diff --git a/aws_lambda_powertools/logging/correlation_paths.py b/aws_lambda_powertools/logging/correlation_paths.py index 73227754363..cbccd85637f 100644 --- a/aws_lambda_powertools/logging/correlation_paths.py +++ b/aws_lambda_powertools/logging/correlation_paths.py @@ -2,5 +2,6 @@ API_GATEWAY_REST = "requestContext.requestId" API_GATEWAY_HTTP = API_GATEWAY_REST -APPLICATION_LOAD_BALANCER = "headers.x-amzn-trace-id" +APPSYNC_RESOLVER = 'request.headers."x-amzn-trace-id"' +APPLICATION_LOAD_BALANCER = 'headers."x-amzn-trace-id"' EVENT_BRIDGE = "id" diff --git a/aws_lambda_powertools/utilities/data_classes/__init__.py b/aws_lambda_powertools/utilities/data_classes/__init__.py index 9c74983f3a9..28179bfd291 100644 --- a/aws_lambda_powertools/utilities/data_classes/__init__.py +++ b/aws_lambda_powertools/utilities/data_classes/__init__.py @@ -1,3 +1,5 @@ +from aws_lambda_powertools.utilities.data_classes.appsync_resolver_event import AppSyncResolverEvent + from .alb_event import ALBEvent from .api_gateway_proxy_event import APIGatewayProxyEvent, APIGatewayProxyEventV2 from .cloud_watch_logs_event import CloudWatchLogsEvent @@ -13,6 +15,7 @@ __all__ = [ "APIGatewayProxyEvent", "APIGatewayProxyEventV2", + "AppSyncResolverEvent", "ALBEvent", "CloudWatchLogsEvent", "ConnectContactFlowEvent", diff --git a/aws_lambda_powertools/utilities/data_classes/appsync/__init__.py b/aws_lambda_powertools/utilities/data_classes/appsync/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/aws_lambda_powertools/utilities/data_classes/appsync/resolver_utils.py b/aws_lambda_powertools/utilities/data_classes/appsync/resolver_utils.py new file mode 100644 index 00000000000..848a24619d4 --- /dev/null +++ b/aws_lambda_powertools/utilities/data_classes/appsync/resolver_utils.py @@ -0,0 +1,76 @@ +import datetime +import time +import uuid +from typing import Any, Dict + +from aws_lambda_powertools.utilities.data_classes import AppSyncResolverEvent +from aws_lambda_powertools.utilities.typing import LambdaContext + + +def make_id(): + return str(uuid.uuid4()) + + +def aws_date(): + now = datetime.datetime.utcnow().date() + return now.strftime("%Y-%m-%d") + + +def aws_time(): + now = datetime.datetime.utcnow().time() + return now.strftime("%H:%M:%S") + + +def aws_datetime(): + now = datetime.datetime.utcnow() + return now.strftime("%Y-%m-%dT%H:%M:%SZ") + + +def aws_timestamp(): + return int(time.time()) + + +class AppSyncResolver: + def __init__(self): + self._resolvers: dict = {} + + def resolver( + self, + type_name: str = "*", + field_name: str = None, + include_event: bool = False, + include_context: bool = False, + **kwargs, + ): + def register_resolver(func): + kwargs["include_event"] = include_event + kwargs["include_context"] = include_context + self._resolvers[f"{type_name}.{field_name}"] = { + "func": func, + "config": kwargs, + } + return func + + return register_resolver + + def resolve(self, event: dict, context: LambdaContext) -> Any: + event = AppSyncResolverEvent(event) + resolver, config = self._resolver(event.type_name, event.field_name) + kwargs = self._kwargs(event, context, config) + return resolver(**kwargs) + + def _resolver(self, type_name: str, field_name: str) -> tuple: + full_name = f"{type_name}.{field_name}" + resolver = self._resolvers.get(full_name, self._resolvers.get(f"*.{field_name}")) + if not resolver: + raise ValueError(f"No resolver found for '{full_name}'") + return resolver["func"], resolver["config"] + + @staticmethod + def _kwargs(event: AppSyncResolverEvent, context: LambdaContext, config: dict) -> Dict[str, Any]: + kwargs = {**event.arguments} + if config.get("include_event", False): + kwargs["event"] = event + if config.get("include_context", False): + kwargs["context"] = context + return kwargs diff --git a/aws_lambda_powertools/utilities/data_classes/appsync_resolver_event.py b/aws_lambda_powertools/utilities/data_classes/appsync_resolver_event.py new file mode 100644 index 00000000000..a6f9e6c58a4 --- /dev/null +++ b/aws_lambda_powertools/utilities/data_classes/appsync_resolver_event.py @@ -0,0 +1,232 @@ +from typing import Any, Dict, List, Optional, Union + +from aws_lambda_powertools.utilities.data_classes.common import DictWrapper, get_header_value + + +def get_identity_object(identity: Optional[dict]) -> Any: + """Get the identity object based on the best detected type""" + # API_KEY authorization + if identity is None: + return None + + # AMAZON_COGNITO_USER_POOLS authorization + if "sub" in identity: + return AppSyncIdentityCognito(identity) + + # AWS_IAM authorization + return AppSyncIdentityIAM(identity) + + +class AppSyncIdentityIAM(DictWrapper): + """AWS_IAM authorization""" + + @property + def source_ip(self) -> List[str]: + """The source IP address of the caller received by AWS AppSync. """ + return self["sourceIp"] + + @property + def username(self) -> str: + """The user name of the authenticated user. IAM user principal""" + return self["username"] + + @property + def account_id(self) -> str: + """The AWS account ID of the caller.""" + return self["accountId"] + + @property + def cognito_identity_pool_id(self) -> str: + """The Amazon Cognito identity pool ID associated with the caller.""" + return self["cognitoIdentityPoolId"] + + @property + def cognito_identity_id(self) -> str: + """The Amazon Cognito identity ID of the caller.""" + return self["cognitoIdentityId"] + + @property + def user_arn(self) -> str: + """The ARN of the IAM user.""" + return self["userArn"] + + @property + def cognito_identity_auth_type(self) -> str: + """Either authenticated or unauthenticated based on the identity type.""" + return self["cognitoIdentityAuthType"] + + @property + def cognito_identity_auth_provider(self) -> str: + """A comma separated list of external identity provider information used in obtaining the + credentials used to sign the request.""" + return self["cognitoIdentityAuthProvider"] + + +class AppSyncIdentityCognito(DictWrapper): + """AMAZON_COGNITO_USER_POOLS authorization""" + + @property + def source_ip(self) -> List[str]: + """The source IP address of the caller received by AWS AppSync. """ + return self["sourceIp"] + + @property + def username(self) -> str: + """The user name of the authenticated user.""" + return self["username"] + + @property + def sub(self) -> str: + """The UUID of the authenticated user.""" + return self["sub"] + + @property + def claims(self) -> Dict[str, str]: + """The claims that the user has.""" + return self["claims"] + + @property + def default_auth_strategy(self) -> str: + """The default authorization strategy for this caller (ALLOW or DENY).""" + return self["defaultAuthStrategy"] + + @property + def groups(self) -> List[str]: + """List of OIDC groups""" + return self["groups"] + + @property + def issuer(self) -> str: + """The token issuer.""" + return self["issuer"] + + +class AppSyncResolverEventInfo(DictWrapper): + """The info section contains information about the GraphQL request""" + + @property + def field_name(self) -> str: + """The name of the field that is currently being resolved.""" + return self["fieldName"] + + @property + def parent_type_name(self) -> str: + """The name of the parent type for the field that is currently being resolved.""" + return self["parentTypeName"] + + @property + def variables(self) -> Dict[str, str]: + """A map which holds all variables that are passed into the GraphQL request.""" + return self.get("variables") + + @property + def selection_set_list(self) -> List[str]: + """A list representation of the fields in the GraphQL selection set. Fields that are aliased will + only be referenced by the alias name, not the field name.""" + return self.get("selectionSetList") + + @property + def selection_set_graphql(self) -> Optional[str]: + """A string representation of the selection set, formatted as GraphQL schema definition language (SDL). + Although fragments are not be merged into the selection set, inline fragments are preserved.""" + return self.get("selectionSetGraphQL") + + +class AppSyncResolverEvent(DictWrapper): + """AppSync resolver event + + **NOTE:** AppSync Resolver Events can come in various shapes this data class + supports both Amplify GraphQL directive @function and Direct Lambda Resolver + + Documentation: + ------------- + - https://docs.aws.amazon.com/appsync/latest/devguide/resolver-context-reference.html + - https://docs.amplify.aws/cli/graphql-transformer/function#structure-of-the-function-event + """ + + def __init__(self, data: dict): + super().__init__(data) + + info: dict = data.get("info") + if not info: + info = {"fieldName": self.get("fieldName"), "parentTypeName": self.get("typeName")} + + self._info = AppSyncResolverEventInfo(info) + + @property + def type_name(self) -> str: + """The name of the parent type for the field that is currently being resolved.""" + return self.info.parent_type_name + + @property + def field_name(self) -> str: + """The name of the field that is currently being resolved.""" + return self.info.field_name + + @property + def arguments(self) -> Dict[str, any]: + """A map that contains all GraphQL arguments for this field.""" + return self["arguments"] + + @property + def identity(self) -> Union[None, AppSyncIdentityIAM, AppSyncIdentityCognito]: + """An object that contains information about the caller. + + Depending of the type of identify found: + + - API_KEY authorization - returns None + - AWS_IAM authorization - returns AppSyncIdentityIAM + - AMAZON_COGNITO_USER_POOLS authorization - returns AppSyncIdentityCognito + """ + return get_identity_object(self.get("identity")) + + @property + def source(self) -> Dict[str, any]: + """A map that contains the resolution of the parent field.""" + return self.get("source") + + @property + def request_headers(self) -> Dict[str, str]: + """Request headers""" + return self["request"]["headers"] + + @property + def prev_result(self) -> Optional[Dict[str, any]]: + """It represents the result of whatever previous operation was executed in a pipeline resolver.""" + prev = self.get("prev") + if not prev: + return None + return prev.get("result") + + @property + def info(self) -> AppSyncResolverEventInfo: + """The info section contains information about the GraphQL request.""" + return self._info + + @property + def stash(self) -> Optional[dict]: + """The stash is a map that is made available inside each resolver and function mapping template. + The same stash instance lives through a single resolver execution. This means that you can use the + stash to pass arbitrary data across request and response mapping templates, and across functions in + a pipeline resolver.""" + return self.get("stash") + + def get_header_value( + self, name: str, default_value: Optional[str] = None, case_sensitive: Optional[bool] = False + ) -> Optional[str]: + """Get header value by name + + Parameters + ---------- + name: str + Header name + default_value: str, optional + Default value if no value was found by name + case_sensitive: bool + Whether to use a case sensitive look up + Returns + ------- + str, optional + Header value + """ + return get_header_value(self.request_headers, name, default_value, case_sensitive) diff --git a/aws_lambda_powertools/utilities/data_classes/common.py b/aws_lambda_powertools/utilities/data_classes/common.py index 94a357b3180..da51d8e0d5b 100644 --- a/aws_lambda_powertools/utilities/data_classes/common.py +++ b/aws_lambda_powertools/utilities/data_classes/common.py @@ -20,6 +20,21 @@ def get(self, key: str) -> Optional[Any]: return self._data.get(key) +def get_header_value(headers: Dict[str, str], name: str, default_value: str, case_sensitive: bool) -> Optional[str]: + """Get header value by name""" + if case_sensitive: + return headers.get(name, default_value) + + name_lower = name.lower() + + return next( + # Iterate over the dict and do a case insensitive key comparison + (value for key, value in headers.items() if key.lower() == name_lower), + # Default value is returned if no matches was found + default_value, + ) + + class BaseProxyEvent(DictWrapper): @property def headers(self) -> Dict[str, str]: @@ -72,7 +87,4 @@ def get_header_value( str, optional Header value """ - if case_sensitive: - return self.headers.get(name, default_value) - - return next((value for key, value in self.headers.items() if name.lower() == key.lower()), default_value) + return get_header_value(self.headers, name, default_value, case_sensitive) diff --git a/docs/core/logger.md b/docs/core/logger.md index 27cbd725f80..dc5e0c0d647 100644 --- a/docs/core/logger.md +++ b/docs/core/logger.md @@ -516,7 +516,7 @@ When logging exceptions, Logger will add new keys named `exception_name` and `ex "timestamp": "2020-08-28 18:11:38,886", "service": "service_undefined", "sampling_rate": 0.0, - "exception_name":"ValueError", + "exception_name": "ValueError", "exception": "Traceback (most recent call last):\n File \"\", line 2, in \nValueError: something went wrong" } ``` diff --git a/docs/utilities/data_classes.md b/docs/utilities/data_classes.md index ca71a928a41..92d4c8b0f70 100644 --- a/docs/utilities/data_classes.md +++ b/docs/utilities/data_classes.md @@ -3,7 +3,8 @@ title: Event Source Data Classes description: Utility --- -The event source data classes utility provides classes describing the schema of common Lambda events triggers. +Event Source Data Classes utility provides classes self-describing Lambda event sources, including API decorators when +applicable. ## Key Features @@ -50,8 +51,10 @@ Event Source | Data_class ------------------------------------------------- | --------------------------------------------------------------------------------- [API Gateway Proxy](#api-gateway-proxy) | `APIGatewayProxyEvent` [API Gateway Proxy event v2](#api-gateway-proxy-v2) | `APIGatewayProxyEventV2` +[AppSync Resolver](#appsync-resolver) | `AppSyncResolverEvent` [CloudWatch Logs](#cloudwatch-logs) | `CloudWatchLogsEvent` [Cognito User Pool](#cognito-user-pool) | Multiple available under `cognito_user_pool_event` +[Connect Contact Flow](#connect-contact-flow) | `ConnectContactFlowEvent` [DynamoDB streams](#dynamodb-streams) | `DynamoDBStreamEvent`, `DynamoDBRecordEventName` [EventBridge](#eventbridge) | `EventBridgeEvent` [Kinesis Data Stream](#kinesis-streams) | `KinesisStreamEvent` @@ -68,9 +71,9 @@ Event Source | Data_class ### API Gateway Proxy -Typically used for API Gateway REST API or HTTP API using v1 proxy event. +It is used for either API Gateway REST API or HTTP API using v1 proxy event. -=== "lambda_app.py" +=== "app.py" ```python from aws_lambda_powertools.utilities.data_classes import APIGatewayProxyEvent @@ -87,7 +90,7 @@ Typically used for API Gateway REST API or HTTP API using v1 proxy event. ### API Gateway Proxy v2 -=== "lambda_app.py" +=== "app.py" ```python from aws_lambda_powertools.utilities.data_classes import APIGatewayProxyEventV2 @@ -101,12 +104,106 @@ Typically used for API Gateway REST API or HTTP API using v1 proxy event. do_something_with(event.body, query_string_parameters) ``` +### AppSync Resolver + +Used when building a Lambda GraphQL Resolvers with [Amplify GraphQL Transform Library](https://docs.amplify.aws/cli/graphql-transformer/function){target="_blank"} +and can also be used for [AppSync Direct Lambda Resolvers](https://aws.amazon.com/blogs/mobile/appsync-direct-lambda/){target="_blank"}. + +=== "app.py" + + ```python hl_lines="2-5 12 14 19 21 29-30" + from aws_lambda_powertools.logging import Logger, correlation_paths + from aws_lambda_powertools.utilities.data_classes.appsync_resolver_event import ( + AppSyncResolverEvent, + AppSyncIdentityCognito + ) + + logger = Logger() + + def get_locations(name: str = None, size: int = 0, page: int = 0): + """Your resolver logic here""" + + @logger.inject_lambda_context(correlation_id_path=correlation_paths.APPSYNC_RESOLVER) + def lambda_handler(event, context): + event: AppSyncResolverEvent = AppSyncResolverEvent(event) + + # Case insensitive look up of request headers + x_forwarded_for = event.get_header_value("x-forwarded-for") + + # Support for AppSyncIdentityCognito or AppSyncIdentityIAM identity types + assert isinstance(event.identity, AppSyncIdentityCognito) + identity: AppSyncIdentityCognito = event.identity + + # Logging with correlation_id + logger.debug({ + "x-forwarded-for": x_forwarded_for, + "username": identity.username + }) + + if event.type_name == "Merchant" and event.field_name == "locations": + return get_locations(**event.arguments) + + raise ValueError(f"Unsupported field resolver: {event.field_name}") + + ``` +=== "Example AppSync Event" + + ```json hl_lines="2-8 14 19 20" + { + "typeName": "Merchant", + "fieldName": "locations", + "arguments": { + "page": 2, + "size": 1, + "name": "value" + }, + "identity": { + "claims": { + "iat": 1615366261 + ... + }, + "username": "mike", + ... + }, + "request": { + "headers": { + "x-amzn-trace-id": "Root=1-60488877-0b0c4e6727ab2a1c545babd0", + "x-forwarded-for": "127.0.0.1" + ... + } + }, + ... + } + ``` + +=== "Example CloudWatch Log" + + ```json hl_lines="5 6 16" + { + "level":"DEBUG", + "location":"lambda_handler:22", + "message":{ + "x-forwarded-for":"127.0.0.1", + "username":"mike" + }, + "timestamp":"2021-03-10 12:38:40,062", + "service":"service_undefined", + "sampling_rate":0.0, + "cold_start":true, + "function_name":"func_name", + "function_memory_size":512, + "function_arn":"func_arn", + "function_request_id":"6735a29c-c000-4ae3-94e6-1f1c934f7f94", + "correlation_id":"Root=1-60488877-0b0c4e6727ab2a1c545babd0" + } + ``` + ### CloudWatch Logs CloudWatch Logs events by default are compressed and base64 encoded. You can use the helper function provided to decode, decompress and parse json data from the event. -=== "lambda_app.py" +=== "app.py" ```python from aws_lambda_powertools.utilities.data_classes import CloudWatchLogsEvent @@ -139,7 +236,7 @@ Define Auth Challenge | `data_classes.cognito_user_pool_event.DefineAuthChalleng Create Auth Challenge | `data_classes.cognito_user_pool_event.CreateAuthChallengeTriggerEvent` Verify Auth Challenge | `data_classes.cognito_user_pool_event.VerifyAuthChallengeResponseTriggerEvent` -=== "lambda_app.py" +=== "app.py" ```python from aws_lambda_powertools.utilities.data_classes.cognito_user_pool_event import PostConfirmationTriggerEvent @@ -151,13 +248,33 @@ Verify Auth Challenge | `data_classes.cognito_user_pool_event.VerifyAuthChalleng do_something_with(user_attributes) ``` +### Connect Contact Flow + +=== "app.py" + + ```python + from aws_lambda_powertools.utilities.data_classes.connect_contact_flow_event import ( + ConnectContactFlowChannel, + ConnectContactFlowEndpointType, + ConnectContactFlowEvent, + ConnectContactFlowInitiationMethod, + ) + + def lambda_handler(event, context): + event: ConnectContactFlowEvent = ConnectContactFlowEvent(event) + assert event.contact_data.attributes == {"Language": "en-US"} + assert event.contact_data.channel == ConnectContactFlowChannel.VOICE + assert event.contact_data.customer_endpoint.endpoint_type == ConnectContactFlowEndpointType.TELEPHONE_NUMBER + assert event.contact_data.initiation_method == ConnectContactFlowInitiationMethod.API + ``` + ### DynamoDB Streams The DynamoDB data class utility provides the base class for `DynamoDBStreamEvent`, a typed class for attributes values (`AttributeValue`), as well as enums for stream view type (`StreamViewType`) and event type (`DynamoDBRecordEventName`). -=== "lambda_app.py" +=== "app.py" ```python from aws_lambda_powertools.utilities.data_classes.dynamo_db_stream_event import ( @@ -177,7 +294,7 @@ attributes values (`AttributeValue`), as well as enums for stream view type (`St ### EventBridge -=== "lambda_app.py" +=== "app.py" ```python from aws_lambda_powertools.utilities.data_classes import EventBridgeEvent @@ -193,12 +310,11 @@ attributes values (`AttributeValue`), as well as enums for stream view type (`St Kinesis events by default contain base64 encoded data. You can use the helper function to access the data either as json or plain text, depending on the original payload. -=== "lambda_app.py" +=== "app.py" ```python from aws_lambda_powertools.utilities.data_classes import KinesisStreamEvent - def lambda_handler(event, context): event: KinesisStreamEvent = KinesisStreamEvent(event) kinesis_record = next(event.records).kinesis @@ -214,9 +330,10 @@ or plain text, depending on the original payload. ### S3 -=== "lambda_app.py" +=== "app.py" ```python + from urllib.parse import unquote_plus from aws_lambda_powertools.utilities.data_classes import S3Event def lambda_handler(event, context): @@ -225,14 +342,14 @@ or plain text, depending on the original payload. # Multiple records can be delivered in a single event for record in event.records: - object_key = record.s3.get_object.key + object_key = unquote_plus(record.s3.get_object.key) do_something_with(f'{bucket_name}/{object_key}') ``` ### SES -=== "lambda_app.py" +=== "app.py" ```python from aws_lambda_powertools.utilities.data_classes import SESEvent @@ -250,7 +367,7 @@ or plain text, depending on the original payload. ### SNS -=== "lambda_app.py" +=== "app.py" ```python from aws_lambda_powertools.utilities.data_classes import SNSEvent @@ -268,7 +385,7 @@ or plain text, depending on the original payload. ### SQS -=== "lambda_app.py" +=== "app.py" ```python from aws_lambda_powertools.utilities.data_classes import SQSEvent @@ -280,25 +397,3 @@ or plain text, depending on the original payload. for record in event.records: do_something_with(record.body) ``` - -### Connect - -**Connect Contact Flow** - -=== "lambda_app.py" - - ```python - from aws_lambda_powertools.utilities.data_classes.connect_contact_flow_event import ( - ConnectContactFlowChannel, - ConnectContactFlowEndpointType, - ConnectContactFlowEvent, - ConnectContactFlowInitiationMethod, - ) - - def lambda_handler(event, context): - event: ConnectContactFlowEvent = ConnectContactFlowEvent(event) - assert event.contact_data.attributes == {"Language": "en-US"} - assert event.contact_data.channel == ConnectContactFlowChannel.VOICE - assert event.contact_data.customer_endpoint.endpoint_type == ConnectContactFlowEndpointType.TELEPHONE_NUMBER - assert event.contact_data.initiation_method == ConnectContactFlowInitiationMethod.API - ``` diff --git a/tests/events/appSyncDirectResolver.json b/tests/events/appSyncDirectResolver.json new file mode 100644 index 00000000000..08c3d00b203 --- /dev/null +++ b/tests/events/appSyncDirectResolver.json @@ -0,0 +1,74 @@ +{ + "arguments": { + "id": "my identifier" + }, + "identity": { + "claims": { + "sub": "192879fc-a240-4bf1-ab5a-d6a00f3063f9", + "email_verified": true, + "iss": "https://cognito-idp.us-west-2.amazonaws.com/us-west-xxxxxxxxxxx", + "phone_number_verified": false, + "cognito:username": "jdoe", + "aud": "7471s60os7h0uu77i1tk27sp9n", + "event_id": "bc334ed8-a938-4474-b644-9547e304e606", + "token_use": "id", + "auth_time": 1599154213, + "phone_number": "+19999999999", + "exp": 1599157813, + "iat": 1599154213, + "email": "jdoe@email.com" + }, + "defaultAuthStrategy": "ALLOW", + "groups": null, + "issuer": "https://cognito-idp.us-west-2.amazonaws.com/us-west-xxxxxxxxxxx", + "sourceIp": [ + "1.1.1.1" + ], + "sub": "192879fc-a240-4bf1-ab5a-d6a00f3063f9", + "username": "jdoe" + }, + "source": null, + "request": { + "headers": { + "x-forwarded-for": "1.1.1.1, 2.2.2.2", + "cloudfront-viewer-country": "US", + "cloudfront-is-tablet-viewer": "false", + "via": "2.0 xxxxxxxxxxxxxxxx.cloudfront.net (CloudFront)", + "cloudfront-forwarded-proto": "https", + "origin": "https://us-west-1.console.aws.amazon.com", + "content-length": "217", + "accept-language": "en-US,en;q=0.9", + "host": "xxxxxxxxxxxxxxxx.appsync-api.us-west-1.amazonaws.com", + "x-forwarded-proto": "https", + "user-agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_6) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/85.0.4183.83 Safari/537.36", + "accept": "*/*", + "cloudfront-is-mobile-viewer": "false", + "cloudfront-is-smarttv-viewer": "false", + "accept-encoding": "gzip, deflate, br", + "referer": "https://us-west-1.console.aws.amazon.com/appsync/home?region=us-west-1", + "content-type": "application/json", + "sec-fetch-mode": "cors", + "x-amz-cf-id": "3aykhqlUwQeANU-HGY7E_guV5EkNeMMtwyOgiA==", + "x-amzn-trace-id": "Root=1-5f512f51-fac632066c5e848ae714", + "authorization": "eyJraWQiOiJScWFCSlJqYVJlM0hrSnBTUFpIcVRXazNOW...", + "sec-fetch-dest": "empty", + "x-amz-user-agent": "AWS-Console-AppSync/", + "cloudfront-is-desktop-viewer": "true", + "sec-fetch-site": "cross-site", + "x-forwarded-port": "443" + } + }, + "prev": null, + "info": { + "selectionSetList": [ + "id", + "field1", + "field2" + ], + "selectionSetGraphQL": "{\n id\n field1\n field2\n}", + "parentTypeName": "Mutation", + "fieldName": "createSomething", + "variables": {} + }, + "stash": {} +} diff --git a/tests/events/appSyncResolverEvent.json b/tests/events/appSyncResolverEvent.json new file mode 100644 index 00000000000..84ac71951c6 --- /dev/null +++ b/tests/events/appSyncResolverEvent.json @@ -0,0 +1,71 @@ +{ + "typeName": "Merchant", + "fieldName": "locations", + "arguments": { + "page": 2, + "size": 1, + "name": "value" + }, + "identity": { + "claims": { + "sub": "07920713-4526-4642-9c88-2953512de441", + "iss": "https://cognito-idp.us-east-1.amazonaws.com/us-east-1_POOL_ID", + "aud": "58rc9bf5kkti90ctmvioppukm9", + "event_id": "7f4c9383-abf6-48b7-b821-91643968b755", + "token_use": "id", + "auth_time": 1615366261, + "name": "Michael Brewer", + "exp": 1615369861, + "iat": 1615366261 + }, + "defaultAuthStrategy": "ALLOW", + "groups": null, + "issuer": "https://cognito-idp.us-east-1.amazonaws.com/us-east-1_POOL_ID", + "sourceIp": [ + "11.215.2.22" + ], + "sub": "07920713-4526-4642-9c88-2953512de441", + "username": "mike" + }, + "source": { + "name": "Value", + "nested": { + "name": "value", + "list": [] + } + }, + "request": { + "headers": { + "x-forwarded-for": "11.215.2.22, 64.44.173.11", + "cloudfront-viewer-country": "US", + "cloudfront-is-tablet-viewer": "false", + "via": "2.0 SOMETHING.cloudfront.net (CloudFront)", + "cloudfront-forwarded-proto": "https", + "origin": "https://console.aws.amazon.com", + "content-length": "156", + "accept-language": "en-US,en;q=0.9", + "host": "SOMETHING.appsync-api.us-east-1.amazonaws.com", + "x-forwarded-proto": "https", + "sec-gpc": "1", + "user-agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) etc.", + "accept": "*/*", + "cloudfront-is-mobile-viewer": "false", + "cloudfront-is-smarttv-viewer": "false", + "accept-encoding": "gzip, deflate, br", + "referer": "https://console.aws.amazon.com/", + "content-type": "application/json", + "sec-fetch-mode": "cors", + "x-amz-cf-id": "Fo5VIuvP6V6anIEt62WzFDCK45mzM4yEdpt5BYxOl9OFqafd-WR0cA==", + "x-amzn-trace-id": "Root=1-60488877-0b0c4e6727ab2a1c545babd0", + "authorization": "AUTH-HEADER", + "sec-fetch-dest": "empty", + "x-amz-user-agent": "AWS-Console-AppSync/", + "cloudfront-is-desktop-viewer": "true", + "sec-fetch-site": "cross-site", + "x-forwarded-port": "443" + } + }, + "prev": { + "result": {} + } +} diff --git a/tests/functional/appsync/test_appsync_resolver_utils.py b/tests/functional/appsync/test_appsync_resolver_utils.py new file mode 100644 index 00000000000..b3ec85c7205 --- /dev/null +++ b/tests/functional/appsync/test_appsync_resolver_utils.py @@ -0,0 +1,219 @@ +import asyncio +import datetime +import json +import os +import sys + +import pytest + +from aws_lambda_powertools.utilities.data_classes import AppSyncResolverEvent +from aws_lambda_powertools.utilities.data_classes.appsync.resolver_utils import ( + AppSyncResolver, + aws_date, + aws_datetime, + aws_time, + aws_timestamp, + make_id, +) +from aws_lambda_powertools.utilities.typing import LambdaContext + + +def load_event(file_name: str) -> dict: + full_file_name = os.path.dirname(os.path.realpath(__file__)) + "/../../events/" + file_name + with open(full_file_name) as fp: + return json.load(fp) + + +def test_direct_resolver(): + # Check whether we can handle an example appsync direct resolver + mock_event = load_event("appSyncDirectResolver.json") + + app = AppSyncResolver() + + @app.resolver(field_name="createSomething", include_context=True) + def create_something(context, id: str): # noqa AA03 VNE003 + assert context == {} + return id + + def handler(event, context): + return app.resolve(event, context) + + result = handler(mock_event, {}) + assert result == "my identifier" + + +def test_amplify_resolver(): + # Check whether we can handle an example appsync resolver + mock_event = load_event("appSyncResolverEvent.json") + + app = AppSyncResolver() + + @app.resolver(type_name="Merchant", field_name="locations", include_event=True) + def get_location(event: AppSyncResolverEvent, page: int, size: int, name: str): + assert event is not None + assert page == 2 + assert size == 1 + return name + + def handler(event, context): + return app.resolve(event, context) + + result = handler(mock_event, {}) + assert result == "value" + + +def test_resolver_no_params(): + # GIVEN + app = AppSyncResolver() + + @app.resolver(type_name="Query", field_name="noParams") + def no_params(): + return "no_params has no params" + + event = {"typeName": "Query", "fieldName": "noParams", "arguments": {}} + + # WHEN + result = app.resolve(event, LambdaContext()) + + # THEN + assert result == "no_params has no params" + + +def test_resolver_include_event(): + # GIVEN + app = AppSyncResolver() + + mock_event = {"typeName": "Query", "fieldName": "field", "arguments": {}} + + @app.resolver(field_name="field", include_event=True) + def get_value(event: AppSyncResolverEvent): + return event + + # WHEN + result = app.resolve(mock_event, LambdaContext()) + + # THEN + assert result._data == mock_event + assert isinstance(result, AppSyncResolverEvent) + + +def test_resolver_include_context(): + # GIVEN + app = AppSyncResolver() + + mock_event = {"typeName": "Query", "fieldName": "field", "arguments": {}} + + @app.resolver(field_name="field", include_context=True) + def get_value(context: LambdaContext): + return context + + # WHEN + mock_context = LambdaContext() + result = app.resolve(mock_event, mock_context) + + # THEN + assert result == mock_context + + +def test_resolver_value_error(): + # GIVEN no defined field resolver + app = AppSyncResolver() + + # WHEN + with pytest.raises(ValueError) as exp: + event = {"typeName": "type", "fieldName": "field", "arguments": {}} + app.resolve(event, LambdaContext()) + + # THEN + assert exp.value.args[0] == "No resolver found for 'type.field'" + + +def test_resolver_yield(): + # GIVEN + app = AppSyncResolver() + + mock_event = {"typeName": "Customer", "fieldName": "field", "arguments": {}} + + @app.resolver(field_name="field") + def func_yield(): + yield "value" + + # WHEN + mock_context = LambdaContext() + result = app.resolve(mock_event, mock_context) + + # THEN + assert next(result) == "value" + + +def test_resolver_multiple_mappings(): + # GIVEN + app = AppSyncResolver() + + @app.resolver(field_name="listLocations") + @app.resolver(field_name="locations") + def get_locations(name: str, description: str = ""): + return name + description + + # WHEN + mock_event1 = {"typeName": "Query", "fieldName": "listLocations", "arguments": {"name": "value"}} + mock_event2 = { + "typeName": "Merchant", + "fieldName": "locations", + "arguments": {"name": "value2", "description": "description"}, + } + result1 = app.resolve(mock_event1, LambdaContext()) + result2 = app.resolve(mock_event2, LambdaContext()) + + # THEN + assert result1 == "value" + assert result2 == "value2description" + + +@pytest.mark.skipif(sys.version_info < (3, 8), reason="only for python versions that support asyncio.run") +def test_resolver_async(): + # GIVEN + app = AppSyncResolver() + + mock_event = {"typeName": "Customer", "fieldName": "field", "arguments": {}} + + @app.resolver(field_name="field") + async def get_async(): + await asyncio.sleep(0.0001) + return "value" + + # WHEN + mock_context = LambdaContext() + result = app.resolve(mock_event, mock_context) + + # THEN + assert asyncio.run(result) == "value" + + +def test_make_id(): + uuid: str = make_id() + assert isinstance(uuid, str) + assert len(uuid) == 36 + + +def test_aws_date(): + date_str = aws_date() + assert isinstance(date_str, str) + assert datetime.datetime.strptime(date_str, "%Y-%m-%d") + + +def test_aws_time(): + time_str = aws_time() + assert isinstance(time_str, str) + assert datetime.datetime.strptime(time_str, "%H:%M:%S") + + +def test_aws_datetime(): + datetime_str = aws_datetime() + assert isinstance(datetime_str, str) + assert datetime.datetime.strptime(datetime_str, "%Y-%m-%dT%H:%M:%SZ") + + +def test_aws_timestamp(): + timestamp = aws_timestamp() + assert isinstance(timestamp, int) diff --git a/tests/functional/test_lambda_trigger_events.py b/tests/functional/test_lambda_trigger_events.py index a6fb82970fc..40dd374960c 100644 --- a/tests/functional/test_lambda_trigger_events.py +++ b/tests/functional/test_lambda_trigger_events.py @@ -8,6 +8,7 @@ ALBEvent, APIGatewayProxyEvent, APIGatewayProxyEventV2, + AppSyncResolverEvent, CloudWatchLogsEvent, EventBridgeEvent, KinesisStreamEvent, @@ -16,6 +17,12 @@ SNSEvent, SQSEvent, ) +from aws_lambda_powertools.utilities.data_classes.appsync_resolver_event import ( + AppSyncIdentityCognito, + AppSyncIdentityIAM, + AppSyncResolverEventInfo, + get_identity_object, +) from aws_lambda_powertools.utilities.data_classes.cognito_user_pool_event import ( CreateAuthChallengeTriggerEvent, CustomMessageTriggerEvent, @@ -874,3 +881,125 @@ def test_alb_event(): assert event.multi_value_headers == event.get("multiValueHeaders") assert event.body == event["body"] assert event.is_base64_encoded == event["isBase64Encoded"] + + +def test_appsync_resolver_event(): + event = AppSyncResolverEvent(load_event("appSyncResolverEvent.json")) + + assert event.type_name == "Merchant" + assert event.field_name == "locations" + assert event.arguments["name"] == "value" + assert event.identity["claims"]["token_use"] == "id" + assert event.source["name"] == "Value" + assert event.get_header_value("X-amzn-trace-id") == "Root=1-60488877-0b0c4e6727ab2a1c545babd0" + assert event.get_header_value("X-amzn-trace-id", case_sensitive=True) is None + assert event.get_header_value("missing", default_value="Foo") == "Foo" + assert event.prev_result == {} + assert event.stash is None + + info = event.info + assert info is not None + assert isinstance(info, AppSyncResolverEventInfo) + assert info.field_name == event["fieldName"] + assert info.parent_type_name == event["typeName"] + assert info.variables is None + assert info.selection_set_list is None + assert info.selection_set_graphql is None + + assert isinstance(event.identity, AppSyncIdentityCognito) + identity: AppSyncIdentityCognito = event.identity + assert identity.claims is not None + assert identity.sub == "07920713-4526-4642-9c88-2953512de441" + assert len(identity.source_ip) == 1 + assert identity.username == "mike" + assert identity.default_auth_strategy == "ALLOW" + assert identity.groups is None + assert identity.issuer == identity["issuer"] + + +def test_get_identity_object_is_none(): + assert get_identity_object(None) is None + + event = AppSyncResolverEvent({}) + assert event.identity is None + + +def test_get_identity_object_iam(): + identity = { + "accountId": "string", + "cognitoIdentityPoolId": "string", + "cognitoIdentityId": "string", + "sourceIp": ["string"], + "username": "string", + "userArn": "string", + "cognitoIdentityAuthType": "string", + "cognitoIdentityAuthProvider": "string", + } + + identity_object = get_identity_object(identity) + + assert isinstance(identity_object, AppSyncIdentityIAM) + assert identity_object.account_id == identity["accountId"] + assert identity_object.cognito_identity_pool_id == identity["cognitoIdentityPoolId"] + assert identity_object.cognito_identity_id == identity["cognitoIdentityId"] + assert identity_object.source_ip == identity["sourceIp"] + assert identity_object.username == identity["username"] + assert identity_object.user_arn == identity["userArn"] + assert identity_object.cognito_identity_auth_type == identity["cognitoIdentityAuthType"] + assert identity_object.cognito_identity_auth_provider == identity["cognitoIdentityAuthProvider"] + + +def test_appsync_resolver_direct(): + event = AppSyncResolverEvent(load_event("appSyncDirectResolver.json")) + + assert event.source is None + assert event.arguments["id"] == "my identifier" + assert event.stash == {} + assert event.prev_result is None + assert isinstance(event.identity, AppSyncIdentityCognito) + + info = event.info + assert info is not None + assert isinstance(info, AppSyncResolverEventInfo) + assert info.selection_set_list is not None + assert info.selection_set_list == info["selectionSetList"] + assert info.selection_set_graphql == info["selectionSetGraphQL"] + assert info.parent_type_name == info["parentTypeName"] + assert info.field_name == info["fieldName"] + + assert event.type_name == info.parent_type_name + assert event.field_name == info.field_name + + +def test_appsync_resolver_event_info(): + info_dict = { + "fieldName": "getPost", + "parentTypeName": "Query", + "variables": {"postId": "123", "authorId": "456"}, + "selectionSetList": ["postId", "title"], + "selectionSetGraphQL": "{\n getPost(id: $postId) {\n postId\n etc..", + } + event = {"info": info_dict} + + event = AppSyncResolverEvent(event) + + assert event.source is None + assert event.identity is None + assert event.info is not None + assert isinstance(event.info, AppSyncResolverEventInfo) + info: AppSyncResolverEventInfo = event.info + assert info.field_name == info_dict["fieldName"] + assert event.field_name == info.field_name + assert info.parent_type_name == info_dict["parentTypeName"] + assert event.type_name == info.parent_type_name + assert info.variables == info_dict["variables"] + assert info.variables["postId"] == "123" + assert info.selection_set_list == info_dict["selectionSetList"] + assert info.selection_set_graphql == info_dict["selectionSetGraphQL"] + + +def test_appsync_resolver_event_empty(): + event = AppSyncResolverEvent({}) + + assert event.info.field_name is None + assert event.info.parent_type_name is None