From a3e30199e18cad0b7fbaf776e584140c75030ef6 Mon Sep 17 00:00:00 2001 From: Michael Brewer Date: Mon, 7 Sep 2020 01:22:24 -0700 Subject: [PATCH 01/30] feat(trigger): class wrapper for events Put together an intial set of common lambda trigger events --- .../utilities/trigger/__init__.py | 15 ++ .../trigger/cloud_watch_logs_event.py | 66 +++++++ .../trigger/dynamo_db_stream_event.py | 181 ++++++++++++++++++ .../utilities/trigger/s3_event.py | 133 +++++++++++++ .../utilities/trigger/ses_event.py | 153 +++++++++++++++ .../utilities/trigger/sns_event.py | 87 +++++++++ .../utilities/trigger/sqs_event.py | 105 ++++++++++ tests/events/cloudWatchLogEvent.json | 5 + tests/events/dynamoStreamEvent.json | 64 +++++++ tests/events/s3Event.json | 38 ++++ tests/events/sesEvent.json | 100 ++++++++++ tests/events/snsEvent.json | 31 +++ tests/events/sqsEvent.json | 42 ++++ .../functional/test_lambda_trigger_events.py | 167 ++++++++++++++++ 14 files changed, 1187 insertions(+) create mode 100644 aws_lambda_powertools/utilities/trigger/__init__.py create mode 100644 aws_lambda_powertools/utilities/trigger/cloud_watch_logs_event.py create mode 100644 aws_lambda_powertools/utilities/trigger/dynamo_db_stream_event.py create mode 100644 aws_lambda_powertools/utilities/trigger/s3_event.py create mode 100644 aws_lambda_powertools/utilities/trigger/ses_event.py create mode 100644 aws_lambda_powertools/utilities/trigger/sns_event.py create mode 100644 aws_lambda_powertools/utilities/trigger/sqs_event.py create mode 100644 tests/events/cloudWatchLogEvent.json create mode 100644 tests/events/dynamoStreamEvent.json create mode 100644 tests/events/s3Event.json create mode 100644 tests/events/sesEvent.json create mode 100644 tests/events/snsEvent.json create mode 100644 tests/events/sqsEvent.json create mode 100644 tests/functional/test_lambda_trigger_events.py diff --git a/aws_lambda_powertools/utilities/trigger/__init__.py b/aws_lambda_powertools/utilities/trigger/__init__.py new file mode 100644 index 00000000000..4f6d16231b6 --- /dev/null +++ b/aws_lambda_powertools/utilities/trigger/__init__.py @@ -0,0 +1,15 @@ +from .cloud_watch_logs_event import CloudWatchLogsEvent +from .dynamo_db_stream_event import DynamoDBStreamEvent +from .s3_event import S3Event +from .ses_event import SESEvent +from .sns_event import SNSEvent +from .sqs_event import SQSEvent + +__all__ = [ + "CloudWatchLogsEvent", + "DynamoDBStreamEvent", + "S3Event", + "SESEvent", + "SNSEvent", + "SQSEvent", +] diff --git a/aws_lambda_powertools/utilities/trigger/cloud_watch_logs_event.py b/aws_lambda_powertools/utilities/trigger/cloud_watch_logs_event.py new file mode 100644 index 00000000000..7fb7182bd55 --- /dev/null +++ b/aws_lambda_powertools/utilities/trigger/cloud_watch_logs_event.py @@ -0,0 +1,66 @@ +import base64 +import json +import zlib +from typing import Dict, List, Optional + + +class CloudWatchLogsLogEvent(dict): + @property + def id(self) -> str: # noqa: A003 + return self["id"] + + @property + def timestamp(self) -> int: + return self["timestamp"] + + @property + def message(self) -> str: + return self["message"] + + @property + def extracted_fields(self) -> Optional[Dict[str, str]]: + return self.get("extractedFields") + + +class CloudWatchLogsDecodedData(dict): + @property + def owner(self) -> str: + return self["owner"] + + @property + def log_group(self) -> str: + return self["logGroup"] + + @property + def log_stream(self) -> str: + return self["logStream"] + + @property + def subscription_filters(self) -> List[str]: + return self["subscriptionFilters"] + + @property + def message_type(self) -> str: + return self["messageType"] + + @property + def log_events(self) -> List[CloudWatchLogsLogEvent]: + return [CloudWatchLogsLogEvent(i) for i in self["logEvents"]] + + +class CloudWatchLogsEventData(dict): + @property + def data(self) -> str: + return self["data"] + + +class CloudWatchLogsEvent(dict): + @property + def aws_logs(self) -> CloudWatchLogsEventData: + return CloudWatchLogsEventData(self["awslogs"]) + + def cloud_watch_logs_decoded_data(self) -> CloudWatchLogsDecodedData: + """Gzip and parse data""" + payload = base64.b64decode(self.aws_logs.data) + decoded: dict = json.loads(zlib.decompress(payload, zlib.MAX_WBITS | 32).decode("UTF-8")) + return CloudWatchLogsDecodedData(decoded) diff --git a/aws_lambda_powertools/utilities/trigger/dynamo_db_stream_event.py b/aws_lambda_powertools/utilities/trigger/dynamo_db_stream_event.py new file mode 100644 index 00000000000..6f084b09f9a --- /dev/null +++ b/aws_lambda_powertools/utilities/trigger/dynamo_db_stream_event.py @@ -0,0 +1,181 @@ +from __future__ import annotations + +from enum import Enum +from typing import Dict, Iterator, List, Optional + + +def _attribute_value(values: dict, key: str) -> Optional[Dict[str, AttributeValue]]: + item: dict = values.get(key) + return None if item is None else {k: AttributeValue(v) for k, v in item.items()} + + +class AttributeValue(dict): + """Represents the data for an attribute + + Documentation: https://docs.aws.amazon.com/amazondynamodb/latest/APIReference/API_streams_AttributeValue.html + """ + + @property + def b_value(self) -> Optional[str]: + """An attribute of type Base64-encoded binary data object""" + return self.get("B") + + @property + def bs_value(self) -> Optional[List[str]]: + """An attribute of type Array of Base64-encoded binary data objects""" + return self.get("BS") + + @property + def bool_value(self) -> Optional[bool]: + """An attribute of type Boolean""" + item = self.get("bool") + return None if item is None else bool(item) + + @property + def list_value(self) -> Optional[List[AttributeValue]]: + """An attribute of type Array of AttributeValue objects""" + item = self.get("L") + return None if item is None else [AttributeValue(i) for i in item] + + @property + def map_value(self) -> Optional[Dict[str, AttributeValue]]: + """An attribute of type String to AttributeValue object map""" + return _attribute_value(self, "M") + + @property + def n_value(self) -> Optional[str]: + """An attribute of type Number""" + return self.get("N") + + @property + def ns_value(self) -> Optional[List[str]]: + """An attribute of type Number Set""" + return self.get("NS") + + @property + def null_value(self) -> Optional[bool]: + """An attribute of type Null.""" + item = self.get("NULL") + return None if item is None else bool(item) + + @property + def s_value(self) -> Optional[str]: + """An attribute of type String""" + return self.get("S") + + @property + def ss_value(self) -> Optional[List[str]]: + """An attribute of type Array of strings""" + return self.get("SS") + + +class StreamViewType(Enum): + """The type of data from the modified DynamoDB item that was captured in this stream record""" + + KEYS_ONLY = 0 # only the key attributes of the modified item + NEW_IMAGE = 1 # the entire item, as it appeared after it was modified. + OLD_IMAGE = 2 # the entire item, as it appeared before it was modified. + NEW_AND_OLD_IMAGES = 3 # both the new and the old item images of the item. + + +class StreamRecord(dict): + @property + def approximate_creation_date_time(self) -> Optional[int]: + """The approximate date and time when the stream record was created, in UNIX epoch time format.""" + item = self.get("ApproximateCreationDateTime") + return None if item is None else int(item) + + @property + def keys(self) -> Optional[Dict[str, AttributeValue]]: + """The primary key attribute(s) for the DynamoDB item that was modified.""" + return _attribute_value(self, "Keys") + + @property + def new_image(self) -> Optional[Dict[str, AttributeValue]]: + """The item in the DynamoDB table as it appeared after it was modified.""" + return _attribute_value(self, "NewImage") + + @property + def old_image(self) -> Optional[Dict[str, AttributeValue]]: + """The item in the DynamoDB table as it appeared before it was modified.""" + return _attribute_value(self, "OldImage") + + @property + def sequence_number(self) -> Optional[str]: + """The sequence number of the stream record.""" + return self.get("SequenceNumber") + + @property + def size_bytes(self) -> Optional[int]: + """The size of the stream record, in bytes.""" + item = self.get("SizeBytes") + return None if item is None else int(item) + + @property + def stream_view_type(self) -> Optional[StreamViewType]: + """The type of data from the modified DynamoDB item that was captured in this stream record""" + item = self.get("StreamViewType") + return None if item is None else StreamViewType[str(item)] + + +class DynamoDBRecordEventName(Enum): + INSERT = 0 # a new item was added to the table + MODIFY = 1 # one or more of an existing item's attributes were modified + REMOVE = 2 # the item was deleted from the table + + +class DynamoDBRecord(dict): + """A description of a unique event within a stream""" + + @property + def aws_region(self) -> Optional[str]: + """The region in which the GetRecords request was received""" + return self.get("awsRegion") + + @property + def dynamodb(self) -> Optional[StreamRecord]: + """The main body of the stream record, containing all of the DynamoDB-specific fields.""" + item = self.get("dynamodb") + return None if item is None else StreamRecord(item) + + @property + def event_id(self) -> Optional[str]: + """A globally unique identifier for the event that was recorded in this stream record.""" + return self.get("eventID") + + @property + def event_name(self) -> Optional[DynamoDBRecordEventName]: + """The type of data modification that was performed on the DynamoDB table""" + item = self.get("eventName") + return None if item is None else DynamoDBRecordEventName[item] + + @property + def event_source(self) -> Optional[str]: + """The AWS service from which the stream record originated. For DynamoDB Streams, this is aws:dynamodb.""" + return self.get("eventSource") + + @property + def event_source_arn(self) -> Optional[str]: + return self.get("eventSourceARN") + + @property + def event_version(self) -> Optional[str]: + """The version number of the stream record format.""" + return self.get("eventVersion") + + @property + def user_identity(self) -> Optional[dict]: + """Contains details about the type of identity that made the request""" + return self.get("userIdentity") + + +class DynamoDBStreamEvent(dict): + """Dynamo DB Stream Event + + Documentation: https://docs.aws.amazon.com/lambda/latest/dg/with-ddb.html + """ + + @property + def records(self) -> Iterator[DynamoDBRecord]: + for record in self["Records"]: + yield DynamoDBRecord(record) diff --git a/aws_lambda_powertools/utilities/trigger/s3_event.py b/aws_lambda_powertools/utilities/trigger/s3_event.py new file mode 100644 index 00000000000..0fe61a92221 --- /dev/null +++ b/aws_lambda_powertools/utilities/trigger/s3_event.py @@ -0,0 +1,133 @@ +from typing import Dict, Iterator, Optional + + +class S3Identity(dict): + @property + def principal_id(self) -> str: + return self["principalId"] + + +class S3RequestParameters(dict): + @property + def source_ip_address(self) -> str: + return self["sourceIPAddress"] + + +class S3Bucket(dict): + @property + def name(self) -> str: + return self["name"] + + @property + def owner_identity(self) -> S3Identity: + return S3Identity(self["ownerIdentity"]) + + @property + def arn(self) -> str: + return self["arn"] + + +class S3Object(dict): + @property + def key(self) -> str: + return self["key"] + + @property + def size(self) -> int: + return int(self["size"]) + + @property + def etag(self) -> str: + return self["eTag"] + + @property + def version_id(self) -> Optional[str]: + return self.get("versionId") + + @property + def sequencer(self) -> str: + return self["sequencer"] + + +class S3Message(dict): + @property + def s3_schema_version(self) -> str: + return self["s3SchemaVersion"] + + @property + def configuration_id(self) -> str: + return self["configurationId"] + + @property + def bucket(self) -> S3Bucket: + return S3Bucket(self["bucket"]) + + @property + def object(self) -> S3Object: # noqa: A003 + return S3Object(self["object"]) + + +class S3EventRecordGlacierRestoreEventData(dict): + @property + def lifecycle_restoration_expiry_time(self) -> str: + return self["lifecycleRestorationExpiryTime"] + + @property + def lifecycle_restore_storage_class(self) -> str: + return self["lifecycleRestoreStorageClass"] + + +class S3EventRecordGlacierEventData(dict): + @property + def restore_event_data(self) -> S3EventRecordGlacierRestoreEventData: + return S3EventRecordGlacierRestoreEventData(self["restoreEventData"]) + + +class S3EventRecord(dict): + @property + def event_version(self) -> str: + return self["eventVersion"] + + @property + def event_source(self) -> str: + return self["eventSource"] + + @property + def aws_region(self) -> str: + return self["awsRegion"] + + @property + def event_time(self) -> str: + return self["eventTime"] + + @property + def event_name(self) -> str: + return self["eventName"] + + @property + def user_identity(self) -> S3Identity: + return S3Identity(self["userIdentity"]) + + @property + def request_parameters(self) -> S3RequestParameters: + return S3RequestParameters(self["requestParameters"]) + + @property + def response_elements(self) -> Dict[str, str]: + return self["responseElements"] + + @property + def s3(self) -> S3Message: + return S3Message(self["s3"]) + + @property + def glacier_event_data(self) -> Optional[S3EventRecordGlacierEventData]: + item = self.get("glacierEventData") + return None if item is None else S3EventRecordGlacierEventData(item) + + +class S3Event(dict): + @property + def records(self) -> Iterator[S3EventRecord]: + for record in self["Records"]: + yield S3EventRecord(record) diff --git a/aws_lambda_powertools/utilities/trigger/ses_event.py b/aws_lambda_powertools/utilities/trigger/ses_event.py new file mode 100644 index 00000000000..ff994b1a93c --- /dev/null +++ b/aws_lambda_powertools/utilities/trigger/ses_event.py @@ -0,0 +1,153 @@ +from typing import Iterator, List + + +class SESMailHeader(dict): + @property + def name(self) -> str: + return self["name"] + + @property + def value(self) -> str: + return self["value"] + + +class SESMailCommonHeaders(dict): + @property + def return_path(self) -> str: + return self["returnPath"] + + @property + def form(self) -> List[str]: + return self["form"] + + @property + def date(self) -> List[str]: + return self["date"] + + @property + def to(self) -> List[str]: + return self["to"] + + @property + def message_id(self) -> str: + return str(self["messageId"]) + + @property + def subject(self) -> str: + return str(self["subject"]) + + +class SESMail(dict): + @property + def timestamp(self) -> str: + return self["timestamp"] + + @property + def source(self) -> str: + return self["source"] + + @property + def message_id(self) -> str: + return self["messageId"] + + @property + def destination(self) -> List[str]: + return self["destination"] + + @property + def headers_truncated(self) -> bool: + return bool(self["headersTruncated"]) + + @property + def headers(self) -> Iterator[SESMailHeader]: + for header in self["headers"]: + yield SESMailHeader(header) + + @property + def common_headers(self) -> SESMailCommonHeaders: + return SESMailCommonHeaders(self["commonHeaders"]) + + +class SESReceiptStatus(dict): + @property + def status(self) -> str: + return str(self["status"]) + + +class SESReceiptAction(dict): + @property + def type(self) -> str: # noqa: A003 + return self["type"] + + @property + def function_arn(self) -> str: + return self["functionArn"] + + @property + def invocation_type(self) -> str: + return self["invocationType"] + + +class SESReceipt(dict): + @property + def timestamp(self) -> str: + return self["timestamp"] + + @property + def processing_time_millis(self) -> int: + return int(self["processingTimeMillis"]) + + @property + def recipients(self) -> Iterator[str]: + return self["recipients"] + + @property + def spam_verdict(self) -> SESReceiptStatus: + return SESReceiptStatus(self["spamVerdict"]) + + @property + def virus_verdict(self) -> SESReceiptStatus: + return SESReceiptStatus(self["virusVerdict"]) + + @property + def spf_verdict(self) -> SESReceiptStatus: + return SESReceiptStatus(self["spfVerdict"]) + + @property + def dmarc_verdict(self) -> SESReceiptStatus: + return SESReceiptStatus(self["dmarcVerdict"]) + + @property + def action(self) -> SESReceiptAction: + return SESReceiptAction(self["action"]) + + +class SESMessage(dict): + @property + def mail(self) -> SESMail: + return SESMail(self["mail"]) + + @property + def receipt(self) -> SESReceipt: + return SESReceipt(self["receipt"]) + + +class SESEventRecord(dict): + @property + def event_source(self) -> str: + return self["eventSource"] + + @property + def event_version(self) -> str: + return self["eventVersion"] + + @property + def ses(self) -> SESMessage: + return SESMessage(self["ses"]) + + +class SESEvent(dict): + @property + def records(self) -> Iterator[SESEventRecord]: + for record in self["Records"]: + yield SESEventRecord(record) diff --git a/aws_lambda_powertools/utilities/trigger/sns_event.py b/aws_lambda_powertools/utilities/trigger/sns_event.py new file mode 100644 index 00000000000..c02d768fc0c --- /dev/null +++ b/aws_lambda_powertools/utilities/trigger/sns_event.py @@ -0,0 +1,87 @@ +from typing import Dict, Iterator + + +class SNSMessageAttribute(dict): + @property + def type(self) -> str: # noqa: A003 + return self["Type"] + + @property + def value(self) -> str: + return self["Value"] + + +class SNSMessage(dict): + @property + def signature_version(self) -> str: + return self["SignatureVersion"] + + @property + def timestamp(self) -> str: + return self["Timestamp"] + + @property + def signature(self) -> str: + return self["Signature"] + + @property + def signing_cert_url(self) -> str: + return self["SigningCertUrl"] + + @property + def message_id(self) -> str: + return self["MessageId"] + + @property + def message(self) -> str: + return self["Message"] + + @property + def message_attributes(self) -> Dict[str, SNSMessageAttribute]: + return {k: SNSMessageAttribute(v) for (k, v) in self["MessageAttributes"].items()} + + @property + def type(self) -> str: # noqa: A003 + return self["Type"] + + @property + def unsubscribe_url(self) -> str: + return self["UnsubscribeUrl"] + + @property + def topic_arn(self) -> str: + return self["TopicArn"] + + @property + def subject(self) -> str: + return self["Subject"] + + +class SNSEventRecord(dict): + @property + def event_version(self) -> str: + return self["EventVersion"] + + @property + def event_subscription_arn(self) -> str: + return self["EventSubscriptionArn"] + + @property + def event_source(self) -> str: + return self["EventSource"] + + @property + def sns(self) -> SNSMessage: + return SNSMessage(self["Sns"]) + + +class SNSEvent(dict): + """SNS Event + + Documentation: https://docs.aws.amazon.com/lambda/latest/dg/with-sns.html + """ + + @property + def records(self) -> Iterator[SNSEventRecord]: + for record in self["Records"]: + yield SNSEventRecord(record) diff --git a/aws_lambda_powertools/utilities/trigger/sqs_event.py b/aws_lambda_powertools/utilities/trigger/sqs_event.py new file mode 100644 index 00000000000..f9de8ea5dbb --- /dev/null +++ b/aws_lambda_powertools/utilities/trigger/sqs_event.py @@ -0,0 +1,105 @@ +from typing import Dict, Iterator, Optional + + +class SQSRecordAttributes(dict): + @property + def aws_trace_header(self) -> Optional[str]: + return self.get("AWSTraceHeader") + + @property + def approximate_receive_count(self) -> str: + return self["ApproximateReceiveCount"] + + @property + def sent_timestamp(self) -> str: + return self["SentTimestamp"] + + @property + def sender_id(self) -> str: + return self["SenderId"] + + @property + def approximate_first_receive_timestamp(self) -> str: + return self["ApproximateFirstReceiveTimestamp"] + + @property + def sequence_number(self) -> Optional[str]: + return self.get("SequenceNumber") + + @property + def message_group_id(self) -> Optional[str]: + return self.get("MessageGroupId") + + @property + def message_deduplication_id(self) -> Optional[str]: + return self.get("MessageDeduplicationId") + + +class SQSMessageAttribute(dict): + @property + def string_value(self) -> Optional[str]: + return self["stringValue"] + + @property + def binary_value(self) -> Optional[str]: + return self["binaryValue"] + + @property + def data_type(self) -> str: + return self["dataType"] + + +class SQSMessageAttributes(Dict[str, SQSMessageAttribute]): + def __getitem__(self, item) -> Optional[SQSMessageAttribute]: + item = super(SQSMessageAttributes, self).get(item) + return None if item is None else SQSMessageAttribute(item) + + +class SQSRecord(dict): + @property + def message_id(self) -> str: + return self["messageId"] + + @property + def receipt_handle(self) -> str: + return self["receiptHandle"] + + @property + def body(self) -> str: + return self["body"] + + @property + def attributes(self) -> SQSRecordAttributes: + return SQSRecordAttributes(self["attributes"]) + + @property + def message_attributes(self) -> SQSMessageAttributes: + return SQSMessageAttributes(self["messageAttributes"]) + + @property + def md5_of_body(self) -> str: + return self["md5OfBody"] + + @property + def event_source(self) -> str: + return self["eventSource"] + + @property + def event_source_arn(self) -> str: + return self["eventSourceARN"] + + @property + def aws_region(self) -> str: + return self["awsRegion"] + + +class SQSEvent(dict): + """SQS Event + + Documentation: https://docs.aws.amazon.com/lambda/latest/dg/with-sqs.html + """ + + @property + def records(self) -> Iterator[SQSRecord]: + for record in self["Records"]: + yield SQSRecord(record) diff --git a/tests/events/cloudWatchLogEvent.json b/tests/events/cloudWatchLogEvent.json new file mode 100644 index 00000000000..aa184c1d013 --- /dev/null +++ b/tests/events/cloudWatchLogEvent.json @@ -0,0 +1,5 @@ +{ + "awslogs": { + "data": "H4sIAAAAAAAAAHWPwQqCQBCGX0Xm7EFtK+smZBEUgXoLCdMhFtKV3akI8d0bLYmibvPPN3wz00CJxmQnTO41whwWQRIctmEcB6sQbFC3CjW3XW8kxpOpP+OC22d1Wml1qZkQGtoMsScxaczKN3plG8zlaHIta5KqWsozoTYw3/djzwhpLwivWFGHGpAFe7DL68JlBUk+l7KSN7tCOEJ4M3/qOI49vMHj+zCKdlFqLaU2ZHV2a4Ct/an0/ivdX8oYc1UVX860fQDQiMdxRQEAAA==" + } +} diff --git a/tests/events/dynamoStreamEvent.json b/tests/events/dynamoStreamEvent.json new file mode 100644 index 00000000000..12c535b005e --- /dev/null +++ b/tests/events/dynamoStreamEvent.json @@ -0,0 +1,64 @@ +{ + "Records": [ + { + "eventID": "1", + "eventVersion": "1.0", + "dynamodb": { + "Keys": { + "Id": { + "N": "101" + } + }, + "NewImage": { + "Message": { + "S": "New item!" + }, + "Id": { + "N": "101" + } + }, + "StreamViewType": "NEW_AND_OLD_IMAGES", + "SequenceNumber": "111", + "SizeBytes": 26 + }, + "awsRegion": "us-west-2", + "eventName": "INSERT", + "eventSourceARN": "eventsource_arn", + "eventSource": "aws:dynamodb" + }, + { + "eventID": "2", + "eventVersion": "1.0", + "dynamodb": { + "OldImage": { + "Message": { + "S": "New item!" + }, + "Id": { + "N": "101" + } + }, + "SequenceNumber": "222", + "Keys": { + "Id": { + "N": "101" + } + }, + "SizeBytes": 59, + "NewImage": { + "Message": { + "S": "This item has changed" + }, + "Id": { + "N": "101" + } + }, + "StreamViewType": "NEW_AND_OLD_IMAGES" + }, + "awsRegion": "us-west-2", + "eventName": "MODIFY", + "eventSourceARN": "source_arn", + "eventSource": "aws:dynamodb" + } + ] +} diff --git a/tests/events/s3Event.json b/tests/events/s3Event.json new file mode 100644 index 00000000000..4558dc3c9e1 --- /dev/null +++ b/tests/events/s3Event.json @@ -0,0 +1,38 @@ +{ + "Records": [ + { + "eventVersion": "2.1", + "eventSource": "aws:s3", + "awsRegion": "us-east-2", + "eventTime": "2019-09-03T19:37:27.192Z", + "eventName": "ObjectCreated:Put", + "userIdentity": { + "principalId": "AWS:AIDAINPONIXQXHT3IKHL2" + }, + "requestParameters": { + "sourceIPAddress": "205.255.255.255" + }, + "responseElements": { + "x-amz-request-id": "D82B88E5F771F645", + "x-amz-id-2": "vlR7PnpV2Ce81l0PRw6jlUpck7Jo5ZsQjryTjKlc5aLWGVHPZLj5NeC6qMa0emYBDXOo6QBU0Wo=" + }, + "s3": { + "s3SchemaVersion": "1.0", + "configurationId": "828aa6fc-f7b5-4305-8584-487c791949c1", + "bucket": { + "name": "lambda-artifacts-deafc19498e3f2df", + "ownerIdentity": { + "principalId": "A3I5XTEXAMAI3E" + }, + "arn": "arn:aws:s3:::lambda-artifacts-deafc19498e3f2df" + }, + "object": { + "key": "b21b84d653bb07b05b1e6b33684dc11b", + "size": 1305107, + "eTag": "b21b84d653bb07b05b1e6b33684dc11b", + "sequencer": "0C0F6F405D6ED209E1" + } + } + } + ] +} diff --git a/tests/events/sesEvent.json b/tests/events/sesEvent.json new file mode 100644 index 00000000000..5a5afd5bab7 --- /dev/null +++ b/tests/events/sesEvent.json @@ -0,0 +1,100 @@ +{ + "Records": [ + { + "eventVersion": "1.0", + "ses": { + "mail": { + "commonHeaders": { + "from": [ + "Jane Doe " + ], + "to": [ + "johndoe@example.com" + ], + "returnPath": "janedoe@example.com", + "messageId": "<0123456789example.com>", + "date": "Wed, 7 Oct 2015 12:34:56 -0700", + "subject": "Test Subject" + }, + "source": "janedoe@example.com", + "timestamp": "1970-01-01T00:00:00.000Z", + "destination": [ + "johndoe@example.com" + ], + "headers": [ + { + "name": "Return-Path", + "value": "" + }, + { + "name": "Received", + "value": "from mailer.example.com (mailer.example.com [203.0.113.1]) by ..." + }, + { + "name": "DKIM-Signature", + "value": "v=1; a=rsa-sha256; c=relaxed/relaxed; d=example.com; s=example; ..." + }, + { + "name": "MIME-Version", + "value": "1.0" + }, + { + "name": "From", + "value": "Jane Doe " + }, + { + "name": "Date", + "value": "Wed, 7 Oct 2015 12:34:56 -0700" + }, + { + "name": "Message-ID", + "value": "<0123456789example.com>" + }, + { + "name": "Subject", + "value": "Test Subject" + }, + { + "name": "To", + "value": "johndoe@example.com" + }, + { + "name": "Content-Type", + "value": "text/plain; charset=UTF-8" + } + ], + "headersTruncated": false, + "messageId": "o3vrnil0e2ic28tr" + }, + "receipt": { + "recipients": [ + "johndoe@example.com" + ], + "timestamp": "1970-01-01T00:00:00.000Z", + "spamVerdict": { + "status": "PASS" + }, + "dkimVerdict": { + "status": "PASS" + }, + "processingTimeMillis": 574, + "action": { + "type": "Lambda", + "invocationType": "Event", + "functionArn": "arn:aws:lambda:us-west-2:012345678912:function:Example" + }, + "dmarcVerdict": { + "status": "PASS" + }, + "spfVerdict": { + "status": "PASS" + }, + "virusVerdict": { + "status": "PASS" + } + } + }, + "eventSource": "aws:ses" + } + ] +} diff --git a/tests/events/snsEvent.json b/tests/events/snsEvent.json new file mode 100644 index 00000000000..b351dfd1418 --- /dev/null +++ b/tests/events/snsEvent.json @@ -0,0 +1,31 @@ +{ + "Records": [ + { + "EventVersion": "1.0", + "EventSubscriptionArn": "arn:aws:sns:us-east-2:123456789012:sns-la ...", + "EventSource": "aws:sns", + "Sns": { + "SignatureVersion": "1", + "Timestamp": "2019-01-02T12:45:07.000Z", + "Signature": "tcc6faL2yUC6dgZdmrwh1Y4cGa/ebXEkAi6RibDsvpi+tE/1+82j...65r==", + "SigningCertUrl": "https://sns.us-east-2.amazonaws.com/SimpleNotificat ...", + "MessageId": "95df01b4-ee98-5cb9-9903-4c221d41eb5e", + "Message": "Hello from SNS!", + "MessageAttributes": { + "Test": { + "Type": "String", + "Value": "TestString" + }, + "TestBinary": { + "Type": "Binary", + "Value": "TestBinary" + } + }, + "Type": "Notification", + "UnsubscribeUrl": "https://sns.us-east-2.amazonaws.com/?Action=Unsubscri ...", + "TopicArn": "arn:aws:sns:us-east-2:123456789012:sns-lambda", + "Subject": "TestInvoke" + } + } + ] +} diff --git a/tests/events/sqsEvent.json b/tests/events/sqsEvent.json new file mode 100644 index 00000000000..7201068d60c --- /dev/null +++ b/tests/events/sqsEvent.json @@ -0,0 +1,42 @@ +{ + "Records": [ + { + "messageId": "059f36b4-87a3-44ab-83d2-661975830a7d", + "receiptHandle": "AQEBwJnKyrHigUMZj6rYigCgxlaS3SLy0a...", + "body": "Test message.", + "attributes": { + "ApproximateReceiveCount": "1", + "SentTimestamp": "1545082649183", + "SenderId": "AIDAIENQZJOLO23YVJ4VO", + "ApproximateFirstReceiveTimestamp": "1545082649185" + }, + "messageAttributes": { + "testAttr": { + "stringValue": "100", + "binaryValue": "base64Str", + "dataType": "Number" + } + }, + "md5OfBody": "e4e68fb7bd0e697a0ae8f1bb342846b3", + "eventSource": "aws:sqs", + "eventSourceARN": "arn:aws:sqs:us-east-2:123456789012:my-queue", + "awsRegion": "us-east-2" + }, + { + "messageId": "2e1424d4-f796-459a-8184-9c92662be6da", + "receiptHandle": "AQEBzWwaftRI0KuVm4tP+/7q1rGgNqicHq...", + "body": "Test message.", + "attributes": { + "ApproximateReceiveCount": "1", + "SentTimestamp": "1545082650636", + "SenderId": "AIDAIENQZJOLO23YVJ4VO", + "ApproximateFirstReceiveTimestamp": "1545082650649" + }, + "messageAttributes": {}, + "md5OfBody": "e4e68fb7bd0e697a0ae8f1bb342846b3", + "eventSource": "aws:sqs", + "eventSourceARN": "arn:aws:sqs:us-east-2:123456789012:my-queue", + "awsRegion": "us-east-2" + } + ] +} diff --git a/tests/functional/test_lambda_trigger_events.py b/tests/functional/test_lambda_trigger_events.py new file mode 100644 index 00000000000..1b71f305650 --- /dev/null +++ b/tests/functional/test_lambda_trigger_events.py @@ -0,0 +1,167 @@ +import json +import os + +from aws_lambda_powertools.utilities.trigger import CloudWatchLogsEvent, S3Event, SESEvent, SNSEvent, SQSEvent +from aws_lambda_powertools.utilities.trigger.dynamo_db_stream_event import ( + DynamoDBRecordEventName, + DynamoDBStreamEvent, + StreamViewType, +) + + +def load_event(file_name: str) -> dict: + full_file_name = os.path.dirname(os.path.realpath(__file__)) + "/../events/" + file_name + with open(full_file_name) as fp: + return json.load(fp) + + +def test_cloud_watch_trigger_event(): + event = CloudWatchLogsEvent(load_event("cloudWatchLogEvent.json")) + + decoded_data = event.cloud_watch_logs_decoded_data() + log_events = decoded_data.log_events + log_event = log_events[0] + + assert decoded_data.owner == "123456789123" + assert decoded_data.log_group == "testLogGroup" + assert decoded_data.log_stream == "testLogStream" + assert decoded_data.subscription_filters == ["testFilter"] + assert decoded_data.message_type == "DATA_MESSAGE" + + assert log_event.id == "eventId1" + assert log_event.timestamp == 1440442987000 + assert log_event.message == "[ERROR] First test message" + assert log_event.extracted_fields is None + + +def test_dynamo_db_stream_trigger_event(): + event = DynamoDBStreamEvent(load_event("dynamoStreamEvent.json")) + + records = list(event.records) + record = records[0] + assert record.aws_region == "us-west-2" + dynamodb = record.dynamodb + assert dynamodb is not None + assert dynamodb.approximate_creation_date_time is None + keys = dynamodb.keys + assert keys is not None + id_key = keys["Id"] + assert id_key.b_value is None + assert id_key.bs_value is None + assert id_key.bool_value is None + assert id_key.list_value is None + assert id_key.map_value is None + assert id_key.n_value == "101" + assert id_key.ns_value is None + assert id_key.null_value is None + assert id_key.s_value is None + assert id_key.ss_value is None + message_key = dynamodb.new_image["Message"] + assert message_key is not None + assert message_key.s_value == "New item!" + assert dynamodb.old_image is None + assert dynamodb.sequence_number == "111" + assert dynamodb.size_bytes == 26 + assert dynamodb.stream_view_type == StreamViewType.NEW_AND_OLD_IMAGES + assert record.event_id == "1" + assert record.event_name is DynamoDBRecordEventName.INSERT + assert record.event_source == "aws:dynamodb" + assert record.event_source_arn == "eventsource_arn" + assert record.event_version == "1.0" + assert record.user_identity is None + + +def test_s3_trigger_event(): + event = S3Event(load_event("s3Event.json")) + records = list(event.records) + assert len(records) == 1 + record = records[0] + assert record.event_version == "2.1" + assert record.event_source == "aws:s3" + assert record.aws_region == "us-east-2" + assert record.event_time == "2019-09-03T19:37:27.192Z" + assert record.event_name == "ObjectCreated:Put" + user_identity = record.user_identity + assert user_identity.principal_id == "AWS:AIDAINPONIXQXHT3IKHL2" + request_parameters = record.request_parameters + assert request_parameters.source_ip_address == "205.255.255.255" + assert record.response_elements["x-amz-request-id"] == "D82B88E5F771F645" + s3 = record.s3 + assert s3.s3_schema_version == "1.0" + assert s3.configuration_id == "828aa6fc-f7b5-4305-8584-487c791949c1" + bucket = s3.bucket + assert bucket.name == "lambda-artifacts-deafc19498e3f2df" + assert bucket.owner_identity.principal_id == "A3I5XTEXAMAI3E" + assert bucket.arn == "arn:aws:s3:::lambda-artifacts-deafc19498e3f2df" + assert s3.object.key == "b21b84d653bb07b05b1e6b33684dc11b" + assert s3.object.size == 1305107 + assert s3.object.etag == "b21b84d653bb07b05b1e6b33684dc11b" + assert s3.object.version_id is None + assert s3.object.sequencer == "0C0F6F405D6ED209E1" + assert record.glacier_event_data is None + + +def test_ses_trigger_event(): + event = SESEvent(load_event("sesEvent.json")) + + records = list(event.records) + record = records[0] + print(record) + assert record.event_source == "aws:ses" + + +def test_sns_trigger_event(): + event = SNSEvent(load_event("snsEvent.json")) + records = list(event.records) + assert len(records) == 1 + record = records[0] + assert record.event_version == "1.0" + assert record.event_subscription_arn == "arn:aws:sns:us-east-2:123456789012:sns-la ..." + assert record.event_source == "aws:sns" + sns = record.sns + assert sns.signature_version == "1" + assert sns.timestamp == "2019-01-02T12:45:07.000Z" + assert sns.signature == "tcc6faL2yUC6dgZdmrwh1Y4cGa/ebXEkAi6RibDsvpi+tE/1+82j...65r==" + assert sns.signing_cert_url == "https://sns.us-east-2.amazonaws.com/SimpleNotificat ..." + assert sns.message_id == "95df01b4-ee98-5cb9-9903-4c221d41eb5e" + assert sns.message == "Hello from SNS!" + message_attributes = sns.message_attributes + test_message_attribute = message_attributes["Test"] + assert test_message_attribute.type == "String" + assert test_message_attribute.value == "TestString" + assert sns.type == "Notification" + assert sns.unsubscribe_url == "https://sns.us-east-2.amazonaws.com/?Action=Unsubscri ..." + assert sns.topic_arn == "arn:aws:sns:us-east-2:123456789012:sns-lambda" + assert sns.subject == "TestInvoke" + + +def test_seq_trigger_event(): + event = SQSEvent(load_event("sqsEvent.json")) + + records = list(event.records) + record = records[0] + attributes = record.attributes + message_attributes = record.message_attributes + test_attr = message_attributes["testAttr"] + + assert len(records) == 2 + assert record.message_id == "059f36b4-87a3-44ab-83d2-661975830a7d" + assert record.receipt_handle == "AQEBwJnKyrHigUMZj6rYigCgxlaS3SLy0a..." + assert record.body == "Test message." + assert attributes.aws_trace_header is None + assert attributes.approximate_receive_count == "1" + assert attributes.sent_timestamp == "1545082649183" + assert attributes.sender_id == "AIDAIENQZJOLO23YVJ4VO" + assert attributes.approximate_first_receive_timestamp == "1545082649185" + assert attributes.sequence_number is None + assert attributes.message_group_id is None + assert attributes.message_deduplication_id is None + assert message_attributes["NotFound"] is None + assert message_attributes.get("NotFound") is None + assert test_attr.string_value == "100" + assert test_attr.binary_value == "base64Str" + assert test_attr.data_type == "Number" + assert record.md5_of_body == "e4e68fb7bd0e697a0ae8f1bb342846b3" + assert record.event_source == "aws:sqs" + assert record.event_source_arn == "arn:aws:sqs:us-east-2:123456789012:my-queue" + assert record.aws_region == "us-east-2" From 2c50acb4f1a402d4af13505015adab54e24c87b3 Mon Sep 17 00:00:00 2001 From: Michael Brewer Date: Mon, 7 Sep 2020 01:34:49 -0700 Subject: [PATCH 02/30] build: fix build for python 3.6 --- .../utilities/trigger/dynamo_db_stream_event.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/aws_lambda_powertools/utilities/trigger/dynamo_db_stream_event.py b/aws_lambda_powertools/utilities/trigger/dynamo_db_stream_event.py index 6f084b09f9a..7dc8cf73915 100644 --- a/aws_lambda_powertools/utilities/trigger/dynamo_db_stream_event.py +++ b/aws_lambda_powertools/utilities/trigger/dynamo_db_stream_event.py @@ -1,14 +1,7 @@ -from __future__ import annotations - from enum import Enum from typing import Dict, Iterator, List, Optional -def _attribute_value(values: dict, key: str) -> Optional[Dict[str, AttributeValue]]: - item: dict = values.get(key) - return None if item is None else {k: AttributeValue(v) for k, v in item.items()} - - class AttributeValue(dict): """Represents the data for an attribute @@ -32,13 +25,13 @@ def bool_value(self) -> Optional[bool]: return None if item is None else bool(item) @property - def list_value(self) -> Optional[List[AttributeValue]]: + def list_value(self) -> Optional[List["AttributeValue"]]: """An attribute of type Array of AttributeValue objects""" item = self.get("L") return None if item is None else [AttributeValue(i) for i in item] @property - def map_value(self) -> Optional[Dict[str, AttributeValue]]: + def map_value(self) -> Optional[Dict[str, "AttributeValue"]]: """An attribute of type String to AttributeValue object map""" return _attribute_value(self, "M") @@ -69,6 +62,11 @@ def ss_value(self) -> Optional[List[str]]: return self.get("SS") +def _attribute_value(values: dict, key: str) -> Optional[Dict[str, AttributeValue]]: + item: dict = values.get(key) + return None if item is None else {k: AttributeValue(v) for k, v in item.items()} + + class StreamViewType(Enum): """The type of data from the modified DynamoDB item that was captured in this stream record""" From bdf4385c4e0546b5eefaa39406171ba18fc90384 Mon Sep 17 00:00:00 2001 From: Michael Brewer Date: Mon, 7 Sep 2020 01:51:20 -0700 Subject: [PATCH 03/30] fix(linters): make the python linters happy --- .../utilities/trigger/cloud_watch_logs_event.py | 6 +++++- .../utilities/trigger/s3_event.py | 3 ++- .../utilities/trigger/ses_event.py | 3 ++- .../utilities/trigger/sns_event.py | 6 ++++-- tests/functional/test_lambda_trigger_events.py | 16 ++++++++-------- 5 files changed, 21 insertions(+), 13 deletions(-) diff --git a/aws_lambda_powertools/utilities/trigger/cloud_watch_logs_event.py b/aws_lambda_powertools/utilities/trigger/cloud_watch_logs_event.py index 7fb7182bd55..63d534c57db 100644 --- a/aws_lambda_powertools/utilities/trigger/cloud_watch_logs_event.py +++ b/aws_lambda_powertools/utilities/trigger/cloud_watch_logs_event.py @@ -6,19 +6,23 @@ class CloudWatchLogsLogEvent(dict): @property - def id(self) -> str: # noqa: A003 + def log_event_id(self) -> str: + """Get the `id` property""" return self["id"] @property def timestamp(self) -> int: + """Get the `timestamp` property""" return self["timestamp"] @property def message(self) -> str: + """Get the `message` property""" return self["message"] @property def extracted_fields(self) -> Optional[Dict[str, str]]: + """Get the `extractedFields` property""" return self.get("extractedFields") diff --git a/aws_lambda_powertools/utilities/trigger/s3_event.py b/aws_lambda_powertools/utilities/trigger/s3_event.py index 0fe61a92221..b0099452690 100644 --- a/aws_lambda_powertools/utilities/trigger/s3_event.py +++ b/aws_lambda_powertools/utilities/trigger/s3_event.py @@ -63,7 +63,8 @@ def bucket(self) -> S3Bucket: return S3Bucket(self["bucket"]) @property - def object(self) -> S3Object: # noqa: A003 + def s3_object(self) -> S3Object: + """Get the `object` property as an S3Object""" return S3Object(self["object"]) diff --git a/aws_lambda_powertools/utilities/trigger/ses_event.py b/aws_lambda_powertools/utilities/trigger/ses_event.py index ff994b1a93c..375d919f9f1 100644 --- a/aws_lambda_powertools/utilities/trigger/ses_event.py +++ b/aws_lambda_powertools/utilities/trigger/ses_event.py @@ -76,7 +76,8 @@ def status(self) -> str: class SESReceiptAction(dict): @property - def type(self) -> str: # noqa: A003 + def action_type(self) -> str: + """Get the `type` property""" return self["type"] @property diff --git a/aws_lambda_powertools/utilities/trigger/sns_event.py b/aws_lambda_powertools/utilities/trigger/sns_event.py index c02d768fc0c..bbf3078587a 100644 --- a/aws_lambda_powertools/utilities/trigger/sns_event.py +++ b/aws_lambda_powertools/utilities/trigger/sns_event.py @@ -3,7 +3,8 @@ class SNSMessageAttribute(dict): @property - def type(self) -> str: # noqa: A003 + def attribute_type(self) -> str: + """Get the `type` property""" return self["Type"] @property @@ -41,7 +42,8 @@ def message_attributes(self) -> Dict[str, SNSMessageAttribute]: return {k: SNSMessageAttribute(v) for (k, v) in self["MessageAttributes"].items()} @property - def type(self) -> str: # noqa: A003 + def message_type(self) -> str: + """Get the `type` property""" return self["Type"] @property diff --git a/tests/functional/test_lambda_trigger_events.py b/tests/functional/test_lambda_trigger_events.py index 1b71f305650..e0c2475f957 100644 --- a/tests/functional/test_lambda_trigger_events.py +++ b/tests/functional/test_lambda_trigger_events.py @@ -28,7 +28,7 @@ def test_cloud_watch_trigger_event(): assert decoded_data.subscription_filters == ["testFilter"] assert decoded_data.message_type == "DATA_MESSAGE" - assert log_event.id == "eventId1" + assert log_event.log_event_id == "eventId1" assert log_event.timestamp == 1440442987000 assert log_event.message == "[ERROR] First test message" assert log_event.extracted_fields is None @@ -93,11 +93,11 @@ def test_s3_trigger_event(): assert bucket.name == "lambda-artifacts-deafc19498e3f2df" assert bucket.owner_identity.principal_id == "A3I5XTEXAMAI3E" assert bucket.arn == "arn:aws:s3:::lambda-artifacts-deafc19498e3f2df" - assert s3.object.key == "b21b84d653bb07b05b1e6b33684dc11b" - assert s3.object.size == 1305107 - assert s3.object.etag == "b21b84d653bb07b05b1e6b33684dc11b" - assert s3.object.version_id is None - assert s3.object.sequencer == "0C0F6F405D6ED209E1" + assert s3.s3_object.key == "b21b84d653bb07b05b1e6b33684dc11b" + assert s3.s3_object.size == 1305107 + assert s3.s3_object.etag == "b21b84d653bb07b05b1e6b33684dc11b" + assert s3.s3_object.version_id is None + assert s3.s3_object.sequencer == "0C0F6F405D6ED209E1" assert record.glacier_event_data is None @@ -127,9 +127,9 @@ def test_sns_trigger_event(): assert sns.message == "Hello from SNS!" message_attributes = sns.message_attributes test_message_attribute = message_attributes["Test"] - assert test_message_attribute.type == "String" + assert test_message_attribute.attribute_type == "String" assert test_message_attribute.value == "TestString" - assert sns.type == "Notification" + assert sns.message_type == "Notification" assert sns.unsubscribe_url == "https://sns.us-east-2.amazonaws.com/?Action=Unsubscri ..." assert sns.topic_arn == "arn:aws:sns:us-east-2:123456789012:sns-lambda" assert sns.subject == "TestInvoke" From 112635de462c6ec4552a3a309191eaa669842010 Mon Sep 17 00:00:00 2001 From: Michael Brewer Date: Mon, 7 Sep 2020 08:36:38 -0700 Subject: [PATCH 04/30] tests: add missing test cases --- .../utilities/trigger/ses_event.py | 9 ++- .../functional/test_lambda_trigger_events.py | 72 ++++++++++++++++++- 2 files changed, 77 insertions(+), 4 deletions(-) diff --git a/aws_lambda_powertools/utilities/trigger/ses_event.py b/aws_lambda_powertools/utilities/trigger/ses_event.py index 375d919f9f1..5bcfb310d37 100644 --- a/aws_lambda_powertools/utilities/trigger/ses_event.py +++ b/aws_lambda_powertools/utilities/trigger/ses_event.py @@ -17,8 +17,10 @@ def return_path(self) -> str: return self["returnPath"] @property - def form(self) -> List[str]: - return self["form"] + def from_header(self) -> List[str]: + """Get the `from` common header as a list""" + # Note: this conflicts with existing python builtins + return self["from"] @property def date(self) -> List[str]: @@ -78,6 +80,7 @@ class SESReceiptAction(dict): @property def action_type(self) -> str: """Get the `type` property""" + # Note: this conflicts with existing python builtins return self["type"] @property @@ -99,7 +102,7 @@ def processing_time_millis(self) -> int: return int(self["processingTimeMillis"]) @property - def recipients(self) -> Iterator[str]: + def recipients(self) -> List[str]: return self["recipients"] @property diff --git a/tests/functional/test_lambda_trigger_events.py b/tests/functional/test_lambda_trigger_events.py index e0c2475f957..ed61c281110 100644 --- a/tests/functional/test_lambda_trigger_events.py +++ b/tests/functional/test_lambda_trigger_events.py @@ -3,6 +3,7 @@ from aws_lambda_powertools.utilities.trigger import CloudWatchLogsEvent, S3Event, SESEvent, SNSEvent, SQSEvent from aws_lambda_powertools.utilities.trigger.dynamo_db_stream_event import ( + AttributeValue, DynamoDBRecordEventName, DynamoDBStreamEvent, StreamViewType, @@ -71,6 +72,24 @@ def test_dynamo_db_stream_trigger_event(): assert record.user_identity is None +def test_dynamo_attribute_value_list_value(): + example_attribute_value = {"L": [{"S": "Cookies"}, {"S": "Coffee"}, {"N": "3.14159"}]} + attribute_value = AttributeValue(example_attribute_value) + list_value = attribute_value.list_value + assert list_value is not None + item = list_value[0] + assert item.s_value == "Cookies" + + +def test_dynamo_attribute_value_map_value(): + example_attribute_value = {"M": {"Name": {"S": "Joe"}, "Age": {"N": "35"}}} + attribute_value = AttributeValue(example_attribute_value) + map_value = attribute_value.map_value + assert map_value is not None + item = map_value["Name"] + assert item.s_value == "Joe" + + def test_s3_trigger_event(): event = S3Event(load_event("s3Event.json")) records = list(event.records) @@ -101,13 +120,64 @@ def test_s3_trigger_event(): assert record.glacier_event_data is None +def test_s3_glacier_event(): + example_event = { + "Records": [ + { + "glacierEventData": { + "restoreEventData": { + "lifecycleRestorationExpiryTime": "1970-01-01T00:01:00.000Z", + "lifecycleRestoreStorageClass": "standard", + } + } + } + ] + } + event = S3Event(example_event) + record = next(event.records) + glacier_event_data = record.glacier_event_data + assert glacier_event_data is not None + assert glacier_event_data.restore_event_data.lifecycle_restoration_expiry_time == "1970-01-01T00:01:00.000Z" + assert glacier_event_data.restore_event_data.lifecycle_restore_storage_class == "standard" + + def test_ses_trigger_event(): event = SESEvent(load_event("sesEvent.json")) + expected_address = "johndoe@example.com" records = list(event.records) record = records[0] - print(record) assert record.event_source == "aws:ses" + assert record.event_version == "1.0" + mail = record.ses.mail + assert mail.timestamp == "1970-01-01T00:00:00.000Z" + assert mail.source == "janedoe@example.com" + assert mail.message_id == "o3vrnil0e2ic28tr" + assert mail.destination == [expected_address] + assert mail.headers_truncated is False + headers = list(mail.headers) + assert len(headers) == 10 + assert headers[0].name == "Return-Path" + assert headers[0].value == "" + common_headers = mail.common_headers + assert common_headers.return_path == "janedoe@example.com" + assert common_headers.from_header == common_headers["from"] + assert common_headers.date == "Wed, 7 Oct 2015 12:34:56 -0700" + assert common_headers.to == [expected_address] + assert common_headers.message_id == "<0123456789example.com>" + assert common_headers.subject == "Test Subject" + receipt = record.ses.receipt + assert receipt.timestamp == "1970-01-01T00:00:00.000Z" + assert receipt.processing_time_millis == 574 + assert receipt.recipients == [expected_address] + assert receipt.spam_verdict.status == "PASS" + assert receipt.virus_verdict.status == "PASS" + assert receipt.spf_verdict.status == "PASS" + assert receipt.dmarc_verdict.status == "PASS" + action = receipt.action + assert action.action_type == action["type"] + assert action.function_arn == action["functionArn"] + assert action.invocation_type == action["invocationType"] def test_sns_trigger_event(): From d9253f574298913d7643ccdb63ca8279aec4ce31 Mon Sep 17 00:00:00 2001 From: Michael Brewer Date: Mon, 7 Sep 2020 13:00:34 -0700 Subject: [PATCH 05/30] docs: Include some docstrings Tracked down some of the AWS docs to inline with the dict wrapper classes. --- .../trigger/cloud_watch_logs_event.py | 25 +++++++- .../trigger/dynamo_db_stream_event.py | 63 ++++++++++++++++--- .../utilities/trigger/s3_event.py | 29 +++++++++ .../utilities/trigger/ses_event.py | 52 ++++++++++++++- .../functional/test_lambda_trigger_events.py | 2 +- 5 files changed, 155 insertions(+), 16 deletions(-) diff --git a/aws_lambda_powertools/utilities/trigger/cloud_watch_logs_event.py b/aws_lambda_powertools/utilities/trigger/cloud_watch_logs_event.py index 63d534c57db..13d463c4322 100644 --- a/aws_lambda_powertools/utilities/trigger/cloud_watch_logs_event.py +++ b/aws_lambda_powertools/utilities/trigger/cloud_watch_logs_event.py @@ -7,7 +7,7 @@ class CloudWatchLogsLogEvent(dict): @property def log_event_id(self) -> str: - """Get the `id` property""" + """The ID property is a unique identifier for every log event.""" return self["id"] @property @@ -29,42 +29,61 @@ def extracted_fields(self) -> Optional[Dict[str, str]]: class CloudWatchLogsDecodedData(dict): @property def owner(self) -> str: + """The AWS Account ID of the originating log data.""" return self["owner"] @property def log_group(self) -> str: + """The log group name of the originating log data.""" return self["logGroup"] @property def log_stream(self) -> str: + """The log stream name of the originating log data.""" return self["logStream"] @property def subscription_filters(self) -> List[str]: + """The list of subscription filter names that matched with the originating log data.""" return self["subscriptionFilters"] @property def message_type(self) -> str: + """Data messages will use the "DATA_MESSAGE" type. + + Sometimes CloudWatch Logs may emit Kinesis records with a "CONTROL_MESSAGE" type, + mainly for checking if the destination is reachable. + """ return self["messageType"] @property def log_events(self) -> List[CloudWatchLogsLogEvent]: + """The actual log data, represented as an array of log event records. + + The ID property is a unique identifier for every log event. + """ return [CloudWatchLogsLogEvent(i) for i in self["logEvents"]] class CloudWatchLogsEventData(dict): @property def data(self) -> str: + """The value of the `data` field is a Base64 encoded ZIP archive.""" return self["data"] class CloudWatchLogsEvent(dict): + """CloudWatch Logs log stream event + + Documentation: https://docs.aws.amazon.com/lambda/latest/dg/services-cloudwatchlogs.html + """ + @property def aws_logs(self) -> CloudWatchLogsEventData: return CloudWatchLogsEventData(self["awslogs"]) - def cloud_watch_logs_decoded_data(self) -> CloudWatchLogsDecodedData: - """Gzip and parse data""" + def decode_cloud_watch_logs_data(self) -> CloudWatchLogsDecodedData: + """Gzip and parse json data""" payload = base64.b64decode(self.aws_logs.data) decoded: dict = json.loads(zlib.decompress(payload, zlib.MAX_WBITS | 32).decode("UTF-8")) return CloudWatchLogsDecodedData(decoded) diff --git a/aws_lambda_powertools/utilities/trigger/dynamo_db_stream_event.py b/aws_lambda_powertools/utilities/trigger/dynamo_db_stream_event.py index 7dc8cf73915..5abbc450405 100644 --- a/aws_lambda_powertools/utilities/trigger/dynamo_db_stream_event.py +++ b/aws_lambda_powertools/utilities/trigger/dynamo_db_stream_event.py @@ -10,55 +10,98 @@ class AttributeValue(dict): @property def b_value(self) -> Optional[str]: - """An attribute of type Base64-encoded binary data object""" + """An attribute of type Base64-encoded binary data object + + Example: + >>> {"B": "dGhpcyB0ZXh0IGlzIGJhc2U2NC1lbmNvZGVk"} + """ return self.get("B") @property def bs_value(self) -> Optional[List[str]]: - """An attribute of type Array of Base64-encoded binary data objects""" + """An attribute of type Array of Base64-encoded binary data objects + + Example: + >>> {"BS": ["U3Vubnk=", "UmFpbnk=", "U25vd3k="]} + """ return self.get("BS") @property def bool_value(self) -> Optional[bool]: - """An attribute of type Boolean""" + """An attribute of type Boolean + + Example: + >>> {"BOOL": True} + """ item = self.get("bool") return None if item is None else bool(item) @property def list_value(self) -> Optional[List["AttributeValue"]]: - """An attribute of type Array of AttributeValue objects""" + """An attribute of type Array of AttributeValue objects + + Example: + >>> {"L": [ {"S": "Cookies"} , {"S": "Coffee"}, {"N": "3.14159"}]} + """ item = self.get("L") return None if item is None else [AttributeValue(i) for i in item] @property def map_value(self) -> Optional[Dict[str, "AttributeValue"]]: - """An attribute of type String to AttributeValue object map""" + """An attribute of type String to AttributeValue object map + + Example: + >>> {"M": {"Name": {"S": "Joe"}, "Age": {"N": "35"}}} + """ return _attribute_value(self, "M") @property def n_value(self) -> Optional[str]: - """An attribute of type Number""" + """An attribute of type Number + + Numbers are sent across the network to DynamoDB as strings, to maximize compatibility across languages + and libraries. However, DynamoDB treats them as number type attributes for mathematical operations. + + Example: + >>> {"N": "123.45"} + """ return self.get("N") @property def ns_value(self) -> Optional[List[str]]: - """An attribute of type Number Set""" + """An attribute of type Number Set + + Example: + >>> {"NS": ["42.2", "-19", "7.5", "3.14"]} + """ return self.get("NS") @property def null_value(self) -> Optional[bool]: - """An attribute of type Null.""" + """An attribute of type Null. + + Example: + >>> {"NULL": True} + """ item = self.get("NULL") return None if item is None else bool(item) @property def s_value(self) -> Optional[str]: - """An attribute of type String""" + """An attribute of type String + + Example: + >>> {"S": "Hello"} + """ return self.get("S") @property def ss_value(self) -> Optional[List[str]]: - """An attribute of type Array of strings""" + """An attribute of type Array of strings + + Example: + >>> {"SS": ["Giraffe", "Hippo" ,"Zebra"]} + """ return self.get("SS") diff --git a/aws_lambda_powertools/utilities/trigger/s3_event.py b/aws_lambda_powertools/utilities/trigger/s3_event.py index b0099452690..99616ea78f8 100644 --- a/aws_lambda_powertools/utilities/trigger/s3_event.py +++ b/aws_lambda_powertools/utilities/trigger/s3_event.py @@ -30,22 +30,29 @@ def arn(self) -> str: class S3Object(dict): @property def key(self) -> str: + """Object key""" return self["key"] @property def size(self) -> int: + """Object byte size""" return int(self["size"]) @property def etag(self) -> str: + """object eTag""" return self["eTag"] @property def version_id(self) -> Optional[str]: + """Object version if bucket is versioning-enabled, otherwise null""" return self.get("versionId") @property def sequencer(self) -> str: + """A string representation of a hexadecimal value used to determine event sequence, + only used with PUTs and DELETEs + """ return self["sequencer"] @@ -56,6 +63,7 @@ def s3_schema_version(self) -> str: @property def configuration_id(self) -> str: + """ID found in the bucket notification configuration""" return self["configurationId"] @property @@ -81,16 +89,19 @@ def lifecycle_restore_storage_class(self) -> str: class S3EventRecordGlacierEventData(dict): @property def restore_event_data(self) -> S3EventRecordGlacierRestoreEventData: + """The restoreEventData key contains attributes related to your restore request.""" return S3EventRecordGlacierRestoreEventData(self["restoreEventData"]) class S3EventRecord(dict): @property def event_version(self) -> str: + """The eventVersion key value contains a major and minor version in the form ..""" return self["eventVersion"] @property def event_source(self) -> str: + """aws:s3""" return self["eventSource"] @property @@ -99,10 +110,13 @@ def aws_region(self) -> str: @property def event_time(self) -> str: + """The time, in ISO-8601 format, for example, 1970-01-01T00:00:00.000Z, when S3 finished + processing the request""" return self["eventTime"] @property def event_name(self) -> str: + """Event type""" return self["eventName"] @property @@ -115,6 +129,12 @@ def request_parameters(self) -> S3RequestParameters: @property def response_elements(self) -> Dict[str, str]: + """The responseElements key value is useful if you want to trace a request by following up with AWS Support. + + Both x-amz-request-id and x-amz-id-2 help Amazon S3 trace an individual request. These values are the same + as those that Amazon S3 returns in the response to the request that initiates the events, so they can be + used to match the event to the request. + """ return self["responseElements"] @property @@ -123,11 +143,20 @@ def s3(self) -> S3Message: @property def glacier_event_data(self) -> Optional[S3EventRecordGlacierEventData]: + """The glacierEventData key is only visible for s3:ObjectRestore:Completed events.""" item = self.get("glacierEventData") return None if item is None else S3EventRecordGlacierEventData(item) class S3Event(dict): + """S3 event notification + + Documentation: + https://docs.aws.amazon.com/lambda/latest/dg/with-s3.html + https://docs.aws.amazon.com/AmazonS3/latest/dev/NotificationHowTo.html + https://docs.aws.amazon.com/AmazonS3/latest/dev/notification-content-structure.html + """ + @property def records(self) -> Iterator[S3EventRecord]: for record in self["Records"]: diff --git a/aws_lambda_powertools/utilities/trigger/ses_event.py b/aws_lambda_powertools/utilities/trigger/ses_event.py index 5bcfb310d37..b798495e379 100644 --- a/aws_lambda_powertools/utilities/trigger/ses_event.py +++ b/aws_lambda_powertools/utilities/trigger/ses_event.py @@ -14,59 +14,78 @@ def value(self) -> str: class SESMailCommonHeaders(dict): @property def return_path(self) -> str: + """The values in the Return-Path header of the email.""" return self["returnPath"] @property def from_header(self) -> List[str]: - """Get the `from` common header as a list""" + """The values in the From header of the email.""" # Note: this conflicts with existing python builtins return self["from"] @property def date(self) -> List[str]: + """The date and time when Amazon SES received the message.""" return self["date"] @property def to(self) -> List[str]: + """The values in the To header of the email.""" return self["to"] @property def message_id(self) -> str: + """The ID of the original message.""" return str(self["messageId"]) @property def subject(self) -> str: + """The value of the Subject header for the email.""" return str(self["subject"]) class SESMail(dict): @property def timestamp(self) -> str: + """String that contains the time at which the email was received, in ISO8601 format.""" return self["timestamp"] @property def source(self) -> str: + """String that contains the email address (specifically, the envelope MAIL FROM address) + that the email was sent from.""" return self["source"] @property def message_id(self) -> str: + """String that contains the unique ID assigned to the email by Amazon SES. + + If the email was delivered to Amazon S3, the message ID is also the Amazon S3 object key that was + used to write the message to your Amazon S3 bucket.""" return self["messageId"] @property def destination(self) -> List[str]: + """A complete list of all recipient addresses (including To: and CC: recipients) + from the MIME headers of the incoming email.""" return self["destination"] @property def headers_truncated(self) -> bool: + """String that specifies whether the headers were truncated in the notification, which will happen + if the headers are larger than 10 KB. Possible values are true and false.""" return bool(self["headersTruncated"]) @property def headers(self) -> Iterator[SESMailHeader]: + """A list of Amazon SES headers and your custom headers. + Each header in the list has a name field and a value field""" for header in self["headers"]: yield SESMailHeader(header) @property def common_headers(self) -> SESMailCommonHeaders: + """A list of headers common to all emails. Each header in the list is composed of a name and a value.""" return SESMailCommonHeaders(self["commonHeaders"]) @@ -79,50 +98,69 @@ def status(self) -> str: class SESReceiptAction(dict): @property def action_type(self) -> str: - """Get the `type` property""" + """String that indicates the type of action that was executed. + + Possible values are S3, SNS, Bounce, Lambda, Stop, and WorkMail + """ # Note: this conflicts with existing python builtins return self["type"] @property def function_arn(self) -> str: + """String that contains the ARN of the Lambda function that was triggered. + Present only for the Lambda action type.""" return self["functionArn"] @property def invocation_type(self) -> str: + """String that contains the invocation type of the Lambda function. Possible values are RequestResponse + and Event. Present only for the Lambda action type.""" return self["invocationType"] class SESReceipt(dict): @property def timestamp(self) -> str: + """String that specifies the date and time at which the action was triggered, in ISO 8601 format.""" return self["timestamp"] @property def processing_time_millis(self) -> int: + """String that specifies the period, in milliseconds, from the time Amazon SES received the message + to the time it triggered the action.""" return int(self["processingTimeMillis"]) @property def recipients(self) -> List[str]: + """A list of recipients (specifically, the envelope RCPT TO addresses) that were matched by the + active receipt rule. The addresses listed here may differ from those listed by the destination + field in the mail object.""" return self["recipients"] @property def spam_verdict(self) -> SESReceiptStatus: + """Object that indicates whether the message is spam.""" return SESReceiptStatus(self["spamVerdict"]) @property def virus_verdict(self) -> SESReceiptStatus: + """Object that indicates whether the message contains a virus.""" return SESReceiptStatus(self["virusVerdict"]) @property def spf_verdict(self) -> SESReceiptStatus: + """Object that indicates whether the Sender Policy Framework (SPF) check passed.""" return SESReceiptStatus(self["spfVerdict"]) @property def dmarc_verdict(self) -> SESReceiptStatus: + """Object that indicates whether the Domain-based Message Authentication, + Reporting & Conformance (DMARC) check passed.""" return SESReceiptStatus(self["dmarcVerdict"]) @property def action(self) -> SESReceiptAction: + """Object that encapsulates information about the action that was executed.""" return SESReceiptAction(self["action"]) @@ -139,6 +177,7 @@ def receipt(self) -> SESReceipt: class SESEventRecord(dict): @property def event_source(self) -> str: + """event source will be: aws:ses""" return self["eventSource"] @property @@ -151,6 +190,15 @@ def ses(self) -> SESMessage: class SESEvent(dict): + """Amazon SES to receive message event trigger + + NOTE: There is a 30-second timeout on RequestResponse invocations. + + Documentation: + https://docs.aws.amazon.com/lambda/latest/dg/services-ses.html + https://docs.aws.amazon.com/ses/latest/DeveloperGuide/receiving-email-action-lambda.html + """ + @property def records(self) -> Iterator[SESEventRecord]: for record in self["Records"]: diff --git a/tests/functional/test_lambda_trigger_events.py b/tests/functional/test_lambda_trigger_events.py index ed61c281110..e35e42fe984 100644 --- a/tests/functional/test_lambda_trigger_events.py +++ b/tests/functional/test_lambda_trigger_events.py @@ -19,7 +19,7 @@ def load_event(file_name: str) -> dict: def test_cloud_watch_trigger_event(): event = CloudWatchLogsEvent(load_event("cloudWatchLogEvent.json")) - decoded_data = event.cloud_watch_logs_decoded_data() + decoded_data = event.decode_cloud_watch_logs_data() log_events = decoded_data.log_events log_event = log_events[0] From 057c947be8b7e111560cd91c750ec86cc7a49585 Mon Sep 17 00:00:00 2001 From: Michael Brewer Date: Mon, 7 Sep 2020 22:22:18 -0700 Subject: [PATCH 06/30] feat(trigger): initial cognito triggers --- .../utilities/trigger/__init__.py | 3 + .../trigger/cognito_user_pool_event.py | 139 ++++++++++++++++++ .../cognitoPostConfirmationTriggerEvent.json | 18 +++ .../events/cognitoPreSignUpTriggerEvent.json | 18 +++ .../functional/test_lambda_trigger_events.py | 44 +++++- 5 files changed, 221 insertions(+), 1 deletion(-) create mode 100644 aws_lambda_powertools/utilities/trigger/cognito_user_pool_event.py create mode 100644 tests/events/cognitoPostConfirmationTriggerEvent.json create mode 100644 tests/events/cognitoPreSignUpTriggerEvent.json diff --git a/aws_lambda_powertools/utilities/trigger/__init__.py b/aws_lambda_powertools/utilities/trigger/__init__.py index 4f6d16231b6..f12fccaa986 100644 --- a/aws_lambda_powertools/utilities/trigger/__init__.py +++ b/aws_lambda_powertools/utilities/trigger/__init__.py @@ -1,4 +1,5 @@ from .cloud_watch_logs_event import CloudWatchLogsEvent +from .cognito_user_pool_event import PostConfirmationTriggerEvent, PreSignUpTriggerEvent from .dynamo_db_stream_event import DynamoDBStreamEvent from .s3_event import S3Event from .ses_event import SESEvent @@ -7,6 +8,8 @@ __all__ = [ "CloudWatchLogsEvent", + "PreSignUpTriggerEvent", + "PostConfirmationTriggerEvent", "DynamoDBStreamEvent", "S3Event", "SESEvent", diff --git a/aws_lambda_powertools/utilities/trigger/cognito_user_pool_event.py b/aws_lambda_powertools/utilities/trigger/cognito_user_pool_event.py new file mode 100644 index 00000000000..2a94fc91c58 --- /dev/null +++ b/aws_lambda_powertools/utilities/trigger/cognito_user_pool_event.py @@ -0,0 +1,139 @@ +from typing import Dict, Optional + + +class CallerContext(dict): + @property + def aws_sdk_version(self) -> str: + """The AWS SDK version number.""" + return self["awsSdkVersion"] + + @property + def client_id(self) -> str: + """The ID of the client associated with the user pool.""" + return self["clientId"] + + +class BaseTriggerEvent(dict): + """Common attributes shared by all User Pool Lambda Trigger Events + + Documentation: + https://docs.aws.amazon.com/cognito/latest/developerguide/cognito-user-identity-pools-working-with-aws-lambda-triggers.html + """ + + @property + def version(self) -> str: + """The version number of your Lambda function.""" + return self["version"] + + @property + def region(self) -> str: + """The AWS Region, as an AWSRegion instance.""" + return self["region"] + + @property + def user_pool_id(self) -> str: + """The user pool ID for the user pool.""" + return self["userPoolId"] + + @property + def trigger_source(self) -> str: + """The name of the event that triggered the Lambda function.""" + return self["triggerSource"] + + @property + def user_name(self) -> str: + """The username of the current user.""" + return self["userName"] + + @property + def caller_context(self) -> CallerContext: + """The caller context""" + return CallerContext(self["callerContext"]) + + +class PreSignUpTriggerEventRequest(dict): + @property + def user_attributes(self) -> Dict[str, str]: + """One or more name-value pairs representing user attributes. The attribute names are the keys.""" + return self["userAttributes"] + + @property + def validation_data(self) -> Optional[Dict[str, str]]: + """One or more name-value pairs containing the validation data in the request to register a user.""" + return self.get("validationData") + + @property + def client_metadata(self) -> Optional[Dict[str, str]]: + """One or more key-value pairs that you can provide as custom input to the Lambda function + that you specify for the pre sign-up trigger.""" + return self.get("clientMetadata") + + +class PreSignUpTriggerEventResponse(dict): + @property + def auto_confirm_user(self) -> bool: + """Set to true to auto-confirm the user, or false otherwise.""" + return bool(self["response"]["autoConfirmUser"]) + + @auto_confirm_user.setter + def auto_confirm_user(self, value: bool): + self["response"]["autoConfirmUser"] = value + + @property + def auto_verify_email(self) -> bool: + """Set to true to set as verified the email of a user who is signing up, or false otherwise.""" + return bool(self["response"]["autoVerifyEmail"]) + + @auto_verify_email.setter + def auto_verify_email(self, value: bool): + self["response"]["autoVerifyEmail"] = value + + @property + def auto_verify_phone(self) -> bool: + """Set to true to set as verified the phone number of a user who is signing up, or false otherwise.""" + return bool(self["response"]["autoVerifyPhone"]) + + @auto_verify_phone.setter + def auto_verify_phone(self, value: bool): + self["response"]["autoVerifyPhone"] = value + + +class PreSignUpTriggerEvent(BaseTriggerEvent): + """Pre Sign-up Lambda Trigger + + Documentation: + https://docs.aws.amazon.com/cognito/latest/developerguide/user-pool-lambda-pre-sign-up.html + """ + + @property + def request(self) -> PreSignUpTriggerEventRequest: + return PreSignUpTriggerEventRequest(self["request"]) + + @property + def response(self) -> PreSignUpTriggerEventResponse: + return PreSignUpTriggerEventResponse(self) + + +class PostConfirmationTriggerEventRequest(dict): + @property + def user_attributes(self) -> Dict[str, str]: + """One or more name-value pairs representing user attributes. The attribute names are the keys.""" + return self["userAttributes"] + + @property + def client_metadata(self) -> Optional[Dict[str, str]]: + """One or more key-value pairs that you can provide as custom input to the Lambda function that you + specify for the post confirmation trigger.""" + return self.get("clientMetadata") + + +class PostConfirmationTriggerEvent(BaseTriggerEvent): + """Post Confirmation Lambda Trigger + + Documentation: + https://docs.aws.amazon.com/cognito/latest/developerguide/user-pool-lambda-post-confirmation.html + """ + + @property + def request(self) -> PostConfirmationTriggerEventRequest: + return PostConfirmationTriggerEventRequest(self["request"]) diff --git a/tests/events/cognitoPostConfirmationTriggerEvent.json b/tests/events/cognitoPostConfirmationTriggerEvent.json new file mode 100644 index 00000000000..e88f98150ca --- /dev/null +++ b/tests/events/cognitoPostConfirmationTriggerEvent.json @@ -0,0 +1,18 @@ +{ + "version": "string", + "triggerSource": "PostConfirmation_ConfirmSignUp", + "region": "us-east-1", + "userPoolId": "string", + "userName": "userName", + "callerContext": { + "awsSdkVersion": "awsSdkVersion", + "clientId": "clientId" + }, + "request": { + "userAttributes": { + "email": "user@example.com", + "email_verified": true + } + }, + "response": {} +} diff --git a/tests/events/cognitoPreSignUpTriggerEvent.json b/tests/events/cognitoPreSignUpTriggerEvent.json new file mode 100644 index 00000000000..feb4eba25dd --- /dev/null +++ b/tests/events/cognitoPreSignUpTriggerEvent.json @@ -0,0 +1,18 @@ +{ + "version": "string", + "triggerSource": "PreSignUp_SignUp", + "region": "us-east-1", + "userPoolId": "string", + "userName": "userName", + "callerContext": { + "awsSdkVersion": "awsSdkVersion", + "clientId": "clientId" + }, + "request": { + "userAttributes": { + "email": "user@example.com", + "phone_number": "+12065550100" + } + }, + "response": {} +} diff --git a/tests/functional/test_lambda_trigger_events.py b/tests/functional/test_lambda_trigger_events.py index e35e42fe984..24e99a53951 100644 --- a/tests/functional/test_lambda_trigger_events.py +++ b/tests/functional/test_lambda_trigger_events.py @@ -1,7 +1,15 @@ import json import os -from aws_lambda_powertools.utilities.trigger import CloudWatchLogsEvent, S3Event, SESEvent, SNSEvent, SQSEvent +from aws_lambda_powertools.utilities.trigger import ( + CloudWatchLogsEvent, + PostConfirmationTriggerEvent, + PreSignUpTriggerEvent, + S3Event, + SESEvent, + SNSEvent, + SQSEvent, +) from aws_lambda_powertools.utilities.trigger.dynamo_db_stream_event import ( AttributeValue, DynamoDBRecordEventName, @@ -35,6 +43,40 @@ def test_cloud_watch_trigger_event(): assert log_event.extracted_fields is None +def test_cognito_pre_signup_trigger_event(): + event = PreSignUpTriggerEvent(load_event("cognitoPreSignUpTriggerEvent.json")) + + assert event.version == "string" + assert event.trigger_source == "PreSignUp_SignUp" + assert event.region == "us-east-1" + assert event.user_pool_id == "string" + assert event.user_name == "userName" + caller_context = event.caller_context + assert caller_context.aws_sdk_version == "awsSdkVersion" + assert caller_context.client_id == "clientId" + + user_attributes = event.request.user_attributes + assert user_attributes["email"] == "user@example.com" + + assert event.request.validation_data is None + assert event.request.client_metadata is None + + event.response.auto_confirm_user = True + assert event.response.auto_confirm_user is True + event.response.auto_verify_phone = True + assert event.response.auto_verify_phone is True + event.response.auto_verify_email = True + assert event.response.auto_verify_email is True + + +def test_cognito_post_confirmation_trigger_event(): + event = PostConfirmationTriggerEvent(load_event("cognitoPostConfirmationTriggerEvent.json")) + + user_attributes = event.request.user_attributes + assert user_attributes["email"] == "user@example.com" + assert event.request.client_metadata is None + + def test_dynamo_db_stream_trigger_event(): event = DynamoDBStreamEvent(load_event("dynamoStreamEvent.json")) From 84b63a3a31ee9674b3218164ef0228bc6d1feb11 Mon Sep 17 00:00:00 2001 From: Michael Brewer Date: Tue, 8 Sep 2020 09:32:00 -0700 Subject: [PATCH 07/30] feat(trigger): add event_bridge_event --- .../utilities/trigger/__init__.py | 2 + .../utilities/trigger/event_bridge_event.py | 60 +++++++++++++++++++ tests/events/eventBridgeEvent.json | 16 +++++ .../functional/test_lambda_trigger_events.py | 15 +++++ 4 files changed, 93 insertions(+) create mode 100644 aws_lambda_powertools/utilities/trigger/event_bridge_event.py create mode 100644 tests/events/eventBridgeEvent.json diff --git a/aws_lambda_powertools/utilities/trigger/__init__.py b/aws_lambda_powertools/utilities/trigger/__init__.py index f12fccaa986..b934b662057 100644 --- a/aws_lambda_powertools/utilities/trigger/__init__.py +++ b/aws_lambda_powertools/utilities/trigger/__init__.py @@ -1,6 +1,7 @@ from .cloud_watch_logs_event import CloudWatchLogsEvent from .cognito_user_pool_event import PostConfirmationTriggerEvent, PreSignUpTriggerEvent from .dynamo_db_stream_event import DynamoDBStreamEvent +from .event_bridge_event import EventBridgeEvent from .s3_event import S3Event from .ses_event import SESEvent from .sns_event import SNSEvent @@ -11,6 +12,7 @@ "PreSignUpTriggerEvent", "PostConfirmationTriggerEvent", "DynamoDBStreamEvent", + "EventBridgeEvent", "S3Event", "SESEvent", "SNSEvent", diff --git a/aws_lambda_powertools/utilities/trigger/event_bridge_event.py b/aws_lambda_powertools/utilities/trigger/event_bridge_event.py new file mode 100644 index 00000000000..bc0520ae50d --- /dev/null +++ b/aws_lambda_powertools/utilities/trigger/event_bridge_event.py @@ -0,0 +1,60 @@ +from typing import Any, Dict, List + + +class EventBridgeEvent(dict): + """Amazon EventBridge Event + + Documentation: + https://docs.aws.amazon.com/eventbridge/latest/userguide/aws-events.html + """ + + @property + def event_id(self) -> str: + """A unique value is generated for every event. This can be helpful in tracing events as + they move through rules to targets, and are processed.""" + return self["id"] + + @property + def version(self) -> str: + """By default, this is set to 0 (zero) in all events.""" + return self["version"] + + @property + def account(self) -> str: + """The 12-digit number identifying an AWS account.""" + return self["account"] + + @property + def time(self) -> str: + """The event timestamp, which can be specified by the service originating the event. + + If the event spans a time interval, the service might choose to report the start time, so + this value can be noticeably before the time the event is actually received. + """ + return self["time"] + + @property + def region(self) -> str: + """Identifies the AWS region where the event originated.""" + return self["region"] + + @property + def resources(self) -> List[str]: + """This JSON array contains ARNs that identify resources that are involved in the event. + Inclusion of these ARNs is at the discretion of the service.""" + return self["resources"] + + @property + def source(self) -> str: + """Identifies the service that sourced the event. All events sourced from within AWS begin with "aws." """ + return self["source"] + + @property + def detail_type(self) -> str: + """Identifies, in combination with the source field, the fields and values that appear in the detail field.""" + return self["detail-type"] + + @property + def detail(self) -> Dict[str, Any]: + """A JSON object, whose content is at the discretion of the service originating the event. """ + return self["detail"] diff --git a/tests/events/eventBridgeEvent.json b/tests/events/eventBridgeEvent.json new file mode 100644 index 00000000000..e8d949001c9 --- /dev/null +++ b/tests/events/eventBridgeEvent.json @@ -0,0 +1,16 @@ +{ + "version": "0", + "id": "6a7e8feb-b491-4cf7-a9f1-bf3703467718", + "detail-type": "EC2 Instance State-change Notification", + "source": "aws.ec2", + "account": "111122223333", + "time": "2017-12-22T18:43:48Z", + "region": "us-west-1", + "resources": [ + "arn:aws:ec2:us-west-1:123456789012:instance/ i-1234567890abcdef0" + ], + "detail": { + "instance-id": " i-1234567890abcdef0", + "state": "terminated" + } +} diff --git a/tests/functional/test_lambda_trigger_events.py b/tests/functional/test_lambda_trigger_events.py index 24e99a53951..9f2b9b3a5c3 100644 --- a/tests/functional/test_lambda_trigger_events.py +++ b/tests/functional/test_lambda_trigger_events.py @@ -3,6 +3,7 @@ from aws_lambda_powertools.utilities.trigger import ( CloudWatchLogsEvent, + EventBridgeEvent, PostConfirmationTriggerEvent, PreSignUpTriggerEvent, S3Event, @@ -132,6 +133,20 @@ def test_dynamo_attribute_value_map_value(): assert item.s_value == "Joe" +def test_event_bridge_event(): + event = EventBridgeEvent(load_event("eventBridgeEvent.json")) + + assert event.event_id == event["id"] + assert event.version == event["version"] + assert event.account == event["account"] + assert event.time == event["time"] + assert event.region == event["region"] + assert event.resources == event["resources"] + assert event.source == event["source"] + assert event.detail_type == event["detail-type"] + assert event.detail == event["detail"] + + def test_s3_trigger_event(): event = S3Event(load_event("s3Event.json")) records = list(event.records) From c1d516bb4821c4fb4f22a1d19226a8fa3a82a2f4 Mon Sep 17 00:00:00 2001 From: Michael Brewer Date: Tue, 8 Sep 2020 10:14:08 -0700 Subject: [PATCH 08/30] feat(trigger): use consistent getter name For cases when the dict name conflicts with a python builtin we should use a consistent method name --- .../trigger/cloud_watch_logs_event.py | 3 +- .../utilities/trigger/event_bridge_event.py | 3 +- .../utilities/trigger/s3_event.py | 33 ++++++++++--------- .../utilities/trigger/ses_event.py | 8 ++--- .../utilities/trigger/sns_event.py | 6 ++-- .../functional/test_lambda_trigger_events.py | 22 ++++++------- 6 files changed, 40 insertions(+), 35 deletions(-) diff --git a/aws_lambda_powertools/utilities/trigger/cloud_watch_logs_event.py b/aws_lambda_powertools/utilities/trigger/cloud_watch_logs_event.py index 13d463c4322..10dc0332e18 100644 --- a/aws_lambda_powertools/utilities/trigger/cloud_watch_logs_event.py +++ b/aws_lambda_powertools/utilities/trigger/cloud_watch_logs_event.py @@ -6,8 +6,9 @@ class CloudWatchLogsLogEvent(dict): @property - def log_event_id(self) -> str: + def get_id(self) -> str: """The ID property is a unique identifier for every log event.""" + # Note: this name conflicts with existing python builtins return self["id"] @property diff --git a/aws_lambda_powertools/utilities/trigger/event_bridge_event.py b/aws_lambda_powertools/utilities/trigger/event_bridge_event.py index bc0520ae50d..74180bb2328 100644 --- a/aws_lambda_powertools/utilities/trigger/event_bridge_event.py +++ b/aws_lambda_powertools/utilities/trigger/event_bridge_event.py @@ -9,9 +9,10 @@ class EventBridgeEvent(dict): """ @property - def event_id(self) -> str: + def get_id(self) -> str: """A unique value is generated for every event. This can be helpful in tracing events as they move through rules to targets, and are processed.""" + # Note: this name conflicts with existing python builtins return self["id"] @property diff --git a/aws_lambda_powertools/utilities/trigger/s3_event.py b/aws_lambda_powertools/utilities/trigger/s3_event.py index 99616ea78f8..0b53e926e91 100644 --- a/aws_lambda_powertools/utilities/trigger/s3_event.py +++ b/aws_lambda_powertools/utilities/trigger/s3_event.py @@ -10,70 +10,71 @@ def principal_id(self) -> str: class S3RequestParameters(dict): @property def source_ip_address(self) -> str: - return self["sourceIPAddress"] + return self["requestParameters"]["sourceIPAddress"] class S3Bucket(dict): @property def name(self) -> str: - return self["name"] + return self["s3"]["bucket"]["name"] @property def owner_identity(self) -> S3Identity: - return S3Identity(self["ownerIdentity"]) + return S3Identity(self["s3"]["bucket"]["ownerIdentity"]) @property def arn(self) -> str: - return self["arn"] + return self["s3"]["bucket"]["arn"] class S3Object(dict): @property def key(self) -> str: """Object key""" - return self["key"] + return self["s3"]["object"]["key"] @property def size(self) -> int: """Object byte size""" - return int(self["size"]) + return int(self["s3"]["object"]["size"]) @property def etag(self) -> str: """object eTag""" - return self["eTag"] + return self["s3"]["object"]["eTag"] @property def version_id(self) -> Optional[str]: """Object version if bucket is versioning-enabled, otherwise null""" - return self.get("versionId") + return self["s3"]["object"].get("versionId") @property def sequencer(self) -> str: """A string representation of a hexadecimal value used to determine event sequence, only used with PUTs and DELETEs """ - return self["sequencer"] + return self["s3"]["object"]["sequencer"] class S3Message(dict): @property def s3_schema_version(self) -> str: - return self["s3SchemaVersion"] + return self["s3"]["s3SchemaVersion"] @property def configuration_id(self) -> str: """ID found in the bucket notification configuration""" - return self["configurationId"] + return self["s3"]["configurationId"] @property def bucket(self) -> S3Bucket: - return S3Bucket(self["bucket"]) + return S3Bucket(self) @property - def s3_object(self) -> S3Object: + def get_object(self) -> S3Object: """Get the `object` property as an S3Object""" - return S3Object(self["object"]) + # Note: this name conflicts with existing python builtins + return S3Object(self) class S3EventRecordGlacierRestoreEventData(dict): @@ -125,7 +126,7 @@ def user_identity(self) -> S3Identity: @property def request_parameters(self) -> S3RequestParameters: - return S3RequestParameters(self["requestParameters"]) + return S3RequestParameters(self) @property def response_elements(self) -> Dict[str, str]: @@ -139,7 +140,7 @@ def response_elements(self) -> Dict[str, str]: @property def s3(self) -> S3Message: - return S3Message(self["s3"]) + return S3Message(self) @property def glacier_event_data(self) -> Optional[S3EventRecordGlacierEventData]: diff --git a/aws_lambda_powertools/utilities/trigger/ses_event.py b/aws_lambda_powertools/utilities/trigger/ses_event.py index b798495e379..2052f62b4b1 100644 --- a/aws_lambda_powertools/utilities/trigger/ses_event.py +++ b/aws_lambda_powertools/utilities/trigger/ses_event.py @@ -18,9 +18,9 @@ def return_path(self) -> str: return self["returnPath"] @property - def from_header(self) -> List[str]: + def get_from(self) -> List[str]: """The values in the From header of the email.""" - # Note: this conflicts with existing python builtins + # Note: this name conflicts with existing python builtins return self["from"] @property @@ -97,12 +97,12 @@ def status(self) -> str: class SESReceiptAction(dict): @property - def action_type(self) -> str: + def get_type(self) -> str: """String that indicates the type of action that was executed. Possible values are S3, SNS, Bounce, Lambda, Stop, and WorkMail """ - # Note: this conflicts with existing python builtins + # Note: this name conflicts with existing python builtins return self["type"] @property diff --git a/aws_lambda_powertools/utilities/trigger/sns_event.py b/aws_lambda_powertools/utilities/trigger/sns_event.py index bbf3078587a..d797e8911a5 100644 --- a/aws_lambda_powertools/utilities/trigger/sns_event.py +++ b/aws_lambda_powertools/utilities/trigger/sns_event.py @@ -3,8 +3,9 @@ class SNSMessageAttribute(dict): @property - def attribute_type(self) -> str: + def get_type(self) -> str: """Get the `type` property""" + # Note: this name conflicts with existing python builtins return self["Type"] @property @@ -42,8 +43,9 @@ def message_attributes(self) -> Dict[str, SNSMessageAttribute]: return {k: SNSMessageAttribute(v) for (k, v) in self["MessageAttributes"].items()} @property - def message_type(self) -> str: + def get_type(self) -> str: """Get the `type` property""" + # Note: this name conflicts with existing python builtins return self["Type"] @property diff --git a/tests/functional/test_lambda_trigger_events.py b/tests/functional/test_lambda_trigger_events.py index 9f2b9b3a5c3..67a34999170 100644 --- a/tests/functional/test_lambda_trigger_events.py +++ b/tests/functional/test_lambda_trigger_events.py @@ -38,7 +38,7 @@ def test_cloud_watch_trigger_event(): assert decoded_data.subscription_filters == ["testFilter"] assert decoded_data.message_type == "DATA_MESSAGE" - assert log_event.log_event_id == "eventId1" + assert log_event.get_id == "eventId1" assert log_event.timestamp == 1440442987000 assert log_event.message == "[ERROR] First test message" assert log_event.extracted_fields is None @@ -136,7 +136,7 @@ def test_dynamo_attribute_value_map_value(): def test_event_bridge_event(): event = EventBridgeEvent(load_event("eventBridgeEvent.json")) - assert event.event_id == event["id"] + assert event.get_id == event["id"] assert event.version == event["version"] assert event.account == event["account"] assert event.time == event["time"] @@ -169,11 +169,11 @@ def test_s3_trigger_event(): assert bucket.name == "lambda-artifacts-deafc19498e3f2df" assert bucket.owner_identity.principal_id == "A3I5XTEXAMAI3E" assert bucket.arn == "arn:aws:s3:::lambda-artifacts-deafc19498e3f2df" - assert s3.s3_object.key == "b21b84d653bb07b05b1e6b33684dc11b" - assert s3.s3_object.size == 1305107 - assert s3.s3_object.etag == "b21b84d653bb07b05b1e6b33684dc11b" - assert s3.s3_object.version_id is None - assert s3.s3_object.sequencer == "0C0F6F405D6ED209E1" + assert s3.get_object.key == "b21b84d653bb07b05b1e6b33684dc11b" + assert s3.get_object.size == 1305107 + assert s3.get_object.etag == "b21b84d653bb07b05b1e6b33684dc11b" + assert s3.get_object.version_id is None + assert s3.get_object.sequencer == "0C0F6F405D6ED209E1" assert record.glacier_event_data is None @@ -218,7 +218,7 @@ def test_ses_trigger_event(): assert headers[0].value == "" common_headers = mail.common_headers assert common_headers.return_path == "janedoe@example.com" - assert common_headers.from_header == common_headers["from"] + assert common_headers.get_from == common_headers["from"] assert common_headers.date == "Wed, 7 Oct 2015 12:34:56 -0700" assert common_headers.to == [expected_address] assert common_headers.message_id == "<0123456789example.com>" @@ -232,7 +232,7 @@ def test_ses_trigger_event(): assert receipt.spf_verdict.status == "PASS" assert receipt.dmarc_verdict.status == "PASS" action = receipt.action - assert action.action_type == action["type"] + assert action.get_type == action["type"] assert action.function_arn == action["functionArn"] assert action.invocation_type == action["invocationType"] @@ -254,9 +254,9 @@ def test_sns_trigger_event(): assert sns.message == "Hello from SNS!" message_attributes = sns.message_attributes test_message_attribute = message_attributes["Test"] - assert test_message_attribute.attribute_type == "String" + assert test_message_attribute.get_type == "String" assert test_message_attribute.value == "TestString" - assert sns.message_type == "Notification" + assert sns.get_type == "Notification" assert sns.unsubscribe_url == "https://sns.us-east-2.amazonaws.com/?Action=Unsubscri ..." assert sns.topic_arn == "arn:aws:sns:us-east-2:123456789012:sns-lambda" assert sns.subject == "TestInvoke" From 2aea86cdfbbee9ed91a4d52c5597c948664af47c Mon Sep 17 00:00:00 2001 From: Michael Brewer Date: Tue, 8 Sep 2020 11:18:24 -0700 Subject: [PATCH 09/30] refactor: less copies passed around --- .../trigger/cloud_watch_logs_event.py | 4 ++-- .../utilities/trigger/ses_event.py | 6 ++--- .../utilities/trigger/sns_event.py | 24 +++++++++---------- .../utilities/trigger/sqs_event.py | 16 ++++++------- 4 files changed, 25 insertions(+), 25 deletions(-) diff --git a/aws_lambda_powertools/utilities/trigger/cloud_watch_logs_event.py b/aws_lambda_powertools/utilities/trigger/cloud_watch_logs_event.py index 10dc0332e18..a31dd4b790a 100644 --- a/aws_lambda_powertools/utilities/trigger/cloud_watch_logs_event.py +++ b/aws_lambda_powertools/utilities/trigger/cloud_watch_logs_event.py @@ -70,7 +70,7 @@ class CloudWatchLogsEventData(dict): @property def data(self) -> str: """The value of the `data` field is a Base64 encoded ZIP archive.""" - return self["data"] + return self["awslogs"]["data"] class CloudWatchLogsEvent(dict): @@ -81,7 +81,7 @@ class CloudWatchLogsEvent(dict): @property def aws_logs(self) -> CloudWatchLogsEventData: - return CloudWatchLogsEventData(self["awslogs"]) + return CloudWatchLogsEventData(self) def decode_cloud_watch_logs_data(self) -> CloudWatchLogsDecodedData: """Gzip and parse json data""" diff --git a/aws_lambda_powertools/utilities/trigger/ses_event.py b/aws_lambda_powertools/utilities/trigger/ses_event.py index 2052f62b4b1..b40032dd815 100644 --- a/aws_lambda_powertools/utilities/trigger/ses_event.py +++ b/aws_lambda_powertools/utilities/trigger/ses_event.py @@ -167,11 +167,11 @@ def action(self) -> SESReceiptAction: class SESMessage(dict): @property def mail(self) -> SESMail: - return SESMail(self["mail"]) + return SESMail(self["ses"]["mail"]) @property def receipt(self) -> SESReceipt: - return SESReceipt(self["receipt"]) + return SESReceipt(self["ses"]["receipt"]) class SESEventRecord(dict): @@ -186,7 +186,7 @@ def event_version(self) -> str: @property def ses(self) -> SESMessage: - return SESMessage(self["ses"]) + return SESMessage(self) class SESEvent(dict): diff --git a/aws_lambda_powertools/utilities/trigger/sns_event.py b/aws_lambda_powertools/utilities/trigger/sns_event.py index d797e8911a5..9f016ce0fbe 100644 --- a/aws_lambda_powertools/utilities/trigger/sns_event.py +++ b/aws_lambda_powertools/utilities/trigger/sns_event.py @@ -16,49 +16,49 @@ def value(self) -> str: class SNSMessage(dict): @property def signature_version(self) -> str: - return self["SignatureVersion"] + return self["Sns"]["SignatureVersion"] @property def timestamp(self) -> str: - return self["Timestamp"] + return self["Sns"]["Timestamp"] @property def signature(self) -> str: - return self["Signature"] + return self["Sns"]["Signature"] @property def signing_cert_url(self) -> str: - return self["SigningCertUrl"] + return self["Sns"]["SigningCertUrl"] @property def message_id(self) -> str: - return self["MessageId"] + return self["Sns"]["MessageId"] @property def message(self) -> str: - return self["Message"] + return self["Sns"]["Message"] @property def message_attributes(self) -> Dict[str, SNSMessageAttribute]: - return {k: SNSMessageAttribute(v) for (k, v) in self["MessageAttributes"].items()} + return {k: SNSMessageAttribute(v) for (k, v) in self["Sns"]["MessageAttributes"].items()} @property def get_type(self) -> str: """Get the `type` property""" # Note: this name conflicts with existing python builtins - return self["Type"] + return self["Sns"]["Type"] @property def unsubscribe_url(self) -> str: - return self["UnsubscribeUrl"] + return self["Sns"]["UnsubscribeUrl"] @property def topic_arn(self) -> str: - return self["TopicArn"] + return self["Sns"]["TopicArn"] @property def subject(self) -> str: - return self["Subject"] + return self["Sns"]["Subject"] class SNSEventRecord(dict): @@ -76,7 +76,7 @@ def event_source(self) -> str: @property def sns(self) -> SNSMessage: - return SNSMessage(self["Sns"]) + return SNSMessage(self) class SNSEvent(dict): diff --git a/aws_lambda_powertools/utilities/trigger/sqs_event.py b/aws_lambda_powertools/utilities/trigger/sqs_event.py index f9de8ea5dbb..ad7a15a61a7 100644 --- a/aws_lambda_powertools/utilities/trigger/sqs_event.py +++ b/aws_lambda_powertools/utilities/trigger/sqs_event.py @@ -4,27 +4,27 @@ class SQSRecordAttributes(dict): @property def aws_trace_header(self) -> Optional[str]: - return self.get("AWSTraceHeader") + return self["attributes"].get("AWSTraceHeader") @property def approximate_receive_count(self) -> str: - return self["ApproximateReceiveCount"] + return self["attributes"]["ApproximateReceiveCount"] @property def sent_timestamp(self) -> str: - return self["SentTimestamp"] + return self["attributes"]["SentTimestamp"] @property def sender_id(self) -> str: - return self["SenderId"] + return self["attributes"]["SenderId"] @property def approximate_first_receive_timestamp(self) -> str: - return self["ApproximateFirstReceiveTimestamp"] + return self["attributes"]["ApproximateFirstReceiveTimestamp"] @property def sequence_number(self) -> Optional[str]: - return self.get("SequenceNumber") + return self["attributes"].get("SequenceNumber") @property def message_group_id(self) -> Optional[str]: @@ -32,7 +32,7 @@ def message_group_id(self) -> Optional[str]: @property def message_deduplication_id(self) -> Optional[str]: - return self.get("MessageDeduplicationId") + return self["attributes"].get("MessageDeduplicationId") class SQSMessageAttribute(dict): @@ -70,7 +70,7 @@ def body(self) -> str: @property def attributes(self) -> SQSRecordAttributes: - return SQSRecordAttributes(self["attributes"]) + return SQSRecordAttributes(self) @property def message_attributes(self) -> SQSMessageAttributes: From 17906ab709df481dc7a9aeb6879aa4651831a58b Mon Sep 17 00:00:00 2001 From: Michael Brewer Date: Tue, 8 Sep 2020 15:08:33 -0700 Subject: [PATCH 10/30] feat(trigger): add UserMigrationTriggerEvent Add support for UserMigrationTriggerEvent and include better docstrings --- .../utilities/trigger/__init__.py | 3 +- .../trigger/cognito_user_pool_event.py | 151 ++++++++++++++++-- .../cognitoUserMigrationTriggerEvent.json | 15 ++ .../functional/test_lambda_trigger_events.py | 24 +++ 4 files changed, 178 insertions(+), 15 deletions(-) create mode 100644 tests/events/cognitoUserMigrationTriggerEvent.json diff --git a/aws_lambda_powertools/utilities/trigger/__init__.py b/aws_lambda_powertools/utilities/trigger/__init__.py index b934b662057..837f4456d84 100644 --- a/aws_lambda_powertools/utilities/trigger/__init__.py +++ b/aws_lambda_powertools/utilities/trigger/__init__.py @@ -1,5 +1,5 @@ from .cloud_watch_logs_event import CloudWatchLogsEvent -from .cognito_user_pool_event import PostConfirmationTriggerEvent, PreSignUpTriggerEvent +from .cognito_user_pool_event import PostConfirmationTriggerEvent, PreSignUpTriggerEvent, UserMigrationTriggerEvent from .dynamo_db_stream_event import DynamoDBStreamEvent from .event_bridge_event import EventBridgeEvent from .s3_event import S3Event @@ -11,6 +11,7 @@ "CloudWatchLogsEvent", "PreSignUpTriggerEvent", "PostConfirmationTriggerEvent", + "UserMigrationTriggerEvent", "DynamoDBStreamEvent", "EventBridgeEvent", "S3Event", diff --git a/aws_lambda_powertools/utilities/trigger/cognito_user_pool_event.py b/aws_lambda_powertools/utilities/trigger/cognito_user_pool_event.py index 2a94fc91c58..6b08f08152a 100644 --- a/aws_lambda_powertools/utilities/trigger/cognito_user_pool_event.py +++ b/aws_lambda_powertools/utilities/trigger/cognito_user_pool_event.py @@ -1,4 +1,4 @@ -from typing import Dict, Optional +from typing import Dict, List, Optional class CallerContext(dict): @@ -17,7 +17,8 @@ class BaseTriggerEvent(dict): """Common attributes shared by all User Pool Lambda Trigger Events Documentation: - https://docs.aws.amazon.com/cognito/latest/developerguide/cognito-user-identity-pools-working-with-aws-lambda-triggers.html + ------------- + https://docs.aws.amazon.com/cognito/latest/developerguide/cognito-user-identity-pools-working-with-aws-lambda-triggers.html """ @property @@ -55,59 +56,68 @@ class PreSignUpTriggerEventRequest(dict): @property def user_attributes(self) -> Dict[str, str]: """One or more name-value pairs representing user attributes. The attribute names are the keys.""" - return self["userAttributes"] + return self["request"]["userAttributes"] @property def validation_data(self) -> Optional[Dict[str, str]]: """One or more name-value pairs containing the validation data in the request to register a user.""" - return self.get("validationData") + return self["request"].get("validationData") @property def client_metadata(self) -> Optional[Dict[str, str]]: """One or more key-value pairs that you can provide as custom input to the Lambda function that you specify for the pre sign-up trigger.""" - return self.get("clientMetadata") + return self["request"].get("clientMetadata") class PreSignUpTriggerEventResponse(dict): @property def auto_confirm_user(self) -> bool: - """Set to true to auto-confirm the user, or false otherwise.""" return bool(self["response"]["autoConfirmUser"]) @auto_confirm_user.setter def auto_confirm_user(self, value: bool): + """Set to true to auto-confirm the user, or false otherwise.""" self["response"]["autoConfirmUser"] = value @property def auto_verify_email(self) -> bool: - """Set to true to set as verified the email of a user who is signing up, or false otherwise.""" return bool(self["response"]["autoVerifyEmail"]) @auto_verify_email.setter def auto_verify_email(self, value: bool): + """Set to true to set as verified the email of a user who is signing up, or false otherwise.""" self["response"]["autoVerifyEmail"] = value @property def auto_verify_phone(self) -> bool: - """Set to true to set as verified the phone number of a user who is signing up, or false otherwise.""" return bool(self["response"]["autoVerifyPhone"]) @auto_verify_phone.setter def auto_verify_phone(self, value: bool): + """Set to true to set as verified the phone number of a user who is signing up, or false otherwise.""" self["response"]["autoVerifyPhone"] = value class PreSignUpTriggerEvent(BaseTriggerEvent): """Pre Sign-up Lambda Trigger + Notes: + ---- + `triggerSource` can be one of the following: + + - `PreSignUp_SignUp` Pre sign-up. + - `PreSignUp_AdminCreateUser` Pre sign-up when an admin creates a new user. + - `PreSignUp_ExternalProvider` Pre sign-up with external provider + Documentation: - https://docs.aws.amazon.com/cognito/latest/developerguide/user-pool-lambda-pre-sign-up.html + ------------- + - https://docs.aws.amazon.com/cognito/latest/developerguide/user-pool-lambda-pre-sign-up.html """ @property def request(self) -> PreSignUpTriggerEventRequest: - return PreSignUpTriggerEventRequest(self["request"]) + return PreSignUpTriggerEventRequest(self) @property def response(self) -> PreSignUpTriggerEventResponse: @@ -118,22 +128,135 @@ class PostConfirmationTriggerEventRequest(dict): @property def user_attributes(self) -> Dict[str, str]: """One or more name-value pairs representing user attributes. The attribute names are the keys.""" - return self["userAttributes"] + return self["request"]["userAttributes"] @property def client_metadata(self) -> Optional[Dict[str, str]]: """One or more key-value pairs that you can provide as custom input to the Lambda function that you specify for the post confirmation trigger.""" - return self.get("clientMetadata") + return self["request"].get("clientMetadata") class PostConfirmationTriggerEvent(BaseTriggerEvent): """Post Confirmation Lambda Trigger + Notes: + ---- + `triggerSource` can be one of the following: + + - `PostConfirmation_ConfirmSignUp` Post sign-up confirmation. + - `PostConfirmation_ConfirmForgotPassword` Post Forgot Password confirmation. + Documentation: - https://docs.aws.amazon.com/cognito/latest/developerguide/user-pool-lambda-post-confirmation.html + ------------- + - https://docs.aws.amazon.com/cognito/latest/developerguide/user-pool-lambda-post-confirmation.html """ @property def request(self) -> PostConfirmationTriggerEventRequest: - return PostConfirmationTriggerEventRequest(self["request"]) + return PostConfirmationTriggerEventRequest(self) + + +class UserMigrationTriggerEventRequest(dict): + @property + def password(self) -> str: + return self["request"]["password"] + + @property + def validation_data(self) -> Optional[Dict[str, str]]: + """One or more name-value pairs containing the validation data in the request to register a user.""" + return self["request"].get("validationData") + + @property + def client_metadata(self) -> Optional[Dict[str, str]]: + """One or more key-value pairs that you can provide as custom input to the Lambda function + that you specify for the pre sign-up trigger.""" + return self["request"].get("clientMetadata") + + +class UserMigrationTriggerEventResponse(dict): + @property + def user_attributes(self) -> Dict[str, str]: + return self["response"]["userAttributes"] + + @user_attributes.setter + def user_attributes(self, value: Dict[str, str]): + """It must contain one or more name-value pairs representing user attributes to be stored in the + user profile in your user pool. You can include both standard and custom user attributes. + Custom attributes require the custom: prefix to distinguish them from standard attributes.""" + self["response"]["userAttributes"] = value + + @property + def final_user_status(self) -> Optional[str]: + return self["response"].get("finalUserStatus") + + @final_user_status.setter + def final_user_status(self, value: str): + """During sign-in, this attribute can be set to CONFIRMED, or not set, to auto-confirm your users and + allow them to sign-in with their previous passwords. This is the simplest experience for the user. + + If this attribute is set to RESET_REQUIRED, the user is required to change his or her password immediately + after migration at the time of sign-in, and your client app needs to handle the PasswordResetRequiredException + during the authentication flow.""" + self["response"]["finalUserStatus"] = value + + @property + def message_action(self) -> Optional[str]: + return self["response"].get("messageAction") + + @message_action.setter + def message_action(self, value: str): + """This attribute can be set to "SUPPRESS" to suppress the welcome message usually sent by + Amazon Cognito to new users. If this attribute is not returned, the welcome message will be sent.""" + self["response"]["messageAction"] = value + + @property + def desired_delivery_mediums(self) -> Optional[List[str]]: + return self["response"].get("desiredDeliveryMediums") + + @desired_delivery_mediums.setter + def desired_delivery_mediums(self, value: List[str]): + """This attribute can be set to "EMAIL" to send the welcome message by email, or "SMS" to send the + welcome message by SMS. If this attribute is not returned, the welcome message will be sent by SMS.""" + self["response"]["desiredDeliveryMediums"] = value + + @property + def force_alias_creation(self) -> Optional[bool]: + return self["response"].get("forceAliasCreation") + + @force_alias_creation.setter + def force_alias_creation(self, value: bool): + """If this parameter is set to "true" and the phone number or email address specified in the UserAttributes + parameter already exists as an alias with a different user, the API call will migrate the alias from the + previous user to the newly created user. The previous user will no longer be able to log in using that alias. + + If this attribute is set to "false" and the alias exists, the user will not be migrated, and an error is + returned to the client app. + + If this attribute is not returned, it is assumed to be "false". + """ + self["response"]["forceAliasCreation"] = value + + +class UserMigrationTriggerEvent(BaseTriggerEvent): + """Migrate User Lambda Trigger + + Notes: + ---- + `triggerSource` can be one of the following: + + - `UserMigration_Authentication` User migration at the time of sign in. + - `UserMigration_ForgotPassword` User migration during forgot-password flow. + + Documentation: + ------------- + - https://docs.aws.amazon.com/cognito/latest/developerguide/user-pool-lambda-migrate-user.html + """ + + @property + def request(self) -> UserMigrationTriggerEventRequest: + return UserMigrationTriggerEventRequest(self) + + @property + def response(self) -> UserMigrationTriggerEventResponse: + return UserMigrationTriggerEventResponse(self) diff --git a/tests/events/cognitoUserMigrationTriggerEvent.json b/tests/events/cognitoUserMigrationTriggerEvent.json new file mode 100644 index 00000000000..2eae4e66189 --- /dev/null +++ b/tests/events/cognitoUserMigrationTriggerEvent.json @@ -0,0 +1,15 @@ +{ + "version": "string", + "triggerSource": "UserMigration_Authentication", + "region": "us-east-1", + "userPoolId": "string", + "userName": "userName", + "callerContext": { + "awsSdkVersion": "awsSdkVersion", + "clientId": "clientId" + }, + "request": { + "password": "password" + }, + "response": {} +} diff --git a/tests/functional/test_lambda_trigger_events.py b/tests/functional/test_lambda_trigger_events.py index 67a34999170..432276dba5d 100644 --- a/tests/functional/test_lambda_trigger_events.py +++ b/tests/functional/test_lambda_trigger_events.py @@ -1,5 +1,6 @@ import json import os +from secrets import compare_digest from aws_lambda_powertools.utilities.trigger import ( CloudWatchLogsEvent, @@ -10,6 +11,7 @@ SESEvent, SNSEvent, SQSEvent, + UserMigrationTriggerEvent, ) from aws_lambda_powertools.utilities.trigger.dynamo_db_stream_event import ( AttributeValue, @@ -78,6 +80,28 @@ def test_cognito_post_confirmation_trigger_event(): assert event.request.client_metadata is None +def test_cognito_user_migration_trigger_event(): + event = UserMigrationTriggerEvent(load_event("cognitoUserMigrationTriggerEvent.json")) + + assert compare_digest(event.request.password, event["request"]["password"]) + assert event.request.validation_data is None + assert event.request.client_metadata is None + + event.response.user_attributes = {"username": "username"} + assert event.response.user_attributes == event["response"]["userAttributes"] + assert event.response.user_attributes == {"username": "username"} + assert event.response.final_user_status is None + assert event.response.message_action is None + assert event.response.force_alias_creation is None + + event.response.final_user_status = "CONFIRMED" + assert event.response.final_user_status == "CONFIRMED" + event.response.message_action = "SUPPRESS" + assert event.response.message_action == "SUPPRESS" + event.response.force_alias_creation = True + assert event.response.force_alias_creation is True + + def test_dynamo_db_stream_trigger_event(): event = DynamoDBStreamEvent(load_event("dynamoStreamEvent.json")) From 1c69762c883ef90064e81fd6fb09be2b840f57cf Mon Sep 17 00:00:00 2001 From: Michael Brewer Date: Tue, 8 Sep 2020 17:16:27 -0700 Subject: [PATCH 11/30] feat(trigger): cognito custom message and pre-auth Add support for CustomMessageTriggerEvent and PreAuthenticationTriggerEvent --- .../utilities/trigger/__init__.py | 10 +- .../trigger/cognito_user_pool_event.py | 125 ++++++++++++++++++ .../cognitoCustomMessageTriggerEvent.json | 20 +++ .../cognitoPreAuthenticationTriggerEvent.json | 20 +++ .../functional/test_lambda_trigger_events.py | 31 +++++ 5 files changed, 205 insertions(+), 1 deletion(-) create mode 100644 tests/events/cognitoCustomMessageTriggerEvent.json create mode 100644 tests/events/cognitoPreAuthenticationTriggerEvent.json diff --git a/aws_lambda_powertools/utilities/trigger/__init__.py b/aws_lambda_powertools/utilities/trigger/__init__.py index 837f4456d84..924097a6cfd 100644 --- a/aws_lambda_powertools/utilities/trigger/__init__.py +++ b/aws_lambda_powertools/utilities/trigger/__init__.py @@ -1,5 +1,11 @@ from .cloud_watch_logs_event import CloudWatchLogsEvent -from .cognito_user_pool_event import PostConfirmationTriggerEvent, PreSignUpTriggerEvent, UserMigrationTriggerEvent +from .cognito_user_pool_event import ( + CustomMessageTriggerEvent, + PostConfirmationTriggerEvent, + PreAuthenticationTriggerEvent, + PreSignUpTriggerEvent, + UserMigrationTriggerEvent, +) from .dynamo_db_stream_event import DynamoDBStreamEvent from .event_bridge_event import EventBridgeEvent from .s3_event import S3Event @@ -12,6 +18,8 @@ "PreSignUpTriggerEvent", "PostConfirmationTriggerEvent", "UserMigrationTriggerEvent", + "CustomMessageTriggerEvent", + "PreAuthenticationTriggerEvent", "DynamoDBStreamEvent", "EventBridgeEvent", "S3Event", diff --git a/aws_lambda_powertools/utilities/trigger/cognito_user_pool_event.py b/aws_lambda_powertools/utilities/trigger/cognito_user_pool_event.py index 6b08f08152a..91c6fc19127 100644 --- a/aws_lambda_powertools/utilities/trigger/cognito_user_pool_event.py +++ b/aws_lambda_powertools/utilities/trigger/cognito_user_pool_event.py @@ -260,3 +260,128 @@ def request(self) -> UserMigrationTriggerEventRequest: @property def response(self) -> UserMigrationTriggerEventResponse: return UserMigrationTriggerEventResponse(self) + + +class CustomMessageTriggerEventRequest(dict): + @property + def code_parameter(self) -> str: + """A string for you to use as the placeholder for the verification code in the custom message.""" + return self["request"]["codeParameter"] + + @property + def username_parameter(self) -> str: + """The username parameter. It is a required request parameter for the admin create user flow.""" + return self["request"]["usernameParameter"] + + @property + def user_attributes(self) -> Dict[str, str]: + """One or more name-value pairs representing user attributes. The attribute names are the keys.""" + return self["request"]["userAttributes"] + + @property + def client_metadata(self) -> Optional[Dict[str, str]]: + """One or more key-value pairs that you can provide as custom input to the Lambda function + that you specify for the pre sign-up trigger.""" + return self["request"].get("clientMetadata") + + +class CustomMessageTriggerEventResponse(dict): + @property + def sms_message(self) -> str: + return self["response"]["smsMessage"] + + @sms_message.setter + def sms_message(self, value: str): + """The custom SMS message to be sent to your users. + Must include the codeParameter value received in the request.""" + self["response"]["smsMessage"] = value + + @property + def email_message(self) -> str: + return self["response"]["emailMessage"] + + @email_message.setter + def email_message(self, value: str): + """The custom email message to be sent to your users. + Must include the codeParameter value received in the request.""" + self["response"]["emailMessage"] = value + + @property + def email_subject(self) -> str: + return self["response"]["emailSubject"] + + @email_subject.setter + def email_subject(self, value: str): + """The subject line for the custom message.""" + self["response"]["emailSubject"] = value + + +class CustomMessageTriggerEvent(BaseTriggerEvent): + """Custom Message Lambda Trigger + + Notes: + ---- + `triggerSource` can be one of the following: + + - `CustomMessage_SignUp` To send the confirmation code post sign-up. + - `CustomMessage_AdminCreateUser` To send the temporary password to a new user. + - `CustomMessage_ResendCode` To resend the confirmation code to an existing user. + - `CustomMessage_ForgotPassword` To send the confirmation code for Forgot Password request. + - `CustomMessage_UpdateUserAttribute` When a user's email or phone number is changed, this trigger sends a + verification code automatically to the user. Cannot be used for other attributes. + - `CustomMessage_VerifyUserAttribute` This trigger sends a verification code to the user when they manually + request it for a new email or phone number. + - `CustomMessage_Authentication` To send MFA code during authentication. + + Documentation: + -------------- + - https://docs.aws.amazon.com/cognito/latest/developerguide/user-pool-lambda-custom-message.html + """ + + @property + def request(self) -> CustomMessageTriggerEventRequest: + return CustomMessageTriggerEventRequest(self) + + @property + def response(self) -> CustomMessageTriggerEventResponse: + return CustomMessageTriggerEventResponse(self) + + +class PreAuthenticationTriggerEventRequest(dict): + @property + def user_not_found(self) -> Optional[bool]: + """This boolean is populated when PreventUserExistenceErrors is set to ENABLED for your User Pool client.""" + return self["request"].get("userNotFound") + + @property + def user_attributes(self) -> Dict[str, str]: + """One or more name-value pairs representing user attributes.""" + return self["request"]["userAttributes"] + + @property + def validation_data(self) -> Optional[Dict[str, str]]: + """One or more key-value pairs containing the validation data in the user's sign-in request.""" + return self["request"].get("validationData") + + +class PreAuthenticationTriggerEvent(BaseTriggerEvent): + """Pre Authentication Lambda Trigger + + Amazon Cognito invokes this trigger when a user attempts to sign in, allowing custom validation + to accept or deny the authentication request. + + Notes: + ---- + `triggerSource` can be one of the following: + + - `PreAuthentication_Authentication` Pre authentication. + + Documentation: + -------------- + - https://docs.aws.amazon.com/cognito/latest/developerguide/user-pool-lambda-pre-authentication.html + """ + + @property + def request(self) -> PreAuthenticationTriggerEventRequest: + """Pre Authentication Request Parameters""" + return PreAuthenticationTriggerEventRequest(self) diff --git a/tests/events/cognitoCustomMessageTriggerEvent.json b/tests/events/cognitoCustomMessageTriggerEvent.json new file mode 100644 index 00000000000..8652c3bff40 --- /dev/null +++ b/tests/events/cognitoCustomMessageTriggerEvent.json @@ -0,0 +1,20 @@ +{ + "version": "1", + "triggerSource": "CustomMessage_AdminCreateUser", + "region": "region", + "userPoolId": "userPoolId", + "userName": "userName", + "callerContext": { + "awsSdk": "awsSdkVersion", + "clientId": "clientId" + }, + "request": { + "userAttributes": { + "phone_number_verified": false, + "email_verified": true + }, + "codeParameter": "####", + "usernameParameter": "username" + }, + "response": {} +} diff --git a/tests/events/cognitoPreAuthenticationTriggerEvent.json b/tests/events/cognitoPreAuthenticationTriggerEvent.json new file mode 100644 index 00000000000..75ff9ce34b3 --- /dev/null +++ b/tests/events/cognitoPreAuthenticationTriggerEvent.json @@ -0,0 +1,20 @@ +{ + "version": "1", + "region": "us-east-1", + "userPoolId": "us-east-1_example", + "userName": "UserName", + "callerContext": { + "awsSdkVersion": "awsSdkVersion", + "clientId": "clientId" + }, + "triggerSource": "PreAuthentication_Authentication", + "request": { + "userAttributes": { + "sub": "4A709A36-7D63-4785-829D-4198EF10EBDA", + "email_verified": "true", + "name": "First Last", + "email": "test@mail.com" + } + }, + "response": {} +} diff --git a/tests/functional/test_lambda_trigger_events.py b/tests/functional/test_lambda_trigger_events.py index 432276dba5d..16a1b280475 100644 --- a/tests/functional/test_lambda_trigger_events.py +++ b/tests/functional/test_lambda_trigger_events.py @@ -4,8 +4,10 @@ from aws_lambda_powertools.utilities.trigger import ( CloudWatchLogsEvent, + CustomMessageTriggerEvent, EventBridgeEvent, PostConfirmationTriggerEvent, + PreAuthenticationTriggerEvent, PreSignUpTriggerEvent, S3Event, SESEvent, @@ -93,6 +95,7 @@ def test_cognito_user_migration_trigger_event(): assert event.response.final_user_status is None assert event.response.message_action is None assert event.response.force_alias_creation is None + assert event.response.desired_delivery_mediums is None event.response.final_user_status = "CONFIRMED" assert event.response.final_user_status == "CONFIRMED" @@ -100,6 +103,34 @@ def test_cognito_user_migration_trigger_event(): assert event.response.message_action == "SUPPRESS" event.response.force_alias_creation = True assert event.response.force_alias_creation is True + event.response.desired_delivery_mediums = ["EMAIL"] + assert event.response.desired_delivery_mediums == ["EMAIL"] + + +def test_cognito_custom_message_trigger_event(): + event = CustomMessageTriggerEvent(load_event("cognitoCustomMessageTriggerEvent.json")) + + assert event.request.code_parameter == "####" + assert event.request.username_parameter == "username" + assert event.request.user_attributes["phone_number_verified"] is False + assert event.request.client_metadata is None + + event.response.sms_message = "sms" + assert event.response.sms_message == event["response"]["smsMessage"] + event.response.email_message = "email" + assert event.response.email_message == event["response"]["emailMessage"] + event.response.email_subject = "subject" + assert event.response.email_subject == event["response"]["emailSubject"] + + +def test_cognito_pre_authentication_trigger_event(): + event = PreAuthenticationTriggerEvent(load_event("cognitoPreAuthenticationTriggerEvent.json")) + + assert event.request.user_not_found is None + event["request"]["userNotFound"] = True + assert event.request.user_not_found is True + assert event.request.user_attributes["email"] == "test@mail.com" + assert event.request.validation_data is None def test_dynamo_db_stream_trigger_event(): From b9cdd6647b6767ddf93fed6be3186b89d945626e Mon Sep 17 00:00:00 2001 From: Michael Brewer Date: Tue, 8 Sep 2020 20:53:41 -0700 Subject: [PATCH 12/30] feat(trigger): cognito pre token and post auth Add support for PreTokenGenerationTriggerEvent and PostAuthenticationTriggerEvent --- .../utilities/trigger/__init__.py | 12 -- .../trigger/cognito_user_pool_event.py | 175 +++++++++++++++++- ...cognitoPostAuthenticationTriggerEvent.json | 18 ++ ...cognitoPreTokenGenerationTriggerEvent.json | 25 +++ .../functional/test_lambda_trigger_events.py | 58 +++++- 5 files changed, 269 insertions(+), 19 deletions(-) create mode 100644 tests/events/cognitoPostAuthenticationTriggerEvent.json create mode 100644 tests/events/cognitoPreTokenGenerationTriggerEvent.json diff --git a/aws_lambda_powertools/utilities/trigger/__init__.py b/aws_lambda_powertools/utilities/trigger/__init__.py index 924097a6cfd..1499b4a5367 100644 --- a/aws_lambda_powertools/utilities/trigger/__init__.py +++ b/aws_lambda_powertools/utilities/trigger/__init__.py @@ -1,11 +1,4 @@ from .cloud_watch_logs_event import CloudWatchLogsEvent -from .cognito_user_pool_event import ( - CustomMessageTriggerEvent, - PostConfirmationTriggerEvent, - PreAuthenticationTriggerEvent, - PreSignUpTriggerEvent, - UserMigrationTriggerEvent, -) from .dynamo_db_stream_event import DynamoDBStreamEvent from .event_bridge_event import EventBridgeEvent from .s3_event import S3Event @@ -15,11 +8,6 @@ __all__ = [ "CloudWatchLogsEvent", - "PreSignUpTriggerEvent", - "PostConfirmationTriggerEvent", - "UserMigrationTriggerEvent", - "CustomMessageTriggerEvent", - "PreAuthenticationTriggerEvent", "DynamoDBStreamEvent", "EventBridgeEvent", "S3Event", diff --git a/aws_lambda_powertools/utilities/trigger/cognito_user_pool_event.py b/aws_lambda_powertools/utilities/trigger/cognito_user_pool_event.py index 91c6fc19127..024a5367e58 100644 --- a/aws_lambda_powertools/utilities/trigger/cognito_user_pool_event.py +++ b/aws_lambda_powertools/utilities/trigger/cognito_user_pool_event.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional class CallerContext(dict): @@ -132,8 +132,8 @@ def user_attributes(self) -> Dict[str, str]: @property def client_metadata(self) -> Optional[Dict[str, str]]: - """One or more key-value pairs that you can provide as custom input to the Lambda function that you - specify for the post confirmation trigger.""" + """One or more key-value pairs that you can provide as custom input to the Lambda function + that you specify for the post confirmation trigger.""" return self["request"].get("clientMetadata") @@ -385,3 +385,172 @@ class PreAuthenticationTriggerEvent(BaseTriggerEvent): def request(self) -> PreAuthenticationTriggerEventRequest: """Pre Authentication Request Parameters""" return PreAuthenticationTriggerEventRequest(self) + + +class PostAuthenticationTriggerEventRequest(dict): + @property + def new_device_used(self) -> bool: + """This flag indicates if the user has signed in on a new device. + It is set only if the remembered devices value of the user pool is set to `Always` or User `Opt-In`.""" + return self["request"]["newDeviceUsed"] + + @property + def user_attributes(self) -> Dict[str, str]: + """One or more name-value pairs representing user attributes.""" + return self["request"]["userAttributes"] + + @property + def client_metadata(self) -> Optional[Dict[str, str]]: + """One or more key-value pairs that you can provide as custom input to the Lambda function + that you specify for the post authentication trigger.""" + return self["request"].get("clientMetadata") + + +class PostAuthenticationTriggerEvent(BaseTriggerEvent): + """Post Authentication Lambda Trigger + + Amazon Cognito invokes this trigger after signing in a user, allowing you to add custom logic after authentication. + + Notes: + ---- + `triggerSource` can be one of the following: + + - `PostAuthentication_Authentication` Post authentication. + + Documentation: + -------------- + - https://docs.aws.amazon.com/cognito/latest/developerguide/user-pool-lambda-post-authentication.html + """ + + @property + def request(self) -> PostAuthenticationTriggerEventRequest: + """Post Authentication Request Parameters""" + return PostAuthenticationTriggerEventRequest(self) + + +class GroupOverrideDetails(dict): + @property + def groups_to_override(self) -> Optional[List[str]]: + """A list of the group names that are associated with the user that the identity token is issued for.""" + return self.get("groupsToOverride") + + @property + def iam_roles_to_override(self) -> Optional[List[str]]: + """A list of the current IAM roles associated with these groups.""" + return self.get("iamRolesToOverride") + + @property + def preferred_role(self) -> Optional[str]: + """A string indicating the preferred IAM role.""" + return self.get("preferredRole") + + +class PreTokenGenerationTriggerEventRequest(dict): + @property + def group_configuration(self) -> GroupOverrideDetails: + """The input object containing the current group configuration""" + return GroupOverrideDetails(self["request"]["groupConfiguration"]) + + @property + def user_attributes(self) -> Dict[str, str]: + """One or more name-value pairs representing user attributes.""" + return self["request"]["userAttributes"] + + @property + def client_metadata(self) -> Optional[Dict[str, str]]: + """One or more key-value pairs that you can provide as custom input to the Lambda function + that you specify for the pre token generation trigger.""" + return self["request"].get("clientMetadata") + + +class ClaimsOverrideDetails(dict): + @property + def claims_to_add_or_override(self) -> Optional[Dict[str, str]]: + return self["response"]["claimsOverrideDetails"].get("claimsToAddOrOverride") + + @property + def claims_to_suppress(self) -> Optional[List[str]]: + return self["response"]["claimsOverrideDetails"].get("claimsToSuppress") + + @property + def group_configuration(self) -> Optional[GroupOverrideDetails]: + group_override_details = self["response"]["claimsOverrideDetails"].get("groupOverrideDetails") + return None if group_override_details is None else GroupOverrideDetails(group_override_details) + + @claims_to_add_or_override.setter + def claims_to_add_or_override(self, value: Dict[str, str]): + """A map of one or more key-value pairs of claims to add or override. + For group related claims, use groupOverrideDetails instead.""" + self["response"]["claimsOverrideDetails"]["claimsToAddOrOverride"] = value + + @claims_to_suppress.setter + def claims_to_suppress(self, value: List[str]): + """A list that contains claims to be suppressed from the identity token.""" + self["response"]["claimsOverrideDetails"]["claimsToSuppress"] = value + + @group_configuration.setter + def group_configuration(self, value: Dict[str, Any]): + """The output object containing the current group configuration. + + It includes groupsToOverride, iamRolesToOverride, and preferredRole. + + The groupOverrideDetails object is replaced with the one you provide. If you provide an empty or null + object in the response, then the groups are suppressed. To leave the existing group configuration + as is, copy the value of the request's groupConfiguration object to the groupOverrideDetails object + in the response, and pass it back to the service. + """ + self["response"]["claimsOverrideDetails"]["groupOverrideDetails"] = value + + def set_group_configuration_groups_to_override(self, value: List[str]): + """A list of the group names that are associated with the user that the identity token is issued for.""" + self["response"]["claimsOverrideDetails"].setdefault("groupOverrideDetails", {}) + self["response"]["claimsOverrideDetails"]["groupOverrideDetails"]["groupsToOverride"] = value + + def set_group_configuration_iam_roles_to_override(self, value: List[str]): + """A list of the current IAM roles associated with these groups.""" + self["response"]["claimsOverrideDetails"].setdefault("groupOverrideDetails", {}) + self["response"]["claimsOverrideDetails"]["groupOverrideDetails"]["iamRolesToOverride"] = value + + def set_group_configuration_preferred_role(self, value: str): + """A string indicating the preferred IAM role.""" + self["response"]["claimsOverrideDetails"].setdefault("groupOverrideDetails", {}) + self["response"]["claimsOverrideDetails"]["groupOverrideDetails"]["preferredRole"] = value + + +class PreTokenGenerationTriggerEventResponse(dict): + @property + def claims_override_details(self) -> ClaimsOverrideDetails: + self["response"].setdefault("claimsOverrideDetails", {}) + return ClaimsOverrideDetails(self) + + +class PreTokenGenerationTriggerEvent(BaseTriggerEvent): + """Pre Token Generation Lambda Trigger + + Amazon Cognito invokes this trigger before token generation allowing you to customize identity token claims. + + Notes: + ---- + `triggerSource` can be one of the following: + + - `TokenGeneration_HostedAuth` Called during authentication from the Amazon Cognito hosted UI sign-in page. + - `TokenGeneration_Authentication` Called after user authentication flows have completed. + - `TokenGeneration_NewPasswordChallenge` Called after the user is created by an admin. This flow is invoked + when the user has to change a temporary password. + - `TokenGeneration_AuthenticateDevice` Called at the end of the authentication of a user device. + - `TokenGeneration_RefreshTokens` Called when a user tries to refresh the identity and access tokens. + + Documentation: + -------------- + - https://docs.aws.amazon.com/cognito/latest/developerguide/user-pool-lambda-pre-token-generation.html + """ + + @property + def request(self) -> PreTokenGenerationTriggerEventRequest: + """Pre Token Generation Request Parameters""" + return PreTokenGenerationTriggerEventRequest(self) + + @property + def response(self) -> PreTokenGenerationTriggerEventResponse: + """Pre Token Generation Response Parameters""" + return PreTokenGenerationTriggerEventResponse(self) diff --git a/tests/events/cognitoPostAuthenticationTriggerEvent.json b/tests/events/cognitoPostAuthenticationTriggerEvent.json new file mode 100644 index 00000000000..3b1faa81bf9 --- /dev/null +++ b/tests/events/cognitoPostAuthenticationTriggerEvent.json @@ -0,0 +1,18 @@ +{ + "version": "1", + "region": "us-east-1", + "userPoolId": "us-east-1_example", + "userName": "UserName", + "callerContext": { + "awsSdkVersion": "awsSdkVersion", + "clientId": "clientId" + }, + "triggerSource": "PostAuthentication_Authentication", + "request": { + "newDeviceUsed": true, + "userAttributes": { + "email": "test@mail.com" + } + }, + "response": {} +} diff --git a/tests/events/cognitoPreTokenGenerationTriggerEvent.json b/tests/events/cognitoPreTokenGenerationTriggerEvent.json new file mode 100644 index 00000000000..f5ee69e0d2d --- /dev/null +++ b/tests/events/cognitoPreTokenGenerationTriggerEvent.json @@ -0,0 +1,25 @@ +{ + "version": "1", + "triggerSource": "TokenGeneration_Authentication", + "region": "us-west-2", + "userPoolId": "us-west-2_example", + "userName": "testqq", + "callerContext": { + "awsSdkVersion": "aws-sdk-unknown-unknown", + "clientId": "71ghuul37mresr7h373b704tua" + }, + "request": { + "userAttributes": { + "sub": "0b0a57c5-f013-426a-81a1-f8ffbfba21f0", + "email_verified": "true", + "cognito:user_status": "CONFIRMED", + "email": "test@mail.com" + }, + "groupConfiguration": { + "groupsToOverride": [], + "iamRolesToOverride": [], + "preferredRole": null + } + }, + "response": {} +} diff --git a/tests/functional/test_lambda_trigger_events.py b/tests/functional/test_lambda_trigger_events.py index 16a1b280475..a7f61970d25 100644 --- a/tests/functional/test_lambda_trigger_events.py +++ b/tests/functional/test_lambda_trigger_events.py @@ -4,15 +4,19 @@ from aws_lambda_powertools.utilities.trigger import ( CloudWatchLogsEvent, - CustomMessageTriggerEvent, EventBridgeEvent, - PostConfirmationTriggerEvent, - PreAuthenticationTriggerEvent, - PreSignUpTriggerEvent, S3Event, SESEvent, SNSEvent, SQSEvent, +) +from aws_lambda_powertools.utilities.trigger.cognito_user_pool_event import ( + CustomMessageTriggerEvent, + PostAuthenticationTriggerEvent, + PostConfirmationTriggerEvent, + PreAuthenticationTriggerEvent, + PreSignUpTriggerEvent, + PreTokenGenerationTriggerEvent, UserMigrationTriggerEvent, ) from aws_lambda_powertools.utilities.trigger.dynamo_db_stream_event import ( @@ -133,6 +137,52 @@ def test_cognito_pre_authentication_trigger_event(): assert event.request.validation_data is None +def test_cognito_post_authentication_trigger_event(): + event = PostAuthenticationTriggerEvent(load_event("cognitoPostAuthenticationTriggerEvent.json")) + + assert event.request.new_device_used is True + assert event.request.user_attributes["email"] == "test@mail.com" + assert event.request.client_metadata is None + + +def test_cognito_pre_token_generation_trigger_event(): + event = PreTokenGenerationTriggerEvent(load_event("cognitoPreTokenGenerationTriggerEvent.json")) + + group_configuration = event.request.group_configuration + assert group_configuration.groups_to_override == [] + assert group_configuration.iam_roles_to_override == [] + assert group_configuration.preferred_role is None + assert event.request.user_attributes["email"] == "test@mail.com" + assert event.request.client_metadata is None + + event["request"]["groupConfiguration"]["preferredRole"] = "temp" + group_configuration = event.request.group_configuration + assert group_configuration.preferred_role == "temp" + + claims_override_details = event.response.claims_override_details + assert claims_override_details.claims_to_add_or_override is None + assert claims_override_details.claims_to_suppress is None + assert claims_override_details.group_configuration is None + + claims_override_details.claims_to_add_or_override = {"test": "value"} + assert claims_override_details.claims_to_add_or_override["test"] == "value" + + claims_override_details.claims_to_suppress = ["email"] + assert claims_override_details.claims_to_suppress[0] == "email" + + claims_override_details.set_group_configuration_groups_to_override(["group-A", "group-B"]) + assert claims_override_details.group_configuration.groups_to_override == ["group-A", "group-B"] + + claims_override_details.set_group_configuration_iam_roles_to_override(["role"]) + assert claims_override_details.group_configuration.iam_roles_to_override == ["role"] + + claims_override_details.set_group_configuration_preferred_role("role_name") + assert claims_override_details.group_configuration.preferred_role == "role_name" + + claims_override_details.group_configuration = {} + assert claims_override_details.group_configuration == {} + + def test_dynamo_db_stream_trigger_event(): event = DynamoDBStreamEvent(load_event("dynamoStreamEvent.json")) From 87277bad078821cfbfc5700b90f0be29fd0b0104 Mon Sep 17 00:00:00 2001 From: Michael Brewer Date: Tue, 8 Sep 2020 23:39:01 -0700 Subject: [PATCH 13/30] tests(trigger): some extra checks --- .../trigger/cognito_user_pool_event.py | 50 +++++++++++-------- .../functional/test_lambda_trigger_events.py | 22 +++++--- 2 files changed, 46 insertions(+), 26 deletions(-) diff --git a/aws_lambda_powertools/utilities/trigger/cognito_user_pool_event.py b/aws_lambda_powertools/utilities/trigger/cognito_user_pool_event.py index 024a5367e58..249bdce25ce 100644 --- a/aws_lambda_powertools/utilities/trigger/cognito_user_pool_event.py +++ b/aws_lambda_powertools/utilities/trigger/cognito_user_pool_event.py @@ -1,16 +1,19 @@ from typing import Any, Dict, List, Optional -class CallerContext(dict): +class CallerContext: + def __init__(self, event: dict): + self._event = event + @property def aws_sdk_version(self) -> str: """The AWS SDK version number.""" - return self["awsSdkVersion"] + return self._event["callerContext"]["awsSdkVersion"] @property def client_id(self) -> str: """The ID of the client associated with the user pool.""" - return self["clientId"] + return self._event["callerContext"]["clientId"] class BaseTriggerEvent(dict): @@ -49,7 +52,7 @@ def user_name(self) -> str: @property def caller_context(self) -> CallerContext: """The caller context""" - return CallerContext(self["callerContext"]) + return CallerContext(self) class PreSignUpTriggerEventRequest(dict): @@ -463,30 +466,33 @@ def client_metadata(self) -> Optional[Dict[str, str]]: return self["request"].get("clientMetadata") -class ClaimsOverrideDetails(dict): +class ClaimsOverrideDetails: + def __init__(self, event: dict): + self._claims_override_details = event["response"]["claimsOverrideDetails"] + @property def claims_to_add_or_override(self) -> Optional[Dict[str, str]]: - return self["response"]["claimsOverrideDetails"].get("claimsToAddOrOverride") + return self._claims_override_details.get("claimsToAddOrOverride") @property def claims_to_suppress(self) -> Optional[List[str]]: - return self["response"]["claimsOverrideDetails"].get("claimsToSuppress") + return self._claims_override_details.get("claimsToSuppress") @property def group_configuration(self) -> Optional[GroupOverrideDetails]: - group_override_details = self["response"]["claimsOverrideDetails"].get("groupOverrideDetails") + group_override_details = self._claims_override_details.get("groupOverrideDetails") return None if group_override_details is None else GroupOverrideDetails(group_override_details) @claims_to_add_or_override.setter def claims_to_add_or_override(self, value: Dict[str, str]): """A map of one or more key-value pairs of claims to add or override. For group related claims, use groupOverrideDetails instead.""" - self["response"]["claimsOverrideDetails"]["claimsToAddOrOverride"] = value + self._claims_override_details["claimsToAddOrOverride"] = value @claims_to_suppress.setter def claims_to_suppress(self, value: List[str]): """A list that contains claims to be suppressed from the identity token.""" - self["response"]["claimsOverrideDetails"]["claimsToSuppress"] = value + self._claims_override_details["claimsToSuppress"] = value @group_configuration.setter def group_configuration(self, value: Dict[str, Any]): @@ -499,29 +505,33 @@ def group_configuration(self, value: Dict[str, Any]): as is, copy the value of the request's groupConfiguration object to the groupOverrideDetails object in the response, and pass it back to the service. """ - self["response"]["claimsOverrideDetails"]["groupOverrideDetails"] = value + self._claims_override_details["groupOverrideDetails"] = value def set_group_configuration_groups_to_override(self, value: List[str]): """A list of the group names that are associated with the user that the identity token is issued for.""" - self["response"]["claimsOverrideDetails"].setdefault("groupOverrideDetails", {}) - self["response"]["claimsOverrideDetails"]["groupOverrideDetails"]["groupsToOverride"] = value + self._claims_override_details.setdefault("groupOverrideDetails", {}) + self._claims_override_details["groupOverrideDetails"]["groupsToOverride"] = value def set_group_configuration_iam_roles_to_override(self, value: List[str]): """A list of the current IAM roles associated with these groups.""" - self["response"]["claimsOverrideDetails"].setdefault("groupOverrideDetails", {}) - self["response"]["claimsOverrideDetails"]["groupOverrideDetails"]["iamRolesToOverride"] = value + self._claims_override_details.setdefault("groupOverrideDetails", {}) + self._claims_override_details["groupOverrideDetails"]["iamRolesToOverride"] = value def set_group_configuration_preferred_role(self, value: str): """A string indicating the preferred IAM role.""" - self["response"]["claimsOverrideDetails"].setdefault("groupOverrideDetails", {}) - self["response"]["claimsOverrideDetails"]["groupOverrideDetails"]["preferredRole"] = value + self._claims_override_details.setdefault("groupOverrideDetails", {}) + self._claims_override_details["groupOverrideDetails"]["preferredRole"] = value + +class PreTokenGenerationTriggerEventResponse: + def __init__(self, event: dict): + self._event = event -class PreTokenGenerationTriggerEventResponse(dict): @property def claims_override_details(self) -> ClaimsOverrideDetails: - self["response"].setdefault("claimsOverrideDetails", {}) - return ClaimsOverrideDetails(self) + # Ensure we have a `claimsOverrideDetails` element + self._event["response"].setdefault("claimsOverrideDetails", {}) + return ClaimsOverrideDetails(self._event) class PreTokenGenerationTriggerEvent(BaseTriggerEvent): diff --git a/tests/functional/test_lambda_trigger_events.py b/tests/functional/test_lambda_trigger_events.py index a7f61970d25..30adac3d537 100644 --- a/tests/functional/test_lambda_trigger_events.py +++ b/tests/functional/test_lambda_trigger_events.py @@ -159,28 +159,38 @@ def test_cognito_pre_token_generation_trigger_event(): group_configuration = event.request.group_configuration assert group_configuration.preferred_role == "temp" + assert event["response"].get("claimsOverrideDetails") is None claims_override_details = event.response.claims_override_details + assert event["response"]["claimsOverrideDetails"] == {} + assert claims_override_details.claims_to_add_or_override is None assert claims_override_details.claims_to_suppress is None assert claims_override_details.group_configuration is None - claims_override_details.claims_to_add_or_override = {"test": "value"} + claims_override_details.group_configuration = {} + assert claims_override_details.group_configuration == {} + + expected_claims = {"test": "value"} + claims_override_details.claims_to_add_or_override = expected_claims assert claims_override_details.claims_to_add_or_override["test"] == "value" + assert event["response"]["claimsOverrideDetails"]["claimsToAddOrOverride"] == expected_claims claims_override_details.claims_to_suppress = ["email"] assert claims_override_details.claims_to_suppress[0] == "email" + assert event["response"]["claimsOverrideDetails"]["claimsToSuppress"] == ["email"] - claims_override_details.set_group_configuration_groups_to_override(["group-A", "group-B"]) - assert claims_override_details.group_configuration.groups_to_override == ["group-A", "group-B"] + expected_groups = ["group-A", "group-B"] + claims_override_details.set_group_configuration_groups_to_override(expected_groups) + assert claims_override_details.group_configuration.groups_to_override == expected_groups + assert event["response"]["claimsOverrideDetails"]["groupOverrideDetails"]["groupsToOverride"] == expected_groups claims_override_details.set_group_configuration_iam_roles_to_override(["role"]) assert claims_override_details.group_configuration.iam_roles_to_override == ["role"] + assert event["response"]["claimsOverrideDetails"]["groupOverrideDetails"]["iamRolesToOverride"] == ["role"] claims_override_details.set_group_configuration_preferred_role("role_name") assert claims_override_details.group_configuration.preferred_role == "role_name" - - claims_override_details.group_configuration = {} - assert claims_override_details.group_configuration == {} + assert event["response"]["claimsOverrideDetails"]["groupOverrideDetails"]["preferredRole"] == "role_name" def test_dynamo_db_stream_trigger_event(): From 92a3ca75743d0267b8505fec74de5f7342cd9255 Mon Sep 17 00:00:00 2001 From: Michael Brewer Date: Wed, 9 Sep 2020 10:29:55 -0700 Subject: [PATCH 14/30] chore(trigger): clean up and docs Changes: * Add some missing docs pull for various AWS Docs * Fix SQS mapping * Make docs consistent --- .../trigger/cloud_watch_logs_event.py | 20 ++-- .../trigger/dynamo_db_stream_event.py | 84 +++++++++------- .../utilities/trigger/event_bridge_event.py | 3 +- .../utilities/trigger/s3_event.py | 14 +-- .../utilities/trigger/ses_event.py | 5 +- .../utilities/trigger/sns_event.py | 25 ++++- .../utilities/trigger/sqs_event.py | 99 ++++++++++++++----- .../functional/test_lambda_trigger_events.py | 3 + 8 files changed, 170 insertions(+), 83 deletions(-) diff --git a/aws_lambda_powertools/utilities/trigger/cloud_watch_logs_event.py b/aws_lambda_powertools/utilities/trigger/cloud_watch_logs_event.py index a31dd4b790a..191c02a2c19 100644 --- a/aws_lambda_powertools/utilities/trigger/cloud_watch_logs_event.py +++ b/aws_lambda_powertools/utilities/trigger/cloud_watch_logs_event.py @@ -66,25 +66,23 @@ def log_events(self) -> List[CloudWatchLogsLogEvent]: return [CloudWatchLogsLogEvent(i) for i in self["logEvents"]] -class CloudWatchLogsEventData(dict): - @property - def data(self) -> str: - """The value of the `data` field is a Base64 encoded ZIP archive.""" - return self["awslogs"]["data"] - - class CloudWatchLogsEvent(dict): """CloudWatch Logs log stream event - Documentation: https://docs.aws.amazon.com/lambda/latest/dg/services-cloudwatchlogs.html + You can use a Lambda function to monitor and analyze logs from an Amazon CloudWatch Logs log stream. + + Documentation: + -------------- + - https://docs.aws.amazon.com/lambda/latest/dg/services-cloudwatchlogs.html """ @property - def aws_logs(self) -> CloudWatchLogsEventData: - return CloudWatchLogsEventData(self) + def aws_logs_data(self) -> str: + """The value of the `data` field is a Base64 encoded ZIP archive.""" + return self["awslogs"]["data"] def decode_cloud_watch_logs_data(self) -> CloudWatchLogsDecodedData: """Gzip and parse json data""" - payload = base64.b64decode(self.aws_logs.data) + payload = base64.b64decode(self.aws_logs_data) decoded: dict = json.loads(zlib.decompress(payload, zlib.MAX_WBITS | 32).decode("UTF-8")) return CloudWatchLogsDecodedData(decoded) diff --git a/aws_lambda_powertools/utilities/trigger/dynamo_db_stream_event.py b/aws_lambda_powertools/utilities/trigger/dynamo_db_stream_event.py index 5abbc450405..a5acf585538 100644 --- a/aws_lambda_powertools/utilities/trigger/dynamo_db_stream_event.py +++ b/aws_lambda_powertools/utilities/trigger/dynamo_db_stream_event.py @@ -2,12 +2,15 @@ from typing import Dict, Iterator, List, Optional -class AttributeValue(dict): +class AttributeValue: """Represents the data for an attribute Documentation: https://docs.aws.amazon.com/amazondynamodb/latest/APIReference/API_streams_AttributeValue.html """ + def __init__(self, attr_value: dict): + self._val = attr_value + @property def b_value(self) -> Optional[str]: """An attribute of type Base64-encoded binary data object @@ -15,7 +18,7 @@ def b_value(self) -> Optional[str]: Example: >>> {"B": "dGhpcyB0ZXh0IGlzIGJhc2U2NC1lbmNvZGVk"} """ - return self.get("B") + return self._val.get("B") @property def bs_value(self) -> Optional[List[str]]: @@ -24,7 +27,7 @@ def bs_value(self) -> Optional[List[str]]: Example: >>> {"BS": ["U3Vubnk=", "UmFpbnk=", "U25vd3k="]} """ - return self.get("BS") + return self._val.get("BS") @property def bool_value(self) -> Optional[bool]: @@ -33,7 +36,7 @@ def bool_value(self) -> Optional[bool]: Example: >>> {"BOOL": True} """ - item = self.get("bool") + item = self._val.get("bool") return None if item is None else bool(item) @property @@ -43,8 +46,8 @@ def list_value(self) -> Optional[List["AttributeValue"]]: Example: >>> {"L": [ {"S": "Cookies"} , {"S": "Coffee"}, {"N": "3.14159"}]} """ - item = self.get("L") - return None if item is None else [AttributeValue(i) for i in item] + item = self._val.get("L") + return None if item is None else [AttributeValue(v) for v in item] @property def map_value(self) -> Optional[Dict[str, "AttributeValue"]]: @@ -53,7 +56,7 @@ def map_value(self) -> Optional[Dict[str, "AttributeValue"]]: Example: >>> {"M": {"Name": {"S": "Joe"}, "Age": {"N": "35"}}} """ - return _attribute_value(self, "M") + return _attribute_value_dict(self._val, "M") @property def n_value(self) -> Optional[str]: @@ -65,7 +68,7 @@ def n_value(self) -> Optional[str]: Example: >>> {"N": "123.45"} """ - return self.get("N") + return self._val.get("N") @property def ns_value(self) -> Optional[List[str]]: @@ -74,7 +77,7 @@ def ns_value(self) -> Optional[List[str]]: Example: >>> {"NS": ["42.2", "-19", "7.5", "3.14"]} """ - return self.get("NS") + return self._val.get("NS") @property def null_value(self) -> Optional[bool]: @@ -83,7 +86,7 @@ def null_value(self) -> Optional[bool]: Example: >>> {"NULL": True} """ - item = self.get("NULL") + item = self._val.get("NULL") return None if item is None else bool(item) @property @@ -93,7 +96,7 @@ def s_value(self) -> Optional[str]: Example: >>> {"S": "Hello"} """ - return self.get("S") + return self._val.get("S") @property def ss_value(self) -> Optional[List[str]]: @@ -102,12 +105,17 @@ def ss_value(self) -> Optional[List[str]]: Example: >>> {"SS": ["Giraffe", "Hippo" ,"Zebra"]} """ - return self.get("SS") + return self._val.get("SS") + +def _attribute_value_dict(attr_values: Dict[str, dict], key: str) -> Optional[Dict[str, AttributeValue]]: + """A dict of type String to AttributeValue object map -def _attribute_value(values: dict, key: str) -> Optional[Dict[str, AttributeValue]]: - item: dict = values.get(key) - return None if item is None else {k: AttributeValue(v) for k, v in item.items()} + Example: + >>> {"NewImage": {"Id": {"S": "xxx-xxx"}, "Value": {"N": "35"}}} + """ + attr_values_dict = attr_values.get(key) + return None if attr_values_dict is None else {k: AttributeValue(v) for k, v in attr_values_dict.items()} class StreamViewType(Enum): @@ -119,43 +127,46 @@ class StreamViewType(Enum): NEW_AND_OLD_IMAGES = 3 # both the new and the old item images of the item. -class StreamRecord(dict): +class StreamRecord: + def __init__(self, stream_record: dict): + self._val = stream_record + @property def approximate_creation_date_time(self) -> Optional[int]: """The approximate date and time when the stream record was created, in UNIX epoch time format.""" - item = self.get("ApproximateCreationDateTime") + item = self._val.get("ApproximateCreationDateTime") return None if item is None else int(item) @property def keys(self) -> Optional[Dict[str, AttributeValue]]: """The primary key attribute(s) for the DynamoDB item that was modified.""" - return _attribute_value(self, "Keys") + return _attribute_value_dict(self._val, "Keys") @property def new_image(self) -> Optional[Dict[str, AttributeValue]]: """The item in the DynamoDB table as it appeared after it was modified.""" - return _attribute_value(self, "NewImage") + return _attribute_value_dict(self._val, "NewImage") @property def old_image(self) -> Optional[Dict[str, AttributeValue]]: """The item in the DynamoDB table as it appeared before it was modified.""" - return _attribute_value(self, "OldImage") + return _attribute_value_dict(self._val, "OldImage") @property def sequence_number(self) -> Optional[str]: """The sequence number of the stream record.""" - return self.get("SequenceNumber") + return self._val.get("SequenceNumber") @property def size_bytes(self) -> Optional[int]: """The size of the stream record, in bytes.""" - item = self.get("SizeBytes") + item = self._val.get("SizeBytes") return None if item is None else int(item) @property def stream_view_type(self) -> Optional[StreamViewType]: """The type of data from the modified DynamoDB item that was captured in this stream record""" - item = self.get("StreamViewType") + item = self._val.get("StreamViewType") return None if item is None else StreamViewType[str(item)] @@ -165,55 +176,60 @@ class DynamoDBRecordEventName(Enum): REMOVE = 2 # the item was deleted from the table -class DynamoDBRecord(dict): +class DynamoDBRecord: """A description of a unique event within a stream""" + def __init__(self, record: dict): + self._val = record + @property def aws_region(self) -> Optional[str]: """The region in which the GetRecords request was received""" - return self.get("awsRegion") + return self._val.get("awsRegion") @property def dynamodb(self) -> Optional[StreamRecord]: """The main body of the stream record, containing all of the DynamoDB-specific fields.""" - item = self.get("dynamodb") - return None if item is None else StreamRecord(item) + stream_record = self._val.get("dynamodb") + return None if stream_record is None else StreamRecord(stream_record) @property def event_id(self) -> Optional[str]: """A globally unique identifier for the event that was recorded in this stream record.""" - return self.get("eventID") + return self._val.get("eventID") @property def event_name(self) -> Optional[DynamoDBRecordEventName]: """The type of data modification that was performed on the DynamoDB table""" - item = self.get("eventName") + item = self._val.get("eventName") return None if item is None else DynamoDBRecordEventName[item] @property def event_source(self) -> Optional[str]: """The AWS service from which the stream record originated. For DynamoDB Streams, this is aws:dynamodb.""" - return self.get("eventSource") + return self._val.get("eventSource") @property def event_source_arn(self) -> Optional[str]: - return self.get("eventSourceARN") + return self._val.get("eventSourceARN") @property def event_version(self) -> Optional[str]: """The version number of the stream record format.""" - return self.get("eventVersion") + return self._val.get("eventVersion") @property def user_identity(self) -> Optional[dict]: """Contains details about the type of identity that made the request""" - return self.get("userIdentity") + return self._val.get("userIdentity") class DynamoDBStreamEvent(dict): """Dynamo DB Stream Event - Documentation: https://docs.aws.amazon.com/lambda/latest/dg/with-ddb.html + Documentation: + ------------- + - https://docs.aws.amazon.com/lambda/latest/dg/with-ddb.html """ @property diff --git a/aws_lambda_powertools/utilities/trigger/event_bridge_event.py b/aws_lambda_powertools/utilities/trigger/event_bridge_event.py index 74180bb2328..2c5ff03dd86 100644 --- a/aws_lambda_powertools/utilities/trigger/event_bridge_event.py +++ b/aws_lambda_powertools/utilities/trigger/event_bridge_event.py @@ -5,7 +5,8 @@ class EventBridgeEvent(dict): """Amazon EventBridge Event Documentation: - https://docs.aws.amazon.com/eventbridge/latest/userguide/aws-events.html + -------------- + - https://docs.aws.amazon.com/eventbridge/latest/userguide/aws-events.html """ @property diff --git a/aws_lambda_powertools/utilities/trigger/s3_event.py b/aws_lambda_powertools/utilities/trigger/s3_event.py index 0b53e926e91..9df916c3543 100644 --- a/aws_lambda_powertools/utilities/trigger/s3_event.py +++ b/aws_lambda_powertools/utilities/trigger/s3_event.py @@ -80,18 +80,18 @@ def get_object(self) -> S3Object: class S3EventRecordGlacierRestoreEventData(dict): @property def lifecycle_restoration_expiry_time(self) -> str: - return self["lifecycleRestorationExpiryTime"] + return self["restoreEventData"]["lifecycleRestorationExpiryTime"] @property def lifecycle_restore_storage_class(self) -> str: - return self["lifecycleRestoreStorageClass"] + return self["restoreEventData"]["lifecycleRestoreStorageClass"] class S3EventRecordGlacierEventData(dict): @property def restore_event_data(self) -> S3EventRecordGlacierRestoreEventData: """The restoreEventData key contains attributes related to your restore request.""" - return S3EventRecordGlacierRestoreEventData(self["restoreEventData"]) + return S3EventRecordGlacierRestoreEventData(self) class S3EventRecord(dict): @@ -107,6 +107,7 @@ def event_source(self) -> str: @property def aws_region(self) -> str: + """aws region eg: us-east-1""" return self["awsRegion"] @property @@ -153,9 +154,10 @@ class S3Event(dict): """S3 event notification Documentation: - https://docs.aws.amazon.com/lambda/latest/dg/with-s3.html - https://docs.aws.amazon.com/AmazonS3/latest/dev/NotificationHowTo.html - https://docs.aws.amazon.com/AmazonS3/latest/dev/notification-content-structure.html + ------------- + - https://docs.aws.amazon.com/lambda/latest/dg/with-s3.html + - https://docs.aws.amazon.com/AmazonS3/latest/dev/NotificationHowTo.html + - https://docs.aws.amazon.com/AmazonS3/latest/dev/notification-content-structure.html """ @property diff --git a/aws_lambda_powertools/utilities/trigger/ses_event.py b/aws_lambda_powertools/utilities/trigger/ses_event.py index b40032dd815..b35ac108f23 100644 --- a/aws_lambda_powertools/utilities/trigger/ses_event.py +++ b/aws_lambda_powertools/utilities/trigger/ses_event.py @@ -195,8 +195,9 @@ class SESEvent(dict): NOTE: There is a 30-second timeout on RequestResponse invocations. Documentation: - https://docs.aws.amazon.com/lambda/latest/dg/services-ses.html - https://docs.aws.amazon.com/ses/latest/DeveloperGuide/receiving-email-action-lambda.html + -------------- + - https://docs.aws.amazon.com/lambda/latest/dg/services-ses.html + - https://docs.aws.amazon.com/ses/latest/DeveloperGuide/receiving-email-action-lambda.html """ @property diff --git a/aws_lambda_powertools/utilities/trigger/sns_event.py b/aws_lambda_powertools/utilities/trigger/sns_event.py index 9f016ce0fbe..9b5315e7f30 100644 --- a/aws_lambda_powertools/utilities/trigger/sns_event.py +++ b/aws_lambda_powertools/utilities/trigger/sns_event.py @@ -4,38 +4,47 @@ class SNSMessageAttribute(dict): @property def get_type(self) -> str: - """Get the `type` property""" + """The supported message attribute data types are String, String.Array, Number, and Binary.""" # Note: this name conflicts with existing python builtins return self["Type"] @property def value(self) -> str: + """The user-specified message attribute value.""" return self["Value"] class SNSMessage(dict): @property def signature_version(self) -> str: + """Version of the Amazon SNS signature used.""" return self["Sns"]["SignatureVersion"] @property def timestamp(self) -> str: + """The time (GMT) when the subscription confirmation was sent.""" return self["Sns"]["Timestamp"] @property def signature(self) -> str: + """Base64-encoded "SHA1withRSA" signature of the Message, MessageId, Type, Timestamp, and TopicArn values.""" return self["Sns"]["Signature"] @property def signing_cert_url(self) -> str: + """The URL to the certificate that was used to sign the message.""" return self["Sns"]["SigningCertUrl"] @property def message_id(self) -> str: + """A Universally Unique Identifier, unique for each message published. + + For a message that Amazon SNS resends during a retry, the message ID of the original message is used.""" return self["Sns"]["MessageId"] @property def message(self) -> str: + """A string that describes the message. """ return self["Sns"]["Message"] @property @@ -44,26 +53,34 @@ def message_attributes(self) -> Dict[str, SNSMessageAttribute]: @property def get_type(self) -> str: - """Get the `type` property""" + """The type of message. + + For a subscription confirmation, the type is SubscriptionConfirmation.""" # Note: this name conflicts with existing python builtins return self["Sns"]["Type"] @property def unsubscribe_url(self) -> str: + """A URL that you can use to unsubscribe the endpoint from this topic. + + If you visit this URL, Amazon SNS unsubscribes the endpoint and stops sending notifications to this endpoint.""" return self["Sns"]["UnsubscribeUrl"] @property def topic_arn(self) -> str: + """The Amazon Resource Name (ARN) for the topic that this endpoint is subscribed to.""" return self["Sns"]["TopicArn"] @property def subject(self) -> str: + """The Subject parameter specified when the notification was published to the topic.""" return self["Sns"]["Subject"] class SNSEventRecord(dict): @property def event_version(self) -> str: + """Event version""" return self["EventVersion"] @property @@ -82,7 +99,9 @@ def sns(self) -> SNSMessage: class SNSEvent(dict): """SNS Event - Documentation: https://docs.aws.amazon.com/lambda/latest/dg/with-sns.html + Documentation: + ------------- + - https://docs.aws.amazon.com/lambda/latest/dg/with-sns.html """ @property diff --git a/aws_lambda_powertools/utilities/trigger/sqs_event.py b/aws_lambda_powertools/utilities/trigger/sqs_event.py index ad7a15a61a7..60604362676 100644 --- a/aws_lambda_powertools/utilities/trigger/sqs_event.py +++ b/aws_lambda_powertools/utilities/trigger/sqs_event.py @@ -1,102 +1,149 @@ from typing import Dict, Iterator, Optional -class SQSRecordAttributes(dict): +class SQSRecordAttributes: + def __init__(self, record_attributes: dict): + self._val = record_attributes + @property def aws_trace_header(self) -> Optional[str]: - return self["attributes"].get("AWSTraceHeader") + """Returns the AWS X-Ray trace header string.""" + return self._val.get("AWSTraceHeader") @property def approximate_receive_count(self) -> str: - return self["attributes"]["ApproximateReceiveCount"] + """Returns the number of times a message has been received across all queues but not deleted.""" + return self._val["ApproximateReceiveCount"] @property def sent_timestamp(self) -> str: - return self["attributes"]["SentTimestamp"] + """Returns the time the message was sent to the queue (epoch time in milliseconds).""" + return self._val["SentTimestamp"] @property def sender_id(self) -> str: - return self["attributes"]["SenderId"] + """For an IAM user, returns the IAM user ID, For an IAM role, returns the IAM role ID""" + return self._val["SenderId"] @property def approximate_first_receive_timestamp(self) -> str: - return self["attributes"]["ApproximateFirstReceiveTimestamp"] + """Returns the time the message was first received from the queue (epoch time in milliseconds).""" + return self._val["ApproximateFirstReceiveTimestamp"] @property def sequence_number(self) -> Optional[str]: - return self["attributes"].get("SequenceNumber") + """The large, non-consecutive number that Amazon SQS assigns to each message.""" + return self._val.get("SequenceNumber") @property def message_group_id(self) -> Optional[str]: - return self.get("MessageGroupId") + """The tag that specifies that a message belongs to a specific message group. + + Messages that belong to the same message group are always processed one by one, in a + strict order relative to the message group (however, messages that belong to different + message groups might be processed out of order).""" + return self._val.get("MessageGroupId") @property def message_deduplication_id(self) -> Optional[str]: - return self["attributes"].get("MessageDeduplicationId") + """The token used for deduplication of sent messages. + + If a message with a particular message deduplication ID is sent successfully, any messages sent + with the same message deduplication ID are accepted successfully but aren't delivered during + the 5-minute deduplication interval.""" + return self._val.get("MessageDeduplicationId") + + +class SQSMessageAttribute: + """The user-specified message attribute value.""" + def __init__(self, message_attribute: dict): + self._val = message_attribute -class SQSMessageAttribute(dict): @property def string_value(self) -> Optional[str]: - return self["stringValue"] + """Strings are Unicode with UTF-8 binary encoding.""" + return self._val["stringValue"] @property def binary_value(self) -> Optional[str]: - return self["binaryValue"] + """Binary type attributes can store any binary data, such as compressed data, encrypted data, or images. + + Base64-encoded binary data object""" + return self._val["binaryValue"] @property def data_type(self) -> str: - return self["dataType"] + """ The message attribute data type. Supported types include `String`, `Number`, and `Binary`.""" + return self._val["dataType"] class SQSMessageAttributes(Dict[str, SQSMessageAttribute]): - def __getitem__(self, item) -> Optional[SQSMessageAttribute]: - item = super(SQSMessageAttributes, self).get(item) + def __getitem__(self, key: str) -> Optional[SQSMessageAttribute]: + item = super(SQSMessageAttributes, self).get(key) return None if item is None else SQSMessageAttribute(item) -class SQSRecord(dict): +class SQSRecord: + """An Amazon SQS message""" + + def __init__(self, record: dict): + self._val = record + @property def message_id(self) -> str: - return self["messageId"] + """A unique identifier for the message. + + A messageId is considered unique across all AWS accounts for an extended period of time.""" + return self._val["messageId"] @property def receipt_handle(self) -> str: - return self["receiptHandle"] + """An identifier associated with the act of receiving the message. + + A new receipt handle is returned every time you receive a message. When deleting a message, + you provide the last received receipt handle to delete the message.""" + return self._val["receiptHandle"] @property def body(self) -> str: - return self["body"] + """The message's contents (not URL-encoded).""" + return self._val["body"] @property def attributes(self) -> SQSRecordAttributes: - return SQSRecordAttributes(self) + """A map of the attributes requested in ReceiveMessage to their respective values.""" + return SQSRecordAttributes(self._val["attributes"]) @property def message_attributes(self) -> SQSMessageAttributes: - return SQSMessageAttributes(self["messageAttributes"]) + """Each message attribute consists of a Name, Type, and Value.""" + return SQSMessageAttributes(self._val["messageAttributes"]) @property def md5_of_body(self) -> str: - return self["md5OfBody"] + """An MD5 digest of the non-URL-encoded message body string.""" + return self._val["md5OfBody"] @property def event_source(self) -> str: - return self["eventSource"] + return self._val["eventSource"] @property def event_source_arn(self) -> str: - return self["eventSourceARN"] + return self._val["eventSourceARN"] @property def aws_region(self) -> str: - return self["awsRegion"] + return self._val["awsRegion"] class SQSEvent(dict): """SQS Event - Documentation: https://docs.aws.amazon.com/lambda/latest/dg/with-sqs.html + Documentation: + -------------- + - https://docs.aws.amazon.com/lambda/latest/dg/with-sqs.html """ @property diff --git a/tests/functional/test_lambda_trigger_events.py b/tests/functional/test_lambda_trigger_events.py index 30adac3d537..ab8e0a697cf 100644 --- a/tests/functional/test_lambda_trigger_events.py +++ b/tests/functional/test_lambda_trigger_events.py @@ -197,6 +197,7 @@ def test_dynamo_db_stream_trigger_event(): event = DynamoDBStreamEvent(load_event("dynamoStreamEvent.json")) records = list(event.records) + record = records[0] assert record.aws_region == "us-west-2" dynamodb = record.dynamodb @@ -241,7 +242,9 @@ def test_dynamo_attribute_value_list_value(): def test_dynamo_attribute_value_map_value(): example_attribute_value = {"M": {"Name": {"S": "Joe"}, "Age": {"N": "35"}}} + attribute_value = AttributeValue(example_attribute_value) + map_value = attribute_value.map_value assert map_value is not None item = map_value["Name"] From f3914614763b7e6426ba6d53a8c33e2696304122 Mon Sep 17 00:00:00 2001 From: Michael Brewer Date: Wed, 9 Sep 2020 19:28:32 -0700 Subject: [PATCH 15/30] chore: consistent naming --- .../trigger/cloud_watch_logs_event.py | 15 +++-- .../trigger/cognito_user_pool_event.py | 59 ++++++++++--------- .../utilities/trigger/s3_event.py | 7 ++- ...nt.json => cognitoCustomMessageEvent.json} | 0 ...on => cognitoPostAuthenticationEvent.json} | 0 ...json => cognitoPostConfirmationEvent.json} | 0 ...son => cognitoPreAuthenticationEvent.json} | 0 ...rEvent.json => cognitoPreSignUpEvent.json} | 0 ...on => cognitoPreTokenGenerationEvent.json} | 0 ...nt.json => cognitoUserMigrationEvent.json} | 0 .../functional/test_lambda_trigger_events.py | 14 ++--- 11 files changed, 53 insertions(+), 42 deletions(-) rename tests/events/{cognitoCustomMessageTriggerEvent.json => cognitoCustomMessageEvent.json} (100%) rename tests/events/{cognitoPostAuthenticationTriggerEvent.json => cognitoPostAuthenticationEvent.json} (100%) rename tests/events/{cognitoPostConfirmationTriggerEvent.json => cognitoPostConfirmationEvent.json} (100%) rename tests/events/{cognitoPreAuthenticationTriggerEvent.json => cognitoPreAuthenticationEvent.json} (100%) rename tests/events/{cognitoPreSignUpTriggerEvent.json => cognitoPreSignUpEvent.json} (100%) rename tests/events/{cognitoPreTokenGenerationTriggerEvent.json => cognitoPreTokenGenerationEvent.json} (100%) rename tests/events/{cognitoUserMigrationTriggerEvent.json => cognitoUserMigrationEvent.json} (100%) diff --git a/aws_lambda_powertools/utilities/trigger/cloud_watch_logs_event.py b/aws_lambda_powertools/utilities/trigger/cloud_watch_logs_event.py index 191c02a2c19..f6bad46d865 100644 --- a/aws_lambda_powertools/utilities/trigger/cloud_watch_logs_event.py +++ b/aws_lambda_powertools/utilities/trigger/cloud_watch_logs_event.py @@ -4,27 +4,30 @@ from typing import Dict, List, Optional -class CloudWatchLogsLogEvent(dict): +class CloudWatchLogsLogEvent: + def __init__(self, log_event: dict): + self._val = log_event + @property def get_id(self) -> str: """The ID property is a unique identifier for every log event.""" # Note: this name conflicts with existing python builtins - return self["id"] + return self._val["id"] @property def timestamp(self) -> int: """Get the `timestamp` property""" - return self["timestamp"] + return self._val["timestamp"] @property def message(self) -> str: """Get the `message` property""" - return self["message"] + return self._val["message"] @property def extracted_fields(self) -> Optional[Dict[str, str]]: """Get the `extractedFields` property""" - return self.get("extractedFields") + return self._val.get("extractedFields") class CloudWatchLogsDecodedData(dict): @@ -82,7 +85,7 @@ def aws_logs_data(self) -> str: return self["awslogs"]["data"] def decode_cloud_watch_logs_data(self) -> CloudWatchLogsDecodedData: - """Gzip and parse json data""" + """Decode, unzip and parse json data""" payload = base64.b64decode(self.aws_logs_data) decoded: dict = json.loads(zlib.decompress(payload, zlib.MAX_WBITS | 32).decode("UTF-8")) return CloudWatchLogsDecodedData(decoded) diff --git a/aws_lambda_powertools/utilities/trigger/cognito_user_pool_event.py b/aws_lambda_powertools/utilities/trigger/cognito_user_pool_event.py index 249bdce25ce..57b807287d9 100644 --- a/aws_lambda_powertools/utilities/trigger/cognito_user_pool_event.py +++ b/aws_lambda_powertools/utilities/trigger/cognito_user_pool_event.py @@ -293,26 +293,26 @@ class CustomMessageTriggerEventResponse(dict): def sms_message(self) -> str: return self["response"]["smsMessage"] + @property + def email_message(self) -> str: + return self["response"]["emailMessage"] + + @property + def email_subject(self) -> str: + return self["response"]["emailSubject"] + @sms_message.setter def sms_message(self, value: str): """The custom SMS message to be sent to your users. Must include the codeParameter value received in the request.""" self["response"]["smsMessage"] = value - @property - def email_message(self) -> str: - return self["response"]["emailMessage"] - @email_message.setter def email_message(self, value: str): """The custom email message to be sent to your users. Must include the codeParameter value received in the request.""" self["response"]["emailMessage"] = value - @property - def email_subject(self) -> str: - return self["response"]["emailSubject"] - @email_subject.setter def email_subject(self, value: str): """The subject line for the custom message.""" @@ -448,51 +448,54 @@ def preferred_role(self) -> Optional[str]: return self.get("preferredRole") -class PreTokenGenerationTriggerEventRequest(dict): +class PreTokenGenerationTriggerEventRequest: + def __init__(self, event: dict): + self._val = event + @property def group_configuration(self) -> GroupOverrideDetails: """The input object containing the current group configuration""" - return GroupOverrideDetails(self["request"]["groupConfiguration"]) + return GroupOverrideDetails(self._val["request"]["groupConfiguration"]) @property def user_attributes(self) -> Dict[str, str]: """One or more name-value pairs representing user attributes.""" - return self["request"]["userAttributes"] + return self._val["request"]["userAttributes"] @property def client_metadata(self) -> Optional[Dict[str, str]]: """One or more key-value pairs that you can provide as custom input to the Lambda function that you specify for the pre token generation trigger.""" - return self["request"].get("clientMetadata") + return self._val["request"].get("clientMetadata") class ClaimsOverrideDetails: def __init__(self, event: dict): - self._claims_override_details = event["response"]["claimsOverrideDetails"] + self._val = event["response"]["claimsOverrideDetails"] @property def claims_to_add_or_override(self) -> Optional[Dict[str, str]]: - return self._claims_override_details.get("claimsToAddOrOverride") + return self._val.get("claimsToAddOrOverride") @property def claims_to_suppress(self) -> Optional[List[str]]: - return self._claims_override_details.get("claimsToSuppress") + return self._val.get("claimsToSuppress") @property def group_configuration(self) -> Optional[GroupOverrideDetails]: - group_override_details = self._claims_override_details.get("groupOverrideDetails") + group_override_details = self._val.get("groupOverrideDetails") return None if group_override_details is None else GroupOverrideDetails(group_override_details) @claims_to_add_or_override.setter def claims_to_add_or_override(self, value: Dict[str, str]): """A map of one or more key-value pairs of claims to add or override. For group related claims, use groupOverrideDetails instead.""" - self._claims_override_details["claimsToAddOrOverride"] = value + self._val["claimsToAddOrOverride"] = value @claims_to_suppress.setter def claims_to_suppress(self, value: List[str]): """A list that contains claims to be suppressed from the identity token.""" - self._claims_override_details["claimsToSuppress"] = value + self._val["claimsToSuppress"] = value @group_configuration.setter def group_configuration(self, value: Dict[str, Any]): @@ -505,33 +508,33 @@ def group_configuration(self, value: Dict[str, Any]): as is, copy the value of the request's groupConfiguration object to the groupOverrideDetails object in the response, and pass it back to the service. """ - self._claims_override_details["groupOverrideDetails"] = value + self._val["groupOverrideDetails"] = value def set_group_configuration_groups_to_override(self, value: List[str]): """A list of the group names that are associated with the user that the identity token is issued for.""" - self._claims_override_details.setdefault("groupOverrideDetails", {}) - self._claims_override_details["groupOverrideDetails"]["groupsToOverride"] = value + self._val.setdefault("groupOverrideDetails", {}) + self._val["groupOverrideDetails"]["groupsToOverride"] = value def set_group_configuration_iam_roles_to_override(self, value: List[str]): """A list of the current IAM roles associated with these groups.""" - self._claims_override_details.setdefault("groupOverrideDetails", {}) - self._claims_override_details["groupOverrideDetails"]["iamRolesToOverride"] = value + self._val.setdefault("groupOverrideDetails", {}) + self._val["groupOverrideDetails"]["iamRolesToOverride"] = value def set_group_configuration_preferred_role(self, value: str): """A string indicating the preferred IAM role.""" - self._claims_override_details.setdefault("groupOverrideDetails", {}) - self._claims_override_details["groupOverrideDetails"]["preferredRole"] = value + self._val.setdefault("groupOverrideDetails", {}) + self._val["groupOverrideDetails"]["preferredRole"] = value class PreTokenGenerationTriggerEventResponse: def __init__(self, event: dict): - self._event = event + self._val = event @property def claims_override_details(self) -> ClaimsOverrideDetails: # Ensure we have a `claimsOverrideDetails` element - self._event["response"].setdefault("claimsOverrideDetails", {}) - return ClaimsOverrideDetails(self._event) + self._val["response"].setdefault("claimsOverrideDetails", {}) + return ClaimsOverrideDetails(self._val) class PreTokenGenerationTriggerEvent(BaseTriggerEvent): diff --git a/aws_lambda_powertools/utilities/trigger/s3_event.py b/aws_lambda_powertools/utilities/trigger/s3_event.py index 9df916c3543..5f6409cd9d8 100644 --- a/aws_lambda_powertools/utilities/trigger/s3_event.py +++ b/aws_lambda_powertools/utilities/trigger/s3_event.py @@ -80,17 +80,22 @@ def get_object(self) -> S3Object: class S3EventRecordGlacierRestoreEventData(dict): @property def lifecycle_restoration_expiry_time(self) -> str: + """Time when the object restoration will be expired.""" return self["restoreEventData"]["lifecycleRestorationExpiryTime"] @property def lifecycle_restore_storage_class(self) -> str: + """Source storage class for restore""" return self["restoreEventData"]["lifecycleRestoreStorageClass"] class S3EventRecordGlacierEventData(dict): @property def restore_event_data(self) -> S3EventRecordGlacierRestoreEventData: - """The restoreEventData key contains attributes related to your restore request.""" + """The restoreEventData key contains attributes related to your restore request. + + The glacierEventData key is only visible for s3:ObjectRestore:Completed events + """ return S3EventRecordGlacierRestoreEventData(self) diff --git a/tests/events/cognitoCustomMessageTriggerEvent.json b/tests/events/cognitoCustomMessageEvent.json similarity index 100% rename from tests/events/cognitoCustomMessageTriggerEvent.json rename to tests/events/cognitoCustomMessageEvent.json diff --git a/tests/events/cognitoPostAuthenticationTriggerEvent.json b/tests/events/cognitoPostAuthenticationEvent.json similarity index 100% rename from tests/events/cognitoPostAuthenticationTriggerEvent.json rename to tests/events/cognitoPostAuthenticationEvent.json diff --git a/tests/events/cognitoPostConfirmationTriggerEvent.json b/tests/events/cognitoPostConfirmationEvent.json similarity index 100% rename from tests/events/cognitoPostConfirmationTriggerEvent.json rename to tests/events/cognitoPostConfirmationEvent.json diff --git a/tests/events/cognitoPreAuthenticationTriggerEvent.json b/tests/events/cognitoPreAuthenticationEvent.json similarity index 100% rename from tests/events/cognitoPreAuthenticationTriggerEvent.json rename to tests/events/cognitoPreAuthenticationEvent.json diff --git a/tests/events/cognitoPreSignUpTriggerEvent.json b/tests/events/cognitoPreSignUpEvent.json similarity index 100% rename from tests/events/cognitoPreSignUpTriggerEvent.json rename to tests/events/cognitoPreSignUpEvent.json diff --git a/tests/events/cognitoPreTokenGenerationTriggerEvent.json b/tests/events/cognitoPreTokenGenerationEvent.json similarity index 100% rename from tests/events/cognitoPreTokenGenerationTriggerEvent.json rename to tests/events/cognitoPreTokenGenerationEvent.json diff --git a/tests/events/cognitoUserMigrationTriggerEvent.json b/tests/events/cognitoUserMigrationEvent.json similarity index 100% rename from tests/events/cognitoUserMigrationTriggerEvent.json rename to tests/events/cognitoUserMigrationEvent.json diff --git a/tests/functional/test_lambda_trigger_events.py b/tests/functional/test_lambda_trigger_events.py index ab8e0a697cf..687ad24197c 100644 --- a/tests/functional/test_lambda_trigger_events.py +++ b/tests/functional/test_lambda_trigger_events.py @@ -53,7 +53,7 @@ def test_cloud_watch_trigger_event(): def test_cognito_pre_signup_trigger_event(): - event = PreSignUpTriggerEvent(load_event("cognitoPreSignUpTriggerEvent.json")) + event = PreSignUpTriggerEvent(load_event("cognitoPreSignUpEvent.json")) assert event.version == "string" assert event.trigger_source == "PreSignUp_SignUp" @@ -79,7 +79,7 @@ def test_cognito_pre_signup_trigger_event(): def test_cognito_post_confirmation_trigger_event(): - event = PostConfirmationTriggerEvent(load_event("cognitoPostConfirmationTriggerEvent.json")) + event = PostConfirmationTriggerEvent(load_event("cognitoPostConfirmationEvent.json")) user_attributes = event.request.user_attributes assert user_attributes["email"] == "user@example.com" @@ -87,7 +87,7 @@ def test_cognito_post_confirmation_trigger_event(): def test_cognito_user_migration_trigger_event(): - event = UserMigrationTriggerEvent(load_event("cognitoUserMigrationTriggerEvent.json")) + event = UserMigrationTriggerEvent(load_event("cognitoUserMigrationEvent.json")) assert compare_digest(event.request.password, event["request"]["password"]) assert event.request.validation_data is None @@ -112,7 +112,7 @@ def test_cognito_user_migration_trigger_event(): def test_cognito_custom_message_trigger_event(): - event = CustomMessageTriggerEvent(load_event("cognitoCustomMessageTriggerEvent.json")) + event = CustomMessageTriggerEvent(load_event("cognitoCustomMessageEvent.json")) assert event.request.code_parameter == "####" assert event.request.username_parameter == "username" @@ -128,7 +128,7 @@ def test_cognito_custom_message_trigger_event(): def test_cognito_pre_authentication_trigger_event(): - event = PreAuthenticationTriggerEvent(load_event("cognitoPreAuthenticationTriggerEvent.json")) + event = PreAuthenticationTriggerEvent(load_event("cognitoPreAuthenticationEvent.json")) assert event.request.user_not_found is None event["request"]["userNotFound"] = True @@ -138,7 +138,7 @@ def test_cognito_pre_authentication_trigger_event(): def test_cognito_post_authentication_trigger_event(): - event = PostAuthenticationTriggerEvent(load_event("cognitoPostAuthenticationTriggerEvent.json")) + event = PostAuthenticationTriggerEvent(load_event("cognitoPostAuthenticationEvent.json")) assert event.request.new_device_used is True assert event.request.user_attributes["email"] == "test@mail.com" @@ -146,7 +146,7 @@ def test_cognito_post_authentication_trigger_event(): def test_cognito_pre_token_generation_trigger_event(): - event = PreTokenGenerationTriggerEvent(load_event("cognitoPreTokenGenerationTriggerEvent.json")) + event = PreTokenGenerationTriggerEvent(load_event("cognitoPreTokenGenerationEvent.json")) group_configuration = event.request.group_configuration assert group_configuration.groups_to_override == [] From b396e8c78f7c363349f7350bc5281607efa85e24 Mon Sep 17 00:00:00 2001 From: Michael Brewer Date: Wed, 9 Sep 2020 23:33:01 -0700 Subject: [PATCH 16/30] feat(trigger): Add api gateway proxy events --- .../utilities/trigger/__init__.py | 3 + .../trigger/api_gateway_proxy_event.py | 430 ++++++++++++++++++ tests/events/apiGatewayProxyEvent.json | 70 +++ tests/events/apiGatewayProxyV2Event.json | 57 +++ .../functional/test_lambda_trigger_events.py | 107 +++++ 5 files changed, 667 insertions(+) create mode 100644 aws_lambda_powertools/utilities/trigger/api_gateway_proxy_event.py create mode 100644 tests/events/apiGatewayProxyEvent.json create mode 100644 tests/events/apiGatewayProxyV2Event.json diff --git a/aws_lambda_powertools/utilities/trigger/__init__.py b/aws_lambda_powertools/utilities/trigger/__init__.py index 1499b4a5367..3b63f1a723f 100644 --- a/aws_lambda_powertools/utilities/trigger/__init__.py +++ b/aws_lambda_powertools/utilities/trigger/__init__.py @@ -1,3 +1,4 @@ +from .api_gateway_proxy_event import APIGatewayProxyEvent, APIGatewayProxyEventV2 from .cloud_watch_logs_event import CloudWatchLogsEvent from .dynamo_db_stream_event import DynamoDBStreamEvent from .event_bridge_event import EventBridgeEvent @@ -7,6 +8,8 @@ from .sqs_event import SQSEvent __all__ = [ + "APIGatewayProxyEvent", + "APIGatewayProxyEventV2", "CloudWatchLogsEvent", "DynamoDBStreamEvent", "EventBridgeEvent", diff --git a/aws_lambda_powertools/utilities/trigger/api_gateway_proxy_event.py b/aws_lambda_powertools/utilities/trigger/api_gateway_proxy_event.py new file mode 100644 index 00000000000..b3bbcb5c2b0 --- /dev/null +++ b/aws_lambda_powertools/utilities/trigger/api_gateway_proxy_event.py @@ -0,0 +1,430 @@ +from typing import Any, Dict, List, Optional + + +class APIGatewayEventIdentity: + def __init__(self, event: dict): + self._val = event + + @property + def access_key(self) -> Optional[str]: + return self._val["requestContext"]["identity"].get("accessKey") + + @property + def account_id(self) -> Optional[str]: + """The AWS account ID associated with the request.""" + return self._val["requestContext"]["identity"].get("accountId") + + @property + def api_key(self) -> Optional[str]: + """For API methods that require an API key, this variable is the API key associated with the method request. + For methods that don't require an API key, this variable is null. """ + return self._val["requestContext"]["identity"].get("apiKey") + + @property + def api_key_id(self) -> Optional[str]: + """The API key ID associated with an API request that requires an API key.""" + return self._val["requestContext"]["identity"].get("apiKeyId") + + @property + def caller(self) -> Optional[str]: + """The principal identifier of the caller making the request.""" + return self._val["requestContext"]["identity"].get("caller") + + @property + def cognito_authentication_provider(self) -> Optional[str]: + """A comma-separated list of the Amazon Cognito authentication providers used by the caller + making the request. Available only if the request was signed with Amazon Cognito credentials.""" + return self._val["requestContext"]["identity"].get("cognitoAuthenticationProvider") + + @property + def cognito_authentication_type(self) -> Optional[str]: + """The Amazon Cognito authentication type of the caller making the request. + Available only if the request was signed with Amazon Cognito credentials.""" + return self._val["requestContext"]["identity"].get("cognitoAuthenticationType") + + @property + def cognito_identity_id(self) -> Optional[str]: + """The Amazon Cognito identity ID of the caller making the request. + Available only if the request was signed with Amazon Cognito credentials.""" + return self._val["requestContext"]["identity"].get("cognitoIdentityId") + + @property + def cognito_identity_pool_id(self) -> Optional[str]: + """The Amazon Cognito identity pool ID of the caller making the request. + Available only if the request was signed with Amazon Cognito credentials.""" + return self._val["requestContext"]["identity"].get("cognitoIdentityPoolId") + + @property + def principal_org_id(self) -> Optional[str]: + """The AWS organization ID.""" + return self._val["requestContext"]["identity"].get("principalOrgId") + + @property + def source_ip(self) -> str: + """The source IP address of the TCP connection making the request to API Gateway.""" + return self._val["requestContext"]["identity"]["sourceIp"] + + @property + def user(self) -> Optional[str]: + """The principal identifier of the user making the request.""" + return self._val["requestContext"]["identity"].get("user") + + @property + def user_agent(self) -> Optional[str]: + """The User Agent of the API caller.""" + return self._val["requestContext"]["identity"].get("userAgent") + + @property + def user_arn(self) -> Optional[str]: + """The Amazon Resource Name (ARN) of the effective user identified after authentication.""" + return self._val["requestContext"]["identity"].get("userArn") + + +class APIGatewayEventAuthorizer: + def __init__(self, event: Dict): + self._val = event + + @property + def claims(self) -> Optional[Dict[str, Any]]: + return self._val["requestContext"]["authorizer"].get("claims") + + @property + def scopes(self) -> Optional[List[str]]: + return self._val["requestContext"]["authorizer"].get("scopes") + + +class APIGatewayEventRequestContext: + def __init__(self, event: Dict[str, Any]): + self._val = event + + @property + def account_id(self) -> str: + """The AWS account ID associated with the request.""" + return self._val["requestContext"]["accountId"] + + @property + def api_id(self) -> str: + """The identifier API Gateway assigns to your API.""" + return self._val["requestContext"]["apiId"] + + @property + def authorizer(self) -> APIGatewayEventAuthorizer: + return APIGatewayEventAuthorizer(self._val) + + @property + def connected_at(self) -> Optional[int]: + """The Epoch-formatted connection time. (WebSocket API)""" + return self._val["requestContext"].get("connectedAt") + + @property + def connection_id(self) -> Optional[str]: + """A unique ID for the connection that can be used to make a callback to the client. (WebSocket API)""" + return self._val["requestContext"].get("connectionId") + + @property + def domain_name(self) -> Optional[str]: + """A domain name""" + return self._val["requestContext"].get("domainName") + + @property + def domain_prefix(self) -> Optional[str]: + return self._val["requestContext"].get("domainPrefix") + + @property + def event_type(self) -> Optional[str]: + """The event type: `CONNECT`, `MESSAGE`, or `DISCONNECT`. (WebSocket API)""" + return self._val["requestContext"].get("eventType") + + @property + def extended_request_id(self) -> Optional[str]: + """An automatically generated ID for the API call, which contains more useful information + for debugging/troubleshooting.""" + return self._val["requestContext"].get("extendedRequestId") + + @property + def protocol(self) -> str: + """The request protocol, for example, HTTP/1.1.""" + return self._val["requestContext"]["protocol"] + + @property + def http_method(self) -> str: + """The HTTP method used. Valid values include: DELETE, GET, HEAD, OPTIONS, PATCH, POST, and PUT.""" + return self._val["requestContext"]["httpMethod"] + + @property + def identity(self) -> APIGatewayEventIdentity: + return APIGatewayEventIdentity(self._val) + + @property + def message_direction(self) -> Optional[str]: + """Message direction (WebSocket API)""" + return self._val["requestContext"].get("messageDirection") + + @property + def message_id(self) -> Optional[str]: + """A unique server-side ID for a message. Available only when the `eventType` is `MESSAGE`.""" + return self._val["requestContext"].get("messageId") + + @property + def path(self) -> str: + return self._val["requestContext"]["path"] + + @property + def stage(self) -> str: + """The deployment stage of the API request """ + return self._val["requestContext"]["stage"] + + @property + def request_id(self) -> str: + """The ID that API Gateway assigns to the API request.""" + return self._val["requestContext"]["requestId"] + + @property + def request_time(self) -> Optional[str]: + """The CLF-formatted request time (dd/MMM/yyyy:HH:mm:ss +-hhmm)""" + return self._val["requestContext"].get("requestTime") + + @property + def request_time_epoch(self) -> int: + """The Epoch-formatted request time.""" + return self._val["requestContext"]["requestTimeEpoch"] + + @property + def resource_id(self) -> str: + return self._val["requestContext"]["resourceId"] + + @property + def resource_path(self) -> str: + return self._val["requestContext"]["resourcePath"] + + @property + def route_key(self) -> Optional[str]: + """The selected route key.""" + return self._val["requestContext"].get("routeKey") + + +class APIGatewayProxyEvent(dict): + """AWS Lambda proxy V1 + + Documentation: + -------------- + - https://docs.aws.amazon.com/apigateway/latest/developerguide/http-api-develop-integrations-lambda.html + """ + + @property + def version(self) -> str: + return self["version"] + + @property + def resource(self) -> str: + return self["resource"] + + @property + def path(self) -> str: + return self["path"] + + @property + def http_method(self) -> str: + """The HTTP method used. Valid values include: DELETE, GET, HEAD, OPTIONS, PATCH, POST, and PUT.""" + return self["httpMethod"] + + @property + def headers(self) -> Dict[str, str]: + return self["headers"] + + @property + def multi_value_headers(self) -> Dict[str, List[str]]: + return self["multiValueHeaders"] + + @property + def query_string_parameters(self) -> Optional[Dict[str, str]]: + return self.get("queryStringParameters") + + @property + def multi_value_query_string_parameters(self) -> Optional[Dict[str, List[str]]]: + return self.get("multiValueQueryStringParameters") + + @property + def request_context(self) -> APIGatewayEventRequestContext: + return APIGatewayEventRequestContext(self) + + @property + def path_parameters(self) -> Optional[Dict[str, str]]: + return self.get("pathParameters") + + @property + def stage_variables(self) -> Optional[Dict[str, str]]: + return self.get("stageVariables") + + @property + def body(self) -> Optional[str]: + return self.get("body") + + @property + def is_base64_encoded(self) -> bool: + return self["isBase64Encoded"] + + +class RequestContextV2Http: + def __init__(self, event: dict): + self._val = event + + @property + def method(self) -> str: + return self._val["requestContext"]["http"]["method"] + + @property + def path(self) -> str: + return self._val["requestContext"]["http"]["path"] + + @property + def protocol(self) -> str: + """The request protocol, for example, HTTP/1.1.""" + return self._val["requestContext"]["http"]["protocol"] + + @property + def source_ip(self) -> str: + """The source IP address of the TCP connection making the request to API Gateway.""" + return self._val["requestContext"]["http"]["sourceIp"] + + @property + def user_agent(self) -> str: + """The User Agent of the API caller.""" + return self._val["requestContext"]["http"]["userAgent"] + + +class RequestContextV2Authorizer: + def __init__(self, event: dict): + self._val = event + + @property + def jwt_claim(self) -> Dict[str, Any]: + return self._val["jwt"]["claims"] + + @property + def jwt_scopes(self) -> List[str]: + return self._val["jwt"]["scopes"] + + +class RequestContextV2: + def __init__(self, event: dict): + self._val = event + + @property + def account_id(self) -> str: + """The AWS account ID associated with the request.""" + return self._val["requestContext"]["accountId"] + + @property + def api_id(self) -> str: + """The identifier API Gateway assigns to your API.""" + return self._val["requestContext"]["apiId"] + + @property + def authorizer(self) -> Optional[RequestContextV2Authorizer]: + authorizer = self._val["requestContext"].get("authorizer") + return None if authorizer is None else RequestContextV2Authorizer(authorizer) + + @property + def domain_name(self) -> str: + """A domain name """ + return self._val["requestContext"]["domainName"] + + @property + def domain_prefix(self) -> str: + return self._val["requestContext"]["domainPrefix"] + + @property + def http(self) -> RequestContextV2Http: + return RequestContextV2Http(self._val) + + @property + def request_id(self) -> str: + """The ID that API Gateway assigns to the API request.""" + return self._val["requestContext"]["requestId"] + + @property + def route_key(self) -> str: + """The selected route key.""" + return self._val["requestContext"]["routeKey"] + + @property + def stage(self) -> str: + """The deployment stage of the API request """ + return self._val["requestContext"]["stage"] + + @property + def time(self) -> str: + """The CLF-formatted request time (dd/MMM/yyyy:HH:mm:ss +-hhmm).""" + return self._val["requestContext"]["time"] + + @property + def time_epoch(self) -> int: + """The Epoch-formatted request time.""" + return self._val["requestContext"]["timeEpoch"] + + +class APIGatewayProxyEventV2(dict): + """AWS Lambda proxy V2 event + + Notes: + ----- + Format 2.0 doesn't have multiValueHeaders or multiValueQueryStringParameters fields. Duplicate headers + are combined with commas and included in the headers field. Duplicate query strings are combined with + commas and included in the queryStringParameters field. + + Format 2.0 includes a new cookies field. All cookie headers in the request are combined with commas and + added to the cookies field. In the response to the client, each cookie becomes a set-cookie header. + + Documentation: + -------------- + - https://docs.aws.amazon.com/apigateway/latest/developerguide/http-api-develop-integrations-lambda.html + """ + + @property + def version(self) -> str: + return self["version"] + + @property + def route_key(self) -> str: + return self["routeKey"] + + @property + def raw_path(self) -> str: + return self["rawPath"] + + @property + def raw_query_string(self) -> str: + return self["rawQueryString"] + + @property + def cookies(self) -> Optional[List[str]]: + return self.get("cookies") + + @property + def headers(self) -> Dict[str, str]: + return self["headers"] + + @property + def query_string_parameters(self) -> Optional[Dict[str, str]]: + return self.get("queryStringParameters") + + @property + def request_context(self) -> RequestContextV2: + return RequestContextV2(self) + + @property + def body(self) -> Optional[str]: + return self.get("body") + + @property + def path_parameters(self) -> Optional[Dict[str, str]]: + return self.get("pathParameters") + + @property + def is_base64_encoded(self) -> bool: + return self["isBase64Encoded"] + + @property + def stage_variables(self) -> Optional[Dict[str, str]]: + return self.get("stageVariables") diff --git a/tests/events/apiGatewayProxyEvent.json b/tests/events/apiGatewayProxyEvent.json new file mode 100644 index 00000000000..1fed04a25bf --- /dev/null +++ b/tests/events/apiGatewayProxyEvent.json @@ -0,0 +1,70 @@ +{ + "version": "1.0", + "resource": "/my/path", + "path": "/my/path", + "httpMethod": "GET", + "headers": { + "Header1": "value1", + "Header2": "value2" + }, + "multiValueHeaders": { + "Header1": [ + "value1" + ], + "Header2": [ + "value1", + "value2" + ] + }, + "queryStringParameters": { + "parameter1": "value1", + "parameter2": "value" + }, + "multiValueQueryStringParameters": { + "parameter1": [ + "value1", + "value2" + ], + "parameter2": [ + "value" + ] + }, + "requestContext": { + "accountId": "123456789012", + "apiId": "id", + "authorizer": { + "claims": null, + "scopes": null + }, + "domainName": "id.execute-api.us-east-1.amazonaws.com", + "domainPrefix": "id", + "extendedRequestId": "request-id", + "httpMethod": "GET", + "identity": { + "accessKey": null, + "accountId": null, + "caller": null, + "cognitoAuthenticationProvider": null, + "cognitoAuthenticationType": null, + "cognitoIdentityId": null, + "cognitoIdentityPoolId": null, + "principalOrgId": null, + "sourceIp": "IP", + "user": null, + "userAgent": "user-agent", + "userArn": null + }, + "path": "/my/path", + "protocol": "HTTP/1.1", + "requestId": "id=", + "requestTime": "04/Mar/2020:19:15:17 +0000", + "requestTimeEpoch": 1583349317135, + "resourceId": null, + "resourcePath": "/my/path", + "stage": "$default" + }, + "pathParameters": null, + "stageVariables": null, + "body": "Hello from Lambda!", + "isBase64Encoded": true +} diff --git a/tests/events/apiGatewayProxyV2Event.json b/tests/events/apiGatewayProxyV2Event.json new file mode 100644 index 00000000000..9c310e6d52f --- /dev/null +++ b/tests/events/apiGatewayProxyV2Event.json @@ -0,0 +1,57 @@ +{ + "version": "2.0", + "routeKey": "$default", + "rawPath": "/my/path", + "rawQueryString": "parameter1=value1¶meter1=value2¶meter2=value", + "cookies": [ + "cookie1", + "cookie2" + ], + "headers": { + "Header1": "value1", + "Header2": "value1,value2" + }, + "queryStringParameters": { + "parameter1": "value1,value2", + "parameter2": "value" + }, + "requestContext": { + "accountId": "123456789012", + "apiId": "api-id", + "authorizer": { + "jwt": { + "claims": { + "claim1": "value1", + "claim2": "value2" + }, + "scopes": [ + "scope1", + "scope2" + ] + } + }, + "domainName": "id.execute-api.us-east-1.amazonaws.com", + "domainPrefix": "id", + "http": { + "method": "POST", + "path": "/my/path", + "protocol": "HTTP/1.1", + "sourceIp": "IP", + "userAgent": "agent" + }, + "requestId": "id", + "routeKey": "$default", + "stage": "$default", + "time": "12/Mar/2020:19:03:58 +0000", + "timeEpoch": 1583348638390 + }, + "body": "Hello from Lambda", + "pathParameters": { + "parameter1": "value1" + }, + "isBase64Encoded": false, + "stageVariables": { + "stageVariable1": "value1", + "stageVariable2": "value2" + } +} diff --git a/tests/functional/test_lambda_trigger_events.py b/tests/functional/test_lambda_trigger_events.py index 687ad24197c..13b3c087693 100644 --- a/tests/functional/test_lambda_trigger_events.py +++ b/tests/functional/test_lambda_trigger_events.py @@ -3,6 +3,8 @@ from secrets import compare_digest from aws_lambda_powertools.utilities.trigger import ( + APIGatewayProxyEvent, + APIGatewayProxyEventV2, CloudWatchLogsEvent, EventBridgeEvent, S3Event, @@ -410,3 +412,108 @@ def test_seq_trigger_event(): assert record.event_source == "aws:sqs" assert record.event_source_arn == "arn:aws:sqs:us-east-2:123456789012:my-queue" assert record.aws_region == "us-east-2" + + +def test_api_gateway_proxy_event(): + event = APIGatewayProxyEvent(load_event("apiGatewayProxyEvent.json")) + + assert event.version == event["version"] + assert event.resource == event["resource"] + assert event.path == event["path"] + assert event.http_method == event["httpMethod"] + assert event.headers == event["headers"] + assert event.multi_value_headers == event["multiValueHeaders"] + assert event.query_string_parameters == event["queryStringParameters"] + assert event.multi_value_query_string_parameters == event["multiValueQueryStringParameters"] + + request_context = event.request_context + assert request_context.account_id == event["requestContext"]["accountId"] + assert request_context.api_id == event["requestContext"]["apiId"] + + authorizer = request_context.authorizer + assert authorizer.claims is None + assert authorizer.scopes is None + + assert request_context.domain_name == event["requestContext"]["domainName"] + assert request_context.domain_prefix == event["requestContext"]["domainPrefix"] + assert request_context.extended_request_id == event["requestContext"]["extendedRequestId"] + assert request_context.http_method == event["requestContext"]["httpMethod"] + + identity = request_context.identity + assert identity.access_key == event["requestContext"]["identity"]["accessKey"] + assert identity.account_id == event["requestContext"]["identity"]["accountId"] + assert identity.caller == event["requestContext"]["identity"]["caller"] + assert ( + identity.cognito_authentication_provider == event["requestContext"]["identity"]["cognitoAuthenticationProvider"] + ) + assert identity.cognito_authentication_type == event["requestContext"]["identity"]["cognitoAuthenticationType"] + assert identity.cognito_identity_id == event["requestContext"]["identity"]["cognitoIdentityId"] + assert identity.cognito_identity_pool_id == event["requestContext"]["identity"]["cognitoIdentityPoolId"] + assert identity.principal_org_id == event["requestContext"]["identity"]["principalOrgId"] + assert identity.source_ip == event["requestContext"]["identity"]["sourceIp"] + assert identity.user == event["requestContext"]["identity"]["user"] + assert identity.user_agent == event["requestContext"]["identity"]["userAgent"] + assert identity.user_arn == event["requestContext"]["identity"]["userArn"] + + assert request_context.path == event["requestContext"]["path"] + assert request_context.protocol == event["requestContext"]["protocol"] + assert request_context.request_id == event["requestContext"]["requestId"] + assert request_context.request_time == event["requestContext"]["requestTime"] + assert request_context.request_time_epoch == event["requestContext"]["requestTimeEpoch"] + assert request_context.resource_id == event["requestContext"]["resourceId"] + assert request_context.resource_path == event["requestContext"]["resourcePath"] + assert request_context.stage == event["requestContext"]["stage"] + + assert event.path_parameters == event["pathParameters"] + assert event.stage_variables == event["stageVariables"] + assert event.body == event["body"] + assert event.is_base64_encoded == event["isBase64Encoded"] + + assert request_context.connected_at is None + assert request_context.connection_id is None + assert request_context.event_type is None + assert request_context.message_direction is None + assert request_context.message_id is None + assert request_context.route_key is None + assert identity.api_key is None + assert identity.api_key_id is None + + +def test_api_gateway_proxy_v2_event(): + event = APIGatewayProxyEventV2(load_event("apiGatewayProxyV2Event.json")) + + assert event.version == event["version"] + assert event.route_key == event["routeKey"] + assert event.raw_path == event["rawPath"] + assert event.raw_query_string == event["rawQueryString"] + assert event.cookies == event["cookies"] + assert event.cookies[0] == "cookie1" + assert event.headers == event["headers"] + assert event.query_string_parameters == event["queryStringParameters"] + assert event.query_string_parameters["parameter2"] == "value" + + request_context = event.request_context + assert request_context.account_id == event["requestContext"]["accountId"] + assert request_context.api_id == event["requestContext"]["apiId"] + assert request_context.authorizer.jwt_claim == event["requestContext"]["authorizer"]["jwt"]["claims"] + assert request_context.authorizer.jwt_scopes == event["requestContext"]["authorizer"]["jwt"]["scopes"] + assert request_context.domain_name == event["requestContext"]["domainName"] + assert request_context.domain_prefix == event["requestContext"]["domainPrefix"] + + http = request_context.http + assert http.method == "POST" + assert http.path == "/my/path" + assert http.protocol == "HTTP/1.1" + assert http.source_ip == "IP" + assert http.user_agent == "agent" + + assert request_context.request_id == event["requestContext"]["requestId"] + assert request_context.route_key == event["requestContext"]["routeKey"] + assert request_context.stage == event["requestContext"]["stage"] + assert request_context.time == event["requestContext"]["time"] + assert request_context.time_epoch == event["requestContext"]["timeEpoch"] + + assert event.body == event["body"] + assert event.path_parameters == event["pathParameters"] + assert event.is_base64_encoded == event["isBase64Encoded"] + assert event.stage_variables == event["stageVariables"] From 553c48a88bdff25fa1e4077f018053cc8e96ead0 Mon Sep 17 00:00:00 2001 From: Michael Brewer Date: Thu, 10 Sep 2020 08:39:07 -0700 Subject: [PATCH 17/30] chore: Add more docs --- .../utilities/trigger/dynamo_db_stream_event.py | 1 + aws_lambda_powertools/utilities/trigger/s3_event.py | 2 +- aws_lambda_powertools/utilities/trigger/ses_event.py | 2 +- aws_lambda_powertools/utilities/trigger/sns_event.py | 1 + aws_lambda_powertools/utilities/trigger/sqs_event.py | 3 +++ 5 files changed, 7 insertions(+), 2 deletions(-) diff --git a/aws_lambda_powertools/utilities/trigger/dynamo_db_stream_event.py b/aws_lambda_powertools/utilities/trigger/dynamo_db_stream_event.py index a5acf585538..4871b33312e 100644 --- a/aws_lambda_powertools/utilities/trigger/dynamo_db_stream_event.py +++ b/aws_lambda_powertools/utilities/trigger/dynamo_db_stream_event.py @@ -211,6 +211,7 @@ def event_source(self) -> Optional[str]: @property def event_source_arn(self) -> Optional[str]: + """The Amazon Resource Name (ARN) of the event source""" return self._val.get("eventSourceARN") @property diff --git a/aws_lambda_powertools/utilities/trigger/s3_event.py b/aws_lambda_powertools/utilities/trigger/s3_event.py index 5f6409cd9d8..93e3dd7835b 100644 --- a/aws_lambda_powertools/utilities/trigger/s3_event.py +++ b/aws_lambda_powertools/utilities/trigger/s3_event.py @@ -107,7 +107,7 @@ def event_version(self) -> str: @property def event_source(self) -> str: - """aws:s3""" + """The AWS service from which the S3 event originated. For S3, this is aws:s3""" return self["eventSource"] @property diff --git a/aws_lambda_powertools/utilities/trigger/ses_event.py b/aws_lambda_powertools/utilities/trigger/ses_event.py index b35ac108f23..e795128b1a4 100644 --- a/aws_lambda_powertools/utilities/trigger/ses_event.py +++ b/aws_lambda_powertools/utilities/trigger/ses_event.py @@ -177,7 +177,7 @@ def receipt(self) -> SESReceipt: class SESEventRecord(dict): @property def event_source(self) -> str: - """event source will be: aws:ses""" + """The AWS service from which the SES event record originated. For SES, this is aws:ses""" return self["eventSource"] @property diff --git a/aws_lambda_powertools/utilities/trigger/sns_event.py b/aws_lambda_powertools/utilities/trigger/sns_event.py index 9b5315e7f30..c33fa39f7d9 100644 --- a/aws_lambda_powertools/utilities/trigger/sns_event.py +++ b/aws_lambda_powertools/utilities/trigger/sns_event.py @@ -89,6 +89,7 @@ def event_subscription_arn(self) -> str: @property def event_source(self) -> str: + """The AWS service from which the SNS event record originated. For SNS, this is aws:sns""" return self["EventSource"] @property diff --git a/aws_lambda_powertools/utilities/trigger/sqs_event.py b/aws_lambda_powertools/utilities/trigger/sqs_event.py index 60604362676..22addce4925 100644 --- a/aws_lambda_powertools/utilities/trigger/sqs_event.py +++ b/aws_lambda_powertools/utilities/trigger/sqs_event.py @@ -127,14 +127,17 @@ def md5_of_body(self) -> str: @property def event_source(self) -> str: + """The AWS service from which the SQS record originated. For SQS, this is `aws:sqs` """ return self._val["eventSource"] @property def event_source_arn(self) -> str: + """The Amazon Resource Name (ARN) of the event source""" return self._val["eventSourceARN"] @property def aws_region(self) -> str: + """aws region eg: us-east-1""" return self._val["awsRegion"] From 7e6a3e5afbaff527056ddeda941029011d6dd457 Mon Sep 17 00:00:00 2001 From: Michael Brewer Date: Sat, 12 Sep 2020 11:34:21 -0700 Subject: [PATCH 18/30] fix(trigger): better type hinting --- .../trigger/api_gateway_proxy_event.py | 134 +++++++++--------- .../trigger/cloud_watch_logs_event.py | 14 +- .../trigger/cognito_user_pool_event.py | 54 +++---- .../trigger/dynamo_db_stream_event.py | 64 ++++----- .../utilities/trigger/s3_event.py | 96 ++++++++----- .../utilities/trigger/ses_event.py | 106 ++++++++------ .../utilities/trigger/sns_event.py | 51 ++++--- .../utilities/trigger/sqs_event.py | 54 +++---- .../functional/test_lambda_trigger_events.py | 8 +- 9 files changed, 319 insertions(+), 262 deletions(-) diff --git a/aws_lambda_powertools/utilities/trigger/api_gateway_proxy_event.py b/aws_lambda_powertools/utilities/trigger/api_gateway_proxy_event.py index b3bbcb5c2b0..5ca2ba22d10 100644 --- a/aws_lambda_powertools/utilities/trigger/api_gateway_proxy_event.py +++ b/aws_lambda_powertools/utilities/trigger/api_gateway_proxy_event.py @@ -2,205 +2,205 @@ class APIGatewayEventIdentity: - def __init__(self, event: dict): - self._val = event + def __init__(self, event: Dict[str, Any]): + self._v = event @property def access_key(self) -> Optional[str]: - return self._val["requestContext"]["identity"].get("accessKey") + return self._v["requestContext"]["identity"].get("accessKey") @property def account_id(self) -> Optional[str]: """The AWS account ID associated with the request.""" - return self._val["requestContext"]["identity"].get("accountId") + return self._v["requestContext"]["identity"].get("accountId") @property def api_key(self) -> Optional[str]: """For API methods that require an API key, this variable is the API key associated with the method request. For methods that don't require an API key, this variable is null. """ - return self._val["requestContext"]["identity"].get("apiKey") + return self._v["requestContext"]["identity"].get("apiKey") @property def api_key_id(self) -> Optional[str]: """The API key ID associated with an API request that requires an API key.""" - return self._val["requestContext"]["identity"].get("apiKeyId") + return self._v["requestContext"]["identity"].get("apiKeyId") @property def caller(self) -> Optional[str]: """The principal identifier of the caller making the request.""" - return self._val["requestContext"]["identity"].get("caller") + return self._v["requestContext"]["identity"].get("caller") @property def cognito_authentication_provider(self) -> Optional[str]: """A comma-separated list of the Amazon Cognito authentication providers used by the caller making the request. Available only if the request was signed with Amazon Cognito credentials.""" - return self._val["requestContext"]["identity"].get("cognitoAuthenticationProvider") + return self._v["requestContext"]["identity"].get("cognitoAuthenticationProvider") @property def cognito_authentication_type(self) -> Optional[str]: """The Amazon Cognito authentication type of the caller making the request. Available only if the request was signed with Amazon Cognito credentials.""" - return self._val["requestContext"]["identity"].get("cognitoAuthenticationType") + return self._v["requestContext"]["identity"].get("cognitoAuthenticationType") @property def cognito_identity_id(self) -> Optional[str]: """The Amazon Cognito identity ID of the caller making the request. Available only if the request was signed with Amazon Cognito credentials.""" - return self._val["requestContext"]["identity"].get("cognitoIdentityId") + return self._v["requestContext"]["identity"].get("cognitoIdentityId") @property def cognito_identity_pool_id(self) -> Optional[str]: """The Amazon Cognito identity pool ID of the caller making the request. Available only if the request was signed with Amazon Cognito credentials.""" - return self._val["requestContext"]["identity"].get("cognitoIdentityPoolId") + return self._v["requestContext"]["identity"].get("cognitoIdentityPoolId") @property def principal_org_id(self) -> Optional[str]: """The AWS organization ID.""" - return self._val["requestContext"]["identity"].get("principalOrgId") + return self._v["requestContext"]["identity"].get("principalOrgId") @property def source_ip(self) -> str: """The source IP address of the TCP connection making the request to API Gateway.""" - return self._val["requestContext"]["identity"]["sourceIp"] + return self._v["requestContext"]["identity"]["sourceIp"] @property def user(self) -> Optional[str]: """The principal identifier of the user making the request.""" - return self._val["requestContext"]["identity"].get("user") + return self._v["requestContext"]["identity"].get("user") @property def user_agent(self) -> Optional[str]: """The User Agent of the API caller.""" - return self._val["requestContext"]["identity"].get("userAgent") + return self._v["requestContext"]["identity"].get("userAgent") @property def user_arn(self) -> Optional[str]: """The Amazon Resource Name (ARN) of the effective user identified after authentication.""" - return self._val["requestContext"]["identity"].get("userArn") + return self._v["requestContext"]["identity"].get("userArn") class APIGatewayEventAuthorizer: - def __init__(self, event: Dict): - self._val = event + def __init__(self, event: Dict[str, Any]): + self._v = event @property def claims(self) -> Optional[Dict[str, Any]]: - return self._val["requestContext"]["authorizer"].get("claims") + return self._v["requestContext"]["authorizer"].get("claims") @property def scopes(self) -> Optional[List[str]]: - return self._val["requestContext"]["authorizer"].get("scopes") + return self._v["requestContext"]["authorizer"].get("scopes") class APIGatewayEventRequestContext: def __init__(self, event: Dict[str, Any]): - self._val = event + self._v = event @property def account_id(self) -> str: """The AWS account ID associated with the request.""" - return self._val["requestContext"]["accountId"] + return self._v["requestContext"]["accountId"] @property def api_id(self) -> str: """The identifier API Gateway assigns to your API.""" - return self._val["requestContext"]["apiId"] + return self._v["requestContext"]["apiId"] @property def authorizer(self) -> APIGatewayEventAuthorizer: - return APIGatewayEventAuthorizer(self._val) + return APIGatewayEventAuthorizer(self._v) @property def connected_at(self) -> Optional[int]: """The Epoch-formatted connection time. (WebSocket API)""" - return self._val["requestContext"].get("connectedAt") + return self._v["requestContext"].get("connectedAt") @property def connection_id(self) -> Optional[str]: """A unique ID for the connection that can be used to make a callback to the client. (WebSocket API)""" - return self._val["requestContext"].get("connectionId") + return self._v["requestContext"].get("connectionId") @property def domain_name(self) -> Optional[str]: """A domain name""" - return self._val["requestContext"].get("domainName") + return self._v["requestContext"].get("domainName") @property def domain_prefix(self) -> Optional[str]: - return self._val["requestContext"].get("domainPrefix") + return self._v["requestContext"].get("domainPrefix") @property def event_type(self) -> Optional[str]: """The event type: `CONNECT`, `MESSAGE`, or `DISCONNECT`. (WebSocket API)""" - return self._val["requestContext"].get("eventType") + return self._v["requestContext"].get("eventType") @property def extended_request_id(self) -> Optional[str]: """An automatically generated ID for the API call, which contains more useful information for debugging/troubleshooting.""" - return self._val["requestContext"].get("extendedRequestId") + return self._v["requestContext"].get("extendedRequestId") @property def protocol(self) -> str: """The request protocol, for example, HTTP/1.1.""" - return self._val["requestContext"]["protocol"] + return self._v["requestContext"]["protocol"] @property def http_method(self) -> str: """The HTTP method used. Valid values include: DELETE, GET, HEAD, OPTIONS, PATCH, POST, and PUT.""" - return self._val["requestContext"]["httpMethod"] + return self._v["requestContext"]["httpMethod"] @property def identity(self) -> APIGatewayEventIdentity: - return APIGatewayEventIdentity(self._val) + return APIGatewayEventIdentity(self._v) @property def message_direction(self) -> Optional[str]: """Message direction (WebSocket API)""" - return self._val["requestContext"].get("messageDirection") + return self._v["requestContext"].get("messageDirection") @property def message_id(self) -> Optional[str]: """A unique server-side ID for a message. Available only when the `eventType` is `MESSAGE`.""" - return self._val["requestContext"].get("messageId") + return self._v["requestContext"].get("messageId") @property def path(self) -> str: - return self._val["requestContext"]["path"] + return self._v["requestContext"]["path"] @property def stage(self) -> str: """The deployment stage of the API request """ - return self._val["requestContext"]["stage"] + return self._v["requestContext"]["stage"] @property def request_id(self) -> str: """The ID that API Gateway assigns to the API request.""" - return self._val["requestContext"]["requestId"] + return self._v["requestContext"]["requestId"] @property def request_time(self) -> Optional[str]: """The CLF-formatted request time (dd/MMM/yyyy:HH:mm:ss +-hhmm)""" - return self._val["requestContext"].get("requestTime") + return self._v["requestContext"].get("requestTime") @property def request_time_epoch(self) -> int: """The Epoch-formatted request time.""" - return self._val["requestContext"]["requestTimeEpoch"] + return self._v["requestContext"]["requestTimeEpoch"] @property def resource_id(self) -> str: - return self._val["requestContext"]["resourceId"] + return self._v["requestContext"]["resourceId"] @property def resource_path(self) -> str: - return self._val["requestContext"]["resourcePath"] + return self._v["requestContext"]["resourcePath"] @property def route_key(self) -> Optional[str]: """The selected route key.""" - return self._val["requestContext"].get("routeKey") + return self._v["requestContext"].get("routeKey") class APIGatewayProxyEvent(dict): @@ -266,102 +266,102 @@ def is_base64_encoded(self) -> bool: class RequestContextV2Http: - def __init__(self, event: dict): - self._val = event + def __init__(self, event: Dict[str, Any]): + self._v = event @property def method(self) -> str: - return self._val["requestContext"]["http"]["method"] + return self._v["requestContext"]["http"]["method"] @property def path(self) -> str: - return self._val["requestContext"]["http"]["path"] + return self._v["requestContext"]["http"]["path"] @property def protocol(self) -> str: """The request protocol, for example, HTTP/1.1.""" - return self._val["requestContext"]["http"]["protocol"] + return self._v["requestContext"]["http"]["protocol"] @property def source_ip(self) -> str: """The source IP address of the TCP connection making the request to API Gateway.""" - return self._val["requestContext"]["http"]["sourceIp"] + return self._v["requestContext"]["http"]["sourceIp"] @property def user_agent(self) -> str: """The User Agent of the API caller.""" - return self._val["requestContext"]["http"]["userAgent"] + return self._v["requestContext"]["http"]["userAgent"] class RequestContextV2Authorizer: - def __init__(self, event: dict): - self._val = event + def __init__(self, event: Dict[str, Any]): + self._v = event @property def jwt_claim(self) -> Dict[str, Any]: - return self._val["jwt"]["claims"] + return self._v["jwt"]["claims"] @property def jwt_scopes(self) -> List[str]: - return self._val["jwt"]["scopes"] + return self._v["jwt"]["scopes"] class RequestContextV2: - def __init__(self, event: dict): - self._val = event + def __init__(self, event: Dict[str, Any]): + self._v = event @property def account_id(self) -> str: """The AWS account ID associated with the request.""" - return self._val["requestContext"]["accountId"] + return self._v["requestContext"]["accountId"] @property def api_id(self) -> str: """The identifier API Gateway assigns to your API.""" - return self._val["requestContext"]["apiId"] + return self._v["requestContext"]["apiId"] @property def authorizer(self) -> Optional[RequestContextV2Authorizer]: - authorizer = self._val["requestContext"].get("authorizer") + authorizer = self._v["requestContext"].get("authorizer") return None if authorizer is None else RequestContextV2Authorizer(authorizer) @property def domain_name(self) -> str: """A domain name """ - return self._val["requestContext"]["domainName"] + return self._v["requestContext"]["domainName"] @property def domain_prefix(self) -> str: - return self._val["requestContext"]["domainPrefix"] + return self._v["requestContext"]["domainPrefix"] @property def http(self) -> RequestContextV2Http: - return RequestContextV2Http(self._val) + return RequestContextV2Http(self._v) @property def request_id(self) -> str: """The ID that API Gateway assigns to the API request.""" - return self._val["requestContext"]["requestId"] + return self._v["requestContext"]["requestId"] @property def route_key(self) -> str: """The selected route key.""" - return self._val["requestContext"]["routeKey"] + return self._v["requestContext"]["routeKey"] @property def stage(self) -> str: """The deployment stage of the API request """ - return self._val["requestContext"]["stage"] + return self._v["requestContext"]["stage"] @property def time(self) -> str: """The CLF-formatted request time (dd/MMM/yyyy:HH:mm:ss +-hhmm).""" - return self._val["requestContext"]["time"] + return self._v["requestContext"]["time"] @property def time_epoch(self) -> int: """The Epoch-formatted request time.""" - return self._val["requestContext"]["timeEpoch"] + return self._v["requestContext"]["timeEpoch"] class APIGatewayProxyEventV2(dict): diff --git a/aws_lambda_powertools/utilities/trigger/cloud_watch_logs_event.py b/aws_lambda_powertools/utilities/trigger/cloud_watch_logs_event.py index f6bad46d865..9f86b9f370d 100644 --- a/aws_lambda_powertools/utilities/trigger/cloud_watch_logs_event.py +++ b/aws_lambda_powertools/utilities/trigger/cloud_watch_logs_event.py @@ -1,33 +1,33 @@ import base64 import json import zlib -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional class CloudWatchLogsLogEvent: - def __init__(self, log_event: dict): - self._val = log_event + def __init__(self, log_event: Dict[str, Any]): + self._v = log_event @property def get_id(self) -> str: """The ID property is a unique identifier for every log event.""" # Note: this name conflicts with existing python builtins - return self._val["id"] + return self._v["id"] @property def timestamp(self) -> int: """Get the `timestamp` property""" - return self._val["timestamp"] + return self._v["timestamp"] @property def message(self) -> str: """Get the `message` property""" - return self._val["message"] + return self._v["message"] @property def extracted_fields(self) -> Optional[Dict[str, str]]: """Get the `extractedFields` property""" - return self._val.get("extractedFields") + return self._v.get("extractedFields") class CloudWatchLogsDecodedData(dict): diff --git a/aws_lambda_powertools/utilities/trigger/cognito_user_pool_event.py b/aws_lambda_powertools/utilities/trigger/cognito_user_pool_event.py index 57b807287d9..55674c794d0 100644 --- a/aws_lambda_powertools/utilities/trigger/cognito_user_pool_event.py +++ b/aws_lambda_powertools/utilities/trigger/cognito_user_pool_event.py @@ -2,18 +2,18 @@ class CallerContext: - def __init__(self, event: dict): - self._event = event + def __init__(self, event: Dict[str, Any]): + self._v = event @property def aws_sdk_version(self) -> str: """The AWS SDK version number.""" - return self._event["callerContext"]["awsSdkVersion"] + return self._v["callerContext"]["awsSdkVersion"] @property def client_id(self) -> str: """The ID of the client associated with the user pool.""" - return self._event["callerContext"]["clientId"] + return self._v["callerContext"]["clientId"] class BaseTriggerEvent(dict): @@ -449,53 +449,53 @@ def preferred_role(self) -> Optional[str]: class PreTokenGenerationTriggerEventRequest: - def __init__(self, event: dict): - self._val = event + def __init__(self, event: Dict[str, Any]): + self._v = event @property def group_configuration(self) -> GroupOverrideDetails: """The input object containing the current group configuration""" - return GroupOverrideDetails(self._val["request"]["groupConfiguration"]) + return GroupOverrideDetails(self._v["request"]["groupConfiguration"]) @property def user_attributes(self) -> Dict[str, str]: """One or more name-value pairs representing user attributes.""" - return self._val["request"]["userAttributes"] + return self._v["request"]["userAttributes"] @property def client_metadata(self) -> Optional[Dict[str, str]]: """One or more key-value pairs that you can provide as custom input to the Lambda function that you specify for the pre token generation trigger.""" - return self._val["request"].get("clientMetadata") + return self._v["request"].get("clientMetadata") class ClaimsOverrideDetails: - def __init__(self, event: dict): - self._val = event["response"]["claimsOverrideDetails"] + def __init__(self, event: Dict[str, Any]): + self._v = event["response"]["claimsOverrideDetails"] @property def claims_to_add_or_override(self) -> Optional[Dict[str, str]]: - return self._val.get("claimsToAddOrOverride") + return self._v.get("claimsToAddOrOverride") @property def claims_to_suppress(self) -> Optional[List[str]]: - return self._val.get("claimsToSuppress") + return self._v.get("claimsToSuppress") @property def group_configuration(self) -> Optional[GroupOverrideDetails]: - group_override_details = self._val.get("groupOverrideDetails") + group_override_details = self._v.get("groupOverrideDetails") return None if group_override_details is None else GroupOverrideDetails(group_override_details) @claims_to_add_or_override.setter def claims_to_add_or_override(self, value: Dict[str, str]): """A map of one or more key-value pairs of claims to add or override. For group related claims, use groupOverrideDetails instead.""" - self._val["claimsToAddOrOverride"] = value + self._v["claimsToAddOrOverride"] = value @claims_to_suppress.setter def claims_to_suppress(self, value: List[str]): """A list that contains claims to be suppressed from the identity token.""" - self._val["claimsToSuppress"] = value + self._v["claimsToSuppress"] = value @group_configuration.setter def group_configuration(self, value: Dict[str, Any]): @@ -508,33 +508,33 @@ def group_configuration(self, value: Dict[str, Any]): as is, copy the value of the request's groupConfiguration object to the groupOverrideDetails object in the response, and pass it back to the service. """ - self._val["groupOverrideDetails"] = value + self._v["groupOverrideDetails"] = value def set_group_configuration_groups_to_override(self, value: List[str]): """A list of the group names that are associated with the user that the identity token is issued for.""" - self._val.setdefault("groupOverrideDetails", {}) - self._val["groupOverrideDetails"]["groupsToOverride"] = value + self._v.setdefault("groupOverrideDetails", {}) + self._v["groupOverrideDetails"]["groupsToOverride"] = value def set_group_configuration_iam_roles_to_override(self, value: List[str]): """A list of the current IAM roles associated with these groups.""" - self._val.setdefault("groupOverrideDetails", {}) - self._val["groupOverrideDetails"]["iamRolesToOverride"] = value + self._v.setdefault("groupOverrideDetails", {}) + self._v["groupOverrideDetails"]["iamRolesToOverride"] = value def set_group_configuration_preferred_role(self, value: str): """A string indicating the preferred IAM role.""" - self._val.setdefault("groupOverrideDetails", {}) - self._val["groupOverrideDetails"]["preferredRole"] = value + self._v.setdefault("groupOverrideDetails", {}) + self._v["groupOverrideDetails"]["preferredRole"] = value class PreTokenGenerationTriggerEventResponse: - def __init__(self, event: dict): - self._val = event + def __init__(self, event: Dict[str, Any]): + self._v = event @property def claims_override_details(self) -> ClaimsOverrideDetails: # Ensure we have a `claimsOverrideDetails` element - self._val["response"].setdefault("claimsOverrideDetails", {}) - return ClaimsOverrideDetails(self._val) + self._v["response"].setdefault("claimsOverrideDetails", {}) + return ClaimsOverrideDetails(self._v) class PreTokenGenerationTriggerEvent(BaseTriggerEvent): diff --git a/aws_lambda_powertools/utilities/trigger/dynamo_db_stream_event.py b/aws_lambda_powertools/utilities/trigger/dynamo_db_stream_event.py index 4871b33312e..559d50526e1 100644 --- a/aws_lambda_powertools/utilities/trigger/dynamo_db_stream_event.py +++ b/aws_lambda_powertools/utilities/trigger/dynamo_db_stream_event.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Dict, Iterator, List, Optional +from typing import Any, Dict, Iterator, List, Optional class AttributeValue: @@ -8,8 +8,8 @@ class AttributeValue: Documentation: https://docs.aws.amazon.com/amazondynamodb/latest/APIReference/API_streams_AttributeValue.html """ - def __init__(self, attr_value: dict): - self._val = attr_value + def __init__(self, attr_value: Dict[str, Any]): + self._v = attr_value @property def b_value(self) -> Optional[str]: @@ -18,7 +18,7 @@ def b_value(self) -> Optional[str]: Example: >>> {"B": "dGhpcyB0ZXh0IGlzIGJhc2U2NC1lbmNvZGVk"} """ - return self._val.get("B") + return self._v.get("B") @property def bs_value(self) -> Optional[List[str]]: @@ -27,7 +27,7 @@ def bs_value(self) -> Optional[List[str]]: Example: >>> {"BS": ["U3Vubnk=", "UmFpbnk=", "U25vd3k="]} """ - return self._val.get("BS") + return self._v.get("BS") @property def bool_value(self) -> Optional[bool]: @@ -36,7 +36,7 @@ def bool_value(self) -> Optional[bool]: Example: >>> {"BOOL": True} """ - item = self._val.get("bool") + item = self._v.get("bool") return None if item is None else bool(item) @property @@ -46,7 +46,7 @@ def list_value(self) -> Optional[List["AttributeValue"]]: Example: >>> {"L": [ {"S": "Cookies"} , {"S": "Coffee"}, {"N": "3.14159"}]} """ - item = self._val.get("L") + item = self._v.get("L") return None if item is None else [AttributeValue(v) for v in item] @property @@ -56,7 +56,7 @@ def map_value(self) -> Optional[Dict[str, "AttributeValue"]]: Example: >>> {"M": {"Name": {"S": "Joe"}, "Age": {"N": "35"}}} """ - return _attribute_value_dict(self._val, "M") + return _attribute_value_dict(self._v, "M") @property def n_value(self) -> Optional[str]: @@ -68,7 +68,7 @@ def n_value(self) -> Optional[str]: Example: >>> {"N": "123.45"} """ - return self._val.get("N") + return self._v.get("N") @property def ns_value(self) -> Optional[List[str]]: @@ -77,7 +77,7 @@ def ns_value(self) -> Optional[List[str]]: Example: >>> {"NS": ["42.2", "-19", "7.5", "3.14"]} """ - return self._val.get("NS") + return self._v.get("NS") @property def null_value(self) -> Optional[bool]: @@ -86,7 +86,7 @@ def null_value(self) -> Optional[bool]: Example: >>> {"NULL": True} """ - item = self._val.get("NULL") + item = self._v.get("NULL") return None if item is None else bool(item) @property @@ -96,7 +96,7 @@ def s_value(self) -> Optional[str]: Example: >>> {"S": "Hello"} """ - return self._val.get("S") + return self._v.get("S") @property def ss_value(self) -> Optional[List[str]]: @@ -105,7 +105,7 @@ def ss_value(self) -> Optional[List[str]]: Example: >>> {"SS": ["Giraffe", "Hippo" ,"Zebra"]} """ - return self._val.get("SS") + return self._v.get("SS") def _attribute_value_dict(attr_values: Dict[str, dict], key: str) -> Optional[Dict[str, AttributeValue]]: @@ -128,45 +128,45 @@ class StreamViewType(Enum): class StreamRecord: - def __init__(self, stream_record: dict): - self._val = stream_record + def __init__(self, stream_record: Dict[str, Any]): + self._v = stream_record @property def approximate_creation_date_time(self) -> Optional[int]: """The approximate date and time when the stream record was created, in UNIX epoch time format.""" - item = self._val.get("ApproximateCreationDateTime") + item = self._v.get("ApproximateCreationDateTime") return None if item is None else int(item) @property def keys(self) -> Optional[Dict[str, AttributeValue]]: """The primary key attribute(s) for the DynamoDB item that was modified.""" - return _attribute_value_dict(self._val, "Keys") + return _attribute_value_dict(self._v, "Keys") @property def new_image(self) -> Optional[Dict[str, AttributeValue]]: """The item in the DynamoDB table as it appeared after it was modified.""" - return _attribute_value_dict(self._val, "NewImage") + return _attribute_value_dict(self._v, "NewImage") @property def old_image(self) -> Optional[Dict[str, AttributeValue]]: """The item in the DynamoDB table as it appeared before it was modified.""" - return _attribute_value_dict(self._val, "OldImage") + return _attribute_value_dict(self._v, "OldImage") @property def sequence_number(self) -> Optional[str]: """The sequence number of the stream record.""" - return self._val.get("SequenceNumber") + return self._v.get("SequenceNumber") @property def size_bytes(self) -> Optional[int]: """The size of the stream record, in bytes.""" - item = self._val.get("SizeBytes") + item = self._v.get("SizeBytes") return None if item is None else int(item) @property def stream_view_type(self) -> Optional[StreamViewType]: """The type of data from the modified DynamoDB item that was captured in this stream record""" - item = self._val.get("StreamViewType") + item = self._v.get("StreamViewType") return None if item is None else StreamViewType[str(item)] @@ -179,50 +179,50 @@ class DynamoDBRecordEventName(Enum): class DynamoDBRecord: """A description of a unique event within a stream""" - def __init__(self, record: dict): - self._val = record + def __init__(self, record: Dict[str, Any]): + self._v = record @property def aws_region(self) -> Optional[str]: """The region in which the GetRecords request was received""" - return self._val.get("awsRegion") + return self._v.get("awsRegion") @property def dynamodb(self) -> Optional[StreamRecord]: """The main body of the stream record, containing all of the DynamoDB-specific fields.""" - stream_record = self._val.get("dynamodb") + stream_record = self._v.get("dynamodb") return None if stream_record is None else StreamRecord(stream_record) @property def event_id(self) -> Optional[str]: """A globally unique identifier for the event that was recorded in this stream record.""" - return self._val.get("eventID") + return self._v.get("eventID") @property def event_name(self) -> Optional[DynamoDBRecordEventName]: """The type of data modification that was performed on the DynamoDB table""" - item = self._val.get("eventName") + item = self._v.get("eventName") return None if item is None else DynamoDBRecordEventName[item] @property def event_source(self) -> Optional[str]: """The AWS service from which the stream record originated. For DynamoDB Streams, this is aws:dynamodb.""" - return self._val.get("eventSource") + return self._v.get("eventSource") @property def event_source_arn(self) -> Optional[str]: """The Amazon Resource Name (ARN) of the event source""" - return self._val.get("eventSourceARN") + return self._v.get("eventSourceARN") @property def event_version(self) -> Optional[str]: """The version number of the stream record format.""" - return self._val.get("eventVersion") + return self._v.get("eventVersion") @property def user_identity(self) -> Optional[dict]: """Contains details about the type of identity that made the request""" - return self._val.get("userIdentity") + return self._v.get("userIdentity") class DynamoDBStreamEvent(dict): diff --git a/aws_lambda_powertools/utilities/trigger/s3_event.py b/aws_lambda_powertools/utilities/trigger/s3_event.py index 93e3dd7835b..eafc9e8b8d9 100644 --- a/aws_lambda_powertools/utilities/trigger/s3_event.py +++ b/aws_lambda_powertools/utilities/trigger/s3_event.py @@ -1,138 +1,162 @@ -from typing import Dict, Iterator, Optional +from typing import Any, Dict, Iterator, Optional -class S3Identity(dict): +class S3Identity: + def __init__(self, s3_identity: Dict[str, str]): + self._v = s3_identity + @property def principal_id(self) -> str: - return self["principalId"] + return self._v["principalId"] + +class S3RequestParameters: + def __init__(self, record: Dict[str, Any]): + self._v = record -class S3RequestParameters(dict): @property def source_ip_address(self) -> str: - return self["requestParameters"]["sourceIPAddress"] + return self._v["requestParameters"]["sourceIPAddress"] -class S3Bucket(dict): +class S3Bucket: + def __init__(self, record: Dict[str, Any]): + self._v = record + @property def name(self) -> str: - return self["s3"]["bucket"]["name"] + return self._v["s3"]["bucket"]["name"] @property def owner_identity(self) -> S3Identity: - return S3Identity(self["s3"]["bucket"]["ownerIdentity"]) + return S3Identity(self._v["s3"]["bucket"]["ownerIdentity"]) @property def arn(self) -> str: - return self["s3"]["bucket"]["arn"] + return self._v["s3"]["bucket"]["arn"] + +class S3Object: + def __init__(self, record: Dict[str, Any]): + self._v = record -class S3Object(dict): @property def key(self) -> str: """Object key""" - return self["s3"]["object"]["key"] + return self._v["s3"]["object"]["key"] @property def size(self) -> int: """Object byte size""" - return int(self["s3"]["object"]["size"]) + return int(self._v["s3"]["object"]["size"]) @property def etag(self) -> str: """object eTag""" - return self["s3"]["object"]["eTag"] + return self._v["s3"]["object"]["eTag"] @property def version_id(self) -> Optional[str]: """Object version if bucket is versioning-enabled, otherwise null""" - return self["s3"]["object"].get("versionId") + return self._v["s3"]["object"].get("versionId") @property def sequencer(self) -> str: """A string representation of a hexadecimal value used to determine event sequence, only used with PUTs and DELETEs """ - return self["s3"]["object"]["sequencer"] + return self._v["s3"]["object"]["sequencer"] -class S3Message(dict): +class S3Message: + def __init__(self, record: Dict[str, Any]): + self._v = record + @property def s3_schema_version(self) -> str: - return self["s3"]["s3SchemaVersion"] + return self._v["s3"]["s3SchemaVersion"] @property def configuration_id(self) -> str: """ID found in the bucket notification configuration""" - return self["s3"]["configurationId"] + return self._v["s3"]["configurationId"] @property def bucket(self) -> S3Bucket: - return S3Bucket(self) + return S3Bucket(self._v) @property def get_object(self) -> S3Object: """Get the `object` property as an S3Object""" # Note: this name conflicts with existing python builtins - return S3Object(self) + return S3Object(self._v) + +class S3EventRecordGlacierRestoreEventData: + def __init__(self, glacier_event_data: Dict[str, Any]): + self._v = glacier_event_data -class S3EventRecordGlacierRestoreEventData(dict): @property def lifecycle_restoration_expiry_time(self) -> str: """Time when the object restoration will be expired.""" - return self["restoreEventData"]["lifecycleRestorationExpiryTime"] + return self._v["restoreEventData"]["lifecycleRestorationExpiryTime"] @property def lifecycle_restore_storage_class(self) -> str: """Source storage class for restore""" - return self["restoreEventData"]["lifecycleRestoreStorageClass"] + return self._v["restoreEventData"]["lifecycleRestoreStorageClass"] -class S3EventRecordGlacierEventData(dict): +class S3EventRecordGlacierEventData: + def __init__(self, glacier_event_data: Dict[str, Any]): + self._v = glacier_event_data + @property def restore_event_data(self) -> S3EventRecordGlacierRestoreEventData: """The restoreEventData key contains attributes related to your restore request. The glacierEventData key is only visible for s3:ObjectRestore:Completed events """ - return S3EventRecordGlacierRestoreEventData(self) + return S3EventRecordGlacierRestoreEventData(self._v) + +class S3EventRecord: + def __init__(self, record: Dict[str, Any]): + self._v = record -class S3EventRecord(dict): @property def event_version(self) -> str: """The eventVersion key value contains a major and minor version in the form ..""" - return self["eventVersion"] + return self._v["eventVersion"] @property def event_source(self) -> str: """The AWS service from which the S3 event originated. For S3, this is aws:s3""" - return self["eventSource"] + return self._v["eventSource"] @property def aws_region(self) -> str: """aws region eg: us-east-1""" - return self["awsRegion"] + return self._v["awsRegion"] @property def event_time(self) -> str: """The time, in ISO-8601 format, for example, 1970-01-01T00:00:00.000Z, when S3 finished processing the request""" - return self["eventTime"] + return self._v["eventTime"] @property def event_name(self) -> str: """Event type""" - return self["eventName"] + return self._v["eventName"] @property def user_identity(self) -> S3Identity: - return S3Identity(self["userIdentity"]) + return S3Identity(self._v["userIdentity"]) @property def request_parameters(self) -> S3RequestParameters: - return S3RequestParameters(self) + return S3RequestParameters(self._v) @property def response_elements(self) -> Dict[str, str]: @@ -142,16 +166,16 @@ def response_elements(self) -> Dict[str, str]: as those that Amazon S3 returns in the response to the request that initiates the events, so they can be used to match the event to the request. """ - return self["responseElements"] + return self._v["responseElements"] @property def s3(self) -> S3Message: - return S3Message(self) + return S3Message(self._v) @property def glacier_event_data(self) -> Optional[S3EventRecordGlacierEventData]: """The glacierEventData key is only visible for s3:ObjectRestore:Completed events.""" - item = self.get("glacierEventData") + item = self._v.get("glacierEventData") return None if item is None else S3EventRecordGlacierEventData(item) diff --git a/aws_lambda_powertools/utilities/trigger/ses_event.py b/aws_lambda_powertools/utilities/trigger/ses_event.py index e795128b1a4..7cda0792287 100644 --- a/aws_lambda_powertools/utilities/trigger/ses_event.py +++ b/aws_lambda_powertools/utilities/trigger/ses_event.py @@ -1,60 +1,69 @@ -from typing import Iterator, List +from typing import Any, Dict, Iterator, List -class SESMailHeader(dict): +class SESMailHeader: + def __init__(self, header: Dict[str, str]): + self._v = header + @property def name(self) -> str: - return self["name"] + return self._v["name"] @property def value(self) -> str: - return self["value"] + return self._v["value"] + +class SESMailCommonHeaders: + def __init__(self, common_headers: Dict[str, Any]): + self._v = common_headers -class SESMailCommonHeaders(dict): @property def return_path(self) -> str: """The values in the Return-Path header of the email.""" - return self["returnPath"] + return self._v["returnPath"] @property def get_from(self) -> List[str]: """The values in the From header of the email.""" # Note: this name conflicts with existing python builtins - return self["from"] + return self._v["from"] @property def date(self) -> List[str]: """The date and time when Amazon SES received the message.""" - return self["date"] + return self._v["date"] @property def to(self) -> List[str]: """The values in the To header of the email.""" - return self["to"] + return self._v["to"] @property def message_id(self) -> str: """The ID of the original message.""" - return str(self["messageId"]) + return str(self._v["messageId"]) @property def subject(self) -> str: """The value of the Subject header for the email.""" - return str(self["subject"]) + return str(self._v["subject"]) -class SESMail(dict): +class SESMail: + def __init__(self, mail: Dict[str, Any]): + self._v = mail + @property def timestamp(self) -> str: """String that contains the time at which the email was received, in ISO8601 format.""" - return self["timestamp"] + return self._v["timestamp"] @property def source(self) -> str: """String that contains the email address (specifically, the envelope MAIL FROM address) that the email was sent from.""" - return self["source"] + return self._v["source"] @property def message_id(self) -> str: @@ -62,40 +71,46 @@ def message_id(self) -> str: If the email was delivered to Amazon S3, the message ID is also the Amazon S3 object key that was used to write the message to your Amazon S3 bucket.""" - return self["messageId"] + return self._v["messageId"] @property def destination(self) -> List[str]: """A complete list of all recipient addresses (including To: and CC: recipients) from the MIME headers of the incoming email.""" - return self["destination"] + return self._v["destination"] @property def headers_truncated(self) -> bool: """String that specifies whether the headers were truncated in the notification, which will happen if the headers are larger than 10 KB. Possible values are true and false.""" - return bool(self["headersTruncated"]) + return bool(self._v["headersTruncated"]) @property def headers(self) -> Iterator[SESMailHeader]: """A list of Amazon SES headers and your custom headers. Each header in the list has a name field and a value field""" - for header in self["headers"]: + for header in self._v["headers"]: yield SESMailHeader(header) @property def common_headers(self) -> SESMailCommonHeaders: """A list of headers common to all emails. Each header in the list is composed of a name and a value.""" - return SESMailCommonHeaders(self["commonHeaders"]) + return SESMailCommonHeaders(self._v["commonHeaders"]) + +class SESReceiptStatus: + def __init__(self, receipt_status: Dict[str, str]): + self._v = receipt_status -class SESReceiptStatus(dict): @property def status(self) -> str: - return str(self["status"]) + return str(self._v["status"]) -class SESReceiptAction(dict): +class SESReceiptAction: + def __init__(self, receipt_action: Dict[str, str]): + self._v = receipt_action + @property def get_type(self) -> str: """String that indicates the type of action that was executed. @@ -103,90 +118,99 @@ def get_type(self) -> str: Possible values are S3, SNS, Bounce, Lambda, Stop, and WorkMail """ # Note: this name conflicts with existing python builtins - return self["type"] + return self._v["type"] @property def function_arn(self) -> str: """String that contains the ARN of the Lambda function that was triggered. Present only for the Lambda action type.""" - return self["functionArn"] + return self._v["functionArn"] @property def invocation_type(self) -> str: """String that contains the invocation type of the Lambda function. Possible values are RequestResponse and Event. Present only for the Lambda action type.""" - return self["invocationType"] + return self._v["invocationType"] + +class SESReceipt: + def __init__(self, receipt: Dict[str, Any]): + self._v = receipt -class SESReceipt(dict): @property def timestamp(self) -> str: """String that specifies the date and time at which the action was triggered, in ISO 8601 format.""" - return self["timestamp"] + return self._v["timestamp"] @property def processing_time_millis(self) -> int: """String that specifies the period, in milliseconds, from the time Amazon SES received the message to the time it triggered the action.""" - return int(self["processingTimeMillis"]) + return int(self._v["processingTimeMillis"]) @property def recipients(self) -> List[str]: """A list of recipients (specifically, the envelope RCPT TO addresses) that were matched by the active receipt rule. The addresses listed here may differ from those listed by the destination field in the mail object.""" - return self["recipients"] + return self._v["recipients"] @property def spam_verdict(self) -> SESReceiptStatus: """Object that indicates whether the message is spam.""" - return SESReceiptStatus(self["spamVerdict"]) + return SESReceiptStatus(self._v["spamVerdict"]) @property def virus_verdict(self) -> SESReceiptStatus: """Object that indicates whether the message contains a virus.""" - return SESReceiptStatus(self["virusVerdict"]) + return SESReceiptStatus(self._v["virusVerdict"]) @property def spf_verdict(self) -> SESReceiptStatus: """Object that indicates whether the Sender Policy Framework (SPF) check passed.""" - return SESReceiptStatus(self["spfVerdict"]) + return SESReceiptStatus(self._v["spfVerdict"]) @property def dmarc_verdict(self) -> SESReceiptStatus: """Object that indicates whether the Domain-based Message Authentication, Reporting & Conformance (DMARC) check passed.""" - return SESReceiptStatus(self["dmarcVerdict"]) + return SESReceiptStatus(self._v["dmarcVerdict"]) @property def action(self) -> SESReceiptAction: """Object that encapsulates information about the action that was executed.""" - return SESReceiptAction(self["action"]) + return SESReceiptAction(self._v["action"]) -class SESMessage(dict): +class SESMessage: + def __init__(self, record: Dict[str, Any]): + self._v = record + @property def mail(self) -> SESMail: - return SESMail(self["ses"]["mail"]) + return SESMail(self._v["ses"]["mail"]) @property def receipt(self) -> SESReceipt: - return SESReceipt(self["ses"]["receipt"]) + return SESReceipt(self._v["ses"]["receipt"]) + +class SESEventRecord: + def __init__(self, record: Dict[str, Any]): + self._v = record -class SESEventRecord(dict): @property def event_source(self) -> str: """The AWS service from which the SES event record originated. For SES, this is aws:ses""" - return self["eventSource"] + return self._v["eventSource"] @property def event_version(self) -> str: - return self["eventVersion"] + return self._v["eventVersion"] @property def ses(self) -> SESMessage: - return SESMessage(self) + return SESMessage(self._v) class SESEvent(dict): diff --git a/aws_lambda_powertools/utilities/trigger/sns_event.py b/aws_lambda_powertools/utilities/trigger/sns_event.py index c33fa39f7d9..3c7663cb850 100644 --- a/aws_lambda_powertools/utilities/trigger/sns_event.py +++ b/aws_lambda_powertools/utilities/trigger/sns_event.py @@ -1,55 +1,61 @@ -from typing import Dict, Iterator +from typing import Any, Dict, Iterator -class SNSMessageAttribute(dict): +class SNSMessageAttribute: + def __init__(self, message_attribute: Dict[str, str]): + self._v = message_attribute + @property def get_type(self) -> str: """The supported message attribute data types are String, String.Array, Number, and Binary.""" # Note: this name conflicts with existing python builtins - return self["Type"] + return self._v["Type"] @property def value(self) -> str: """The user-specified message attribute value.""" - return self["Value"] + return self._v["Value"] + +class SNSMessage: + def __init__(self, message: Dict[str, Any]): + self._v = message -class SNSMessage(dict): @property def signature_version(self) -> str: """Version of the Amazon SNS signature used.""" - return self["Sns"]["SignatureVersion"] + return self._v["Sns"]["SignatureVersion"] @property def timestamp(self) -> str: """The time (GMT) when the subscription confirmation was sent.""" - return self["Sns"]["Timestamp"] + return self._v["Sns"]["Timestamp"] @property def signature(self) -> str: """Base64-encoded "SHA1withRSA" signature of the Message, MessageId, Type, Timestamp, and TopicArn values.""" - return self["Sns"]["Signature"] + return self._v["Sns"]["Signature"] @property def signing_cert_url(self) -> str: """The URL to the certificate that was used to sign the message.""" - return self["Sns"]["SigningCertUrl"] + return self._v["Sns"]["SigningCertUrl"] @property def message_id(self) -> str: """A Universally Unique Identifier, unique for each message published. For a message that Amazon SNS resends during a retry, the message ID of the original message is used.""" - return self["Sns"]["MessageId"] + return self._v["Sns"]["MessageId"] @property def message(self) -> str: """A string that describes the message. """ - return self["Sns"]["Message"] + return self._v["Sns"]["Message"] @property def message_attributes(self) -> Dict[str, SNSMessageAttribute]: - return {k: SNSMessageAttribute(v) for (k, v) in self["Sns"]["MessageAttributes"].items()} + return {k: SNSMessageAttribute(v) for (k, v) in self._v["Sns"]["MessageAttributes"].items()} @property def get_type(self) -> str: @@ -57,44 +63,47 @@ def get_type(self) -> str: For a subscription confirmation, the type is SubscriptionConfirmation.""" # Note: this name conflicts with existing python builtins - return self["Sns"]["Type"] + return self._v["Sns"]["Type"] @property def unsubscribe_url(self) -> str: """A URL that you can use to unsubscribe the endpoint from this topic. If you visit this URL, Amazon SNS unsubscribes the endpoint and stops sending notifications to this endpoint.""" - return self["Sns"]["UnsubscribeUrl"] + return self._v["Sns"]["UnsubscribeUrl"] @property def topic_arn(self) -> str: """The Amazon Resource Name (ARN) for the topic that this endpoint is subscribed to.""" - return self["Sns"]["TopicArn"] + return self._v["Sns"]["TopicArn"] @property def subject(self) -> str: """The Subject parameter specified when the notification was published to the topic.""" - return self["Sns"]["Subject"] + return self._v["Sns"]["Subject"] + +class SNSEventRecord: + def __init__(self, record: Dict[str, Any]): + self._v = record -class SNSEventRecord(dict): @property def event_version(self) -> str: """Event version""" - return self["EventVersion"] + return self._v["EventVersion"] @property def event_subscription_arn(self) -> str: - return self["EventSubscriptionArn"] + return self._v["EventSubscriptionArn"] @property def event_source(self) -> str: """The AWS service from which the SNS event record originated. For SNS, this is aws:sns""" - return self["EventSource"] + return self._v["EventSource"] @property def sns(self) -> SNSMessage: - return SNSMessage(self) + return SNSMessage(self._v) class SNSEvent(dict): diff --git a/aws_lambda_powertools/utilities/trigger/sqs_event.py b/aws_lambda_powertools/utilities/trigger/sqs_event.py index 22addce4925..f50c2afe3d6 100644 --- a/aws_lambda_powertools/utilities/trigger/sqs_event.py +++ b/aws_lambda_powertools/utilities/trigger/sqs_event.py @@ -1,39 +1,39 @@ -from typing import Dict, Iterator, Optional +from typing import Any, Dict, Iterator, Optional class SQSRecordAttributes: - def __init__(self, record_attributes: dict): - self._val = record_attributes + def __init__(self, record_attributes: Dict[str, str]): + self._v = record_attributes @property def aws_trace_header(self) -> Optional[str]: """Returns the AWS X-Ray trace header string.""" - return self._val.get("AWSTraceHeader") + return self._v.get("AWSTraceHeader") @property def approximate_receive_count(self) -> str: """Returns the number of times a message has been received across all queues but not deleted.""" - return self._val["ApproximateReceiveCount"] + return self._v["ApproximateReceiveCount"] @property def sent_timestamp(self) -> str: """Returns the time the message was sent to the queue (epoch time in milliseconds).""" - return self._val["SentTimestamp"] + return self._v["SentTimestamp"] @property def sender_id(self) -> str: """For an IAM user, returns the IAM user ID, For an IAM role, returns the IAM role ID""" - return self._val["SenderId"] + return self._v["SenderId"] @property def approximate_first_receive_timestamp(self) -> str: """Returns the time the message was first received from the queue (epoch time in milliseconds).""" - return self._val["ApproximateFirstReceiveTimestamp"] + return self._v["ApproximateFirstReceiveTimestamp"] @property def sequence_number(self) -> Optional[str]: """The large, non-consecutive number that Amazon SQS assigns to each message.""" - return self._val.get("SequenceNumber") + return self._v.get("SequenceNumber") @property def message_group_id(self) -> Optional[str]: @@ -42,7 +42,7 @@ def message_group_id(self) -> Optional[str]: Messages that belong to the same message group are always processed one by one, in a strict order relative to the message group (however, messages that belong to different message groups might be processed out of order).""" - return self._val.get("MessageGroupId") + return self._v.get("MessageGroupId") @property def message_deduplication_id(self) -> Optional[str]: @@ -51,31 +51,31 @@ def message_deduplication_id(self) -> Optional[str]: If a message with a particular message deduplication ID is sent successfully, any messages sent with the same message deduplication ID are accepted successfully but aren't delivered during the 5-minute deduplication interval.""" - return self._val.get("MessageDeduplicationId") + return self._v.get("MessageDeduplicationId") class SQSMessageAttribute: """The user-specified message attribute value.""" - def __init__(self, message_attribute: dict): - self._val = message_attribute + def __init__(self, message_attribute: Dict[str, str]): + self._v = message_attribute @property def string_value(self) -> Optional[str]: """Strings are Unicode with UTF-8 binary encoding.""" - return self._val["stringValue"] + return self._v["stringValue"] @property def binary_value(self) -> Optional[str]: """Binary type attributes can store any binary data, such as compressed data, encrypted data, or images. Base64-encoded binary data object""" - return self._val["binaryValue"] + return self._v["binaryValue"] @property def data_type(self) -> str: """ The message attribute data type. Supported types include `String`, `Number`, and `Binary`.""" - return self._val["dataType"] + return self._v["dataType"] class SQSMessageAttributes(Dict[str, SQSMessageAttribute]): @@ -87,15 +87,15 @@ def __getitem__(self, key: str) -> Optional[SQSMessageAttribute]: class SQSRecord: """An Amazon SQS message""" - def __init__(self, record: dict): - self._val = record + def __init__(self, record: Dict[str, Any]): + self._v = record @property def message_id(self) -> str: """A unique identifier for the message. A messageId is considered unique across all AWS accounts for an extended period of time.""" - return self._val["messageId"] + return self._v["messageId"] @property def receipt_handle(self) -> str: @@ -103,42 +103,42 @@ def receipt_handle(self) -> str: A new receipt handle is returned every time you receive a message. When deleting a message, you provide the last received receipt handle to delete the message.""" - return self._val["receiptHandle"] + return self._v["receiptHandle"] @property def body(self) -> str: """The message's contents (not URL-encoded).""" - return self._val["body"] + return self._v["body"] @property def attributes(self) -> SQSRecordAttributes: """A map of the attributes requested in ReceiveMessage to their respective values.""" - return SQSRecordAttributes(self._val["attributes"]) + return SQSRecordAttributes(self._v["attributes"]) @property def message_attributes(self) -> SQSMessageAttributes: """Each message attribute consists of a Name, Type, and Value.""" - return SQSMessageAttributes(self._val["messageAttributes"]) + return SQSMessageAttributes(self._v["messageAttributes"]) @property def md5_of_body(self) -> str: """An MD5 digest of the non-URL-encoded message body string.""" - return self._val["md5OfBody"] + return self._v["md5OfBody"] @property def event_source(self) -> str: """The AWS service from which the SQS record originated. For SQS, this is `aws:sqs` """ - return self._val["eventSource"] + return self._v["eventSource"] @property def event_source_arn(self) -> str: """The Amazon Resource Name (ARN) of the event source""" - return self._val["eventSourceARN"] + return self._v["eventSourceARN"] @property def aws_region(self) -> str: """aws region eg: us-east-1""" - return self._val["awsRegion"] + return self._v["awsRegion"] class SQSEvent(dict): diff --git a/tests/functional/test_lambda_trigger_events.py b/tests/functional/test_lambda_trigger_events.py index 13b3c087693..a845345a7a7 100644 --- a/tests/functional/test_lambda_trigger_events.py +++ b/tests/functional/test_lambda_trigger_events.py @@ -338,7 +338,7 @@ def test_ses_trigger_event(): assert headers[0].value == "" common_headers = mail.common_headers assert common_headers.return_path == "janedoe@example.com" - assert common_headers.get_from == common_headers["from"] + assert common_headers.get_from == common_headers._v["from"] assert common_headers.date == "Wed, 7 Oct 2015 12:34:56 -0700" assert common_headers.to == [expected_address] assert common_headers.message_id == "<0123456789example.com>" @@ -352,9 +352,9 @@ def test_ses_trigger_event(): assert receipt.spf_verdict.status == "PASS" assert receipt.dmarc_verdict.status == "PASS" action = receipt.action - assert action.get_type == action["type"] - assert action.function_arn == action["functionArn"] - assert action.invocation_type == action["invocationType"] + assert action.get_type == action._v["type"] + assert action.function_arn == action._v["functionArn"] + assert action.invocation_type == action._v["invocationType"] def test_sns_trigger_event(): From 7a7e8627924b069457fa4478b41fb9a1ad2bd075 Mon Sep 17 00:00:00 2001 From: Michael Brewer Date: Sat, 12 Sep 2020 12:16:30 -0700 Subject: [PATCH 19/30] feat(trigger): Add conveince methods For both SNS and S3 event notifications we actually include a single records, so we can have conveince metods to access popular resources --- .../utilities/trigger/s3_event.py | 15 +++++++++++++++ .../utilities/trigger/sns_event.py | 10 ++++++++++ tests/functional/test_lambda_trigger_events.py | 5 +++++ 3 files changed, 30 insertions(+) diff --git a/aws_lambda_powertools/utilities/trigger/s3_event.py b/aws_lambda_powertools/utilities/trigger/s3_event.py index eafc9e8b8d9..ee40f3b697f 100644 --- a/aws_lambda_powertools/utilities/trigger/s3_event.py +++ b/aws_lambda_powertools/utilities/trigger/s3_event.py @@ -193,3 +193,18 @@ class S3Event(dict): def records(self) -> Iterator[S3EventRecord]: for record in self["Records"]: yield S3EventRecord(record) + + @property + def record(self) -> S3EventRecord: + """Get the first s3 event record""" + return next(self.records) + + @property + def bucket_name(self) -> str: + """Get the bucket name for the first s3 event record""" + return self.record.s3.bucket.name + + @property + def object_key(self) -> str: + """Get the object key for the first s3 event record""" + return self.record.s3.get_object.key diff --git a/aws_lambda_powertools/utilities/trigger/sns_event.py b/aws_lambda_powertools/utilities/trigger/sns_event.py index 3c7663cb850..05c71739bbb 100644 --- a/aws_lambda_powertools/utilities/trigger/sns_event.py +++ b/aws_lambda_powertools/utilities/trigger/sns_event.py @@ -118,3 +118,13 @@ class SNSEvent(dict): def records(self) -> Iterator[SNSEventRecord]: for record in self["Records"]: yield SNSEventRecord(record) + + @property + def record(self) -> SNSEventRecord: + """Return the first SNS event record""" + return next(self.records) + + @property + def sns_message(self) -> str: + """Return the message for the first sns event record""" + return self.record.sns.message diff --git a/tests/functional/test_lambda_trigger_events.py b/tests/functional/test_lambda_trigger_events.py index a845345a7a7..cdbc0478573 100644 --- a/tests/functional/test_lambda_trigger_events.py +++ b/tests/functional/test_lambda_trigger_events.py @@ -295,6 +295,9 @@ def test_s3_trigger_event(): assert s3.get_object.version_id is None assert s3.get_object.sequencer == "0C0F6F405D6ED209E1" assert record.glacier_event_data is None + assert event.record._v == event["Records"][0] + assert event.bucket_name == "lambda-artifacts-deafc19498e3f2df" + assert event.object_key == "b21b84d653bb07b05b1e6b33684dc11b" def test_s3_glacier_event(): @@ -380,6 +383,8 @@ def test_sns_trigger_event(): assert sns.unsubscribe_url == "https://sns.us-east-2.amazonaws.com/?Action=Unsubscri ..." assert sns.topic_arn == "arn:aws:sns:us-east-2:123456789012:sns-lambda" assert sns.subject == "TestInvoke" + assert event.record._v == event["Records"][0] + assert event.sns_message == "Hello from SNS!" def test_seq_trigger_event(): From 379b11f2e3c352c2b7e57fa2ac2c752ff51913bd Mon Sep 17 00:00:00 2001 From: Michael Brewer Date: Sat, 12 Sep 2020 20:40:42 -0700 Subject: [PATCH 20/30] feat(trigger): Create DictWrapper Use new DictWrapper abstract class for the data classes that wrap an event Dict --- .../trigger/api_gateway_proxy_event.py | 142 ++++++++---------- .../trigger/cloud_watch_logs_event.py | 15 +- .../trigger/cognito_user_pool_event.py | 78 +++++----- .../utilities/trigger/common.py | 13 ++ .../trigger/dynamo_db_stream_event.py | 69 ++++----- .../utilities/trigger/s3_event.py | 96 +++++------- .../utilities/trigger/ses_event.py | 118 +++++++-------- .../utilities/trigger/sns_event.py | 51 +++---- .../utilities/trigger/sqs_event.py | 57 +++---- .../functional/test_lambda_trigger_events.py | 19 ++- 10 files changed, 298 insertions(+), 360 deletions(-) create mode 100644 aws_lambda_powertools/utilities/trigger/common.py diff --git a/aws_lambda_powertools/utilities/trigger/api_gateway_proxy_event.py b/aws_lambda_powertools/utilities/trigger/api_gateway_proxy_event.py index 5ca2ba22d10..2c293a2a60e 100644 --- a/aws_lambda_powertools/utilities/trigger/api_gateway_proxy_event.py +++ b/aws_lambda_powertools/utilities/trigger/api_gateway_proxy_event.py @@ -1,206 +1,199 @@ from typing import Any, Dict, List, Optional +from aws_lambda_powertools.utilities.trigger.common import DictWrapper -class APIGatewayEventIdentity: - def __init__(self, event: Dict[str, Any]): - self._v = event +class APIGatewayEventIdentity(DictWrapper): @property def access_key(self) -> Optional[str]: - return self._v["requestContext"]["identity"].get("accessKey") + return self["requestContext"]["identity"].get("accessKey") @property def account_id(self) -> Optional[str]: """The AWS account ID associated with the request.""" - return self._v["requestContext"]["identity"].get("accountId") + return self["requestContext"]["identity"].get("accountId") @property def api_key(self) -> Optional[str]: """For API methods that require an API key, this variable is the API key associated with the method request. For methods that don't require an API key, this variable is null. """ - return self._v["requestContext"]["identity"].get("apiKey") + return self["requestContext"]["identity"].get("apiKey") @property def api_key_id(self) -> Optional[str]: """The API key ID associated with an API request that requires an API key.""" - return self._v["requestContext"]["identity"].get("apiKeyId") + return self["requestContext"]["identity"].get("apiKeyId") @property def caller(self) -> Optional[str]: """The principal identifier of the caller making the request.""" - return self._v["requestContext"]["identity"].get("caller") + return self["requestContext"]["identity"].get("caller") @property def cognito_authentication_provider(self) -> Optional[str]: """A comma-separated list of the Amazon Cognito authentication providers used by the caller making the request. Available only if the request was signed with Amazon Cognito credentials.""" - return self._v["requestContext"]["identity"].get("cognitoAuthenticationProvider") + return self["requestContext"]["identity"].get("cognitoAuthenticationProvider") @property def cognito_authentication_type(self) -> Optional[str]: """The Amazon Cognito authentication type of the caller making the request. Available only if the request was signed with Amazon Cognito credentials.""" - return self._v["requestContext"]["identity"].get("cognitoAuthenticationType") + return self["requestContext"]["identity"].get("cognitoAuthenticationType") @property def cognito_identity_id(self) -> Optional[str]: """The Amazon Cognito identity ID of the caller making the request. Available only if the request was signed with Amazon Cognito credentials.""" - return self._v["requestContext"]["identity"].get("cognitoIdentityId") + return self["requestContext"]["identity"].get("cognitoIdentityId") @property def cognito_identity_pool_id(self) -> Optional[str]: """The Amazon Cognito identity pool ID of the caller making the request. Available only if the request was signed with Amazon Cognito credentials.""" - return self._v["requestContext"]["identity"].get("cognitoIdentityPoolId") + return self["requestContext"]["identity"].get("cognitoIdentityPoolId") @property def principal_org_id(self) -> Optional[str]: """The AWS organization ID.""" - return self._v["requestContext"]["identity"].get("principalOrgId") + return self["requestContext"]["identity"].get("principalOrgId") @property def source_ip(self) -> str: """The source IP address of the TCP connection making the request to API Gateway.""" - return self._v["requestContext"]["identity"]["sourceIp"] + return self["requestContext"]["identity"]["sourceIp"] @property def user(self) -> Optional[str]: """The principal identifier of the user making the request.""" - return self._v["requestContext"]["identity"].get("user") + return self["requestContext"]["identity"].get("user") @property def user_agent(self) -> Optional[str]: """The User Agent of the API caller.""" - return self._v["requestContext"]["identity"].get("userAgent") + return self["requestContext"]["identity"].get("userAgent") @property def user_arn(self) -> Optional[str]: """The Amazon Resource Name (ARN) of the effective user identified after authentication.""" - return self._v["requestContext"]["identity"].get("userArn") + return self["requestContext"]["identity"].get("userArn") -class APIGatewayEventAuthorizer: - def __init__(self, event: Dict[str, Any]): - self._v = event - +class APIGatewayEventAuthorizer(DictWrapper): @property def claims(self) -> Optional[Dict[str, Any]]: - return self._v["requestContext"]["authorizer"].get("claims") + return self["requestContext"]["authorizer"].get("claims") @property def scopes(self) -> Optional[List[str]]: - return self._v["requestContext"]["authorizer"].get("scopes") - + return self["requestContext"]["authorizer"].get("scopes") -class APIGatewayEventRequestContext: - def __init__(self, event: Dict[str, Any]): - self._v = event +class APIGatewayEventRequestContext(DictWrapper): @property def account_id(self) -> str: """The AWS account ID associated with the request.""" - return self._v["requestContext"]["accountId"] + return self["requestContext"]["accountId"] @property def api_id(self) -> str: """The identifier API Gateway assigns to your API.""" - return self._v["requestContext"]["apiId"] + return self["requestContext"]["apiId"] @property def authorizer(self) -> APIGatewayEventAuthorizer: - return APIGatewayEventAuthorizer(self._v) + return APIGatewayEventAuthorizer(self._data) @property def connected_at(self) -> Optional[int]: """The Epoch-formatted connection time. (WebSocket API)""" - return self._v["requestContext"].get("connectedAt") + return self["requestContext"].get("connectedAt") @property def connection_id(self) -> Optional[str]: """A unique ID for the connection that can be used to make a callback to the client. (WebSocket API)""" - return self._v["requestContext"].get("connectionId") + return self["requestContext"].get("connectionId") @property def domain_name(self) -> Optional[str]: """A domain name""" - return self._v["requestContext"].get("domainName") + return self["requestContext"].get("domainName") @property def domain_prefix(self) -> Optional[str]: - return self._v["requestContext"].get("domainPrefix") + return self["requestContext"].get("domainPrefix") @property def event_type(self) -> Optional[str]: """The event type: `CONNECT`, `MESSAGE`, or `DISCONNECT`. (WebSocket API)""" - return self._v["requestContext"].get("eventType") + return self["requestContext"].get("eventType") @property def extended_request_id(self) -> Optional[str]: """An automatically generated ID for the API call, which contains more useful information for debugging/troubleshooting.""" - return self._v["requestContext"].get("extendedRequestId") + return self["requestContext"].get("extendedRequestId") @property def protocol(self) -> str: """The request protocol, for example, HTTP/1.1.""" - return self._v["requestContext"]["protocol"] + return self["requestContext"]["protocol"] @property def http_method(self) -> str: """The HTTP method used. Valid values include: DELETE, GET, HEAD, OPTIONS, PATCH, POST, and PUT.""" - return self._v["requestContext"]["httpMethod"] + return self["requestContext"]["httpMethod"] @property def identity(self) -> APIGatewayEventIdentity: - return APIGatewayEventIdentity(self._v) + return APIGatewayEventIdentity(self._data) @property def message_direction(self) -> Optional[str]: """Message direction (WebSocket API)""" - return self._v["requestContext"].get("messageDirection") + return self["requestContext"].get("messageDirection") @property def message_id(self) -> Optional[str]: """A unique server-side ID for a message. Available only when the `eventType` is `MESSAGE`.""" - return self._v["requestContext"].get("messageId") + return self["requestContext"].get("messageId") @property def path(self) -> str: - return self._v["requestContext"]["path"] + return self["requestContext"]["path"] @property def stage(self) -> str: """The deployment stage of the API request """ - return self._v["requestContext"]["stage"] + return self["requestContext"]["stage"] @property def request_id(self) -> str: """The ID that API Gateway assigns to the API request.""" - return self._v["requestContext"]["requestId"] + return self["requestContext"]["requestId"] @property def request_time(self) -> Optional[str]: """The CLF-formatted request time (dd/MMM/yyyy:HH:mm:ss +-hhmm)""" - return self._v["requestContext"].get("requestTime") + return self["requestContext"].get("requestTime") @property def request_time_epoch(self) -> int: """The Epoch-formatted request time.""" - return self._v["requestContext"]["requestTimeEpoch"] + return self["requestContext"]["requestTimeEpoch"] @property def resource_id(self) -> str: - return self._v["requestContext"]["resourceId"] + return self["requestContext"]["resourceId"] @property def resource_path(self) -> str: - return self._v["requestContext"]["resourcePath"] + return self["requestContext"]["resourcePath"] @property def route_key(self) -> Optional[str]: """The selected route key.""" - return self._v["requestContext"].get("routeKey") + return self["requestContext"].get("routeKey") class APIGatewayProxyEvent(dict): @@ -265,103 +258,94 @@ def is_base64_encoded(self) -> bool: return self["isBase64Encoded"] -class RequestContextV2Http: - def __init__(self, event: Dict[str, Any]): - self._v = event - +class RequestContextV2Http(DictWrapper): @property def method(self) -> str: - return self._v["requestContext"]["http"]["method"] + return self["requestContext"]["http"]["method"] @property def path(self) -> str: - return self._v["requestContext"]["http"]["path"] + return self["requestContext"]["http"]["path"] @property def protocol(self) -> str: """The request protocol, for example, HTTP/1.1.""" - return self._v["requestContext"]["http"]["protocol"] + return self["requestContext"]["http"]["protocol"] @property def source_ip(self) -> str: """The source IP address of the TCP connection making the request to API Gateway.""" - return self._v["requestContext"]["http"]["sourceIp"] + return self["requestContext"]["http"]["sourceIp"] @property def user_agent(self) -> str: """The User Agent of the API caller.""" - return self._v["requestContext"]["http"]["userAgent"] + return self["requestContext"]["http"]["userAgent"] -class RequestContextV2Authorizer: - def __init__(self, event: Dict[str, Any]): - self._v = event - +class RequestContextV2Authorizer(DictWrapper): @property def jwt_claim(self) -> Dict[str, Any]: - return self._v["jwt"]["claims"] + return self["jwt"]["claims"] @property def jwt_scopes(self) -> List[str]: - return self._v["jwt"]["scopes"] - + return self["jwt"]["scopes"] -class RequestContextV2: - def __init__(self, event: Dict[str, Any]): - self._v = event +class RequestContextV2(DictWrapper): @property def account_id(self) -> str: """The AWS account ID associated with the request.""" - return self._v["requestContext"]["accountId"] + return self["requestContext"]["accountId"] @property def api_id(self) -> str: """The identifier API Gateway assigns to your API.""" - return self._v["requestContext"]["apiId"] + return self["requestContext"]["apiId"] @property def authorizer(self) -> Optional[RequestContextV2Authorizer]: - authorizer = self._v["requestContext"].get("authorizer") + authorizer = self["requestContext"].get("authorizer") return None if authorizer is None else RequestContextV2Authorizer(authorizer) @property def domain_name(self) -> str: """A domain name """ - return self._v["requestContext"]["domainName"] + return self["requestContext"]["domainName"] @property def domain_prefix(self) -> str: - return self._v["requestContext"]["domainPrefix"] + return self["requestContext"]["domainPrefix"] @property def http(self) -> RequestContextV2Http: - return RequestContextV2Http(self._v) + return RequestContextV2Http(self._data) @property def request_id(self) -> str: """The ID that API Gateway assigns to the API request.""" - return self._v["requestContext"]["requestId"] + return self["requestContext"]["requestId"] @property def route_key(self) -> str: """The selected route key.""" - return self._v["requestContext"]["routeKey"] + return self["requestContext"]["routeKey"] @property def stage(self) -> str: """The deployment stage of the API request """ - return self._v["requestContext"]["stage"] + return self["requestContext"]["stage"] @property def time(self) -> str: """The CLF-formatted request time (dd/MMM/yyyy:HH:mm:ss +-hhmm).""" - return self._v["requestContext"]["time"] + return self["requestContext"]["time"] @property def time_epoch(self) -> int: """The Epoch-formatted request time.""" - return self._v["requestContext"]["timeEpoch"] + return self["requestContext"]["timeEpoch"] class APIGatewayProxyEventV2(dict): diff --git a/aws_lambda_powertools/utilities/trigger/cloud_watch_logs_event.py b/aws_lambda_powertools/utilities/trigger/cloud_watch_logs_event.py index 9f86b9f370d..427392e06aa 100644 --- a/aws_lambda_powertools/utilities/trigger/cloud_watch_logs_event.py +++ b/aws_lambda_powertools/utilities/trigger/cloud_watch_logs_event.py @@ -1,33 +1,32 @@ import base64 import json import zlib -from typing import Any, Dict, List, Optional +from typing import Dict, List, Optional +from aws_lambda_powertools.utilities.trigger.common import DictWrapper -class CloudWatchLogsLogEvent: - def __init__(self, log_event: Dict[str, Any]): - self._v = log_event +class CloudWatchLogsLogEvent(DictWrapper): @property def get_id(self) -> str: """The ID property is a unique identifier for every log event.""" # Note: this name conflicts with existing python builtins - return self._v["id"] + return self["id"] @property def timestamp(self) -> int: """Get the `timestamp` property""" - return self._v["timestamp"] + return self["timestamp"] @property def message(self) -> str: """Get the `message` property""" - return self._v["message"] + return self["message"] @property def extracted_fields(self) -> Optional[Dict[str, str]]: """Get the `extractedFields` property""" - return self._v.get("extractedFields") + return self.get("extractedFields") class CloudWatchLogsDecodedData(dict): diff --git a/aws_lambda_powertools/utilities/trigger/cognito_user_pool_event.py b/aws_lambda_powertools/utilities/trigger/cognito_user_pool_event.py index 55674c794d0..036c04902eb 100644 --- a/aws_lambda_powertools/utilities/trigger/cognito_user_pool_event.py +++ b/aws_lambda_powertools/utilities/trigger/cognito_user_pool_event.py @@ -1,19 +1,18 @@ from typing import Any, Dict, List, Optional +from aws_lambda_powertools.utilities.trigger.common import DictWrapper -class CallerContext: - def __init__(self, event: Dict[str, Any]): - self._v = event +class CallerContext(DictWrapper): @property def aws_sdk_version(self) -> str: """The AWS SDK version number.""" - return self._v["callerContext"]["awsSdkVersion"] + return self["callerContext"]["awsSdkVersion"] @property def client_id(self) -> str: """The ID of the client associated with the user pool.""" - return self._v["callerContext"]["clientId"] + return self["callerContext"]["clientId"] class BaseTriggerEvent(dict): @@ -55,7 +54,7 @@ def caller_context(self) -> CallerContext: return CallerContext(self) -class PreSignUpTriggerEventRequest(dict): +class PreSignUpTriggerEventRequest(DictWrapper): @property def user_attributes(self) -> Dict[str, str]: """One or more name-value pairs representing user attributes. The attribute names are the keys.""" @@ -73,7 +72,7 @@ def client_metadata(self) -> Optional[Dict[str, str]]: return self["request"].get("clientMetadata") -class PreSignUpTriggerEventResponse(dict): +class PreSignUpTriggerEventResponse(DictWrapper): @property def auto_confirm_user(self) -> bool: return bool(self["response"]["autoConfirmUser"]) @@ -127,7 +126,7 @@ def response(self) -> PreSignUpTriggerEventResponse: return PreSignUpTriggerEventResponse(self) -class PostConfirmationTriggerEventRequest(dict): +class PostConfirmationTriggerEventRequest(DictWrapper): @property def user_attributes(self) -> Dict[str, str]: """One or more name-value pairs representing user attributes. The attribute names are the keys.""" @@ -160,7 +159,7 @@ def request(self) -> PostConfirmationTriggerEventRequest: return PostConfirmationTriggerEventRequest(self) -class UserMigrationTriggerEventRequest(dict): +class UserMigrationTriggerEventRequest(DictWrapper): @property def password(self) -> str: return self["request"]["password"] @@ -177,7 +176,7 @@ def client_metadata(self) -> Optional[Dict[str, str]]: return self["request"].get("clientMetadata") -class UserMigrationTriggerEventResponse(dict): +class UserMigrationTriggerEventResponse(DictWrapper): @property def user_attributes(self) -> Dict[str, str]: return self["response"]["userAttributes"] @@ -265,7 +264,7 @@ def response(self) -> UserMigrationTriggerEventResponse: return UserMigrationTriggerEventResponse(self) -class CustomMessageTriggerEventRequest(dict): +class CustomMessageTriggerEventRequest(DictWrapper): @property def code_parameter(self) -> str: """A string for you to use as the placeholder for the verification code in the custom message.""" @@ -288,7 +287,7 @@ def client_metadata(self) -> Optional[Dict[str, str]]: return self["request"].get("clientMetadata") -class CustomMessageTriggerEventResponse(dict): +class CustomMessageTriggerEventResponse(DictWrapper): @property def sms_message(self) -> str: return self["response"]["smsMessage"] @@ -350,7 +349,7 @@ def response(self) -> CustomMessageTriggerEventResponse: return CustomMessageTriggerEventResponse(self) -class PreAuthenticationTriggerEventRequest(dict): +class PreAuthenticationTriggerEventRequest(DictWrapper): @property def user_not_found(self) -> Optional[bool]: """This boolean is populated when PreventUserExistenceErrors is set to ENABLED for your User Pool client.""" @@ -390,7 +389,7 @@ def request(self) -> PreAuthenticationTriggerEventRequest: return PreAuthenticationTriggerEventRequest(self) -class PostAuthenticationTriggerEventRequest(dict): +class PostAuthenticationTriggerEventRequest(DictWrapper): @property def new_device_used(self) -> bool: """This flag indicates if the user has signed in on a new device. @@ -431,7 +430,7 @@ def request(self) -> PostAuthenticationTriggerEventRequest: return PostAuthenticationTriggerEventRequest(self) -class GroupOverrideDetails(dict): +class GroupOverrideDetails(DictWrapper): @property def groups_to_override(self) -> Optional[List[str]]: """A list of the group names that are associated with the user that the identity token is issued for.""" @@ -448,54 +447,48 @@ def preferred_role(self) -> Optional[str]: return self.get("preferredRole") -class PreTokenGenerationTriggerEventRequest: - def __init__(self, event: Dict[str, Any]): - self._v = event - +class PreTokenGenerationTriggerEventRequest(DictWrapper): @property def group_configuration(self) -> GroupOverrideDetails: """The input object containing the current group configuration""" - return GroupOverrideDetails(self._v["request"]["groupConfiguration"]) + return GroupOverrideDetails(self["request"]["groupConfiguration"]) @property def user_attributes(self) -> Dict[str, str]: """One or more name-value pairs representing user attributes.""" - return self._v["request"]["userAttributes"] + return self["request"]["userAttributes"] @property def client_metadata(self) -> Optional[Dict[str, str]]: """One or more key-value pairs that you can provide as custom input to the Lambda function that you specify for the pre token generation trigger.""" - return self._v["request"].get("clientMetadata") - + return self["request"].get("clientMetadata") -class ClaimsOverrideDetails: - def __init__(self, event: Dict[str, Any]): - self._v = event["response"]["claimsOverrideDetails"] +class ClaimsOverrideDetails(DictWrapper): @property def claims_to_add_or_override(self) -> Optional[Dict[str, str]]: - return self._v.get("claimsToAddOrOverride") + return self.get("claimsToAddOrOverride") @property def claims_to_suppress(self) -> Optional[List[str]]: - return self._v.get("claimsToSuppress") + return self.get("claimsToSuppress") @property def group_configuration(self) -> Optional[GroupOverrideDetails]: - group_override_details = self._v.get("groupOverrideDetails") + group_override_details = self.get("groupOverrideDetails") return None if group_override_details is None else GroupOverrideDetails(group_override_details) @claims_to_add_or_override.setter def claims_to_add_or_override(self, value: Dict[str, str]): """A map of one or more key-value pairs of claims to add or override. For group related claims, use groupOverrideDetails instead.""" - self._v["claimsToAddOrOverride"] = value + self._data["claimsToAddOrOverride"] = value @claims_to_suppress.setter def claims_to_suppress(self, value: List[str]): """A list that contains claims to be suppressed from the identity token.""" - self._v["claimsToSuppress"] = value + self._data["claimsToSuppress"] = value @group_configuration.setter def group_configuration(self, value: Dict[str, Any]): @@ -508,33 +501,30 @@ def group_configuration(self, value: Dict[str, Any]): as is, copy the value of the request's groupConfiguration object to the groupOverrideDetails object in the response, and pass it back to the service. """ - self._v["groupOverrideDetails"] = value + self._data["groupOverrideDetails"] = value def set_group_configuration_groups_to_override(self, value: List[str]): """A list of the group names that are associated with the user that the identity token is issued for.""" - self._v.setdefault("groupOverrideDetails", {}) - self._v["groupOverrideDetails"]["groupsToOverride"] = value + self._data.setdefault("groupOverrideDetails", {}) + self["groupOverrideDetails"]["groupsToOverride"] = value def set_group_configuration_iam_roles_to_override(self, value: List[str]): """A list of the current IAM roles associated with these groups.""" - self._v.setdefault("groupOverrideDetails", {}) - self._v["groupOverrideDetails"]["iamRolesToOverride"] = value + self._data.setdefault("groupOverrideDetails", {}) + self["groupOverrideDetails"]["iamRolesToOverride"] = value def set_group_configuration_preferred_role(self, value: str): """A string indicating the preferred IAM role.""" - self._v.setdefault("groupOverrideDetails", {}) - self._v["groupOverrideDetails"]["preferredRole"] = value - + self._data.setdefault("groupOverrideDetails", {}) + self["groupOverrideDetails"]["preferredRole"] = value -class PreTokenGenerationTriggerEventResponse: - def __init__(self, event: Dict[str, Any]): - self._v = event +class PreTokenGenerationTriggerEventResponse(DictWrapper): @property def claims_override_details(self) -> ClaimsOverrideDetails: # Ensure we have a `claimsOverrideDetails` element - self._v["response"].setdefault("claimsOverrideDetails", {}) - return ClaimsOverrideDetails(self._v) + self._data["response"].setdefault("claimsOverrideDetails", {}) + return ClaimsOverrideDetails(self._data["response"]["claimsOverrideDetails"]) class PreTokenGenerationTriggerEvent(BaseTriggerEvent): diff --git a/aws_lambda_powertools/utilities/trigger/common.py b/aws_lambda_powertools/utilities/trigger/common.py new file mode 100644 index 00000000000..8feb38e837f --- /dev/null +++ b/aws_lambda_powertools/utilities/trigger/common.py @@ -0,0 +1,13 @@ +from abc import ABC +from typing import Any, Dict, Optional + + +class DictWrapper(ABC): + def __init__(self, data: Dict[str, Any]): + self._data = data + + def __getitem__(self, key: str) -> Any: + return self._data[key] + + def get(self, key: str) -> Optional[Any]: + return self._data.get(key) diff --git a/aws_lambda_powertools/utilities/trigger/dynamo_db_stream_event.py b/aws_lambda_powertools/utilities/trigger/dynamo_db_stream_event.py index 559d50526e1..8623d733a0e 100644 --- a/aws_lambda_powertools/utilities/trigger/dynamo_db_stream_event.py +++ b/aws_lambda_powertools/utilities/trigger/dynamo_db_stream_event.py @@ -1,16 +1,15 @@ from enum import Enum -from typing import Any, Dict, Iterator, List, Optional +from typing import Dict, Iterator, List, Optional +from aws_lambda_powertools.utilities.trigger.common import DictWrapper -class AttributeValue: + +class AttributeValue(DictWrapper): """Represents the data for an attribute Documentation: https://docs.aws.amazon.com/amazondynamodb/latest/APIReference/API_streams_AttributeValue.html """ - def __init__(self, attr_value: Dict[str, Any]): - self._v = attr_value - @property def b_value(self) -> Optional[str]: """An attribute of type Base64-encoded binary data object @@ -18,7 +17,7 @@ def b_value(self) -> Optional[str]: Example: >>> {"B": "dGhpcyB0ZXh0IGlzIGJhc2U2NC1lbmNvZGVk"} """ - return self._v.get("B") + return self.get("B") @property def bs_value(self) -> Optional[List[str]]: @@ -27,7 +26,7 @@ def bs_value(self) -> Optional[List[str]]: Example: >>> {"BS": ["U3Vubnk=", "UmFpbnk=", "U25vd3k="]} """ - return self._v.get("BS") + return self.get("BS") @property def bool_value(self) -> Optional[bool]: @@ -36,7 +35,7 @@ def bool_value(self) -> Optional[bool]: Example: >>> {"BOOL": True} """ - item = self._v.get("bool") + item = self.get("bool") return None if item is None else bool(item) @property @@ -46,7 +45,7 @@ def list_value(self) -> Optional[List["AttributeValue"]]: Example: >>> {"L": [ {"S": "Cookies"} , {"S": "Coffee"}, {"N": "3.14159"}]} """ - item = self._v.get("L") + item = self.get("L") return None if item is None else [AttributeValue(v) for v in item] @property @@ -56,7 +55,7 @@ def map_value(self) -> Optional[Dict[str, "AttributeValue"]]: Example: >>> {"M": {"Name": {"S": "Joe"}, "Age": {"N": "35"}}} """ - return _attribute_value_dict(self._v, "M") + return _attribute_value_dict(self._data, "M") @property def n_value(self) -> Optional[str]: @@ -68,7 +67,7 @@ def n_value(self) -> Optional[str]: Example: >>> {"N": "123.45"} """ - return self._v.get("N") + return self.get("N") @property def ns_value(self) -> Optional[List[str]]: @@ -77,7 +76,7 @@ def ns_value(self) -> Optional[List[str]]: Example: >>> {"NS": ["42.2", "-19", "7.5", "3.14"]} """ - return self._v.get("NS") + return self.get("NS") @property def null_value(self) -> Optional[bool]: @@ -86,7 +85,7 @@ def null_value(self) -> Optional[bool]: Example: >>> {"NULL": True} """ - item = self._v.get("NULL") + item = self.get("NULL") return None if item is None else bool(item) @property @@ -96,7 +95,7 @@ def s_value(self) -> Optional[str]: Example: >>> {"S": "Hello"} """ - return self._v.get("S") + return self.get("S") @property def ss_value(self) -> Optional[List[str]]: @@ -105,7 +104,7 @@ def ss_value(self) -> Optional[List[str]]: Example: >>> {"SS": ["Giraffe", "Hippo" ,"Zebra"]} """ - return self._v.get("SS") + return self.get("SS") def _attribute_value_dict(attr_values: Dict[str, dict], key: str) -> Optional[Dict[str, AttributeValue]]: @@ -127,46 +126,43 @@ class StreamViewType(Enum): NEW_AND_OLD_IMAGES = 3 # both the new and the old item images of the item. -class StreamRecord: - def __init__(self, stream_record: Dict[str, Any]): - self._v = stream_record - +class StreamRecord(DictWrapper): @property def approximate_creation_date_time(self) -> Optional[int]: """The approximate date and time when the stream record was created, in UNIX epoch time format.""" - item = self._v.get("ApproximateCreationDateTime") + item = self.get("ApproximateCreationDateTime") return None if item is None else int(item) @property def keys(self) -> Optional[Dict[str, AttributeValue]]: """The primary key attribute(s) for the DynamoDB item that was modified.""" - return _attribute_value_dict(self._v, "Keys") + return _attribute_value_dict(self._data, "Keys") @property def new_image(self) -> Optional[Dict[str, AttributeValue]]: """The item in the DynamoDB table as it appeared after it was modified.""" - return _attribute_value_dict(self._v, "NewImage") + return _attribute_value_dict(self._data, "NewImage") @property def old_image(self) -> Optional[Dict[str, AttributeValue]]: """The item in the DynamoDB table as it appeared before it was modified.""" - return _attribute_value_dict(self._v, "OldImage") + return _attribute_value_dict(self._data, "OldImage") @property def sequence_number(self) -> Optional[str]: """The sequence number of the stream record.""" - return self._v.get("SequenceNumber") + return self.get("SequenceNumber") @property def size_bytes(self) -> Optional[int]: """The size of the stream record, in bytes.""" - item = self._v.get("SizeBytes") + item = self.get("SizeBytes") return None if item is None else int(item) @property def stream_view_type(self) -> Optional[StreamViewType]: """The type of data from the modified DynamoDB item that was captured in this stream record""" - item = self._v.get("StreamViewType") + item = self.get("StreamViewType") return None if item is None else StreamViewType[str(item)] @@ -176,53 +172,50 @@ class DynamoDBRecordEventName(Enum): REMOVE = 2 # the item was deleted from the table -class DynamoDBRecord: +class DynamoDBRecord(DictWrapper): """A description of a unique event within a stream""" - def __init__(self, record: Dict[str, Any]): - self._v = record - @property def aws_region(self) -> Optional[str]: """The region in which the GetRecords request was received""" - return self._v.get("awsRegion") + return self.get("awsRegion") @property def dynamodb(self) -> Optional[StreamRecord]: """The main body of the stream record, containing all of the DynamoDB-specific fields.""" - stream_record = self._v.get("dynamodb") + stream_record = self.get("dynamodb") return None if stream_record is None else StreamRecord(stream_record) @property def event_id(self) -> Optional[str]: """A globally unique identifier for the event that was recorded in this stream record.""" - return self._v.get("eventID") + return self.get("eventID") @property def event_name(self) -> Optional[DynamoDBRecordEventName]: """The type of data modification that was performed on the DynamoDB table""" - item = self._v.get("eventName") + item = self.get("eventName") return None if item is None else DynamoDBRecordEventName[item] @property def event_source(self) -> Optional[str]: """The AWS service from which the stream record originated. For DynamoDB Streams, this is aws:dynamodb.""" - return self._v.get("eventSource") + return self.get("eventSource") @property def event_source_arn(self) -> Optional[str]: """The Amazon Resource Name (ARN) of the event source""" - return self._v.get("eventSourceARN") + return self.get("eventSourceARN") @property def event_version(self) -> Optional[str]: """The version number of the stream record format.""" - return self._v.get("eventVersion") + return self.get("eventVersion") @property def user_identity(self) -> Optional[dict]: """Contains details about the type of identity that made the request""" - return self._v.get("userIdentity") + return self.get("userIdentity") class DynamoDBStreamEvent(dict): diff --git a/aws_lambda_powertools/utilities/trigger/s3_event.py b/aws_lambda_powertools/utilities/trigger/s3_event.py index ee40f3b697f..229ccebc729 100644 --- a/aws_lambda_powertools/utilities/trigger/s3_event.py +++ b/aws_lambda_powertools/utilities/trigger/s3_event.py @@ -1,162 +1,140 @@ -from typing import Any, Dict, Iterator, Optional +from typing import Dict, Iterator, Optional +from aws_lambda_powertools.utilities.trigger.common import DictWrapper -class S3Identity: - def __init__(self, s3_identity: Dict[str, str]): - self._v = s3_identity +class S3Identity(DictWrapper): @property def principal_id(self) -> str: - return self._v["principalId"] + return self["principalId"] -class S3RequestParameters: - def __init__(self, record: Dict[str, Any]): - self._v = record - +class S3RequestParameters(DictWrapper): @property def source_ip_address(self) -> str: - return self._v["requestParameters"]["sourceIPAddress"] - + return self["requestParameters"]["sourceIPAddress"] -class S3Bucket: - def __init__(self, record: Dict[str, Any]): - self._v = record +class S3Bucket(DictWrapper): @property def name(self) -> str: - return self._v["s3"]["bucket"]["name"] + return self["s3"]["bucket"]["name"] @property def owner_identity(self) -> S3Identity: - return S3Identity(self._v["s3"]["bucket"]["ownerIdentity"]) + return S3Identity(self["s3"]["bucket"]["ownerIdentity"]) @property def arn(self) -> str: - return self._v["s3"]["bucket"]["arn"] + return self["s3"]["bucket"]["arn"] -class S3Object: - def __init__(self, record: Dict[str, Any]): - self._v = record - +class S3Object(DictWrapper): @property def key(self) -> str: """Object key""" - return self._v["s3"]["object"]["key"] + return self["s3"]["object"]["key"] @property def size(self) -> int: """Object byte size""" - return int(self._v["s3"]["object"]["size"]) + return int(self["s3"]["object"]["size"]) @property def etag(self) -> str: """object eTag""" - return self._v["s3"]["object"]["eTag"] + return self["s3"]["object"]["eTag"] @property def version_id(self) -> Optional[str]: """Object version if bucket is versioning-enabled, otherwise null""" - return self._v["s3"]["object"].get("versionId") + return self["s3"]["object"].get("versionId") @property def sequencer(self) -> str: """A string representation of a hexadecimal value used to determine event sequence, only used with PUTs and DELETEs """ - return self._v["s3"]["object"]["sequencer"] - + return self["s3"]["object"]["sequencer"] -class S3Message: - def __init__(self, record: Dict[str, Any]): - self._v = record +class S3Message(DictWrapper): @property def s3_schema_version(self) -> str: - return self._v["s3"]["s3SchemaVersion"] + return self["s3"]["s3SchemaVersion"] @property def configuration_id(self) -> str: """ID found in the bucket notification configuration""" - return self._v["s3"]["configurationId"] + return self["s3"]["configurationId"] @property def bucket(self) -> S3Bucket: - return S3Bucket(self._v) + return S3Bucket(self._data) @property def get_object(self) -> S3Object: """Get the `object` property as an S3Object""" # Note: this name conflicts with existing python builtins - return S3Object(self._v) + return S3Object(self._data) -class S3EventRecordGlacierRestoreEventData: - def __init__(self, glacier_event_data: Dict[str, Any]): - self._v = glacier_event_data - +class S3EventRecordGlacierRestoreEventData(DictWrapper): @property def lifecycle_restoration_expiry_time(self) -> str: """Time when the object restoration will be expired.""" - return self._v["restoreEventData"]["lifecycleRestorationExpiryTime"] + return self["restoreEventData"]["lifecycleRestorationExpiryTime"] @property def lifecycle_restore_storage_class(self) -> str: """Source storage class for restore""" - return self._v["restoreEventData"]["lifecycleRestoreStorageClass"] - + return self["restoreEventData"]["lifecycleRestoreStorageClass"] -class S3EventRecordGlacierEventData: - def __init__(self, glacier_event_data: Dict[str, Any]): - self._v = glacier_event_data +class S3EventRecordGlacierEventData(DictWrapper): @property def restore_event_data(self) -> S3EventRecordGlacierRestoreEventData: """The restoreEventData key contains attributes related to your restore request. The glacierEventData key is only visible for s3:ObjectRestore:Completed events """ - return S3EventRecordGlacierRestoreEventData(self._v) - + return S3EventRecordGlacierRestoreEventData(self._data) -class S3EventRecord: - def __init__(self, record: Dict[str, Any]): - self._v = record +class S3EventRecord(DictWrapper): @property def event_version(self) -> str: """The eventVersion key value contains a major and minor version in the form ..""" - return self._v["eventVersion"] + return self["eventVersion"] @property def event_source(self) -> str: """The AWS service from which the S3 event originated. For S3, this is aws:s3""" - return self._v["eventSource"] + return self["eventSource"] @property def aws_region(self) -> str: """aws region eg: us-east-1""" - return self._v["awsRegion"] + return self["awsRegion"] @property def event_time(self) -> str: """The time, in ISO-8601 format, for example, 1970-01-01T00:00:00.000Z, when S3 finished processing the request""" - return self._v["eventTime"] + return self["eventTime"] @property def event_name(self) -> str: """Event type""" - return self._v["eventName"] + return self["eventName"] @property def user_identity(self) -> S3Identity: - return S3Identity(self._v["userIdentity"]) + return S3Identity(self["userIdentity"]) @property def request_parameters(self) -> S3RequestParameters: - return S3RequestParameters(self._v) + return S3RequestParameters(self._data) @property def response_elements(self) -> Dict[str, str]: @@ -166,16 +144,16 @@ def response_elements(self) -> Dict[str, str]: as those that Amazon S3 returns in the response to the request that initiates the events, so they can be used to match the event to the request. """ - return self._v["responseElements"] + return self["responseElements"] @property def s3(self) -> S3Message: - return S3Message(self._v) + return S3Message(self._data) @property def glacier_event_data(self) -> Optional[S3EventRecordGlacierEventData]: """The glacierEventData key is only visible for s3:ObjectRestore:Completed events.""" - item = self._v.get("glacierEventData") + item = self.get("glacierEventData") return None if item is None else S3EventRecordGlacierEventData(item) diff --git a/aws_lambda_powertools/utilities/trigger/ses_event.py b/aws_lambda_powertools/utilities/trigger/ses_event.py index 7cda0792287..adfd612c4f8 100644 --- a/aws_lambda_powertools/utilities/trigger/ses_event.py +++ b/aws_lambda_powertools/utilities/trigger/ses_event.py @@ -1,69 +1,62 @@ -from typing import Any, Dict, Iterator, List +from typing import Iterator, List +from aws_lambda_powertools.utilities.trigger.common import DictWrapper -class SESMailHeader: - def __init__(self, header: Dict[str, str]): - self._v = header +class SESMailHeader(DictWrapper): @property def name(self) -> str: - return self._v["name"] + return self["name"] @property def value(self) -> str: - return self._v["value"] + return self["value"] -class SESMailCommonHeaders: - def __init__(self, common_headers: Dict[str, Any]): - self._v = common_headers - +class SESMailCommonHeaders(DictWrapper): @property def return_path(self) -> str: """The values in the Return-Path header of the email.""" - return self._v["returnPath"] + return self["returnPath"] @property def get_from(self) -> List[str]: """The values in the From header of the email.""" # Note: this name conflicts with existing python builtins - return self._v["from"] + return self["from"] @property def date(self) -> List[str]: """The date and time when Amazon SES received the message.""" - return self._v["date"] + return self["date"] @property def to(self) -> List[str]: """The values in the To header of the email.""" - return self._v["to"] + return self["to"] @property def message_id(self) -> str: """The ID of the original message.""" - return str(self._v["messageId"]) + return str(self["messageId"]) @property def subject(self) -> str: """The value of the Subject header for the email.""" - return str(self._v["subject"]) - + return str(self["subject"]) -class SESMail: - def __init__(self, mail: Dict[str, Any]): - self._v = mail +class SESMail(DictWrapper): @property def timestamp(self) -> str: """String that contains the time at which the email was received, in ISO8601 format.""" - return self._v["timestamp"] + return self["timestamp"] @property def source(self) -> str: """String that contains the email address (specifically, the envelope MAIL FROM address) that the email was sent from.""" - return self._v["source"] + return self["source"] @property def message_id(self) -> str: @@ -71,46 +64,40 @@ def message_id(self) -> str: If the email was delivered to Amazon S3, the message ID is also the Amazon S3 object key that was used to write the message to your Amazon S3 bucket.""" - return self._v["messageId"] + return self["messageId"] @property def destination(self) -> List[str]: """A complete list of all recipient addresses (including To: and CC: recipients) from the MIME headers of the incoming email.""" - return self._v["destination"] + return self["destination"] @property def headers_truncated(self) -> bool: """String that specifies whether the headers were truncated in the notification, which will happen if the headers are larger than 10 KB. Possible values are true and false.""" - return bool(self._v["headersTruncated"]) + return bool(self["headersTruncated"]) @property def headers(self) -> Iterator[SESMailHeader]: """A list of Amazon SES headers and your custom headers. Each header in the list has a name field and a value field""" - for header in self._v["headers"]: + for header in self["headers"]: yield SESMailHeader(header) @property def common_headers(self) -> SESMailCommonHeaders: """A list of headers common to all emails. Each header in the list is composed of a name and a value.""" - return SESMailCommonHeaders(self._v["commonHeaders"]) + return SESMailCommonHeaders(self["commonHeaders"]) -class SESReceiptStatus: - def __init__(self, receipt_status: Dict[str, str]): - self._v = receipt_status - +class SESReceiptStatus(DictWrapper): @property def status(self) -> str: - return str(self._v["status"]) - + return str(self["status"]) -class SESReceiptAction: - def __init__(self, receipt_action: Dict[str, str]): - self._v = receipt_action +class SESReceiptAction(DictWrapper): @property def get_type(self) -> str: """String that indicates the type of action that was executed. @@ -118,99 +105,90 @@ def get_type(self) -> str: Possible values are S3, SNS, Bounce, Lambda, Stop, and WorkMail """ # Note: this name conflicts with existing python builtins - return self._v["type"] + return self["type"] @property def function_arn(self) -> str: """String that contains the ARN of the Lambda function that was triggered. Present only for the Lambda action type.""" - return self._v["functionArn"] + return self["functionArn"] @property def invocation_type(self) -> str: """String that contains the invocation type of the Lambda function. Possible values are RequestResponse and Event. Present only for the Lambda action type.""" - return self._v["invocationType"] + return self["invocationType"] -class SESReceipt: - def __init__(self, receipt: Dict[str, Any]): - self._v = receipt - +class SESReceipt(DictWrapper): @property def timestamp(self) -> str: """String that specifies the date and time at which the action was triggered, in ISO 8601 format.""" - return self._v["timestamp"] + return self["timestamp"] @property def processing_time_millis(self) -> int: """String that specifies the period, in milliseconds, from the time Amazon SES received the message to the time it triggered the action.""" - return int(self._v["processingTimeMillis"]) + return int(self["processingTimeMillis"]) @property def recipients(self) -> List[str]: """A list of recipients (specifically, the envelope RCPT TO addresses) that were matched by the active receipt rule. The addresses listed here may differ from those listed by the destination field in the mail object.""" - return self._v["recipients"] + return self["recipients"] @property def spam_verdict(self) -> SESReceiptStatus: """Object that indicates whether the message is spam.""" - return SESReceiptStatus(self._v["spamVerdict"]) + return SESReceiptStatus(self["spamVerdict"]) @property def virus_verdict(self) -> SESReceiptStatus: """Object that indicates whether the message contains a virus.""" - return SESReceiptStatus(self._v["virusVerdict"]) + return SESReceiptStatus(self["virusVerdict"]) @property def spf_verdict(self) -> SESReceiptStatus: """Object that indicates whether the Sender Policy Framework (SPF) check passed.""" - return SESReceiptStatus(self._v["spfVerdict"]) + return SESReceiptStatus(self["spfVerdict"]) @property def dmarc_verdict(self) -> SESReceiptStatus: """Object that indicates whether the Domain-based Message Authentication, Reporting & Conformance (DMARC) check passed.""" - return SESReceiptStatus(self._v["dmarcVerdict"]) + return SESReceiptStatus(self["dmarcVerdict"]) @property def action(self) -> SESReceiptAction: """Object that encapsulates information about the action that was executed.""" - return SESReceiptAction(self._v["action"]) - + return SESReceiptAction(self["action"]) -class SESMessage: - def __init__(self, record: Dict[str, Any]): - self._v = record +class SESMessage(DictWrapper): @property def mail(self) -> SESMail: - return SESMail(self._v["ses"]["mail"]) + return SESMail(self["ses"]["mail"]) @property def receipt(self) -> SESReceipt: - return SESReceipt(self._v["ses"]["receipt"]) + return SESReceipt(self["ses"]["receipt"]) -class SESEventRecord: - def __init__(self, record: Dict[str, Any]): - self._v = record - +class SESEventRecord(DictWrapper): @property def event_source(self) -> str: """The AWS service from which the SES event record originated. For SES, this is aws:ses""" - return self._v["eventSource"] + return self["eventSource"] @property def event_version(self) -> str: - return self._v["eventVersion"] + return self["eventVersion"] @property def ses(self) -> SESMessage: - return SESMessage(self._v) + return SESMessage(self._data) class SESEvent(dict): @@ -228,3 +206,15 @@ class SESEvent(dict): def records(self) -> Iterator[SESEventRecord]: for record in self["Records"]: yield SESEventRecord(record) + + @property + def record(self) -> SESEventRecord: + return next(self.records) + + @property + def mail(self) -> SESMail: + return self.record.ses.mail + + @property + def receipt(self) -> SESReceipt: + return self.record.ses.receipt diff --git a/aws_lambda_powertools/utilities/trigger/sns_event.py b/aws_lambda_powertools/utilities/trigger/sns_event.py index 05c71739bbb..7f6c70f8a6d 100644 --- a/aws_lambda_powertools/utilities/trigger/sns_event.py +++ b/aws_lambda_powertools/utilities/trigger/sns_event.py @@ -1,61 +1,57 @@ -from typing import Any, Dict, Iterator +from typing import Dict, Iterator +from aws_lambda_powertools.utilities.trigger.common import DictWrapper -class SNSMessageAttribute: - def __init__(self, message_attribute: Dict[str, str]): - self._v = message_attribute +class SNSMessageAttribute(DictWrapper): @property def get_type(self) -> str: """The supported message attribute data types are String, String.Array, Number, and Binary.""" # Note: this name conflicts with existing python builtins - return self._v["Type"] + return self["Type"] @property def value(self) -> str: """The user-specified message attribute value.""" - return self._v["Value"] + return self["Value"] -class SNSMessage: - def __init__(self, message: Dict[str, Any]): - self._v = message - +class SNSMessage(DictWrapper): @property def signature_version(self) -> str: """Version of the Amazon SNS signature used.""" - return self._v["Sns"]["SignatureVersion"] + return self["Sns"]["SignatureVersion"] @property def timestamp(self) -> str: """The time (GMT) when the subscription confirmation was sent.""" - return self._v["Sns"]["Timestamp"] + return self["Sns"]["Timestamp"] @property def signature(self) -> str: """Base64-encoded "SHA1withRSA" signature of the Message, MessageId, Type, Timestamp, and TopicArn values.""" - return self._v["Sns"]["Signature"] + return self["Sns"]["Signature"] @property def signing_cert_url(self) -> str: """The URL to the certificate that was used to sign the message.""" - return self._v["Sns"]["SigningCertUrl"] + return self["Sns"]["SigningCertUrl"] @property def message_id(self) -> str: """A Universally Unique Identifier, unique for each message published. For a message that Amazon SNS resends during a retry, the message ID of the original message is used.""" - return self._v["Sns"]["MessageId"] + return self["Sns"]["MessageId"] @property def message(self) -> str: """A string that describes the message. """ - return self._v["Sns"]["Message"] + return self["Sns"]["Message"] @property def message_attributes(self) -> Dict[str, SNSMessageAttribute]: - return {k: SNSMessageAttribute(v) for (k, v) in self._v["Sns"]["MessageAttributes"].items()} + return {k: SNSMessageAttribute(v) for (k, v) in self["Sns"]["MessageAttributes"].items()} @property def get_type(self) -> str: @@ -63,47 +59,44 @@ def get_type(self) -> str: For a subscription confirmation, the type is SubscriptionConfirmation.""" # Note: this name conflicts with existing python builtins - return self._v["Sns"]["Type"] + return self["Sns"]["Type"] @property def unsubscribe_url(self) -> str: """A URL that you can use to unsubscribe the endpoint from this topic. If you visit this URL, Amazon SNS unsubscribes the endpoint and stops sending notifications to this endpoint.""" - return self._v["Sns"]["UnsubscribeUrl"] + return self["Sns"]["UnsubscribeUrl"] @property def topic_arn(self) -> str: """The Amazon Resource Name (ARN) for the topic that this endpoint is subscribed to.""" - return self._v["Sns"]["TopicArn"] + return self["Sns"]["TopicArn"] @property def subject(self) -> str: """The Subject parameter specified when the notification was published to the topic.""" - return self._v["Sns"]["Subject"] - + return self["Sns"]["Subject"] -class SNSEventRecord: - def __init__(self, record: Dict[str, Any]): - self._v = record +class SNSEventRecord(DictWrapper): @property def event_version(self) -> str: """Event version""" - return self._v["EventVersion"] + return self["EventVersion"] @property def event_subscription_arn(self) -> str: - return self._v["EventSubscriptionArn"] + return self["EventSubscriptionArn"] @property def event_source(self) -> str: """The AWS service from which the SNS event record originated. For SNS, this is aws:sns""" - return self._v["EventSource"] + return self["EventSource"] @property def sns(self) -> SNSMessage: - return SNSMessage(self._v) + return SNSMessage(self._data) class SNSEvent(dict): diff --git a/aws_lambda_powertools/utilities/trigger/sqs_event.py b/aws_lambda_powertools/utilities/trigger/sqs_event.py index f50c2afe3d6..34d6289a8bb 100644 --- a/aws_lambda_powertools/utilities/trigger/sqs_event.py +++ b/aws_lambda_powertools/utilities/trigger/sqs_event.py @@ -1,39 +1,38 @@ -from typing import Any, Dict, Iterator, Optional +from typing import Dict, Iterator, Optional +from aws_lambda_powertools.utilities.trigger.common import DictWrapper -class SQSRecordAttributes: - def __init__(self, record_attributes: Dict[str, str]): - self._v = record_attributes +class SQSRecordAttributes(DictWrapper): @property def aws_trace_header(self) -> Optional[str]: """Returns the AWS X-Ray trace header string.""" - return self._v.get("AWSTraceHeader") + return self.get("AWSTraceHeader") @property def approximate_receive_count(self) -> str: """Returns the number of times a message has been received across all queues but not deleted.""" - return self._v["ApproximateReceiveCount"] + return self["ApproximateReceiveCount"] @property def sent_timestamp(self) -> str: """Returns the time the message was sent to the queue (epoch time in milliseconds).""" - return self._v["SentTimestamp"] + return self["SentTimestamp"] @property def sender_id(self) -> str: """For an IAM user, returns the IAM user ID, For an IAM role, returns the IAM role ID""" - return self._v["SenderId"] + return self["SenderId"] @property def approximate_first_receive_timestamp(self) -> str: """Returns the time the message was first received from the queue (epoch time in milliseconds).""" - return self._v["ApproximateFirstReceiveTimestamp"] + return self["ApproximateFirstReceiveTimestamp"] @property def sequence_number(self) -> Optional[str]: """The large, non-consecutive number that Amazon SQS assigns to each message.""" - return self._v.get("SequenceNumber") + return self.get("SequenceNumber") @property def message_group_id(self) -> Optional[str]: @@ -42,7 +41,7 @@ def message_group_id(self) -> Optional[str]: Messages that belong to the same message group are always processed one by one, in a strict order relative to the message group (however, messages that belong to different message groups might be processed out of order).""" - return self._v.get("MessageGroupId") + return self.get("MessageGroupId") @property def message_deduplication_id(self) -> Optional[str]: @@ -51,31 +50,28 @@ def message_deduplication_id(self) -> Optional[str]: If a message with a particular message deduplication ID is sent successfully, any messages sent with the same message deduplication ID are accepted successfully but aren't delivered during the 5-minute deduplication interval.""" - return self._v.get("MessageDeduplicationId") + return self.get("MessageDeduplicationId") -class SQSMessageAttribute: +class SQSMessageAttribute(DictWrapper): """The user-specified message attribute value.""" - def __init__(self, message_attribute: Dict[str, str]): - self._v = message_attribute - @property def string_value(self) -> Optional[str]: """Strings are Unicode with UTF-8 binary encoding.""" - return self._v["stringValue"] + return self["stringValue"] @property def binary_value(self) -> Optional[str]: """Binary type attributes can store any binary data, such as compressed data, encrypted data, or images. Base64-encoded binary data object""" - return self._v["binaryValue"] + return self["binaryValue"] @property def data_type(self) -> str: """ The message attribute data type. Supported types include `String`, `Number`, and `Binary`.""" - return self._v["dataType"] + return self["dataType"] class SQSMessageAttributes(Dict[str, SQSMessageAttribute]): @@ -84,18 +80,15 @@ def __getitem__(self, key: str) -> Optional[SQSMessageAttribute]: return None if item is None else SQSMessageAttribute(item) -class SQSRecord: +class SQSRecord(DictWrapper): """An Amazon SQS message""" - def __init__(self, record: Dict[str, Any]): - self._v = record - @property def message_id(self) -> str: """A unique identifier for the message. A messageId is considered unique across all AWS accounts for an extended period of time.""" - return self._v["messageId"] + return self["messageId"] @property def receipt_handle(self) -> str: @@ -103,42 +96,42 @@ def receipt_handle(self) -> str: A new receipt handle is returned every time you receive a message. When deleting a message, you provide the last received receipt handle to delete the message.""" - return self._v["receiptHandle"] + return self["receiptHandle"] @property def body(self) -> str: """The message's contents (not URL-encoded).""" - return self._v["body"] + return self["body"] @property def attributes(self) -> SQSRecordAttributes: """A map of the attributes requested in ReceiveMessage to their respective values.""" - return SQSRecordAttributes(self._v["attributes"]) + return SQSRecordAttributes(self["attributes"]) @property def message_attributes(self) -> SQSMessageAttributes: """Each message attribute consists of a Name, Type, and Value.""" - return SQSMessageAttributes(self._v["messageAttributes"]) + return SQSMessageAttributes(self["messageAttributes"]) @property def md5_of_body(self) -> str: """An MD5 digest of the non-URL-encoded message body string.""" - return self._v["md5OfBody"] + return self["md5OfBody"] @property def event_source(self) -> str: """The AWS service from which the SQS record originated. For SQS, this is `aws:sqs` """ - return self._v["eventSource"] + return self["eventSource"] @property def event_source_arn(self) -> str: """The Amazon Resource Name (ARN) of the event source""" - return self._v["eventSourceARN"] + return self["eventSourceARN"] @property def aws_region(self) -> str: """aws region eg: us-east-1""" - return self._v["awsRegion"] + return self["awsRegion"] class SQSEvent(dict): diff --git a/tests/functional/test_lambda_trigger_events.py b/tests/functional/test_lambda_trigger_events.py index cdbc0478573..5f68bbbd7ff 100644 --- a/tests/functional/test_lambda_trigger_events.py +++ b/tests/functional/test_lambda_trigger_events.py @@ -78,6 +78,7 @@ def test_cognito_pre_signup_trigger_event(): assert event.response.auto_verify_phone is True event.response.auto_verify_email = True assert event.response.auto_verify_email is True + assert event["response"]["autoVerifyEmail"] is True def test_cognito_post_confirmation_trigger_event(): @@ -170,7 +171,8 @@ def test_cognito_pre_token_generation_trigger_event(): assert claims_override_details.group_configuration is None claims_override_details.group_configuration = {} - assert claims_override_details.group_configuration == {} + assert claims_override_details.group_configuration._data == {} + assert event["response"]["claimsOverrideDetails"]["groupOverrideDetails"] == {} expected_claims = {"test": "value"} claims_override_details.claims_to_add_or_override = expected_claims @@ -295,7 +297,7 @@ def test_s3_trigger_event(): assert s3.get_object.version_id is None assert s3.get_object.sequencer == "0C0F6F405D6ED209E1" assert record.glacier_event_data is None - assert event.record._v == event["Records"][0] + assert event.record._data == event["Records"][0] assert event.bucket_name == "lambda-artifacts-deafc19498e3f2df" assert event.object_key == "b21b84d653bb07b05b1e6b33684dc11b" @@ -341,7 +343,7 @@ def test_ses_trigger_event(): assert headers[0].value == "" common_headers = mail.common_headers assert common_headers.return_path == "janedoe@example.com" - assert common_headers.get_from == common_headers._v["from"] + assert common_headers.get_from == common_headers._data["from"] assert common_headers.date == "Wed, 7 Oct 2015 12:34:56 -0700" assert common_headers.to == [expected_address] assert common_headers.message_id == "<0123456789example.com>" @@ -355,9 +357,12 @@ def test_ses_trigger_event(): assert receipt.spf_verdict.status == "PASS" assert receipt.dmarc_verdict.status == "PASS" action = receipt.action - assert action.get_type == action._v["type"] - assert action.function_arn == action._v["functionArn"] - assert action.invocation_type == action._v["invocationType"] + assert action.get_type == action._data["type"] + assert action.function_arn == action._data["functionArn"] + assert action.invocation_type == action._data["invocationType"] + assert event.record._data == event["Records"][0] + assert event.mail._data == event["Records"][0]["ses"]["mail"] + assert event.receipt._data == event["Records"][0]["ses"]["receipt"] def test_sns_trigger_event(): @@ -383,7 +388,7 @@ def test_sns_trigger_event(): assert sns.unsubscribe_url == "https://sns.us-east-2.amazonaws.com/?Action=Unsubscri ..." assert sns.topic_arn == "arn:aws:sns:us-east-2:123456789012:sns-lambda" assert sns.subject == "TestInvoke" - assert event.record._v == event["Records"][0] + assert event.record._data == event["Records"][0] assert event.sns_message == "Hello from SNS!" From cc180d3ddb1b5ff3f758289ba5c79d296cb7acfd Mon Sep 17 00:00:00 2001 From: Michael Brewer Date: Sat, 12 Sep 2020 23:23:29 -0700 Subject: [PATCH 21/30] feat(trigger): API gateway helper methods --- .../trigger/api_gateway_proxy_event.py | 66 ++++++++++++++++ .../functional/test_lambda_trigger_events.py | 76 +++++++++++++++++++ 2 files changed, 142 insertions(+) diff --git a/aws_lambda_powertools/utilities/trigger/api_gateway_proxy_event.py b/aws_lambda_powertools/utilities/trigger/api_gateway_proxy_event.py index 2c293a2a60e..3cb6cb30d38 100644 --- a/aws_lambda_powertools/utilities/trigger/api_gateway_proxy_event.py +++ b/aws_lambda_powertools/utilities/trigger/api_gateway_proxy_event.py @@ -257,6 +257,39 @@ def body(self) -> Optional[str]: def is_base64_encoded(self) -> bool: return self["isBase64Encoded"] + def get_query_string_value(self, name: str, default_value: Optional[str] = None) -> Optional[str]: + """Get query string value by name + + Parameters + ---------- + name: str + Query string parameter name + default_value: str, optional + Default value if no value was found by name + Returns + ------- + str, optional + Query string parameter value + """ + params = self.query_string_parameters + return default_value if params is None else params.get(name, default_value) + + def get_header_value(self, name: str, default_value: Optional[str] = None) -> Optional[str]: + """Get header value by name + + Parameters + ---------- + name: str + Header name + default_value: str, optional + Default value if no value was found by name + Returns + ------- + str, optional + Header value + """ + return self.headers.get(name, default_value) + class RequestContextV2Http(DictWrapper): @property @@ -412,3 +445,36 @@ def is_base64_encoded(self) -> bool: @property def stage_variables(self) -> Optional[Dict[str, str]]: return self.get("stageVariables") + + def get_query_string_value(self, name: str, default_value: Optional[str] = None) -> Optional[str]: + """Get query string value by name + + Parameters + ---------- + name: str + Query string parameter name + default_value: str, optional + Default value if no value was found by name + Returns + ------- + str, optional + Query string parameter value + """ + params = self.query_string_parameters + return default_value if params is None else params.get(name, default_value) + + def get_header_value(self, name: str, default_value: Optional[str] = None) -> Optional[str]: + """Get header value by name + + Parameters + ---------- + name: str + Header name + default_value: str, optional + Default value if no value was found by name + Returns + ------- + str, optional + Header value + """ + return self.headers.get(name, default_value) diff --git a/tests/functional/test_lambda_trigger_events.py b/tests/functional/test_lambda_trigger_events.py index 5f68bbbd7ff..f2da181e247 100644 --- a/tests/functional/test_lambda_trigger_events.py +++ b/tests/functional/test_lambda_trigger_events.py @@ -527,3 +527,79 @@ def test_api_gateway_proxy_v2_event(): assert event.path_parameters == event["pathParameters"] assert event.is_base64_encoded == event["isBase64Encoded"] assert event.stage_variables == event["stageVariables"] + + +def test_api_gateway_proxy_get_query_string_value(): + default_value = "default" + set_value = "value" + + event = APIGatewayProxyEvent({}) + value = event.get_query_string_value("test", default_value) + assert value == default_value + + event["queryStringParameters"] = {"test": set_value} + value = event.get_query_string_value("test", default_value) + assert value == set_value + + value = event.get_query_string_value("unknown", default_value) + assert value == default_value + + value = event.get_query_string_value("unknown") + assert value is None + + +def test_api_gateway_proxy_v2_get_query_string_value(): + default_value = "default" + set_value = "value" + + event = APIGatewayProxyEventV2({}) + value = event.get_query_string_value("test", default_value) + assert value == default_value + + event["queryStringParameters"] = {"test": set_value} + value = event.get_query_string_value("test", default_value) + assert value == set_value + + value = event.get_query_string_value("unknown", default_value) + assert value == default_value + + value = event.get_query_string_value("unknown") + assert value is None + + +def test_api_gateway_proxy_get_header_value(): + default_value = "default" + set_value = "value" + + event = APIGatewayProxyEvent({"headers": {}}) + value = event.get_header_value("test", default_value) + assert value == default_value + + event["headers"] = {"test": set_value} + value = event.get_header_value("test", default_value) + assert value == set_value + + value = event.get_header_value("unknown", default_value) + assert value == default_value + + value = event.get_header_value("unknown") + assert value is None + + +def test_api_gateway_proxy_v2_get_header_value(): + default_value = "default" + set_value = "value" + + event = APIGatewayProxyEventV2({"headers": {}}) + value = event.get_header_value("test", default_value) + assert value == default_value + + event["headers"] = {"test": set_value} + value = event.get_header_value("test", default_value) + assert value == set_value + + value = event.get_header_value("unknown", default_value) + assert value == default_value + + value = event.get_header_value("unknown") + assert value is None From 38113563aa71b1bacf64480a49cd2d360a057e63 Mon Sep 17 00:00:00 2001 From: Michael Brewer Date: Sun, 13 Sep 2020 01:43:38 -0700 Subject: [PATCH 22/30] feat(trigger): Kinesis stream event Add support for Kinesis stream events with a helper method to decode data --- .../utilities/trigger/__init__.py | 2 + .../utilities/trigger/kinesis_stream_event.py | 96 +++++++++++++++++++ .../utilities/trigger/ses_event.py | 1 + tests/events/kinesisStreamEvent.json | 36 +++++++ .../functional/test_lambda_trigger_events.py | 36 +++++++ 5 files changed, 171 insertions(+) create mode 100644 aws_lambda_powertools/utilities/trigger/kinesis_stream_event.py create mode 100644 tests/events/kinesisStreamEvent.json diff --git a/aws_lambda_powertools/utilities/trigger/__init__.py b/aws_lambda_powertools/utilities/trigger/__init__.py index 3b63f1a723f..e2e8d57136f 100644 --- a/aws_lambda_powertools/utilities/trigger/__init__.py +++ b/aws_lambda_powertools/utilities/trigger/__init__.py @@ -2,6 +2,7 @@ from .cloud_watch_logs_event import CloudWatchLogsEvent from .dynamo_db_stream_event import DynamoDBStreamEvent from .event_bridge_event import EventBridgeEvent +from .kinesis_stream_event import KinesisStreamEvent from .s3_event import S3Event from .ses_event import SESEvent from .sns_event import SNSEvent @@ -13,6 +14,7 @@ "CloudWatchLogsEvent", "DynamoDBStreamEvent", "EventBridgeEvent", + "KinesisStreamEvent", "S3Event", "SESEvent", "SNSEvent", diff --git a/aws_lambda_powertools/utilities/trigger/kinesis_stream_event.py b/aws_lambda_powertools/utilities/trigger/kinesis_stream_event.py new file mode 100644 index 00000000000..14be80a5681 --- /dev/null +++ b/aws_lambda_powertools/utilities/trigger/kinesis_stream_event.py @@ -0,0 +1,96 @@ +import base64 +import json +from typing import Iterator + +from aws_lambda_powertools.utilities.trigger.common import DictWrapper + + +class KinesisStreamRecordPayload(DictWrapper): + @property + def approximate_arrival_timestamp(self) -> float: + """The approximate time that the record was inserted into the stream""" + return float(self["kinesis"]["approximateArrivalTimestamp"]) + + @property + def data(self) -> str: + """The data blob""" + return self["kinesis"]["data"] + + @property + def kinesis_schema_version(self) -> str: + """Schema version for the record""" + return self["kinesis"]["kinesisSchemaVersion"] + + @property + def partition_key(self) -> str: + """Identifies which shard in the stream the data record is assigned to""" + return self["kinesis"]["partitionKey"] + + @property + def sequence_number(self) -> str: + """The unique identifier of the record within its shard""" + return self["kinesis"]["sequenceNumber"] + + def data_as_text(self) -> str: + """Decode binary encoded data as text""" + return base64.b64decode(self.data).decode("utf-8") + + def data_as_json(self) -> dict: + """Decode binary encoded data as json""" + return json.loads(self.data_as_text()) + + +class KinesisStreamRecord(DictWrapper): + @property + def aws_region(self) -> str: + """AWS region where the event originated eg: us-east-1""" + return self["awsRegion"] + + @property + def event_id(self) -> str: + """A globally unique identifier for the event that was recorded in this stream record.""" + return self["eventID"] + + @property + def event_name(self) -> str: + """Event type eg: aws:kinesis:record""" + return self["eventName"] + + @property + def event_source(self) -> str: + """The AWS service from which the Kinesis event originated. For Kinesis, this is aws:kinesis""" + return self["eventSource"] + + @property + def event_source_arn(self) -> str: + """The Amazon Resource Name (ARN) of the event source""" + return self["eventSourceARN"] + + @property + def event_version(self) -> str: + """The eventVersion key value contains a major and minor version in the form ..""" + return self["eventVersion"] + + @property + def invoke_identity_arn(self) -> str: + """The ARN for the identity used to invoke the Lambda Function""" + return self["invokeIdentityArn"] + + @property + def kinesis(self) -> KinesisStreamRecordPayload: + """Underlying Kinesis record associated with the event""" + return KinesisStreamRecordPayload(self._data) + + +class KinesisStreamEvent(dict): + """Kinesis stream event + + Documentation: + -------------- + - https://docs.aws.amazon.com/lambda/latest/dg/with-kinesis.html + """ + + @property + def records(self) -> Iterator[KinesisStreamRecord]: + for record in self["Records"]: + yield KinesisStreamRecord(record) diff --git a/aws_lambda_powertools/utilities/trigger/ses_event.py b/aws_lambda_powertools/utilities/trigger/ses_event.py index adfd612c4f8..3a5ad3e2e40 100644 --- a/aws_lambda_powertools/utilities/trigger/ses_event.py +++ b/aws_lambda_powertools/utilities/trigger/ses_event.py @@ -184,6 +184,7 @@ def event_source(self) -> str: @property def event_version(self) -> str: + """The eventVersion key value contains a major and minor version in the form ..""" return self["eventVersion"] @property diff --git a/tests/events/kinesisStreamEvent.json b/tests/events/kinesisStreamEvent.json new file mode 100644 index 00000000000..ef8e2096388 --- /dev/null +++ b/tests/events/kinesisStreamEvent.json @@ -0,0 +1,36 @@ +{ + "Records": [ + { + "kinesis": { + "kinesisSchemaVersion": "1.0", + "partitionKey": "1", + "sequenceNumber": "49590338271490256608559692538361571095921575989136588898", + "data": "SGVsbG8sIHRoaXMgaXMgYSB0ZXN0Lg==", + "approximateArrivalTimestamp": 1545084650.987 + }, + "eventSource": "aws:kinesis", + "eventVersion": "1.0", + "eventID": "shardId-000000000006:49590338271490256608559692538361571095921575989136588898", + "eventName": "aws:kinesis:record", + "invokeIdentityArn": "arn:aws:iam::123456789012:role/lambda-role", + "awsRegion": "us-east-2", + "eventSourceARN": "arn:aws:kinesis:us-east-2:123456789012:stream/lambda-stream" + }, + { + "kinesis": { + "kinesisSchemaVersion": "1.0", + "partitionKey": "1", + "sequenceNumber": "49590338271490256608559692540925702759324208523137515618", + "data": "VGhpcyBpcyBvbmx5IGEgdGVzdC4=", + "approximateArrivalTimestamp": 1545084711.166 + }, + "eventSource": "aws:kinesis", + "eventVersion": "1.0", + "eventID": "shardId-000000000006:49590338271490256608559692540925702759324208523137515618", + "eventName": "aws:kinesis:record", + "invokeIdentityArn": "arn:aws:iam::123456789012:role/lambda-role", + "awsRegion": "us-east-2", + "eventSourceARN": "arn:aws:kinesis:us-east-2:123456789012:stream/lambda-stream" + } + ] +} diff --git a/tests/functional/test_lambda_trigger_events.py b/tests/functional/test_lambda_trigger_events.py index f2da181e247..ea3faff6745 100644 --- a/tests/functional/test_lambda_trigger_events.py +++ b/tests/functional/test_lambda_trigger_events.py @@ -1,3 +1,4 @@ +import base64 import json import os from secrets import compare_digest @@ -7,6 +8,7 @@ APIGatewayProxyEventV2, CloudWatchLogsEvent, EventBridgeEvent, + KinesisStreamEvent, S3Event, SESEvent, SNSEvent, @@ -603,3 +605,37 @@ def test_api_gateway_proxy_v2_get_header_value(): value = event.get_header_value("unknown") assert value is None + + +def test_kinesis_stream_event(): + event = KinesisStreamEvent(load_event("kinesisStreamEvent.json")) + + records = list(event.records) + assert len(records) == 2 + record = records[0] + + assert record.aws_region == "us-east-2" + assert record.event_id == "shardId-000000000006:49590338271490256608559692538361571095921575989136588898" + assert record.event_name == "aws:kinesis:record" + assert record.event_source == "aws:kinesis" + assert record.event_source_arn == "arn:aws:kinesis:us-east-2:123456789012:stream/lambda-stream" + assert record.event_version == "1.0" + assert record.invoke_identity_arn == "arn:aws:iam::123456789012:role/lambda-role" + + kinesis = record.kinesis + assert kinesis._data["kinesis"] == event["Records"][0]["kinesis"] + + assert kinesis.approximate_arrival_timestamp == 1545084650.987 + assert kinesis.data == event["Records"][0]["kinesis"]["data"] + assert kinesis.kinesis_schema_version == "1.0" + assert kinesis.partition_key == "1" + assert kinesis.sequence_number == "49590338271490256608559692538361571095921575989136588898" + + assert kinesis.data_as_text() == "Hello, this is a test." + + +def test_kinesis_stream_event_json_data(): + json_value = {"test": "value"} + data = base64.b64encode(bytes(json.dumps(json_value), "utf-8")).decode("utf-8") + event = KinesisStreamEvent({"Records": [{"kinesis": {"data": data}}]}) + assert next(event.records).kinesis.data_as_json() == json_value From 3c41f36ef9c9c1eca512c699638702e2abb14cff Mon Sep 17 00:00:00 2001 From: Michael Brewer Date: Sun, 13 Sep 2020 21:19:52 -0700 Subject: [PATCH 23/30] feat(trigger): Application load balancer event Create BaseProxyEvent for some reusable code for the http proxy events Add support for ALB events --- .../utilities/trigger/__init__.py | 2 + .../utilities/trigger/alb_event.py | 38 +++++++ .../trigger/api_gateway_proxy_event.py | 104 +----------------- .../utilities/trigger/common.py | 51 +++++++++ tests/events/albEvent.json | 28 +++++ .../functional/test_lambda_trigger_events.py | 61 ++++------ 6 files changed, 141 insertions(+), 143 deletions(-) create mode 100644 aws_lambda_powertools/utilities/trigger/alb_event.py create mode 100644 tests/events/albEvent.json diff --git a/aws_lambda_powertools/utilities/trigger/__init__.py b/aws_lambda_powertools/utilities/trigger/__init__.py index e2e8d57136f..47ca29c2148 100644 --- a/aws_lambda_powertools/utilities/trigger/__init__.py +++ b/aws_lambda_powertools/utilities/trigger/__init__.py @@ -1,3 +1,4 @@ +from .alb_event import ALBEvent from .api_gateway_proxy_event import APIGatewayProxyEvent, APIGatewayProxyEventV2 from .cloud_watch_logs_event import CloudWatchLogsEvent from .dynamo_db_stream_event import DynamoDBStreamEvent @@ -11,6 +12,7 @@ __all__ = [ "APIGatewayProxyEvent", "APIGatewayProxyEventV2", + "ALBEvent", "CloudWatchLogsEvent", "DynamoDBStreamEvent", "EventBridgeEvent", diff --git a/aws_lambda_powertools/utilities/trigger/alb_event.py b/aws_lambda_powertools/utilities/trigger/alb_event.py new file mode 100644 index 00000000000..359796451dc --- /dev/null +++ b/aws_lambda_powertools/utilities/trigger/alb_event.py @@ -0,0 +1,38 @@ +from typing import Dict, List, Optional + +from aws_lambda_powertools.utilities.trigger.common import BaseProxyEvent, DictWrapper + + +class ALBEventRequestContext(DictWrapper): + @property + def elb_target_group_arn(self) -> str: + return self["requestContext"]["elb"]["targetGroupArn"] + + +class ALBEvent(BaseProxyEvent): + """Application load balancer event + + Documentation: + -------------- + - https://docs.aws.amazon.com/lambda/latest/dg/services-alb.html + """ + + @property + def request_context(self) -> ALBEventRequestContext: + return ALBEventRequestContext(self) + + @property + def http_method(self) -> str: + return self["httpMethod"] + + @property + def path(self) -> str: + return self["path"] + + @property + def multi_value_query_string_parameters(self) -> Optional[Dict[str, List[str]]]: + return self.get("multiValueQueryStringParameters") + + @property + def multi_value_headers(self) -> Optional[Dict[str, List[str]]]: + return self.get("multiValueHeaders") diff --git a/aws_lambda_powertools/utilities/trigger/api_gateway_proxy_event.py b/aws_lambda_powertools/utilities/trigger/api_gateway_proxy_event.py index 3cb6cb30d38..aeec67de19c 100644 --- a/aws_lambda_powertools/utilities/trigger/api_gateway_proxy_event.py +++ b/aws_lambda_powertools/utilities/trigger/api_gateway_proxy_event.py @@ -1,6 +1,6 @@ from typing import Any, Dict, List, Optional -from aws_lambda_powertools.utilities.trigger.common import DictWrapper +from aws_lambda_powertools.utilities.trigger.common import BaseProxyEvent, DictWrapper class APIGatewayEventIdentity(DictWrapper): @@ -196,7 +196,7 @@ def route_key(self) -> Optional[str]: return self["requestContext"].get("routeKey") -class APIGatewayProxyEvent(dict): +class APIGatewayProxyEvent(BaseProxyEvent): """AWS Lambda proxy V1 Documentation: @@ -221,18 +221,10 @@ def http_method(self) -> str: """The HTTP method used. Valid values include: DELETE, GET, HEAD, OPTIONS, PATCH, POST, and PUT.""" return self["httpMethod"] - @property - def headers(self) -> Dict[str, str]: - return self["headers"] - @property def multi_value_headers(self) -> Dict[str, List[str]]: return self["multiValueHeaders"] - @property - def query_string_parameters(self) -> Optional[Dict[str, str]]: - return self.get("queryStringParameters") - @property def multi_value_query_string_parameters(self) -> Optional[Dict[str, List[str]]]: return self.get("multiValueQueryStringParameters") @@ -249,47 +241,6 @@ def path_parameters(self) -> Optional[Dict[str, str]]: def stage_variables(self) -> Optional[Dict[str, str]]: return self.get("stageVariables") - @property - def body(self) -> Optional[str]: - return self.get("body") - - @property - def is_base64_encoded(self) -> bool: - return self["isBase64Encoded"] - - def get_query_string_value(self, name: str, default_value: Optional[str] = None) -> Optional[str]: - """Get query string value by name - - Parameters - ---------- - name: str - Query string parameter name - default_value: str, optional - Default value if no value was found by name - Returns - ------- - str, optional - Query string parameter value - """ - params = self.query_string_parameters - return default_value if params is None else params.get(name, default_value) - - def get_header_value(self, name: str, default_value: Optional[str] = None) -> Optional[str]: - """Get header value by name - - Parameters - ---------- - name: str - Header name - default_value: str, optional - Default value if no value was found by name - Returns - ------- - str, optional - Header value - """ - return self.headers.get(name, default_value) - class RequestContextV2Http(DictWrapper): @property @@ -381,7 +332,7 @@ def time_epoch(self) -> int: return self["requestContext"]["timeEpoch"] -class APIGatewayProxyEventV2(dict): +class APIGatewayProxyEventV2(BaseProxyEvent): """AWS Lambda proxy V2 event Notes: @@ -418,63 +369,14 @@ def raw_query_string(self) -> str: def cookies(self) -> Optional[List[str]]: return self.get("cookies") - @property - def headers(self) -> Dict[str, str]: - return self["headers"] - - @property - def query_string_parameters(self) -> Optional[Dict[str, str]]: - return self.get("queryStringParameters") - @property def request_context(self) -> RequestContextV2: return RequestContextV2(self) - @property - def body(self) -> Optional[str]: - return self.get("body") - @property def path_parameters(self) -> Optional[Dict[str, str]]: return self.get("pathParameters") - @property - def is_base64_encoded(self) -> bool: - return self["isBase64Encoded"] - @property def stage_variables(self) -> Optional[Dict[str, str]]: return self.get("stageVariables") - - def get_query_string_value(self, name: str, default_value: Optional[str] = None) -> Optional[str]: - """Get query string value by name - - Parameters - ---------- - name: str - Query string parameter name - default_value: str, optional - Default value if no value was found by name - Returns - ------- - str, optional - Query string parameter value - """ - params = self.query_string_parameters - return default_value if params is None else params.get(name, default_value) - - def get_header_value(self, name: str, default_value: Optional[str] = None) -> Optional[str]: - """Get header value by name - - Parameters - ---------- - name: str - Header name - default_value: str, optional - Default value if no value was found by name - Returns - ------- - str, optional - Header value - """ - return self.headers.get(name, default_value) diff --git a/aws_lambda_powertools/utilities/trigger/common.py b/aws_lambda_powertools/utilities/trigger/common.py index 8feb38e837f..33803856da9 100644 --- a/aws_lambda_powertools/utilities/trigger/common.py +++ b/aws_lambda_powertools/utilities/trigger/common.py @@ -11,3 +11,54 @@ def __getitem__(self, key: str) -> Any: def get(self, key: str) -> Optional[Any]: return self._data.get(key) + + +class BaseProxyEvent(dict): + @property + def headers(self) -> Dict[str, str]: + return self["headers"] + + @property + def query_string_parameters(self) -> Optional[Dict[str, str]]: + return self.get("queryStringParameters") + + @property + def is_base64_encoded(self) -> bool: + return self.get("isBase64Encoded") + + @property + def body(self) -> Optional[str]: + return self.get("body") + + def get_query_string_value(self, name: str, default_value: Optional[str] = None) -> Optional[str]: + """Get query string value by name + + Parameters + ---------- + name: str + Query string parameter name + default_value: str, optional + Default value if no value was found by name + Returns + ------- + str, optional + Query string parameter value + """ + params = self.query_string_parameters + return default_value if params is None else params.get(name, default_value) + + def get_header_value(self, name: str, default_value: Optional[str] = None) -> Optional[str]: + """Get header value by name + + Parameters + ---------- + name: str + Header name + default_value: str, optional + Default value if no value was found by name + Returns + ------- + str, optional + Header value + """ + return self.headers.get(name, default_value) diff --git a/tests/events/albEvent.json b/tests/events/albEvent.json new file mode 100644 index 00000000000..9328cb39e12 --- /dev/null +++ b/tests/events/albEvent.json @@ -0,0 +1,28 @@ +{ + "requestContext": { + "elb": { + "targetGroupArn": "arn:aws:elasticloadbalancing:us-east-2:123456789012:targetgroup/lambda-279XGJDqGZ5rsrHC2Fjr/49e9d65c45c6791a" + } + }, + "httpMethod": "GET", + "path": "/lambda", + "queryStringParameters": { + "query": "1234ABCD" + }, + "headers": { + "accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,image/apng,*/*;q=0.8", + "accept-encoding": "gzip", + "accept-language": "en-US,en;q=0.9", + "connection": "keep-alive", + "host": "lambda-alb-123578498.us-east-2.elb.amazonaws.com", + "upgrade-insecure-requests": "1", + "user-agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/71.0.3578.98 Safari/537.36", + "x-amzn-trace-id": "Root=1-5c536348-3d683b8b04734faae651f476", + "x-forwarded-for": "72.12.164.125", + "x-forwarded-port": "80", + "x-forwarded-proto": "http", + "x-imforwards": "20" + }, + "body": "Test", + "isBase64Encoded": false +} diff --git a/tests/functional/test_lambda_trigger_events.py b/tests/functional/test_lambda_trigger_events.py index ea3faff6745..54f203979e3 100644 --- a/tests/functional/test_lambda_trigger_events.py +++ b/tests/functional/test_lambda_trigger_events.py @@ -4,6 +4,7 @@ from secrets import compare_digest from aws_lambda_powertools.utilities.trigger import ( + ALBEvent, APIGatewayProxyEvent, APIGatewayProxyEventV2, CloudWatchLogsEvent, @@ -23,6 +24,7 @@ PreTokenGenerationTriggerEvent, UserMigrationTriggerEvent, ) +from aws_lambda_powertools.utilities.trigger.common import BaseProxyEvent from aws_lambda_powertools.utilities.trigger.dynamo_db_stream_event import ( AttributeValue, DynamoDBRecordEventName, @@ -531,11 +533,11 @@ def test_api_gateway_proxy_v2_event(): assert event.stage_variables == event["stageVariables"] -def test_api_gateway_proxy_get_query_string_value(): +def test_base_proxy_event_get_query_string_value(): default_value = "default" set_value = "value" - event = APIGatewayProxyEvent({}) + event = BaseProxyEvent({}) value = event.get_query_string_value("test", default_value) assert value == default_value @@ -550,49 +552,11 @@ def test_api_gateway_proxy_get_query_string_value(): assert value is None -def test_api_gateway_proxy_v2_get_query_string_value(): +def test_base_proxy_event_get_header_value(): default_value = "default" set_value = "value" - event = APIGatewayProxyEventV2({}) - value = event.get_query_string_value("test", default_value) - assert value == default_value - - event["queryStringParameters"] = {"test": set_value} - value = event.get_query_string_value("test", default_value) - assert value == set_value - - value = event.get_query_string_value("unknown", default_value) - assert value == default_value - - value = event.get_query_string_value("unknown") - assert value is None - - -def test_api_gateway_proxy_get_header_value(): - default_value = "default" - set_value = "value" - - event = APIGatewayProxyEvent({"headers": {}}) - value = event.get_header_value("test", default_value) - assert value == default_value - - event["headers"] = {"test": set_value} - value = event.get_header_value("test", default_value) - assert value == set_value - - value = event.get_header_value("unknown", default_value) - assert value == default_value - - value = event.get_header_value("unknown") - assert value is None - - -def test_api_gateway_proxy_v2_get_header_value(): - default_value = "default" - set_value = "value" - - event = APIGatewayProxyEventV2({"headers": {}}) + event = BaseProxyEvent({"headers": {}}) value = event.get_header_value("test", default_value) assert value == default_value @@ -639,3 +603,16 @@ def test_kinesis_stream_event_json_data(): data = base64.b64encode(bytes(json.dumps(json_value), "utf-8")).decode("utf-8") event = KinesisStreamEvent({"Records": [{"kinesis": {"data": data}}]}) assert next(event.records).kinesis.data_as_json() == json_value + + +def test_alb_event(): + event = ALBEvent(load_event("albEvent.json")) + assert event.request_context.elb_target_group_arn == event["requestContext"]["elb"]["targetGroupArn"] + assert event.http_method == event["httpMethod"] + assert event.path == event["path"] + assert event.query_string_parameters == event["queryStringParameters"] + assert event.headers == event["headers"] + assert event.multi_value_query_string_parameters == event.get("multiValueQueryStringParameters") + assert event.multi_value_headers == event.get("multiValueHeaders") + assert event.body == event["body"] + assert event.is_base64_encoded == event["isBase64Encoded"] From 35358eda8d825c69f7b5fe64a502b5c831c72e8b Mon Sep 17 00:00:00 2001 From: Michael Brewer Date: Sun, 13 Sep 2020 21:32:54 -0700 Subject: [PATCH 24/30] refactor(trigger): Split decompress and parse For handling CloudWatchLogsEvent log data split the Decode and decompress from the parse as CloudWatchLogsDecodedData --- .../utilities/trigger/cloud_watch_logs_event.py | 12 ++++++++---- tests/functional/test_lambda_trigger_events.py | 2 +- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/aws_lambda_powertools/utilities/trigger/cloud_watch_logs_event.py b/aws_lambda_powertools/utilities/trigger/cloud_watch_logs_event.py index 427392e06aa..7ee42ce45cc 100644 --- a/aws_lambda_powertools/utilities/trigger/cloud_watch_logs_event.py +++ b/aws_lambda_powertools/utilities/trigger/cloud_watch_logs_event.py @@ -83,8 +83,12 @@ def aws_logs_data(self) -> str: """The value of the `data` field is a Base64 encoded ZIP archive.""" return self["awslogs"]["data"] - def decode_cloud_watch_logs_data(self) -> CloudWatchLogsDecodedData: - """Decode, unzip and parse json data""" + @property + def decompress_logs_data(self) -> bytes: + """Decode and decompress log data""" payload = base64.b64decode(self.aws_logs_data) - decoded: dict = json.loads(zlib.decompress(payload, zlib.MAX_WBITS | 32).decode("UTF-8")) - return CloudWatchLogsDecodedData(decoded) + return zlib.decompress(payload, zlib.MAX_WBITS | 32) + + def parse_logs_data(self) -> CloudWatchLogsDecodedData: + """Decode, decompress and parse json data as CloudWatchLogsDecodedData""" + return CloudWatchLogsDecodedData(json.loads(self.decompress_logs_data.decode("UTF-8"))) diff --git a/tests/functional/test_lambda_trigger_events.py b/tests/functional/test_lambda_trigger_events.py index 54f203979e3..42549db4c89 100644 --- a/tests/functional/test_lambda_trigger_events.py +++ b/tests/functional/test_lambda_trigger_events.py @@ -42,7 +42,7 @@ def load_event(file_name: str) -> dict: def test_cloud_watch_trigger_event(): event = CloudWatchLogsEvent(load_event("cloudWatchLogEvent.json")) - decoded_data = event.decode_cloud_watch_logs_data() + decoded_data = event.parse_logs_data() log_events = decoded_data.log_events log_event = log_events[0] From a746ddc03a94cc65beece0cd2a6a1c23853fb024 Mon Sep 17 00:00:00 2001 From: Michael Brewer Date: Sun, 13 Sep 2020 23:04:24 -0700 Subject: [PATCH 25/30] feat(trigger): Some attribte caching --- .../trigger/cloud_watch_logs_event.py | 19 ++++++++++++++----- .../functional/test_lambda_trigger_events.py | 18 +++++++++++------- 2 files changed, 25 insertions(+), 12 deletions(-) diff --git a/aws_lambda_powertools/utilities/trigger/cloud_watch_logs_event.py b/aws_lambda_powertools/utilities/trigger/cloud_watch_logs_event.py index 7ee42ce45cc..c11363811c1 100644 --- a/aws_lambda_powertools/utilities/trigger/cloud_watch_logs_event.py +++ b/aws_lambda_powertools/utilities/trigger/cloud_watch_logs_event.py @@ -1,7 +1,7 @@ import base64 import json import zlib -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional from aws_lambda_powertools.utilities.trigger.common import DictWrapper @@ -78,17 +78,26 @@ class CloudWatchLogsEvent(dict): - https://docs.aws.amazon.com/lambda/latest/dg/services-cloudwatchlogs.html """ + def __init__(self, event: Dict[str, Any]): + super().__init__(event) + self._decompressed_logs_data = None + self._json_logs_data = None + @property - def aws_logs_data(self) -> str: + def raw_logs_data(self) -> str: """The value of the `data` field is a Base64 encoded ZIP archive.""" return self["awslogs"]["data"] @property def decompress_logs_data(self) -> bytes: """Decode and decompress log data""" - payload = base64.b64decode(self.aws_logs_data) - return zlib.decompress(payload, zlib.MAX_WBITS | 32) + if self._decompressed_logs_data is None: + payload = base64.b64decode(self.raw_logs_data) + self._decompressed_logs_data = zlib.decompress(payload, zlib.MAX_WBITS | 32) + return self._decompressed_logs_data def parse_logs_data(self) -> CloudWatchLogsDecodedData: """Decode, decompress and parse json data as CloudWatchLogsDecodedData""" - return CloudWatchLogsDecodedData(json.loads(self.decompress_logs_data.decode("UTF-8"))) + if self._json_logs_data is None: + self._json_logs_data = json.loads(self.decompress_logs_data.decode("UTF-8")) + return CloudWatchLogsDecodedData(self._json_logs_data) diff --git a/tests/functional/test_lambda_trigger_events.py b/tests/functional/test_lambda_trigger_events.py index 42549db4c89..4fde880fa53 100644 --- a/tests/functional/test_lambda_trigger_events.py +++ b/tests/functional/test_lambda_trigger_events.py @@ -42,15 +42,19 @@ def load_event(file_name: str) -> dict: def test_cloud_watch_trigger_event(): event = CloudWatchLogsEvent(load_event("cloudWatchLogEvent.json")) - decoded_data = event.parse_logs_data() - log_events = decoded_data.log_events + decompressed_logs_data = event.decompress_logs_data + assert event.decompress_logs_data == decompressed_logs_data + + json_logs_data = event.parse_logs_data() + assert event.parse_logs_data() == json_logs_data + log_events = json_logs_data.log_events log_event = log_events[0] - assert decoded_data.owner == "123456789123" - assert decoded_data.log_group == "testLogGroup" - assert decoded_data.log_stream == "testLogStream" - assert decoded_data.subscription_filters == ["testFilter"] - assert decoded_data.message_type == "DATA_MESSAGE" + assert json_logs_data.owner == "123456789123" + assert json_logs_data.log_group == "testLogGroup" + assert json_logs_data.log_stream == "testLogStream" + assert json_logs_data.subscription_filters == ["testFilter"] + assert json_logs_data.message_type == "DATA_MESSAGE" assert log_event.get_id == "eventId1" assert log_event.timestamp == 1440442987000 From 7b4ec44c4789b0b9715f17a8e91abb85ce735511 Mon Sep 17 00:00:00 2001 From: Michael Brewer Date: Sun, 13 Sep 2020 23:25:58 -0700 Subject: [PATCH 26/30] fix(trigger): Init with __init__ --- .../utilities/trigger/cloud_watch_logs_event.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/aws_lambda_powertools/utilities/trigger/cloud_watch_logs_event.py b/aws_lambda_powertools/utilities/trigger/cloud_watch_logs_event.py index c11363811c1..86c975895c8 100644 --- a/aws_lambda_powertools/utilities/trigger/cloud_watch_logs_event.py +++ b/aws_lambda_powertools/utilities/trigger/cloud_watch_logs_event.py @@ -1,7 +1,7 @@ import base64 import json import zlib -from typing import Any, Dict, List, Optional +from typing import Dict, List, Optional from aws_lambda_powertools.utilities.trigger.common import DictWrapper @@ -78,10 +78,8 @@ class CloudWatchLogsEvent(dict): - https://docs.aws.amazon.com/lambda/latest/dg/services-cloudwatchlogs.html """ - def __init__(self, event: Dict[str, Any]): - super().__init__(event) - self._decompressed_logs_data = None - self._json_logs_data = None + _decompressed_logs_data = None + _json_logs_data = None @property def raw_logs_data(self) -> str: From 6306436b2da2951b7e69421a7e3c9c83ce399c44 Mon Sep 17 00:00:00 2001 From: Michael Brewer Date: Sun, 13 Sep 2020 23:52:41 -0700 Subject: [PATCH 27/30] fix(trigger): Add missing __eq__ We don't need to include `_decompressed_logs_data` and `_json_logs_data` in equality tests --- .../utilities/trigger/cloud_watch_logs_event.py | 3 +++ tests/functional/test_lambda_trigger_events.py | 3 +++ 2 files changed, 6 insertions(+) diff --git a/aws_lambda_powertools/utilities/trigger/cloud_watch_logs_event.py b/aws_lambda_powertools/utilities/trigger/cloud_watch_logs_event.py index 86c975895c8..0c62545b470 100644 --- a/aws_lambda_powertools/utilities/trigger/cloud_watch_logs_event.py +++ b/aws_lambda_powertools/utilities/trigger/cloud_watch_logs_event.py @@ -81,6 +81,9 @@ class CloudWatchLogsEvent(dict): _decompressed_logs_data = None _json_logs_data = None + def __eq__(self, other): + return super(CloudWatchLogsEvent, self).__eq__(other) + @property def raw_logs_data(self) -> str: """The value of the `data` field is a Base64 encoded ZIP archive.""" diff --git a/tests/functional/test_lambda_trigger_events.py b/tests/functional/test_lambda_trigger_events.py index 4fde880fa53..667de110897 100644 --- a/tests/functional/test_lambda_trigger_events.py +++ b/tests/functional/test_lambda_trigger_events.py @@ -61,6 +61,9 @@ def test_cloud_watch_trigger_event(): assert log_event.message == "[ERROR] First test message" assert log_event.extracted_fields is None + event2 = CloudWatchLogsEvent(load_event("cloudWatchLogEvent.json")) + assert event == event2 + def test_cognito_pre_signup_trigger_event(): event = PreSignUpTriggerEvent(load_event("cognitoPreSignUpEvent.json")) From 0ffb1be8fe4dd54c7626615076ee929ea65449c0 Mon Sep 17 00:00:00 2001 From: Michael Brewer Date: Mon, 14 Sep 2020 09:46:45 -0700 Subject: [PATCH 28/30] feat(trigger): unquote_plus s3 object key --- aws_lambda_powertools/utilities/trigger/s3_event.py | 7 ++++--- tests/functional/test_lambda_trigger_events.py | 8 ++++++++ 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/aws_lambda_powertools/utilities/trigger/s3_event.py b/aws_lambda_powertools/utilities/trigger/s3_event.py index 229ccebc729..20b3489a709 100644 --- a/aws_lambda_powertools/utilities/trigger/s3_event.py +++ b/aws_lambda_powertools/utilities/trigger/s3_event.py @@ -1,4 +1,5 @@ from typing import Dict, Iterator, Optional +from urllib.parse import unquote_plus from aws_lambda_powertools.utilities.trigger.common import DictWrapper @@ -180,9 +181,9 @@ def record(self) -> S3EventRecord: @property def bucket_name(self) -> str: """Get the bucket name for the first s3 event record""" - return self.record.s3.bucket.name + return self["Records"][0]["s3"]["bucket"]["name"] @property def object_key(self) -> str: - """Get the object key for the first s3 event record""" - return self.record.s3.get_object.key + """Get the object key for the first s3 event record and unquote plus""" + return unquote_plus(self["Records"][0]["s3"]["object"]["key"]) diff --git a/tests/functional/test_lambda_trigger_events.py b/tests/functional/test_lambda_trigger_events.py index 667de110897..54fe15a928d 100644 --- a/tests/functional/test_lambda_trigger_events.py +++ b/tests/functional/test_lambda_trigger_events.py @@ -2,6 +2,7 @@ import json import os from secrets import compare_digest +from urllib.parse import quote_plus from aws_lambda_powertools.utilities.trigger import ( ALBEvent, @@ -313,6 +314,13 @@ def test_s3_trigger_event(): assert event.object_key == "b21b84d653bb07b05b1e6b33684dc11b" +def test_s3_key_unquote_plus(): + tricky_name = "foo name+value" + event_dict = {"Records": [{"s3": {"object": {"key": quote_plus(tricky_name)}}}]} + event = S3Event(event_dict) + assert event.object_key == tricky_name + + def test_s3_glacier_event(): example_event = { "Records": [ From d0c9e5b638bd9a725111c613befb0706d8dcba0a Mon Sep 17 00:00:00 2001 From: Michael Brewer Date: Thu, 17 Sep 2020 13:42:06 -0700 Subject: [PATCH 29/30] refactor(data_classes): rename package from trigger --- .../{trigger => data_classes}/__init__.py | 0 .../{trigger => data_classes}/alb_event.py | 2 +- .../api_gateway_proxy_event.py | 2 +- .../cloud_watch_logs_event.py | 2 +- .../cognito_user_pool_event.py | 25 ++++++++++--------- .../{trigger => data_classes}/common.py | 3 +-- .../dynamo_db_stream_event.py | 2 +- .../event_bridge_event.py | 0 .../kinesis_stream_event.py | 2 +- .../{trigger => data_classes}/s3_event.py | 2 +- .../{trigger => data_classes}/ses_event.py | 4 +-- .../{trigger => data_classes}/sns_event.py | 2 +- .../{trigger => data_classes}/sqs_event.py | 2 +- .../functional/test_lambda_trigger_events.py | 8 +++--- 14 files changed, 28 insertions(+), 28 deletions(-) rename aws_lambda_powertools/utilities/{trigger => data_classes}/__init__.py (100%) rename aws_lambda_powertools/utilities/{trigger => data_classes}/alb_event.py (91%) rename aws_lambda_powertools/utilities/{trigger => data_classes}/api_gateway_proxy_event.py (99%) rename aws_lambda_powertools/utilities/{trigger => data_classes}/cloud_watch_logs_event.py (97%) rename aws_lambda_powertools/utilities/{trigger => data_classes}/cognito_user_pool_event.py (95%) rename aws_lambda_powertools/utilities/{trigger => data_classes}/common.py (97%) rename aws_lambda_powertools/utilities/{trigger => data_classes}/dynamo_db_stream_event.py (99%) rename aws_lambda_powertools/utilities/{trigger => data_classes}/event_bridge_event.py (100%) rename aws_lambda_powertools/utilities/{trigger => data_classes}/kinesis_stream_event.py (97%) rename aws_lambda_powertools/utilities/{trigger => data_classes}/s3_event.py (98%) rename aws_lambda_powertools/utilities/{trigger => data_classes}/ses_event.py (98%) rename aws_lambda_powertools/utilities/{trigger => data_classes}/sns_event.py (98%) rename aws_lambda_powertools/utilities/{trigger => data_classes}/sqs_event.py (98%) diff --git a/aws_lambda_powertools/utilities/trigger/__init__.py b/aws_lambda_powertools/utilities/data_classes/__init__.py similarity index 100% rename from aws_lambda_powertools/utilities/trigger/__init__.py rename to aws_lambda_powertools/utilities/data_classes/__init__.py diff --git a/aws_lambda_powertools/utilities/trigger/alb_event.py b/aws_lambda_powertools/utilities/data_classes/alb_event.py similarity index 91% rename from aws_lambda_powertools/utilities/trigger/alb_event.py rename to aws_lambda_powertools/utilities/data_classes/alb_event.py index 359796451dc..5de23dc3ab0 100644 --- a/aws_lambda_powertools/utilities/trigger/alb_event.py +++ b/aws_lambda_powertools/utilities/data_classes/alb_event.py @@ -1,6 +1,6 @@ from typing import Dict, List, Optional -from aws_lambda_powertools.utilities.trigger.common import BaseProxyEvent, DictWrapper +from aws_lambda_powertools.utilities.data_classes.common import BaseProxyEvent, DictWrapper class ALBEventRequestContext(DictWrapper): diff --git a/aws_lambda_powertools/utilities/trigger/api_gateway_proxy_event.py b/aws_lambda_powertools/utilities/data_classes/api_gateway_proxy_event.py similarity index 99% rename from aws_lambda_powertools/utilities/trigger/api_gateway_proxy_event.py rename to aws_lambda_powertools/utilities/data_classes/api_gateway_proxy_event.py index aeec67de19c..a253348fac4 100644 --- a/aws_lambda_powertools/utilities/trigger/api_gateway_proxy_event.py +++ b/aws_lambda_powertools/utilities/data_classes/api_gateway_proxy_event.py @@ -1,6 +1,6 @@ from typing import Any, Dict, List, Optional -from aws_lambda_powertools.utilities.trigger.common import BaseProxyEvent, DictWrapper +from aws_lambda_powertools.utilities.data_classes.common import BaseProxyEvent, DictWrapper class APIGatewayEventIdentity(DictWrapper): diff --git a/aws_lambda_powertools/utilities/trigger/cloud_watch_logs_event.py b/aws_lambda_powertools/utilities/data_classes/cloud_watch_logs_event.py similarity index 97% rename from aws_lambda_powertools/utilities/trigger/cloud_watch_logs_event.py rename to aws_lambda_powertools/utilities/data_classes/cloud_watch_logs_event.py index 0c62545b470..7c063b0a7a7 100644 --- a/aws_lambda_powertools/utilities/trigger/cloud_watch_logs_event.py +++ b/aws_lambda_powertools/utilities/data_classes/cloud_watch_logs_event.py @@ -3,7 +3,7 @@ import zlib from typing import Dict, List, Optional -from aws_lambda_powertools.utilities.trigger.common import DictWrapper +from aws_lambda_powertools.utilities.data_classes.common import DictWrapper class CloudWatchLogsLogEvent(DictWrapper): diff --git a/aws_lambda_powertools/utilities/trigger/cognito_user_pool_event.py b/aws_lambda_powertools/utilities/data_classes/cognito_user_pool_event.py similarity index 95% rename from aws_lambda_powertools/utilities/trigger/cognito_user_pool_event.py rename to aws_lambda_powertools/utilities/data_classes/cognito_user_pool_event.py index 036c04902eb..cc533fb4d13 100644 --- a/aws_lambda_powertools/utilities/trigger/cognito_user_pool_event.py +++ b/aws_lambda_powertools/utilities/data_classes/cognito_user_pool_event.py @@ -1,6 +1,6 @@ from typing import Any, Dict, List, Optional -from aws_lambda_powertools.utilities.trigger.common import DictWrapper +from aws_lambda_powertools.utilities.data_classes.common import DictWrapper class CallerContext(DictWrapper): @@ -68,7 +68,7 @@ def validation_data(self) -> Optional[Dict[str, str]]: @property def client_metadata(self) -> Optional[Dict[str, str]]: """One or more key-value pairs that you can provide as custom input to the Lambda function - that you specify for the pre sign-up trigger.""" + that you specify for the pre sign-up data_classes.""" return self["request"].get("clientMetadata") @@ -135,7 +135,7 @@ def user_attributes(self) -> Dict[str, str]: @property def client_metadata(self) -> Optional[Dict[str, str]]: """One or more key-value pairs that you can provide as custom input to the Lambda function - that you specify for the post confirmation trigger.""" + that you specify for the post confirmation data_classes.""" return self["request"].get("clientMetadata") @@ -172,7 +172,7 @@ def validation_data(self) -> Optional[Dict[str, str]]: @property def client_metadata(self) -> Optional[Dict[str, str]]: """One or more key-value pairs that you can provide as custom input to the Lambda function - that you specify for the pre sign-up trigger.""" + that you specify for the pre sign-up data_classes.""" return self["request"].get("clientMetadata") @@ -283,7 +283,7 @@ def user_attributes(self) -> Dict[str, str]: @property def client_metadata(self) -> Optional[Dict[str, str]]: """One or more key-value pairs that you can provide as custom input to the Lambda function - that you specify for the pre sign-up trigger.""" + that you specify for the pre sign-up data_classes.""" return self["request"].get("clientMetadata") @@ -329,9 +329,9 @@ class CustomMessageTriggerEvent(BaseTriggerEvent): - `CustomMessage_AdminCreateUser` To send the temporary password to a new user. - `CustomMessage_ResendCode` To resend the confirmation code to an existing user. - `CustomMessage_ForgotPassword` To send the confirmation code for Forgot Password request. - - `CustomMessage_UpdateUserAttribute` When a user's email or phone number is changed, this trigger sends a + - `CustomMessage_UpdateUserAttribute` When a user's email or phone number is changed, this data_classes sends a verification code automatically to the user. Cannot be used for other attributes. - - `CustomMessage_VerifyUserAttribute` This trigger sends a verification code to the user when they manually + - `CustomMessage_VerifyUserAttribute` This data_classes sends a verification code to the user when they manually request it for a new email or phone number. - `CustomMessage_Authentication` To send MFA code during authentication. @@ -369,7 +369,7 @@ def validation_data(self) -> Optional[Dict[str, str]]: class PreAuthenticationTriggerEvent(BaseTriggerEvent): """Pre Authentication Lambda Trigger - Amazon Cognito invokes this trigger when a user attempts to sign in, allowing custom validation + Amazon Cognito invokes this data_classes when a user attempts to sign in, allowing custom validation to accept or deny the authentication request. Notes: @@ -404,14 +404,15 @@ def user_attributes(self) -> Dict[str, str]: @property def client_metadata(self) -> Optional[Dict[str, str]]: """One or more key-value pairs that you can provide as custom input to the Lambda function - that you specify for the post authentication trigger.""" + that you specify for the post authentication data_classes.""" return self["request"].get("clientMetadata") class PostAuthenticationTriggerEvent(BaseTriggerEvent): """Post Authentication Lambda Trigger - Amazon Cognito invokes this trigger after signing in a user, allowing you to add custom logic after authentication. + Amazon Cognito invokes this data_classes after signing in a user, allowing you to add custom logic + after authentication. Notes: ---- @@ -461,7 +462,7 @@ def user_attributes(self) -> Dict[str, str]: @property def client_metadata(self) -> Optional[Dict[str, str]]: """One or more key-value pairs that you can provide as custom input to the Lambda function - that you specify for the pre token generation trigger.""" + that you specify for the pre token generation data_classes.""" return self["request"].get("clientMetadata") @@ -530,7 +531,7 @@ def claims_override_details(self) -> ClaimsOverrideDetails: class PreTokenGenerationTriggerEvent(BaseTriggerEvent): """Pre Token Generation Lambda Trigger - Amazon Cognito invokes this trigger before token generation allowing you to customize identity token claims. + Amazon Cognito invokes this data_classes before token generation allowing you to customize identity token claims. Notes: ---- diff --git a/aws_lambda_powertools/utilities/trigger/common.py b/aws_lambda_powertools/utilities/data_classes/common.py similarity index 97% rename from aws_lambda_powertools/utilities/trigger/common.py rename to aws_lambda_powertools/utilities/data_classes/common.py index 33803856da9..851a5e5bea4 100644 --- a/aws_lambda_powertools/utilities/trigger/common.py +++ b/aws_lambda_powertools/utilities/data_classes/common.py @@ -1,8 +1,7 @@ -from abc import ABC from typing import Any, Dict, Optional -class DictWrapper(ABC): +class DictWrapper: def __init__(self, data: Dict[str, Any]): self._data = data diff --git a/aws_lambda_powertools/utilities/trigger/dynamo_db_stream_event.py b/aws_lambda_powertools/utilities/data_classes/dynamo_db_stream_event.py similarity index 99% rename from aws_lambda_powertools/utilities/trigger/dynamo_db_stream_event.py rename to aws_lambda_powertools/utilities/data_classes/dynamo_db_stream_event.py index 8623d733a0e..e37b6d1f23e 100644 --- a/aws_lambda_powertools/utilities/trigger/dynamo_db_stream_event.py +++ b/aws_lambda_powertools/utilities/data_classes/dynamo_db_stream_event.py @@ -1,7 +1,7 @@ from enum import Enum from typing import Dict, Iterator, List, Optional -from aws_lambda_powertools.utilities.trigger.common import DictWrapper +from aws_lambda_powertools.utilities.data_classes.common import DictWrapper class AttributeValue(DictWrapper): diff --git a/aws_lambda_powertools/utilities/trigger/event_bridge_event.py b/aws_lambda_powertools/utilities/data_classes/event_bridge_event.py similarity index 100% rename from aws_lambda_powertools/utilities/trigger/event_bridge_event.py rename to aws_lambda_powertools/utilities/data_classes/event_bridge_event.py diff --git a/aws_lambda_powertools/utilities/trigger/kinesis_stream_event.py b/aws_lambda_powertools/utilities/data_classes/kinesis_stream_event.py similarity index 97% rename from aws_lambda_powertools/utilities/trigger/kinesis_stream_event.py rename to aws_lambda_powertools/utilities/data_classes/kinesis_stream_event.py index 14be80a5681..b8b1606f8e9 100644 --- a/aws_lambda_powertools/utilities/trigger/kinesis_stream_event.py +++ b/aws_lambda_powertools/utilities/data_classes/kinesis_stream_event.py @@ -2,7 +2,7 @@ import json from typing import Iterator -from aws_lambda_powertools.utilities.trigger.common import DictWrapper +from aws_lambda_powertools.utilities.data_classes.common import DictWrapper class KinesisStreamRecordPayload(DictWrapper): diff --git a/aws_lambda_powertools/utilities/trigger/s3_event.py b/aws_lambda_powertools/utilities/data_classes/s3_event.py similarity index 98% rename from aws_lambda_powertools/utilities/trigger/s3_event.py rename to aws_lambda_powertools/utilities/data_classes/s3_event.py index 20b3489a709..f6cf97e9e5d 100644 --- a/aws_lambda_powertools/utilities/trigger/s3_event.py +++ b/aws_lambda_powertools/utilities/data_classes/s3_event.py @@ -1,7 +1,7 @@ from typing import Dict, Iterator, Optional from urllib.parse import unquote_plus -from aws_lambda_powertools.utilities.trigger.common import DictWrapper +from aws_lambda_powertools.utilities.data_classes.common import DictWrapper class S3Identity(DictWrapper): diff --git a/aws_lambda_powertools/utilities/trigger/ses_event.py b/aws_lambda_powertools/utilities/data_classes/ses_event.py similarity index 98% rename from aws_lambda_powertools/utilities/trigger/ses_event.py rename to aws_lambda_powertools/utilities/data_classes/ses_event.py index 3a5ad3e2e40..fb738e43945 100644 --- a/aws_lambda_powertools/utilities/trigger/ses_event.py +++ b/aws_lambda_powertools/utilities/data_classes/ses_event.py @@ -1,6 +1,6 @@ from typing import Iterator, List -from aws_lambda_powertools.utilities.trigger.common import DictWrapper +from aws_lambda_powertools.utilities.data_classes.common import DictWrapper class SESMailHeader(DictWrapper): @@ -193,7 +193,7 @@ def ses(self) -> SESMessage: class SESEvent(dict): - """Amazon SES to receive message event trigger + """Amazon SES to receive message event data_classes NOTE: There is a 30-second timeout on RequestResponse invocations. diff --git a/aws_lambda_powertools/utilities/trigger/sns_event.py b/aws_lambda_powertools/utilities/data_classes/sns_event.py similarity index 98% rename from aws_lambda_powertools/utilities/trigger/sns_event.py rename to aws_lambda_powertools/utilities/data_classes/sns_event.py index 7f6c70f8a6d..191c2f43425 100644 --- a/aws_lambda_powertools/utilities/trigger/sns_event.py +++ b/aws_lambda_powertools/utilities/data_classes/sns_event.py @@ -1,6 +1,6 @@ from typing import Dict, Iterator -from aws_lambda_powertools.utilities.trigger.common import DictWrapper +from aws_lambda_powertools.utilities.data_classes.common import DictWrapper class SNSMessageAttribute(DictWrapper): diff --git a/aws_lambda_powertools/utilities/trigger/sqs_event.py b/aws_lambda_powertools/utilities/data_classes/sqs_event.py similarity index 98% rename from aws_lambda_powertools/utilities/trigger/sqs_event.py rename to aws_lambda_powertools/utilities/data_classes/sqs_event.py index 34d6289a8bb..095526258ee 100644 --- a/aws_lambda_powertools/utilities/trigger/sqs_event.py +++ b/aws_lambda_powertools/utilities/data_classes/sqs_event.py @@ -1,6 +1,6 @@ from typing import Dict, Iterator, Optional -from aws_lambda_powertools.utilities.trigger.common import DictWrapper +from aws_lambda_powertools.utilities.data_classes.common import DictWrapper class SQSRecordAttributes(DictWrapper): diff --git a/tests/functional/test_lambda_trigger_events.py b/tests/functional/test_lambda_trigger_events.py index 54fe15a928d..f64a319eed1 100644 --- a/tests/functional/test_lambda_trigger_events.py +++ b/tests/functional/test_lambda_trigger_events.py @@ -4,7 +4,7 @@ from secrets import compare_digest from urllib.parse import quote_plus -from aws_lambda_powertools.utilities.trigger import ( +from aws_lambda_powertools.utilities.data_classes import ( ALBEvent, APIGatewayProxyEvent, APIGatewayProxyEventV2, @@ -16,7 +16,7 @@ SNSEvent, SQSEvent, ) -from aws_lambda_powertools.utilities.trigger.cognito_user_pool_event import ( +from aws_lambda_powertools.utilities.data_classes.cognito_user_pool_event import ( CustomMessageTriggerEvent, PostAuthenticationTriggerEvent, PostConfirmationTriggerEvent, @@ -25,8 +25,8 @@ PreTokenGenerationTriggerEvent, UserMigrationTriggerEvent, ) -from aws_lambda_powertools.utilities.trigger.common import BaseProxyEvent -from aws_lambda_powertools.utilities.trigger.dynamo_db_stream_event import ( +from aws_lambda_powertools.utilities.data_classes.common import BaseProxyEvent +from aws_lambda_powertools.utilities.data_classes.dynamo_db_stream_event import ( AttributeValue, DynamoDBRecordEventName, DynamoDBStreamEvent, From e860731f2cf04dea5453e237b3658b9cb0ab34e0 Mon Sep 17 00:00:00 2001 From: Michael Brewer Date: Thu, 17 Sep 2020 14:03:36 -0700 Subject: [PATCH 30/30] refactor(data_classes): Use DictWrapper consistently --- .../utilities/data_classes/cloud_watch_logs_event.py | 7 ++----- .../utilities/data_classes/cognito_user_pool_event.py | 2 +- aws_lambda_powertools/utilities/data_classes/common.py | 4 +++- .../utilities/data_classes/dynamo_db_stream_event.py | 2 +- .../utilities/data_classes/event_bridge_event.py | 4 +++- .../utilities/data_classes/kinesis_stream_event.py | 2 +- aws_lambda_powertools/utilities/data_classes/s3_event.py | 2 +- aws_lambda_powertools/utilities/data_classes/ses_event.py | 2 +- aws_lambda_powertools/utilities/data_classes/sns_event.py | 2 +- aws_lambda_powertools/utilities/data_classes/sqs_event.py | 2 +- tests/functional/test_lambda_trigger_events.py | 8 ++++---- 11 files changed, 19 insertions(+), 18 deletions(-) diff --git a/aws_lambda_powertools/utilities/data_classes/cloud_watch_logs_event.py b/aws_lambda_powertools/utilities/data_classes/cloud_watch_logs_event.py index 7c063b0a7a7..978f6956fc2 100644 --- a/aws_lambda_powertools/utilities/data_classes/cloud_watch_logs_event.py +++ b/aws_lambda_powertools/utilities/data_classes/cloud_watch_logs_event.py @@ -29,7 +29,7 @@ def extracted_fields(self) -> Optional[Dict[str, str]]: return self.get("extractedFields") -class CloudWatchLogsDecodedData(dict): +class CloudWatchLogsDecodedData(DictWrapper): @property def owner(self) -> str: """The AWS Account ID of the originating log data.""" @@ -68,7 +68,7 @@ def log_events(self) -> List[CloudWatchLogsLogEvent]: return [CloudWatchLogsLogEvent(i) for i in self["logEvents"]] -class CloudWatchLogsEvent(dict): +class CloudWatchLogsEvent(DictWrapper): """CloudWatch Logs log stream event You can use a Lambda function to monitor and analyze logs from an Amazon CloudWatch Logs log stream. @@ -81,9 +81,6 @@ class CloudWatchLogsEvent(dict): _decompressed_logs_data = None _json_logs_data = None - def __eq__(self, other): - return super(CloudWatchLogsEvent, self).__eq__(other) - @property def raw_logs_data(self) -> str: """The value of the `data` field is a Base64 encoded ZIP archive.""" diff --git a/aws_lambda_powertools/utilities/data_classes/cognito_user_pool_event.py b/aws_lambda_powertools/utilities/data_classes/cognito_user_pool_event.py index cc533fb4d13..7bf38715006 100644 --- a/aws_lambda_powertools/utilities/data_classes/cognito_user_pool_event.py +++ b/aws_lambda_powertools/utilities/data_classes/cognito_user_pool_event.py @@ -15,7 +15,7 @@ def client_id(self) -> str: return self["callerContext"]["clientId"] -class BaseTriggerEvent(dict): +class BaseTriggerEvent(DictWrapper): """Common attributes shared by all User Pool Lambda Trigger Events Documentation: diff --git a/aws_lambda_powertools/utilities/data_classes/common.py b/aws_lambda_powertools/utilities/data_classes/common.py index 851a5e5bea4..73cf1b339ff 100644 --- a/aws_lambda_powertools/utilities/data_classes/common.py +++ b/aws_lambda_powertools/utilities/data_classes/common.py @@ -2,6 +2,8 @@ class DictWrapper: + """Provides a single read only access to a wrapper dict""" + def __init__(self, data: Dict[str, Any]): self._data = data @@ -12,7 +14,7 @@ def get(self, key: str) -> Optional[Any]: return self._data.get(key) -class BaseProxyEvent(dict): +class BaseProxyEvent(DictWrapper): @property def headers(self) -> Dict[str, str]: return self["headers"] diff --git a/aws_lambda_powertools/utilities/data_classes/dynamo_db_stream_event.py b/aws_lambda_powertools/utilities/data_classes/dynamo_db_stream_event.py index e37b6d1f23e..db581ceaf7d 100644 --- a/aws_lambda_powertools/utilities/data_classes/dynamo_db_stream_event.py +++ b/aws_lambda_powertools/utilities/data_classes/dynamo_db_stream_event.py @@ -218,7 +218,7 @@ def user_identity(self) -> Optional[dict]: return self.get("userIdentity") -class DynamoDBStreamEvent(dict): +class DynamoDBStreamEvent(DictWrapper): """Dynamo DB Stream Event Documentation: diff --git a/aws_lambda_powertools/utilities/data_classes/event_bridge_event.py b/aws_lambda_powertools/utilities/data_classes/event_bridge_event.py index 2c5ff03dd86..cb299309a69 100644 --- a/aws_lambda_powertools/utilities/data_classes/event_bridge_event.py +++ b/aws_lambda_powertools/utilities/data_classes/event_bridge_event.py @@ -1,7 +1,9 @@ from typing import Any, Dict, List +from aws_lambda_powertools.utilities.data_classes.common import DictWrapper -class EventBridgeEvent(dict): + +class EventBridgeEvent(DictWrapper): """Amazon EventBridge Event Documentation: diff --git a/aws_lambda_powertools/utilities/data_classes/kinesis_stream_event.py b/aws_lambda_powertools/utilities/data_classes/kinesis_stream_event.py index b8b1606f8e9..6af1484f155 100644 --- a/aws_lambda_powertools/utilities/data_classes/kinesis_stream_event.py +++ b/aws_lambda_powertools/utilities/data_classes/kinesis_stream_event.py @@ -82,7 +82,7 @@ def kinesis(self) -> KinesisStreamRecordPayload: return KinesisStreamRecordPayload(self._data) -class KinesisStreamEvent(dict): +class KinesisStreamEvent(DictWrapper): """Kinesis stream event Documentation: diff --git a/aws_lambda_powertools/utilities/data_classes/s3_event.py b/aws_lambda_powertools/utilities/data_classes/s3_event.py index f6cf97e9e5d..2670142d575 100644 --- a/aws_lambda_powertools/utilities/data_classes/s3_event.py +++ b/aws_lambda_powertools/utilities/data_classes/s3_event.py @@ -158,7 +158,7 @@ def glacier_event_data(self) -> Optional[S3EventRecordGlacierEventData]: return None if item is None else S3EventRecordGlacierEventData(item) -class S3Event(dict): +class S3Event(DictWrapper): """S3 event notification Documentation: diff --git a/aws_lambda_powertools/utilities/data_classes/ses_event.py b/aws_lambda_powertools/utilities/data_classes/ses_event.py index fb738e43945..518981618dc 100644 --- a/aws_lambda_powertools/utilities/data_classes/ses_event.py +++ b/aws_lambda_powertools/utilities/data_classes/ses_event.py @@ -192,7 +192,7 @@ def ses(self) -> SESMessage: return SESMessage(self._data) -class SESEvent(dict): +class SESEvent(DictWrapper): """Amazon SES to receive message event data_classes NOTE: There is a 30-second timeout on RequestResponse invocations. diff --git a/aws_lambda_powertools/utilities/data_classes/sns_event.py b/aws_lambda_powertools/utilities/data_classes/sns_event.py index 191c2f43425..e96b096fe6b 100644 --- a/aws_lambda_powertools/utilities/data_classes/sns_event.py +++ b/aws_lambda_powertools/utilities/data_classes/sns_event.py @@ -99,7 +99,7 @@ def sns(self) -> SNSMessage: return SNSMessage(self._data) -class SNSEvent(dict): +class SNSEvent(DictWrapper): """SNS Event Documentation: diff --git a/aws_lambda_powertools/utilities/data_classes/sqs_event.py b/aws_lambda_powertools/utilities/data_classes/sqs_event.py index 095526258ee..778b8f56f36 100644 --- a/aws_lambda_powertools/utilities/data_classes/sqs_event.py +++ b/aws_lambda_powertools/utilities/data_classes/sqs_event.py @@ -134,7 +134,7 @@ def aws_region(self) -> str: return self["awsRegion"] -class SQSEvent(dict): +class SQSEvent(DictWrapper): """SQS Event Documentation: diff --git a/tests/functional/test_lambda_trigger_events.py b/tests/functional/test_lambda_trigger_events.py index f64a319eed1..21e775b7a5f 100644 --- a/tests/functional/test_lambda_trigger_events.py +++ b/tests/functional/test_lambda_trigger_events.py @@ -47,7 +47,7 @@ def test_cloud_watch_trigger_event(): assert event.decompress_logs_data == decompressed_logs_data json_logs_data = event.parse_logs_data() - assert event.parse_logs_data() == json_logs_data + assert event.parse_logs_data()._data == json_logs_data._data log_events = json_logs_data.log_events log_event = log_events[0] @@ -63,7 +63,7 @@ def test_cloud_watch_trigger_event(): assert log_event.extracted_fields is None event2 = CloudWatchLogsEvent(load_event("cloudWatchLogEvent.json")) - assert event == event2 + assert event._data == event2._data def test_cognito_pre_signup_trigger_event(): @@ -556,7 +556,7 @@ def test_base_proxy_event_get_query_string_value(): value = event.get_query_string_value("test", default_value) assert value == default_value - event["queryStringParameters"] = {"test": set_value} + event._data["queryStringParameters"] = {"test": set_value} value = event.get_query_string_value("test", default_value) assert value == set_value @@ -575,7 +575,7 @@ def test_base_proxy_event_get_header_value(): value = event.get_header_value("test", default_value) assert value == default_value - event["headers"] = {"test": set_value} + event._data["headers"] = {"test": set_value} value = event.get_header_value("test", default_value) assert value == set_value