From f78c0d6bb696bd8cacd5f52166fd405ae05e9fcb Mon Sep 17 00:00:00 2001 From: Michael Brewer Date: Tue, 14 Dec 2021 21:53:57 -0800 Subject: [PATCH 1/5] feat(event-handler): Add exception_handler support --- .../event_handler/api_gateway.py | 21 ++++++++++++-- .../event_handler/test_api_gateway.py | 29 ++++++++++++++++++- 2 files changed, 47 insertions(+), 3 deletions(-) diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index b3d77df24b4..f331d538307 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -10,7 +10,7 @@ from enum import Enum from functools import partial from http import HTTPStatus -from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union from aws_lambda_powertools.event_handler import content_types from aws_lambda_powertools.event_handler.exceptions import ServiceError @@ -435,6 +435,7 @@ def __init__( self._proxy_type = proxy_type self._routes: List[Route] = [] self._route_keys: List[str] = [] + self._exception_handlers: Dict[Type, Callable] = {} self._cors = cors self._cors_enabled: bool = cors is not None self._cors_methods: Set[str] = {"OPTIONS"} @@ -618,7 +619,11 @@ def _call_route(self, route: Route, args: Dict[str, str]) -> ResponseBuilder: ), route, ) - except Exception: + except Exception as exc: + handler = self._lookup_exception_handler(exc) + if handler: + return ResponseBuilder(handler(exc)) + if self._debug: # If the user has turned on debug mode, # we'll let the original exception propagate so @@ -676,6 +681,18 @@ def include_router(self, router: "Router", prefix: Optional[str] = None) -> None self.route(*route)(func) + def exception_handler(self, exception): + def register_exception_handler(func: Callable): + self._exception_handlers[exception] = func + + return register_exception_handler + + def _lookup_exception_handler(self, exc: Exception) -> Optional[Callable]: + for cls in type(exc).__mro__: + if cls in self._exception_handlers: + return self._exception_handlers[cls] + return None + class Router(BaseRouter): """Router helper class to allow splitting ApiGatewayResolver into multiple files""" diff --git a/tests/functional/event_handler/test_api_gateway.py b/tests/functional/event_handler/test_api_gateway.py index f28752e6de6..5dee15793a1 100644 --- a/tests/functional/event_handler/test_api_gateway.py +++ b/tests/functional/event_handler/test_api_gateway.py @@ -163,7 +163,7 @@ def patch_func(): def handler(event, context): return app.resolve(event, context) - # Also check check the route configurations + # Also check the route configurations routes = app._routes assert len(routes) == 5 for route in routes: @@ -1076,3 +1076,30 @@ def foo(): assert result["statusCode"] == 200 assert result["headers"]["Content-Type"] == content_types.APPLICATION_JSON + + +def test_exception_handler(): + # GIVEN a resolver with an exception handler defined for ValueError + app = ApiGatewayResolver(proxy_type=ProxyEventType.APIGatewayProxyEvent) + + @app.exception_handler(ValueError) + def handle_value_error(ex: ValueError): + print(f"request path is '{app.current_event.path}'") + return Response( + status_code=418, + content_type=content_types.TEXT_HTML, + body=str(ex), + ) + + @app.get("/my/path") + def get_lambda() -> Response: + raise ValueError("Foo!") + + # WHEN calling the event handler + # AND a ValueError is raised + result = app(LOAD_GW_EVENT, {}) + + # THEN call the exception_handler + assert result["statusCode"] == 418 + assert result["headers"]["Content-Type"] == content_types.TEXT_HTML + assert result["body"] == "Foo!" From 5337dc35fa155cf07e6d1b9798eedfaa8f692636 Mon Sep 17 00:00:00 2001 From: Michael Brewer Date: Wed, 15 Dec 2021 08:21:49 -0800 Subject: [PATCH 2/5] feat: allow for overriding ServiceErrors --- .../event_handler/api_gateway.py | 33 ++++++++++--------- 1 file changed, 18 insertions(+), 15 deletions(-) diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index f331d538307..1b3026d1cca 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -610,19 +610,10 @@ def _call_route(self, route: Route, args: Dict[str, str]) -> ResponseBuilder: """Actually call the matching route with any provided keyword arguments.""" try: return ResponseBuilder(self._to_response(route.func(**args)), route) - except ServiceError as e: - return ResponseBuilder( - Response( - status_code=e.status_code, - content_type=content_types.APPLICATION_JSON, - body=self._json_dump({"statusCode": e.status_code, "message": e.msg}), - ), - route, - ) except Exception as exc: - handler = self._lookup_exception_handler(exc) - if handler: - return ResponseBuilder(handler(exc)) + response = self._call_exception_handler(exc, route) + if response: + return response if self._debug: # If the user has turned on debug mode, @@ -687,10 +678,22 @@ def register_exception_handler(func: Callable): return register_exception_handler - def _lookup_exception_handler(self, exc: Exception) -> Optional[Callable]: - for cls in type(exc).__mro__: + def _call_exception_handler(self, exp: Exception, route: Route) -> Optional[ResponseBuilder]: + for cls in type(exp).__mro__: if cls in self._exception_handlers: - return self._exception_handlers[cls] + handler = self._exception_handlers[cls] + return ResponseBuilder(handler(exp), route) + + if isinstance(exp, ServiceError): + return ResponseBuilder( + Response( + status_code=exp.status_code, + content_type=content_types.APPLICATION_JSON, + body=self._json_dump({"statusCode": exp.status_code, "message": exp.msg}), + ), + route, + ) + return None From f0bae3cd77cd6cbc7d672fba127078cf98826625 Mon Sep 17 00:00:00 2001 From: Michael Brewer Date: Wed, 15 Dec 2021 10:38:59 -0800 Subject: [PATCH 3/5] feat: add not_found handler --- .../event_handler/api_gateway.py | 37 +++++++++++++------ .../event_handler/test_api_gateway.py | 19 ++++++++++ 2 files changed, 45 insertions(+), 11 deletions(-) diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 1b3026d1cca..2f4169ac1fc 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -13,7 +13,7 @@ from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union from aws_lambda_powertools.event_handler import content_types -from aws_lambda_powertools.event_handler.exceptions import ServiceError +from aws_lambda_powertools.event_handler.exceptions import NotFoundError, ServiceError from aws_lambda_powertools.shared import constants from aws_lambda_powertools.shared.functions import resolve_truthy_env_var_choice from aws_lambda_powertools.shared.json_encoder import Encoder @@ -435,7 +435,7 @@ def __init__( self._proxy_type = proxy_type self._routes: List[Route] = [] self._route_keys: List[str] = [] - self._exception_handlers: Dict[Type, Callable] = {} + self._exception_handlers: Dict[Union[int, Type], Callable] = {} self._cors = cors self._cors_enabled: bool = cors is not None self._cors_methods: Set[str] = {"OPTIONS"} @@ -597,6 +597,11 @@ def _not_found(self, method: str) -> ResponseBuilder: headers["Access-Control-Allow-Methods"] = ",".join(sorted(self._cors_methods)) return ResponseBuilder(Response(status_code=204, content_type=None, headers=headers, body=None)) + # Allow for custom exception handlers + handler = self._exception_handlers.get(404) + if handler: + return ResponseBuilder(handler(NotFoundError())) + return ResponseBuilder( Response( status_code=HTTPStatus.NOT_FOUND.value, @@ -611,9 +616,9 @@ def _call_route(self, route: Route, args: Dict[str, str]) -> ResponseBuilder: try: return ResponseBuilder(self._to_response(route.func(**args)), route) except Exception as exc: - response = self._call_exception_handler(exc, route) - if response: - return response + response_builder = self._call_exception_handler(exc, route) + if response_builder: + return response_builder if self._debug: # If the user has turned on debug mode, @@ -624,8 +629,10 @@ def _call_route(self, route: Route, args: Dict[str, str]) -> ResponseBuilder: status_code=500, content_type=content_types.TEXT_PLAIN, body="".join(traceback.format_exc()), - ) + ), + route, ) + raise def _to_response(self, result: Union[Dict, Response]) -> Response: @@ -672,17 +679,25 @@ def include_router(self, router: "Router", prefix: Optional[str] = None) -> None self.route(*route)(func) - def exception_handler(self, exception): + def not_found(self): + return self.exception_handler(404) + + def exception_handler(self, exc_class_or_status_code: Union[int, Type[Exception]]): def register_exception_handler(func: Callable): - self._exception_handlers[exception] = func + self._exception_handlers[exc_class_or_status_code] = func return register_exception_handler - def _call_exception_handler(self, exp: Exception, route: Route) -> Optional[ResponseBuilder]: + def _lookup_exception_handler(self, exp: Exception) -> Optional[Callable]: for cls in type(exp).__mro__: if cls in self._exception_handlers: - handler = self._exception_handlers[cls] - return ResponseBuilder(handler(exp), route) + return self._exception_handlers[cls] + return None + + def _call_exception_handler(self, exp: Exception, route: Route) -> Optional[ResponseBuilder]: + handler = self._lookup_exception_handler(exp) + if handler: + return ResponseBuilder(handler(exp), route) if isinstance(exp, ServiceError): return ResponseBuilder( diff --git a/tests/functional/event_handler/test_api_gateway.py b/tests/functional/event_handler/test_api_gateway.py index 5dee15793a1..7fc316034f0 100644 --- a/tests/functional/event_handler/test_api_gateway.py +++ b/tests/functional/event_handler/test_api_gateway.py @@ -1103,3 +1103,22 @@ def get_lambda() -> Response: assert result["statusCode"] == 418 assert result["headers"]["Content-Type"] == content_types.TEXT_HTML assert result["body"] == "Foo!" + + +def test_exception_handler_not_found(): + # GIVEN a resolver with an exception handler defined for ValueError + app = ApiGatewayResolver(proxy_type=ProxyEventType.APIGatewayProxyEvent) + + @app.not_found() + def handle_not_found(exc: NotFoundError): + assert isinstance(exc, NotFoundError) + return Response(status_code=404, content_type=content_types.TEXT_PLAIN, body="I am a teapot!") + + # WHEN calling the event handler + # AND a ValueError is raised + result = app(LOAD_GW_EVENT, {}) + + # THEN call the exception_handler + assert result["statusCode"] == 404 + assert result["headers"]["Content-Type"] == content_types.TEXT_PLAIN + assert result["body"] == "I am a teapot!" From 019e099352b2a952d1f4d9c5da3a8b2c6753edf9 Mon Sep 17 00:00:00 2001 From: Michael Brewer Date: Wed, 15 Dec 2021 11:17:22 -0800 Subject: [PATCH 4/5] feat: add example for custom service errors --- .../event_handler/api_gateway.py | 4 +- .../event_handler/test_api_gateway.py | 39 ++++++++++++++++--- 2 files changed, 35 insertions(+), 8 deletions(-) diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 2f4169ac1fc..8327054bda0 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -679,8 +679,8 @@ def include_router(self, router: "Router", prefix: Optional[str] = None) -> None self.route(*route)(func) - def not_found(self): - return self.exception_handler(404) + def not_found(self, func: Callable): + return self.exception_handler(404)(func) def exception_handler(self, exc_class_or_status_code: Union[int, Type[Exception]]): def register_exception_handler(func: Callable): diff --git a/tests/functional/event_handler/test_api_gateway.py b/tests/functional/event_handler/test_api_gateway.py index 7fc316034f0..45b1e3f41a4 100644 --- a/tests/functional/event_handler/test_api_gateway.py +++ b/tests/functional/event_handler/test_api_gateway.py @@ -1080,7 +1080,7 @@ def foo(): def test_exception_handler(): # GIVEN a resolver with an exception handler defined for ValueError - app = ApiGatewayResolver(proxy_type=ProxyEventType.APIGatewayProxyEvent) + app = ApiGatewayResolver() @app.exception_handler(ValueError) def handle_value_error(ex: ValueError): @@ -1105,17 +1105,44 @@ def get_lambda() -> Response: assert result["body"] == "Foo!" +def test_exception_handler_service_error(): + # GIVEN + app = ApiGatewayResolver() + + @app.exception_handler(ServiceError) + def service_error(ex: ServiceError): + print(ex.msg) + return Response( + status_code=ex.status_code, + content_type=content_types.APPLICATION_JSON, + body="CUSTOM ERROR FORMAT", + ) + + @app.get("/my/path") + def get_lambda() -> Response: + raise InternalServerError("Something sensitive") + + # WHEN calling the event handler + # AND a ServiceError is raised + result = app(LOAD_GW_EVENT, {}) + + # THEN call the exception_handler + assert result["statusCode"] == 500 + assert result["headers"]["Content-Type"] == content_types.APPLICATION_JSON + assert result["body"] == "CUSTOM ERROR FORMAT" + + def test_exception_handler_not_found(): - # GIVEN a resolver with an exception handler defined for ValueError - app = ApiGatewayResolver(proxy_type=ProxyEventType.APIGatewayProxyEvent) + # GIVEN a resolver with an exception handler defined for a 404 not found + app = ApiGatewayResolver() - @app.not_found() - def handle_not_found(exc: NotFoundError): + @app.not_found + def handle_not_found(exc: NotFoundError) -> Response: assert isinstance(exc, NotFoundError) return Response(status_code=404, content_type=content_types.TEXT_PLAIN, body="I am a teapot!") # WHEN calling the event handler - # AND a ValueError is raised + # AND not route is found result = app(LOAD_GW_EVENT, {}) # THEN call the exception_handler From eb1d7a89520c33c3ff01786846875c95fc1d162a Mon Sep 17 00:00:00 2001 From: Michael Brewer Date: Thu, 16 Dec 2021 09:23:28 +0200 Subject: [PATCH 5/5] chore: some refactoring --- .../event_handler/api_gateway.py | 72 +++++++++---------- 1 file changed, 36 insertions(+), 36 deletions(-) diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 8327054bda0..5bd3bc0b70e 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -27,7 +27,6 @@ _SAFE_URI = "-._~()'!*:@,;" # https://www.ietf.org/rfc/rfc3986.txt # API GW/ALB decode non-safe URI chars; we must support them too _UNSAFE_URI = "%<>\[\]{}|^" # noqa: W605 - _NAMED_GROUP_BOUNDARY_PATTERN = fr"(?P\1[{_SAFE_URI}{_UNSAFE_URI}\\w]+)" @@ -435,7 +434,7 @@ def __init__( self._proxy_type = proxy_type self._routes: List[Route] = [] self._route_keys: List[str] = [] - self._exception_handlers: Dict[Union[int, Type], Callable] = {} + self._exception_handlers: Dict[Type, Callable] = {} self._cors = cors self._cors_enabled: bool = cors is not None self._cors_methods: Set[str] = {"OPTIONS"} @@ -597,8 +596,7 @@ def _not_found(self, method: str) -> ResponseBuilder: headers["Access-Control-Allow-Methods"] = ",".join(sorted(self._cors_methods)) return ResponseBuilder(Response(status_code=204, content_type=None, headers=headers, body=None)) - # Allow for custom exception handlers - handler = self._exception_handlers.get(404) + handler = self._lookup_exception_handler(NotFoundError) if handler: return ResponseBuilder(handler(NotFoundError())) @@ -635,6 +633,40 @@ def _call_route(self, route: Route, args: Dict[str, str]) -> ResponseBuilder: raise + def not_found(self, func: Callable): + return self.exception_handler(NotFoundError)(func) + + def exception_handler(self, exc_class: Type[Exception]): + def register_exception_handler(func: Callable): + self._exception_handlers[exc_class] = func + + return register_exception_handler + + def _lookup_exception_handler(self, exp_type: Type) -> Optional[Callable]: + # Use "Method Resolution Order" to allow for matching against a base class + # of an exception + for cls in exp_type.__mro__: + if cls in self._exception_handlers: + return self._exception_handlers[cls] + return None + + def _call_exception_handler(self, exp: Exception, route: Route) -> Optional[ResponseBuilder]: + handler = self._lookup_exception_handler(type(exp)) + if handler: + return ResponseBuilder(handler(exp), route) + + if isinstance(exp, ServiceError): + return ResponseBuilder( + Response( + status_code=exp.status_code, + content_type=content_types.APPLICATION_JSON, + body=self._json_dump({"statusCode": exp.status_code, "message": exp.msg}), + ), + route, + ) + + return None + def _to_response(self, result: Union[Dict, Response]) -> Response: """Convert the route's result to a Response @@ -679,38 +711,6 @@ def include_router(self, router: "Router", prefix: Optional[str] = None) -> None self.route(*route)(func) - def not_found(self, func: Callable): - return self.exception_handler(404)(func) - - def exception_handler(self, exc_class_or_status_code: Union[int, Type[Exception]]): - def register_exception_handler(func: Callable): - self._exception_handlers[exc_class_or_status_code] = func - - return register_exception_handler - - def _lookup_exception_handler(self, exp: Exception) -> Optional[Callable]: - for cls in type(exp).__mro__: - if cls in self._exception_handlers: - return self._exception_handlers[cls] - return None - - def _call_exception_handler(self, exp: Exception, route: Route) -> Optional[ResponseBuilder]: - handler = self._lookup_exception_handler(exp) - if handler: - return ResponseBuilder(handler(exp), route) - - if isinstance(exp, ServiceError): - return ResponseBuilder( - Response( - status_code=exp.status_code, - content_type=content_types.APPLICATION_JSON, - body=self._json_dump({"statusCode": exp.status_code, "message": exp.msg}), - ), - route, - ) - - return None - class Router(BaseRouter): """Router helper class to allow splitting ApiGatewayResolver into multiple files"""