diff --git a/aws_lambda_powertools/event_handler/appsync.py b/aws_lambda_powertools/event_handler/appsync.py index 7f4cce5c8bf..69b90c4cbb6 100644 --- a/aws_lambda_powertools/event_handler/appsync.py +++ b/aws_lambda_powertools/event_handler/appsync.py @@ -1,11 +1,13 @@ import logging -from typing import Any, Callable, Optional +from typing import Any, Callable, Optional, Type, TypeVar from aws_lambda_powertools.utilities.data_classes import AppSyncResolverEvent from aws_lambda_powertools.utilities.typing import LambdaContext logger = logging.getLogger(__name__) +AppSyncResolverEventT = TypeVar("AppSyncResolverEventT", bound=AppSyncResolverEvent) + class AppSyncResolver: """ @@ -38,7 +40,7 @@ def common_field() -> str: return str(uuid.uuid4()) """ - current_event: AppSyncResolverEvent + current_event: AppSyncResolverEventT # type: ignore[valid-type] lambda_context: LambdaContext def __init__(self): @@ -62,7 +64,9 @@ def register_resolver(func): return register_resolver - def resolve(self, event: dict, context: LambdaContext) -> Any: + def resolve( + self, event: dict, context: LambdaContext, data_model: Type[AppSyncResolverEvent] = AppSyncResolverEvent + ) -> Any: """Resolve field_name Parameters @@ -71,6 +75,56 @@ def resolve(self, event: dict, context: LambdaContext) -> Any: Lambda event context : LambdaContext Lambda context + data_model: + Your data data_model to decode AppSync event, by default AppSyncResolverEvent + + Example + ------- + + ```python + from aws_lambda_powertools.event_handler import AppSyncResolver + from aws_lambda_powertools.utilities.typing import LambdaContext + + @app.resolver(field_name="createSomething") + def create_something(id: str): # noqa AA03 VNE003 + return id + + def handler(event, context: LambdaContext): + return app.resolve(event, context) + ``` + + **Bringing custom models** + + ```python + from aws_lambda_powertools import Logger, Tracer + + from aws_lambda_powertools.logging import correlation_paths + from aws_lambda_powertools.event_handler import AppSyncResolver + + tracer = Tracer(service="sample_resolver") + logger = Logger(service="sample_resolver") + app = AppSyncResolver() + + + class MyCustomModel(AppSyncResolverEvent): + @property + def country_viewer(self) -> str: + return self.request_headers.get("cloudfront-viewer-country") + + + @app.resolver(field_name="listLocations") + @app.resolver(field_name="locations") + def get_locations(name: str, description: str = ""): + if app.current_event.country_viewer == "US": + ... + return name + description + + + @logger.inject_lambda_context(correlation_id_path=correlation_paths.APPSYNC_RESOLVER) + @tracer.capture_lambda_handler + def lambda_handler(event, context): + return app.resolve(event, context, data_model=MyCustomModel) + ``` Returns ------- @@ -82,7 +136,7 @@ def resolve(self, event: dict, context: LambdaContext) -> Any: ValueError If we could not find a field resolver """ - self.current_event = AppSyncResolverEvent(event) + self.current_event = data_model(event) self.lambda_context = context resolver = self._get_resolver(self.current_event.type_name, self.current_event.field_name) return resolver(**self.current_event.arguments) @@ -108,6 +162,8 @@ def _get_resolver(self, type_name: str, field_name: str) -> Callable: raise ValueError(f"No resolver found for '{full_name}'") return resolver["func"] - def __call__(self, event, context) -> Any: + def __call__( + self, event: dict, context: LambdaContext, data_model: Type[AppSyncResolverEvent] = AppSyncResolverEvent + ) -> Any: """Implicit lambda handler which internally calls `resolve`""" - return self.resolve(event, context) + return self.resolve(event, context, data_model) diff --git a/docs/core/event_handler/appsync.md b/docs/core/event_handler/appsync.md index 67ad1999285..a47b8a4c641 100644 --- a/docs/core/event_handler/appsync.md +++ b/docs/core/event_handler/appsync.md @@ -598,6 +598,118 @@ Use the following code for `merchantInfo` and `searchMerchant` functions respect } ``` +### Custom data models + +You can subclass `AppSyncResolverEvent` to bring your own set of methods to handle incoming events, by using `data_model` param in the `resolve` method. + + +=== "custom_model.py" + + ```python hl_lines="11-14 19 26" + from aws_lambda_powertools import Logger, Tracer + + from aws_lambda_powertools.logging import correlation_paths + from aws_lambda_powertools.event_handler import AppSyncResolver + + tracer = Tracer(service="sample_resolver") + logger = Logger(service="sample_resolver") + app = AppSyncResolver() + + + class MyCustomModel(AppSyncResolverEvent): + @property + def country_viewer(self) -> str: + return self.request_headers.get("cloudfront-viewer-country") + + @app.resolver(field_name="listLocations") + @app.resolver(field_name="locations") + def get_locations(name: str, description: str = ""): + if app.current_event.country_viewer == "US": + ... + return name + description + + @logger.inject_lambda_context(correlation_id_path=correlation_paths.APPSYNC_RESOLVER) + @tracer.capture_lambda_handler + def lambda_handler(event, context): + return app.resolve(event, context, data_model=MyCustomModel) + ``` + +=== "schema.graphql" + + ```typescript hl_lines="6 20" + schema { + query: Query + } + + type Query { + listLocations: [Location] + } + + type Location { + id: ID! + name: String! + description: String + address: String + } + + type Merchant { + id: String! + name: String! + description: String + locations: [Location] + } + ``` + +=== "listLocations_event.json" + + ```json + { + "arguments": {}, + "identity": null, + "source": null, + "request": { + "headers": { + "x-forwarded-for": "1.2.3.4, 5.6.7.8", + "accept-encoding": "gzip, deflate, br", + "cloudfront-viewer-country": "NL", + "cloudfront-is-tablet-viewer": "false", + "referer": "https://eu-west-1.console.aws.amazon.com/appsync/home?region=eu-west-1", + "via": "2.0 9fce949f3749407c8e6a75087e168b47.cloudfront.net (CloudFront)", + "cloudfront-forwarded-proto": "https", + "origin": "https://eu-west-1.console.aws.amazon.com", + "x-api-key": "da1-c33ullkbkze3jg5hf5ddgcs4fq", + "content-type": "application/json", + "x-amzn-trace-id": "Root=1-606eb2f2-1babc433453a332c43fb4494", + "x-amz-cf-id": "SJw16ZOPuMZMINx5Xcxa9pB84oMPSGCzNOfrbJLvd80sPa0waCXzYQ==", + "content-length": "114", + "x-amz-user-agent": "AWS-Console-AppSync/", + "x-forwarded-proto": "https", + "host": "ldcvmkdnd5az3lm3gnf5ixvcyy.appsync-api.eu-west-1.amazonaws.com", + "accept-language": "en-US,en;q=0.5", + "user-agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10.15; rv:78.0) Gecko/20100101 Firefox/78.0", + "cloudfront-is-desktop-viewer": "true", + "cloudfront-is-mobile-viewer": "false", + "accept": "*/*", + "x-forwarded-port": "443", + "cloudfront-is-smarttv-viewer": "false" + } + }, + "prev": null, + "info": { + "parentTypeName": "Query", + "selectionSetList": [ + "id", + "name", + "description" + ], + "selectionSetGraphQL": "{\n id\n name\n description\n}", + "fieldName": "listLocations", + "variables": {} + }, + "stash": {} + } + ``` + ## Testing your code You can test your resolvers by passing a mocked or actual AppSync Lambda event that you're expecting. diff --git a/tests/functional/event_handler/test_appsync.py b/tests/functional/event_handler/test_appsync.py index e260fef89ab..26a3ffdcb1f 100644 --- a/tests/functional/event_handler/test_appsync.py +++ b/tests/functional/event_handler/test_appsync.py @@ -138,3 +138,26 @@ async def get_async(): # THEN assert asyncio.run(result) == "value" + + +def test_resolve_custom_data_model(): + # Check whether we can handle an example appsync direct resolver + mock_event = load_event("appSyncDirectResolver.json") + + class MyCustomModel(AppSyncResolverEvent): + @property + def country_viewer(self): + return self.request_headers.get("cloudfront-viewer-country") + + app = AppSyncResolver() + + @app.resolver(field_name="createSomething") + def create_something(id: str): # noqa AA03 VNE003 + return id + + # Call the implicit handler + result = app(event=mock_event, context=LambdaContext(), data_model=MyCustomModel) + + assert result == "my identifier" + + assert app.current_event.country_viewer == "US"