Skip to content

Commit 1409499

Browse files
feat(event_handler): add support for additional response models (#3591)
* feat(event_handler): add support for additional response models * fix: I hate sonarcube * fix: pydantic 2 * fix: refactor * fix: increase coverage * chore: update docs --------- Co-authored-by: Leandro Damascena <lcdama@amazon.pt>
1 parent 76dc016 commit 1409499

File tree

6 files changed

+215
-27
lines changed

6 files changed

+215
-27
lines changed

aws_lambda_powertools/event_handler/api_gateway.py

Lines changed: 64 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,9 @@
3737
from aws_lambda_powertools.event_handler.openapi.types import (
3838
COMPONENT_REF_PREFIX,
3939
METHODS_WITH_BODY,
40+
OpenAPIResponse,
41+
OpenAPIResponseContentModel,
42+
OpenAPIResponseContentSchema,
4043
validation_error_definition,
4144
validation_error_response_definition,
4245
)
@@ -273,7 +276,7 @@ def __init__(
273276
cache_control: Optional[str],
274277
summary: Optional[str],
275278
description: Optional[str],
276-
responses: Optional[Dict[int, Dict[str, Any]]],
279+
responses: Optional[Dict[int, OpenAPIResponse]],
277280
response_description: Optional[str],
278281
tags: Optional[List[str]],
279282
operation_id: Optional[str],
@@ -303,7 +306,7 @@ def __init__(
303306
The OpenAPI summary for this route
304307
description: Optional[str]
305308
The OpenAPI description for this route
306-
responses: Optional[Dict[int, Dict[str, Any]]]
309+
responses: Optional[Dict[int, OpenAPIResponse]]
307310
The OpenAPI responses for this route
308311
response_description: Optional[str]
309312
The OpenAPI response description for this route
@@ -442,7 +445,7 @@ def dependant(self) -> "Dependant":
442445
if self._dependant is None:
443446
from aws_lambda_powertools.event_handler.openapi.dependant import get_dependant
444447

445-
self._dependant = get_dependant(path=self.openapi_path, call=self.func)
448+
self._dependant = get_dependant(path=self.openapi_path, call=self.func, responses=self.responses)
446449

447450
return self._dependant
448451

@@ -501,11 +504,54 @@ def _get_openapi_path(
501504

502505
# Add the response to the OpenAPI operation
503506
if self.responses:
504-
# If the user supplied responses, we use them and don't set a default 200 response
507+
for status_code in list(self.responses):
508+
response = self.responses[status_code]
509+
510+
# Case 1: there is not 'content' key
511+
if "content" not in response:
512+
response["content"] = {
513+
"application/json": self._openapi_operation_return(
514+
param=dependant.return_param,
515+
model_name_map=model_name_map,
516+
field_mapping=field_mapping,
517+
),
518+
}
519+
520+
# Case 2: there is a 'content' key
521+
else:
522+
# Need to iterate to transform any 'model' into a 'schema'
523+
for content_type, payload in response["content"].items():
524+
new_payload: OpenAPIResponseContentSchema
525+
526+
# Case 2.1: the 'content' has a model
527+
if "model" in payload:
528+
# Find the model in the dependant's extra models
529+
return_field = next(
530+
filter(
531+
lambda model: model.type_ is cast(OpenAPIResponseContentModel, payload)["model"],
532+
self.dependant.response_extra_models,
533+
),
534+
)
535+
if not return_field:
536+
raise AssertionError("Model declared in custom responses was not found")
537+
538+
new_payload = self._openapi_operation_return(
539+
param=return_field,
540+
model_name_map=model_name_map,
541+
field_mapping=field_mapping,
542+
)
543+
544+
# Case 2.2: the 'content' has a schema
545+
else:
546+
# Do nothing! We already have what we need!
547+
new_payload = payload
548+
549+
response["content"][content_type] = new_payload
550+
505551
operation["responses"] = self.responses
506552
else:
507553
# Set the default 200 response
508-
responses = operation.setdefault("responses", self.responses or {})
554+
responses = operation.setdefault("responses", {})
509555
success_response = responses.setdefault(200, {})
510556
success_response["description"] = self.response_description or _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION
511557
success_response["content"] = {"application/json": {"schema": {}}}
@@ -682,7 +728,7 @@ def _openapi_operation_return(
682728
Tuple["ModelField", Literal["validation", "serialization"]],
683729
"JsonSchemaValue",
684730
],
685-
) -> Dict[str, Any]:
731+
) -> OpenAPIResponseContentSchema:
686732
"""
687733
Returns the OpenAPI operation return.
688734
"""
@@ -832,7 +878,7 @@ def route(
832878
cache_control: Optional[str] = None,
833879
summary: Optional[str] = None,
834880
description: Optional[str] = None,
835-
responses: Optional[Dict[int, Dict[str, Any]]] = None,
881+
responses: Optional[Dict[int, OpenAPIResponse]] = None,
836882
response_description: str = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION,
837883
tags: Optional[List[str]] = None,
838884
operation_id: Optional[str] = None,
@@ -890,7 +936,7 @@ def get(
890936
cache_control: Optional[str] = None,
891937
summary: Optional[str] = None,
892938
description: Optional[str] = None,
893-
responses: Optional[Dict[int, Dict[str, Any]]] = None,
939+
responses: Optional[Dict[int, OpenAPIResponse]] = None,
894940
response_description: str = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION,
895941
tags: Optional[List[str]] = None,
896942
operation_id: Optional[str] = None,
@@ -943,7 +989,7 @@ def post(
943989
cache_control: Optional[str] = None,
944990
summary: Optional[str] = None,
945991
description: Optional[str] = None,
946-
responses: Optional[Dict[int, Dict[str, Any]]] = None,
992+
responses: Optional[Dict[int, OpenAPIResponse]] = None,
947993
response_description: str = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION,
948994
tags: Optional[List[str]] = None,
949995
operation_id: Optional[str] = None,
@@ -997,7 +1043,7 @@ def put(
9971043
cache_control: Optional[str] = None,
9981044
summary: Optional[str] = None,
9991045
description: Optional[str] = None,
1000-
responses: Optional[Dict[int, Dict[str, Any]]] = None,
1046+
responses: Optional[Dict[int, OpenAPIResponse]] = None,
10011047
response_description: str = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION,
10021048
tags: Optional[List[str]] = None,
10031049
operation_id: Optional[str] = None,
@@ -1051,7 +1097,7 @@ def delete(
10511097
cache_control: Optional[str] = None,
10521098
summary: Optional[str] = None,
10531099
description: Optional[str] = None,
1054-
responses: Optional[Dict[int, Dict[str, Any]]] = None,
1100+
responses: Optional[Dict[int, OpenAPIResponse]] = None,
10551101
response_description: str = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION,
10561102
tags: Optional[List[str]] = None,
10571103
operation_id: Optional[str] = None,
@@ -1104,7 +1150,7 @@ def patch(
11041150
cache_control: Optional[str] = None,
11051151
summary: Optional[str] = None,
11061152
description: Optional[str] = None,
1107-
responses: Optional[Dict[int, Dict[str, Any]]] = None,
1153+
responses: Optional[Dict[int, OpenAPIResponse]] = None,
11081154
response_description: str = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION,
11091155
tags: Optional[List[str]] = None,
11101156
operation_id: Optional[str] = None,
@@ -1662,7 +1708,7 @@ def route(
16621708
cache_control: Optional[str] = None,
16631709
summary: Optional[str] = None,
16641710
description: Optional[str] = None,
1665-
responses: Optional[Dict[int, Dict[str, Any]]] = None,
1711+
responses: Optional[Dict[int, OpenAPIResponse]] = None,
16661712
response_description: str = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION,
16671713
tags: Optional[List[str]] = None,
16681714
operation_id: Optional[str] = None,
@@ -2110,6 +2156,9 @@ def _get_fields_from_routes(routes: Sequence[Route]) -> List["ModelField"]:
21102156
if route.dependant.return_param:
21112157
responses_from_routes.append(route.dependant.return_param)
21122158

2159+
if route.dependant.response_extra_models:
2160+
responses_from_routes.extend(route.dependant.response_extra_models)
2161+
21132162
flat_models = list(responses_from_routes + request_fields_from_routes + body_fields_from_routes)
21142163
return flat_models
21152164

@@ -2132,7 +2181,7 @@ def route(
21322181
cache_control: Optional[str] = None,
21332182
summary: Optional[str] = None,
21342183
description: Optional[str] = None,
2135-
responses: Optional[Dict[int, Dict[str, Any]]] = None,
2184+
responses: Optional[Dict[int, OpenAPIResponse]] = None,
21362185
response_description: Optional[str] = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION,
21372186
tags: Optional[List[str]] = None,
21382187
operation_id: Optional[str] = None,
@@ -2221,7 +2270,7 @@ def route(
22212270
cache_control: Optional[str] = None,
22222271
summary: Optional[str] = None,
22232272
description: Optional[str] = None,
2224-
responses: Optional[Dict[int, Dict[str, Any]]] = None,
2273+
responses: Optional[Dict[int, OpenAPIResponse]] = None,
22252274
response_description: str = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION,
22262275
tags: Optional[List[str]] = None,
22272276
operation_id: Optional[str] = None,

aws_lambda_powertools/event_handler/openapi/dependant.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
create_response_field,
2525
get_flat_dependant,
2626
)
27+
from aws_lambda_powertools.event_handler.openapi.types import OpenAPIResponse, OpenAPIResponseContentModel
2728

2829
"""
2930
This turns the opaque function signature into typed, validated models.
@@ -145,6 +146,7 @@ def get_dependant(
145146
path: str,
146147
call: Callable[..., Any],
147148
name: Optional[str] = None,
149+
responses: Optional[Dict[int, OpenAPIResponse]] = None,
148150
) -> Dependant:
149151
"""
150152
Returns a dependant model for a handler function. A dependant model is a model that contains
@@ -158,6 +160,8 @@ def get_dependant(
158160
The handler function
159161
name: str, optional
160162
The name of the handler function
163+
responses: List[Dict[int, OpenAPIResponse]], optional
164+
The list of extra responses for the handler function
161165
162166
Returns
163167
-------
@@ -195,6 +199,34 @@ def get_dependant(
195199
else:
196200
add_param_to_fields(field=param_field, dependant=dependant)
197201

202+
_add_return_annotation(dependant, endpoint_signature)
203+
_add_extra_responses(dependant, responses)
204+
205+
return dependant
206+
207+
208+
def _add_extra_responses(dependant: Dependant, responses: Optional[Dict[int, OpenAPIResponse]]):
209+
# Also add the optional extra responses to the dependant model.
210+
if not responses:
211+
return
212+
213+
for response in responses.values():
214+
for schema in response.get("content", {}).values():
215+
if "model" in schema:
216+
response_field = analyze_param(
217+
param_name="return",
218+
annotation=cast(OpenAPIResponseContentModel, schema)["model"],
219+
value=None,
220+
is_path_param=False,
221+
is_response_param=True,
222+
)
223+
if response_field is None:
224+
raise AssertionError("Response field is None for response model")
225+
226+
dependant.response_extra_models.append(response_field)
227+
228+
229+
def _add_return_annotation(dependant: Dependant, endpoint_signature: inspect.Signature):
198230
# If the return annotation is not empty, add it to the dependant model.
199231
return_annotation = endpoint_signature.return_annotation
200232
if return_annotation is not inspect.Signature.empty:
@@ -210,8 +242,6 @@ def get_dependant(
210242

211243
dependant.return_param = param_field
212244

213-
return dependant
214-
215245

216246
def is_body_param(*, param_field: ModelField, is_path_param: bool) -> bool:
217247
"""

aws_lambda_powertools/event_handler/openapi/params.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def __init__(
4949
cookie_params: Optional[List[ModelField]] = None,
5050
body_params: Optional[List[ModelField]] = None,
5151
return_param: Optional[ModelField] = None,
52+
response_extra_models: Optional[List[ModelField]] = None,
5253
name: Optional[str] = None,
5354
call: Optional[Callable[..., Any]] = None,
5455
request_param_name: Optional[str] = None,
@@ -64,6 +65,7 @@ def __init__(
6465
self.cookie_params = cookie_params or []
6566
self.body_params = body_params or []
6667
self.return_param = return_param or None
68+
self.response_extra_models = response_extra_models or []
6769
self.request_param_name = request_param_name
6870
self.websocket_param_name = websocket_param_name
6971
self.http_connection_param_name = http_connection_param_name

aws_lambda_powertools/event_handler/openapi/types.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
from enum import Enum
33
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Set, Type, Union
44

5+
from aws_lambda_powertools.shared.types import NotRequired, TypedDict
6+
57
if TYPE_CHECKING:
68
from pydantic import BaseModel # noqa: F401
79

@@ -43,3 +45,16 @@
4345
},
4446
},
4547
}
48+
49+
50+
class OpenAPIResponseContentSchema(TypedDict, total=False):
51+
schema: Dict
52+
53+
54+
class OpenAPIResponseContentModel(TypedDict):
55+
model: Any
56+
57+
58+
class OpenAPIResponse(TypedDict):
59+
description: str
60+
content: NotRequired[Dict[str, Union[OpenAPIResponseContentSchema, OpenAPIResponseContentModel]]]

docs/core/event_handler/api_gateway.md

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -955,15 +955,15 @@ Customize your API endpoints by adding metadata to endpoint definitions. This pr
955955

956956
Here's a breakdown of various customizable fields:
957957

958-
| Field Name | Type | Description |
959-
| ---------------------- | --------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
960-
| `summary` | `str` | A concise overview of the main functionality of the endpoint. This brief introduction is usually displayed in autogenerated API documentation and helps consumers quickly understand what the endpoint does. |
961-
| `description` | `str` | A more detailed explanation of the endpoint, which can include information about the operation's behavior, including side effects, error states, and other operational guidelines. |
962-
| `responses` | `Dict[int, Dict[str, Any]]` | A dictionary that maps each HTTP status code to a Response Object as defined by the [OpenAPI Specification](https://swagger.io/specification/#response-object). This allows you to describe expected responses, including default or error messages, and their corresponding schemas for different status codes. |
963-
| `response_description` | `str` | Provides the default textual description of the response sent by the endpoint when the operation is successful. It is intended to give a human-readable understanding of the result. |
964-
| `tags` | `List[str]` | Tags are a way to categorize and group endpoints within the API documentation. They can help organize the operations by resources or other heuristic. |
965-
| `operation_id` | `str` | A unique identifier for the operation, which can be used for referencing this operation in documentation or code. This ID must be unique across all operations described in the API. |
966-
| `include_in_schema` | `bool` | A boolean value that determines whether or not this operation should be included in the OpenAPI schema. Setting it to `False` can hide the endpoint from generated documentation and schema exports, which might be useful for private or experimental endpoints. |
958+
| Field Name | Type | Description |
959+
| ---------------------- |-----------------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
960+
| `summary` | `str` | A concise overview of the main functionality of the endpoint. This brief introduction is usually displayed in autogenerated API documentation and helps consumers quickly understand what the endpoint does. |
961+
| `description` | `str` | A more detailed explanation of the endpoint, which can include information about the operation's behavior, including side effects, error states, and other operational guidelines. |
962+
| `responses` | `Dict[int, Dict[str, OpenAPIResponse]]` | A dictionary that maps each HTTP status code to a Response Object as defined by the [OpenAPI Specification](https://swagger.io/specification/#response-object). This allows you to describe expected responses, including default or error messages, and their corresponding schemas or models for different status codes. |
963+
| `response_description` | `str` | Provides the default textual description of the response sent by the endpoint when the operation is successful. It is intended to give a human-readable understanding of the result. |
964+
| `tags` | `List[str]` | Tags are a way to categorize and group endpoints within the API documentation. They can help organize the operations by resources or other heuristic. |
965+
| `operation_id` | `str` | A unique identifier for the operation, which can be used for referencing this operation in documentation or code. This ID must be unique across all operations described in the API. |
966+
| `include_in_schema` | `bool` | A boolean value that determines whether or not this operation should be included in the OpenAPI schema. Setting it to `False` can hide the endpoint from generated documentation and schema exports, which might be useful for private or experimental endpoints. |
967967

968968
To implement these customizations, include extra parameters when defining your routes:
969969

0 commit comments

Comments
 (0)