diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index b3d77df24b4..5bd3bc0b70e 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -10,10 +10,10 @@ from enum import Enum from functools import partial from http import HTTPStatus -from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union from aws_lambda_powertools.event_handler import content_types -from aws_lambda_powertools.event_handler.exceptions import ServiceError +from aws_lambda_powertools.event_handler.exceptions import NotFoundError, ServiceError from aws_lambda_powertools.shared import constants from aws_lambda_powertools.shared.functions import resolve_truthy_env_var_choice from aws_lambda_powertools.shared.json_encoder import Encoder @@ -27,7 +27,6 @@ _SAFE_URI = "-._~()'!*:@,;" # https://www.ietf.org/rfc/rfc3986.txt # API GW/ALB decode non-safe URI chars; we must support them too _UNSAFE_URI = "%<>\[\]{}|^" # noqa: W605 - _NAMED_GROUP_BOUNDARY_PATTERN = fr"(?P\1[{_SAFE_URI}{_UNSAFE_URI}\\w]+)" @@ -435,6 +434,7 @@ def __init__( self._proxy_type = proxy_type self._routes: List[Route] = [] self._route_keys: List[str] = [] + self._exception_handlers: Dict[Type, Callable] = {} self._cors = cors self._cors_enabled: bool = cors is not None self._cors_methods: Set[str] = {"OPTIONS"} @@ -596,6 +596,10 @@ def _not_found(self, method: str) -> ResponseBuilder: headers["Access-Control-Allow-Methods"] = ",".join(sorted(self._cors_methods)) return ResponseBuilder(Response(status_code=204, content_type=None, headers=headers, body=None)) + handler = self._lookup_exception_handler(NotFoundError) + if handler: + return ResponseBuilder(handler(NotFoundError())) + return ResponseBuilder( Response( status_code=HTTPStatus.NOT_FOUND.value, @@ -609,16 +613,11 @@ def _call_route(self, route: Route, args: Dict[str, str]) -> ResponseBuilder: """Actually call the matching route with any provided keyword arguments.""" try: return ResponseBuilder(self._to_response(route.func(**args)), route) - except ServiceError as e: - return ResponseBuilder( - Response( - status_code=e.status_code, - content_type=content_types.APPLICATION_JSON, - body=self._json_dump({"statusCode": e.status_code, "message": e.msg}), - ), - route, - ) - except Exception: + except Exception as exc: + response_builder = self._call_exception_handler(exc, route) + if response_builder: + return response_builder + if self._debug: # If the user has turned on debug mode, # we'll let the original exception propagate so @@ -628,10 +627,46 @@ def _call_route(self, route: Route, args: Dict[str, str]) -> ResponseBuilder: status_code=500, content_type=content_types.TEXT_PLAIN, body="".join(traceback.format_exc()), - ) + ), + route, ) + raise + def not_found(self, func: Callable): + return self.exception_handler(NotFoundError)(func) + + def exception_handler(self, exc_class: Type[Exception]): + def register_exception_handler(func: Callable): + self._exception_handlers[exc_class] = func + + return register_exception_handler + + def _lookup_exception_handler(self, exp_type: Type) -> Optional[Callable]: + # Use "Method Resolution Order" to allow for matching against a base class + # of an exception + for cls in exp_type.__mro__: + if cls in self._exception_handlers: + return self._exception_handlers[cls] + return None + + def _call_exception_handler(self, exp: Exception, route: Route) -> Optional[ResponseBuilder]: + handler = self._lookup_exception_handler(type(exp)) + if handler: + return ResponseBuilder(handler(exp), route) + + if isinstance(exp, ServiceError): + return ResponseBuilder( + Response( + status_code=exp.status_code, + content_type=content_types.APPLICATION_JSON, + body=self._json_dump({"statusCode": exp.status_code, "message": exp.msg}), + ), + route, + ) + + return None + def _to_response(self, result: Union[Dict, Response]) -> Response: """Convert the route's result to a Response diff --git a/tests/functional/event_handler/test_api_gateway.py b/tests/functional/event_handler/test_api_gateway.py index f28752e6de6..45b1e3f41a4 100644 --- a/tests/functional/event_handler/test_api_gateway.py +++ b/tests/functional/event_handler/test_api_gateway.py @@ -163,7 +163,7 @@ def patch_func(): def handler(event, context): return app.resolve(event, context) - # Also check check the route configurations + # Also check the route configurations routes = app._routes assert len(routes) == 5 for route in routes: @@ -1076,3 +1076,76 @@ def foo(): assert result["statusCode"] == 200 assert result["headers"]["Content-Type"] == content_types.APPLICATION_JSON + + +def test_exception_handler(): + # GIVEN a resolver with an exception handler defined for ValueError + app = ApiGatewayResolver() + + @app.exception_handler(ValueError) + def handle_value_error(ex: ValueError): + print(f"request path is '{app.current_event.path}'") + return Response( + status_code=418, + content_type=content_types.TEXT_HTML, + body=str(ex), + ) + + @app.get("/my/path") + def get_lambda() -> Response: + raise ValueError("Foo!") + + # WHEN calling the event handler + # AND a ValueError is raised + result = app(LOAD_GW_EVENT, {}) + + # THEN call the exception_handler + assert result["statusCode"] == 418 + assert result["headers"]["Content-Type"] == content_types.TEXT_HTML + assert result["body"] == "Foo!" + + +def test_exception_handler_service_error(): + # GIVEN + app = ApiGatewayResolver() + + @app.exception_handler(ServiceError) + def service_error(ex: ServiceError): + print(ex.msg) + return Response( + status_code=ex.status_code, + content_type=content_types.APPLICATION_JSON, + body="CUSTOM ERROR FORMAT", + ) + + @app.get("/my/path") + def get_lambda() -> Response: + raise InternalServerError("Something sensitive") + + # WHEN calling the event handler + # AND a ServiceError is raised + result = app(LOAD_GW_EVENT, {}) + + # THEN call the exception_handler + assert result["statusCode"] == 500 + assert result["headers"]["Content-Type"] == content_types.APPLICATION_JSON + assert result["body"] == "CUSTOM ERROR FORMAT" + + +def test_exception_handler_not_found(): + # GIVEN a resolver with an exception handler defined for a 404 not found + app = ApiGatewayResolver() + + @app.not_found + def handle_not_found(exc: NotFoundError) -> Response: + assert isinstance(exc, NotFoundError) + return Response(status_code=404, content_type=content_types.TEXT_PLAIN, body="I am a teapot!") + + # WHEN calling the event handler + # AND not route is found + result = app(LOAD_GW_EVENT, {}) + + # THEN call the exception_handler + assert result["statusCode"] == 404 + assert result["headers"]["Content-Type"] == content_types.TEXT_PLAIN + assert result["body"] == "I am a teapot!"