Skip to content

Commit 481905e

Browse files
authored
feat(event_handler): allow customers to catch request validation errors (#3396)
1 parent 270060d commit 481905e

File tree

4 files changed

+93
-42
lines changed

4 files changed

+93
-42
lines changed

aws_lambda_powertools/event_handler/api_gateway.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from aws_lambda_powertools.event_handler import content_types
3333
from aws_lambda_powertools.event_handler.exceptions import NotFoundError, ServiceError
3434
from aws_lambda_powertools.event_handler.openapi.constants import DEFAULT_API_VERSION, DEFAULT_OPENAPI_VERSION
35+
from aws_lambda_powertools.event_handler.openapi.exceptions import RequestValidationError
3536
from aws_lambda_powertools.event_handler.openapi.swagger_ui.html import generate_swagger_html
3637
from aws_lambda_powertools.event_handler.openapi.types import (
3738
COMPONENT_REF_PREFIX,
@@ -1972,6 +1973,17 @@ def _call_exception_handler(self, exp: Exception, route: Route) -> Optional[Resp
19721973
except ServiceError as service_error:
19731974
exp = service_error
19741975

1976+
if isinstance(exp, RequestValidationError):
1977+
return self._response_builder_class(
1978+
response=Response(
1979+
status_code=HTTPStatus.UNPROCESSABLE_ENTITY,
1980+
content_type=content_types.APPLICATION_JSON,
1981+
body={"statusCode": HTTPStatus.UNPROCESSABLE_ENTITY, "message": exp.errors()},
1982+
),
1983+
serializer=self._serializer,
1984+
route=route,
1985+
)
1986+
19751987
if isinstance(exp, ServiceError):
19761988
return self._response_builder_class(
19771989
response=Response(

aws_lambda_powertools/event_handler/middlewares/openapi_validation.py

Lines changed: 33 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -62,50 +62,43 @@ def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) ->
6262
values: Dict[str, Any] = {}
6363
errors: List[Any] = []
6464

65-
try:
66-
# Process path values, which can be found on the route_args
67-
path_values, path_errors = _request_params_to_args(
68-
route.dependant.path_params,
69-
app.context["_route_args"],
65+
# Process path values, which can be found on the route_args
66+
path_values, path_errors = _request_params_to_args(
67+
route.dependant.path_params,
68+
app.context["_route_args"],
69+
)
70+
71+
# Process query values
72+
query_values, query_errors = _request_params_to_args(
73+
route.dependant.query_params,
74+
app.current_event.query_string_parameters or {},
75+
)
76+
77+
values.update(path_values)
78+
values.update(query_values)
79+
errors += path_errors + query_errors
80+
81+
# Process the request body, if it exists
82+
if route.dependant.body_params:
83+
(body_values, body_errors) = _request_body_to_args(
84+
required_params=route.dependant.body_params,
85+
received_body=self._get_body(app),
7086
)
87+
values.update(body_values)
88+
errors.extend(body_errors)
7189

72-
# Process query values
73-
query_values, query_errors = _request_params_to_args(
74-
route.dependant.query_params,
75-
app.current_event.query_string_parameters or {},
76-
)
77-
78-
values.update(path_values)
79-
values.update(query_values)
80-
errors += path_errors + query_errors
90+
if errors:
91+
# Raise the validation errors
92+
raise RequestValidationError(_normalize_errors(errors))
93+
else:
94+
# Re-write the route_args with the validated values, and call the next middleware
95+
app.context["_route_args"] = values
8196

82-
# Process the request body, if it exists
83-
if route.dependant.body_params:
84-
(body_values, body_errors) = _request_body_to_args(
85-
required_params=route.dependant.body_params,
86-
received_body=self._get_body(app),
87-
)
88-
values.update(body_values)
89-
errors.extend(body_errors)
97+
# Call the handler by calling the next middleware
98+
response = next_middleware(app)
9099

91-
if errors:
92-
# Raise the validation errors
93-
raise RequestValidationError(_normalize_errors(errors))
94-
else:
95-
# Re-write the route_args with the validated values, and call the next middleware
96-
app.context["_route_args"] = values
97-
98-
# Call the handler by calling the next middleware
99-
response = next_middleware(app)
100-
101-
# Process the response
102-
return self._handle_response(route=route, response=response)
103-
except RequestValidationError as e:
104-
return Response(
105-
status_code=422,
106-
content_type="application/json",
107-
body=json.dumps({"detail": e.errors()}),
108-
)
100+
# Process the response
101+
return self._handle_response(route=route, response=response)
109102

110103
def _handle_response(self, *, route: Route, response: Response):
111104
# Process the response body if it exists

tests/functional/event_handler/test_api_gateway.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
ServiceError,
3131
UnauthorizedError,
3232
)
33+
from aws_lambda_powertools.event_handler.openapi.exceptions import RequestValidationError
3334
from aws_lambda_powertools.shared import constants
3435
from aws_lambda_powertools.shared.cookies import Cookie
3536
from aws_lambda_powertools.shared.json_encoder import Encoder
@@ -1458,6 +1459,51 @@ def get_lambda() -> Response:
14581459
assert result["body"] == "Foo!"
14591460

14601461

1462+
def test_exception_handler_with_data_validation():
1463+
# GIVEN a resolver with an exception handler defined for RequestValidationError
1464+
app = ApiGatewayResolver(enable_validation=True)
1465+
1466+
@app.exception_handler(RequestValidationError)
1467+
def handle_validation_error(ex: RequestValidationError):
1468+
print(f"request path is '{app.current_event.path}'")
1469+
return Response(
1470+
status_code=422,
1471+
content_type=content_types.TEXT_PLAIN,
1472+
body=f"Invalid data. Number of errors: {len(ex.errors())}",
1473+
)
1474+
1475+
@app.get("/my/path")
1476+
def get_lambda(param: int):
1477+
...
1478+
1479+
# WHEN calling the event handler
1480+
# AND a RequestValidationError is raised
1481+
result = app(LOAD_GW_EVENT, {})
1482+
1483+
# THEN call the exception_handler
1484+
assert result["statusCode"] == 422
1485+
assert result["multiValueHeaders"]["Content-Type"] == [content_types.TEXT_PLAIN]
1486+
assert result["body"] == "Invalid data. Number of errors: 1"
1487+
1488+
1489+
def test_data_validation_error():
1490+
# GIVEN a resolver without an exception handler
1491+
app = ApiGatewayResolver(enable_validation=True)
1492+
1493+
@app.get("/my/path")
1494+
def get_lambda(param: int):
1495+
...
1496+
1497+
# WHEN calling the event handler
1498+
# AND a RequestValidationError is raised
1499+
result = app(LOAD_GW_EVENT, {})
1500+
1501+
# THEN call the exception_handler
1502+
assert result["statusCode"] == 422
1503+
assert result["multiValueHeaders"]["Content-Type"] == [content_types.APPLICATION_JSON]
1504+
assert "missing" in result["body"]
1505+
1506+
14611507
def test_exception_handler_service_error():
14621508
# GIVEN
14631509
app = ApiGatewayResolver()

tests/functional/event_handler/test_openapi_validation_middleware.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -343,7 +343,7 @@ class Model(BaseModel):
343343
# WHEN a handler is defined with a body parameter
344344
@app.post("/")
345345
def handler(user: Model) -> Response[Model]:
346-
return Response(body=user, status_code=200)
346+
return Response(body=user, status_code=200, content_type="application/json")
347347

348348
LOAD_GW_EVENT["httpMethod"] = "POST"
349349
LOAD_GW_EVENT["path"] = "/"
@@ -353,7 +353,7 @@ def handler(user: Model) -> Response[Model]:
353353
# THEN the body must be a dict
354354
result = app(LOAD_GW_EVENT, {})
355355
assert result["statusCode"] == 200
356-
assert result["body"] == {"name": "John", "age": 30}
356+
assert json.loads(result["body"]) == {"name": "John", "age": 30}
357357

358358

359359
def test_validate_response_invalid_return():

0 commit comments

Comments
 (0)