Skip to content

Commit ae1a9c8

Browse files
committed
feat(event_handler): add support for additional response models
1 parent aab88f8 commit ae1a9c8

File tree

3 files changed

+147
-15
lines changed

3 files changed

+147
-15
lines changed

aws_lambda_powertools/event_handler/api_gateway.py

Lines changed: 60 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,9 @@
3737
from aws_lambda_powertools.event_handler.openapi.types import (
3838
COMPONENT_REF_PREFIX,
3939
METHODS_WITH_BODY,
40+
OpenAPIResponse,
41+
OpenAPIResponseContentModel,
42+
OpenAPIResponseContentSchema,
4043
validation_error_definition,
4144
validation_error_response_definition,
4245
)
@@ -273,7 +276,7 @@ def __init__(
273276
cache_control: Optional[str],
274277
summary: Optional[str],
275278
description: Optional[str],
276-
responses: Optional[Dict[int, Dict[str, Any]]],
279+
responses: Optional[Dict[int, OpenAPIResponse]],
277280
response_description: Optional[str],
278281
tags: Optional[List[str]],
279282
operation_id: Optional[str],
@@ -303,7 +306,7 @@ def __init__(
303306
The OpenAPI summary for this route
304307
description: Optional[str]
305308
The OpenAPI description for this route
306-
responses: Optional[Dict[int, Dict[str, Any]]]
309+
responses: Optional[Dict[int, OpenAPIResponse]]
307310
The OpenAPI responses for this route
308311
response_description: Optional[str]
309312
The OpenAPI response description for this route
@@ -501,11 +504,54 @@ def _get_openapi_path(
501504

502505
# Add the response to the OpenAPI operation
503506
if self.responses:
504-
# If the user supplied responses, we use them and don't set a default 200 response
507+
for status_code in list(self.responses):
508+
response = self.responses[status_code]
509+
510+
# Case 1: there is not 'content' key
511+
if "content" not in response:
512+
response["content"] = {
513+
"application/json": self._openapi_operation_return(
514+
param=dependant.return_param,
515+
model_name_map=model_name_map,
516+
field_mapping=field_mapping,
517+
),
518+
}
519+
520+
# Case 2: there is a 'content' key
521+
else:
522+
# Need to iterate to transform any 'model' into a 'schema'
523+
for content_type, payload in response["content"].items():
524+
new_payload: OpenAPIResponseContentSchema
525+
526+
# Case 2.1: the 'content' has a model
527+
if "model" in payload:
528+
from aws_lambda_powertools.event_handler.openapi.params import analyze_param
529+
530+
return_field = analyze_param(
531+
param_name="return",
532+
annotation=cast(OpenAPIResponseContentModel, payload)["model"],
533+
value=None,
534+
is_path_param=False,
535+
is_response_param=True,
536+
)
537+
538+
new_payload = self._openapi_operation_return(
539+
param=return_field,
540+
model_name_map=model_name_map,
541+
field_mapping=field_mapping,
542+
)
543+
544+
# Case 2.2: the 'content' has a schema
545+
else:
546+
# Do nothing! We already have what we need!
547+
new_payload = payload
548+
549+
response["content"][content_type] = new_payload
550+
505551
operation["responses"] = self.responses
506552
else:
507553
# Set the default 200 response
508-
responses = operation.setdefault("responses", self.responses or {})
554+
responses = operation.setdefault("responses", {})
509555
success_response = responses.setdefault(200, {})
510556
success_response["description"] = self.response_description or _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION
511557
success_response["content"] = {"application/json": {"schema": {}}}
@@ -682,7 +728,7 @@ def _openapi_operation_return(
682728
Tuple["ModelField", Literal["validation", "serialization"]],
683729
"JsonSchemaValue",
684730
],
685-
) -> Dict[str, Any]:
731+
) -> OpenAPIResponseContentSchema:
686732
"""
687733
Returns the OpenAPI operation return.
688734
"""
@@ -832,7 +878,7 @@ def route(
832878
cache_control: Optional[str] = None,
833879
summary: Optional[str] = None,
834880
description: Optional[str] = None,
835-
responses: Optional[Dict[int, Dict[str, Any]]] = None,
881+
responses: Optional[Dict[int, OpenAPIResponse]] = None,
836882
response_description: str = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION,
837883
tags: Optional[List[str]] = None,
838884
operation_id: Optional[str] = None,
@@ -890,7 +936,7 @@ def get(
890936
cache_control: Optional[str] = None,
891937
summary: Optional[str] = None,
892938
description: Optional[str] = None,
893-
responses: Optional[Dict[int, Dict[str, Any]]] = None,
939+
responses: Optional[Dict[int, OpenAPIResponse]] = None,
894940
response_description: str = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION,
895941
tags: Optional[List[str]] = None,
896942
operation_id: Optional[str] = None,
@@ -943,7 +989,7 @@ def post(
943989
cache_control: Optional[str] = None,
944990
summary: Optional[str] = None,
945991
description: Optional[str] = None,
946-
responses: Optional[Dict[int, Dict[str, Any]]] = None,
992+
responses: Optional[Dict[int, OpenAPIResponse]] = None,
947993
response_description: str = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION,
948994
tags: Optional[List[str]] = None,
949995
operation_id: Optional[str] = None,
@@ -997,7 +1043,7 @@ def put(
9971043
cache_control: Optional[str] = None,
9981044
summary: Optional[str] = None,
9991045
description: Optional[str] = None,
1000-
responses: Optional[Dict[int, Dict[str, Any]]] = None,
1046+
responses: Optional[Dict[int, OpenAPIResponse]] = None,
10011047
response_description: str = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION,
10021048
tags: Optional[List[str]] = None,
10031049
operation_id: Optional[str] = None,
@@ -1051,7 +1097,7 @@ def delete(
10511097
cache_control: Optional[str] = None,
10521098
summary: Optional[str] = None,
10531099
description: Optional[str] = None,
1054-
responses: Optional[Dict[int, Dict[str, Any]]] = None,
1100+
responses: Optional[Dict[int, OpenAPIResponse]] = None,
10551101
response_description: str = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION,
10561102
tags: Optional[List[str]] = None,
10571103
operation_id: Optional[str] = None,
@@ -1104,7 +1150,7 @@ def patch(
11041150
cache_control: Optional[str] = None,
11051151
summary: Optional[str] = None,
11061152
description: Optional[str] = None,
1107-
responses: Optional[Dict[int, Dict[str, Any]]] = None,
1153+
responses: Optional[Dict[int, OpenAPIResponse]] = None,
11081154
response_description: str = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION,
11091155
tags: Optional[List[str]] = None,
11101156
operation_id: Optional[str] = None,
@@ -1657,7 +1703,7 @@ def route(
16571703
cache_control: Optional[str] = None,
16581704
summary: Optional[str] = None,
16591705
description: Optional[str] = None,
1660-
responses: Optional[Dict[int, Dict[str, Any]]] = None,
1706+
responses: Optional[Dict[int, OpenAPIResponse]] = None,
16611707
response_description: str = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION,
16621708
tags: Optional[List[str]] = None,
16631709
operation_id: Optional[str] = None,
@@ -2127,7 +2173,7 @@ def route(
21272173
cache_control: Optional[str] = None,
21282174
summary: Optional[str] = None,
21292175
description: Optional[str] = None,
2130-
responses: Optional[Dict[int, Dict[str, Any]]] = None,
2176+
responses: Optional[Dict[int, OpenAPIResponse]] = None,
21312177
response_description: Optional[str] = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION,
21322178
tags: Optional[List[str]] = None,
21332179
operation_id: Optional[str] = None,
@@ -2216,7 +2262,7 @@ def route(
22162262
cache_control: Optional[str] = None,
22172263
summary: Optional[str] = None,
22182264
description: Optional[str] = None,
2219-
responses: Optional[Dict[int, Dict[str, Any]]] = None,
2265+
responses: Optional[Dict[int, OpenAPIResponse]] = None,
22202266
response_description: str = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION,
22212267
tags: Optional[List[str]] = None,
22222268
operation_id: Optional[str] = None,

aws_lambda_powertools/event_handler/openapi/types.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
from enum import Enum
33
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Set, Type, Union
44

5+
from aws_lambda_powertools.shared.types import NotRequired, TypedDict
6+
57
if TYPE_CHECKING:
68
from pydantic import BaseModel # noqa: F401
79

@@ -43,3 +45,16 @@
4345
},
4446
},
4547
}
48+
49+
50+
class OpenAPIResponseContentSchema(TypedDict, total=False):
51+
schema: Dict
52+
53+
54+
class OpenAPIResponseContentModel(TypedDict):
55+
model: Any
56+
57+
58+
class OpenAPIResponse(TypedDict):
59+
description: str
60+
content: NotRequired[Dict[str, Union[OpenAPIResponseContentSchema, OpenAPIResponseContentModel]]]

tests/functional/event_handler/test_openapi_responses.py

Lines changed: 72 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,9 @@
1-
from aws_lambda_powertools.event_handler import APIGatewayRestResolver
1+
from random import random
2+
from typing import Union
3+
4+
from pydantic import BaseModel
5+
6+
from aws_lambda_powertools.event_handler import APIGatewayRestResolver, Response
27

38

49
def test_openapi_default_response():
@@ -47,3 +52,69 @@ def handler():
4752

4853
assert 200 not in responses.keys()
4954
assert 422 not in responses.keys()
55+
56+
57+
def test_openapi_union_response():
58+
app = APIGatewayRestResolver(enable_validation=True)
59+
60+
class User(BaseModel):
61+
pass
62+
63+
class Order(BaseModel):
64+
pass
65+
66+
@app.get(
67+
"/",
68+
responses={
69+
200: {"description": "200 Response", "content": {"application/json": {"model": User}}},
70+
202: {"description": "202 Response", "content": {"application/json": {"model": Order}}},
71+
},
72+
)
73+
def handler() -> Response[Union[User, Order]]:
74+
if random() > 0.5:
75+
return Response(status_code=200, body=User())
76+
else:
77+
return Response(status_code=202, body=Order())
78+
79+
schema = app.get_openapi_schema()
80+
responses = schema.paths["/"].get.responses
81+
assert 200 in responses.keys()
82+
assert responses[200].description == "200 Response"
83+
assert responses[200].content["application/json"].schema_.ref == "#/components/schemas/User"
84+
85+
assert 202 in responses.keys()
86+
assert responses[202].description == "202 Response"
87+
assert responses[202].content["application/json"].schema_.ref == "#/components/schemas/Order"
88+
89+
90+
def test_openapi_union_partial_response():
91+
app = APIGatewayRestResolver(enable_validation=True)
92+
93+
class User(BaseModel):
94+
pass
95+
96+
class Order(BaseModel):
97+
pass
98+
99+
@app.get(
100+
"/",
101+
responses={
102+
200: {"description": "200 Response"},
103+
202: {"description": "202 Response", "content": {"application/json": {"model": Order}}},
104+
},
105+
)
106+
def handler() -> Response[Union[User, Order]]:
107+
if random() > 0.5:
108+
return Response(status_code=200, body=User())
109+
else:
110+
return Response(status_code=202, body=Order())
111+
112+
schema = app.get_openapi_schema()
113+
responses = schema.paths["/"].get.responses
114+
assert 200 in responses.keys()
115+
assert responses[200].description == "200 Response"
116+
assert responses[200].content["application/json"].schema_.anyOf is not None
117+
118+
assert 202 in responses.keys()
119+
assert responses[202].description == "202 Response"
120+
assert responses[202].content["application/json"].schema_.ref == "#/components/schemas/Order"

0 commit comments

Comments
 (0)