From ae1a9c8b2b005abd04f5d970cbcf12235d50c0b9 Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Thu, 4 Jan 2024 15:25:09 +0100 Subject: [PATCH 1/6] feat(event_handler): add support for additional response models --- .../event_handler/api_gateway.py | 74 +++++++++++++++---- .../event_handler/openapi/types.py | 15 ++++ .../event_handler/test_openapi_responses.py | 73 +++++++++++++++++- 3 files changed, 147 insertions(+), 15 deletions(-) diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 79e194e3719..cc59f44b158 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -37,6 +37,9 @@ from aws_lambda_powertools.event_handler.openapi.types import ( COMPONENT_REF_PREFIX, METHODS_WITH_BODY, + OpenAPIResponse, + OpenAPIResponseContentModel, + OpenAPIResponseContentSchema, validation_error_definition, validation_error_response_definition, ) @@ -273,7 +276,7 @@ def __init__( cache_control: Optional[str], summary: Optional[str], description: Optional[str], - responses: Optional[Dict[int, Dict[str, Any]]], + responses: Optional[Dict[int, OpenAPIResponse]], response_description: Optional[str], tags: Optional[List[str]], operation_id: Optional[str], @@ -303,7 +306,7 @@ def __init__( The OpenAPI summary for this route description: Optional[str] The OpenAPI description for this route - responses: Optional[Dict[int, Dict[str, Any]]] + responses: Optional[Dict[int, OpenAPIResponse]] The OpenAPI responses for this route response_description: Optional[str] The OpenAPI response description for this route @@ -501,11 +504,54 @@ def _get_openapi_path( # Add the response to the OpenAPI operation if self.responses: - # If the user supplied responses, we use them and don't set a default 200 response + for status_code in list(self.responses): + response = self.responses[status_code] + + # Case 1: there is not 'content' key + if "content" not in response: + response["content"] = { + "application/json": self._openapi_operation_return( + param=dependant.return_param, + model_name_map=model_name_map, + field_mapping=field_mapping, + ), + } + + # Case 2: there is a 'content' key + else: + # Need to iterate to transform any 'model' into a 'schema' + for content_type, payload in response["content"].items(): + new_payload: OpenAPIResponseContentSchema + + # Case 2.1: the 'content' has a model + if "model" in payload: + from aws_lambda_powertools.event_handler.openapi.params import analyze_param + + return_field = analyze_param( + param_name="return", + annotation=cast(OpenAPIResponseContentModel, payload)["model"], + value=None, + is_path_param=False, + is_response_param=True, + ) + + new_payload = self._openapi_operation_return( + param=return_field, + model_name_map=model_name_map, + field_mapping=field_mapping, + ) + + # Case 2.2: the 'content' has a schema + else: + # Do nothing! We already have what we need! + new_payload = payload + + response["content"][content_type] = new_payload + operation["responses"] = self.responses else: # Set the default 200 response - responses = operation.setdefault("responses", self.responses or {}) + responses = operation.setdefault("responses", {}) success_response = responses.setdefault(200, {}) success_response["description"] = self.response_description or _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION success_response["content"] = {"application/json": {"schema": {}}} @@ -682,7 +728,7 @@ def _openapi_operation_return( Tuple["ModelField", Literal["validation", "serialization"]], "JsonSchemaValue", ], - ) -> Dict[str, Any]: + ) -> OpenAPIResponseContentSchema: """ Returns the OpenAPI operation return. """ @@ -832,7 +878,7 @@ def route( cache_control: Optional[str] = None, summary: Optional[str] = None, description: Optional[str] = None, - responses: Optional[Dict[int, Dict[str, Any]]] = None, + responses: Optional[Dict[int, OpenAPIResponse]] = None, response_description: str = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, tags: Optional[List[str]] = None, operation_id: Optional[str] = None, @@ -890,7 +936,7 @@ def get( cache_control: Optional[str] = None, summary: Optional[str] = None, description: Optional[str] = None, - responses: Optional[Dict[int, Dict[str, Any]]] = None, + responses: Optional[Dict[int, OpenAPIResponse]] = None, response_description: str = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, tags: Optional[List[str]] = None, operation_id: Optional[str] = None, @@ -943,7 +989,7 @@ def post( cache_control: Optional[str] = None, summary: Optional[str] = None, description: Optional[str] = None, - responses: Optional[Dict[int, Dict[str, Any]]] = None, + responses: Optional[Dict[int, OpenAPIResponse]] = None, response_description: str = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, tags: Optional[List[str]] = None, operation_id: Optional[str] = None, @@ -997,7 +1043,7 @@ def put( cache_control: Optional[str] = None, summary: Optional[str] = None, description: Optional[str] = None, - responses: Optional[Dict[int, Dict[str, Any]]] = None, + responses: Optional[Dict[int, OpenAPIResponse]] = None, response_description: str = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, tags: Optional[List[str]] = None, operation_id: Optional[str] = None, @@ -1051,7 +1097,7 @@ def delete( cache_control: Optional[str] = None, summary: Optional[str] = None, description: Optional[str] = None, - responses: Optional[Dict[int, Dict[str, Any]]] = None, + responses: Optional[Dict[int, OpenAPIResponse]] = None, response_description: str = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, tags: Optional[List[str]] = None, operation_id: Optional[str] = None, @@ -1104,7 +1150,7 @@ def patch( cache_control: Optional[str] = None, summary: Optional[str] = None, description: Optional[str] = None, - responses: Optional[Dict[int, Dict[str, Any]]] = None, + responses: Optional[Dict[int, OpenAPIResponse]] = None, response_description: str = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, tags: Optional[List[str]] = None, operation_id: Optional[str] = None, @@ -1657,7 +1703,7 @@ def route( cache_control: Optional[str] = None, summary: Optional[str] = None, description: Optional[str] = None, - responses: Optional[Dict[int, Dict[str, Any]]] = None, + responses: Optional[Dict[int, OpenAPIResponse]] = None, response_description: str = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, tags: Optional[List[str]] = None, operation_id: Optional[str] = None, @@ -2127,7 +2173,7 @@ def route( cache_control: Optional[str] = None, summary: Optional[str] = None, description: Optional[str] = None, - responses: Optional[Dict[int, Dict[str, Any]]] = None, + responses: Optional[Dict[int, OpenAPIResponse]] = None, response_description: Optional[str] = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, tags: Optional[List[str]] = None, operation_id: Optional[str] = None, @@ -2216,7 +2262,7 @@ def route( cache_control: Optional[str] = None, summary: Optional[str] = None, description: Optional[str] = None, - responses: Optional[Dict[int, Dict[str, Any]]] = None, + responses: Optional[Dict[int, OpenAPIResponse]] = None, response_description: str = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, tags: Optional[List[str]] = None, operation_id: Optional[str] = None, diff --git a/aws_lambda_powertools/event_handler/openapi/types.py b/aws_lambda_powertools/event_handler/openapi/types.py index 0d166de1131..beafa0e566c 100644 --- a/aws_lambda_powertools/event_handler/openapi/types.py +++ b/aws_lambda_powertools/event_handler/openapi/types.py @@ -2,6 +2,8 @@ from enum import Enum from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Set, Type, Union +from aws_lambda_powertools.shared.types import NotRequired, TypedDict + if TYPE_CHECKING: from pydantic import BaseModel # noqa: F401 @@ -43,3 +45,16 @@ }, }, } + + +class OpenAPIResponseContentSchema(TypedDict, total=False): + schema: Dict + + +class OpenAPIResponseContentModel(TypedDict): + model: Any + + +class OpenAPIResponse(TypedDict): + description: str + content: NotRequired[Dict[str, Union[OpenAPIResponseContentSchema, OpenAPIResponseContentModel]]] diff --git a/tests/functional/event_handler/test_openapi_responses.py b/tests/functional/event_handler/test_openapi_responses.py index bd470867428..819ab2e1041 100644 --- a/tests/functional/event_handler/test_openapi_responses.py +++ b/tests/functional/event_handler/test_openapi_responses.py @@ -1,4 +1,9 @@ -from aws_lambda_powertools.event_handler import APIGatewayRestResolver +from random import random +from typing import Union + +from pydantic import BaseModel + +from aws_lambda_powertools.event_handler import APIGatewayRestResolver, Response def test_openapi_default_response(): @@ -47,3 +52,69 @@ def handler(): assert 200 not in responses.keys() assert 422 not in responses.keys() + + +def test_openapi_union_response(): + app = APIGatewayRestResolver(enable_validation=True) + + class User(BaseModel): + pass + + class Order(BaseModel): + pass + + @app.get( + "/", + responses={ + 200: {"description": "200 Response", "content": {"application/json": {"model": User}}}, + 202: {"description": "202 Response", "content": {"application/json": {"model": Order}}}, + }, + ) + def handler() -> Response[Union[User, Order]]: + if random() > 0.5: + return Response(status_code=200, body=User()) + else: + return Response(status_code=202, body=Order()) + + schema = app.get_openapi_schema() + responses = schema.paths["/"].get.responses + assert 200 in responses.keys() + assert responses[200].description == "200 Response" + assert responses[200].content["application/json"].schema_.ref == "#/components/schemas/User" + + assert 202 in responses.keys() + assert responses[202].description == "202 Response" + assert responses[202].content["application/json"].schema_.ref == "#/components/schemas/Order" + + +def test_openapi_union_partial_response(): + app = APIGatewayRestResolver(enable_validation=True) + + class User(BaseModel): + pass + + class Order(BaseModel): + pass + + @app.get( + "/", + responses={ + 200: {"description": "200 Response"}, + 202: {"description": "202 Response", "content": {"application/json": {"model": Order}}}, + }, + ) + def handler() -> Response[Union[User, Order]]: + if random() > 0.5: + return Response(status_code=200, body=User()) + else: + return Response(status_code=202, body=Order()) + + schema = app.get_openapi_schema() + responses = schema.paths["/"].get.responses + assert 200 in responses.keys() + assert responses[200].description == "200 Response" + assert responses[200].content["application/json"].schema_.anyOf is not None + + assert 202 in responses.keys() + assert responses[202].description == "202 Response" + assert responses[202].content["application/json"].schema_.ref == "#/components/schemas/Order" From 568bcc825c853136edb63cc79ff12b31a93101e6 Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Thu, 4 Jan 2024 15:29:19 +0100 Subject: [PATCH 2/6] fix: I hate sonarcube --- tests/functional/event_handler/test_openapi_responses.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/functional/event_handler/test_openapi_responses.py b/tests/functional/event_handler/test_openapi_responses.py index 819ab2e1041..cec6b0b5b1a 100644 --- a/tests/functional/event_handler/test_openapi_responses.py +++ b/tests/functional/event_handler/test_openapi_responses.py @@ -1,4 +1,4 @@ -from random import random +from secrets import randbelow from typing import Union from pydantic import BaseModel @@ -71,7 +71,7 @@ class Order(BaseModel): }, ) def handler() -> Response[Union[User, Order]]: - if random() > 0.5: + if randbelow(2) > 0: return Response(status_code=200, body=User()) else: return Response(status_code=202, body=Order()) @@ -104,7 +104,7 @@ class Order(BaseModel): }, ) def handler() -> Response[Union[User, Order]]: - if random() > 0.5: + if randbelow(2) > 0: return Response(status_code=200, body=User()) else: return Response(status_code=202, body=Order()) From d7792fd5ef55ca6b436cbbdc80904f3df8d8bc06 Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Thu, 4 Jan 2024 17:02:25 +0100 Subject: [PATCH 3/6] fix: pydantic 2 --- .../event_handler/api_gateway.py | 21 ++++++++++-------- .../event_handler/openapi/dependant.py | 22 +++++++++++++++++++ .../event_handler/openapi/params.py | 2 ++ 3 files changed, 36 insertions(+), 9 deletions(-) diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index cc59f44b158..821b2dcbed2 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -445,7 +445,7 @@ def dependant(self) -> "Dependant": if self._dependant is None: from aws_lambda_powertools.event_handler.openapi.dependant import get_dependant - self._dependant = get_dependant(path=self.openapi_path, call=self.func) + self._dependant = get_dependant(path=self.openapi_path, call=self.func, responses=self.responses) return self._dependant @@ -525,15 +525,15 @@ def _get_openapi_path( # Case 2.1: the 'content' has a model if "model" in payload: - from aws_lambda_powertools.event_handler.openapi.params import analyze_param - - return_field = analyze_param( - param_name="return", - annotation=cast(OpenAPIResponseContentModel, payload)["model"], - value=None, - is_path_param=False, - is_response_param=True, + # Find the model in the dependant's extra models + return_field = next( + filter( + lambda model: model.type_ is cast(OpenAPIResponseContentModel, payload)["model"], + self.dependant.response_extra_models, + ), ) + if not return_field: + raise AssertionError("Model declared in custom responses was not found") new_payload = self._openapi_operation_return( param=return_field, @@ -2151,6 +2151,9 @@ def _get_fields_from_routes(routes: Sequence[Route]) -> List["ModelField"]: if route.dependant.return_param: responses_from_routes.append(route.dependant.return_param) + if route.dependant.response_extra_models: + responses_from_routes.extend(route.dependant.response_extra_models) + flat_models = list(responses_from_routes + request_fields_from_routes + body_fields_from_routes) return flat_models diff --git a/aws_lambda_powertools/event_handler/openapi/dependant.py b/aws_lambda_powertools/event_handler/openapi/dependant.py index e22eb535a7e..a81107348cf 100644 --- a/aws_lambda_powertools/event_handler/openapi/dependant.py +++ b/aws_lambda_powertools/event_handler/openapi/dependant.py @@ -24,6 +24,7 @@ create_response_field, get_flat_dependant, ) +from aws_lambda_powertools.event_handler.openapi.types import OpenAPIResponse, OpenAPIResponseContentModel """ This turns the opaque function signature into typed, validated models. @@ -145,6 +146,7 @@ def get_dependant( path: str, call: Callable[..., Any], name: Optional[str] = None, + responses: Optional[Dict[int, OpenAPIResponse]] = None, ) -> Dependant: """ Returns a dependant model for a handler function. A dependant model is a model that contains @@ -158,6 +160,8 @@ def get_dependant( The handler function name: str, optional The name of the handler function + responses: List[Dict[int, OpenAPIResponse]], optional + The list of extra responses for the handler function Returns ------- @@ -210,6 +214,24 @@ def get_dependant( dependant.return_param = param_field + # Also add the optional extra responses to the dependant model. + if responses: + for response in responses.values(): + if "content" in response: + for schema in response["content"].values(): + if "model" in schema: + response_field = analyze_param( + param_name="return", + annotation=cast(OpenAPIResponseContentModel, schema)["model"], + value=None, + is_path_param=False, + is_response_param=True, + ) + if response_field is None: + raise AssertionError("Response field is None for response model") + + dependant.response_extra_models.append(response_field) + return dependant diff --git a/aws_lambda_powertools/event_handler/openapi/params.py b/aws_lambda_powertools/event_handler/openapi/params.py index bd542ba7932..78426cbc7c9 100644 --- a/aws_lambda_powertools/event_handler/openapi/params.py +++ b/aws_lambda_powertools/event_handler/openapi/params.py @@ -49,6 +49,7 @@ def __init__( cookie_params: Optional[List[ModelField]] = None, body_params: Optional[List[ModelField]] = None, return_param: Optional[ModelField] = None, + response_extra_models: Optional[List[ModelField]] = None, name: Optional[str] = None, call: Optional[Callable[..., Any]] = None, request_param_name: Optional[str] = None, @@ -64,6 +65,7 @@ def __init__( self.cookie_params = cookie_params or [] self.body_params = body_params or [] self.return_param = return_param or None + self.response_extra_models = response_extra_models or [] self.request_param_name = request_param_name self.websocket_param_name = websocket_param_name self.http_connection_param_name = http_connection_param_name From e9d1c31a50fbf7f02cbcef2da2980490e9d1e203 Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Thu, 4 Jan 2024 17:08:57 +0100 Subject: [PATCH 4/6] fix: refactor --- .../event_handler/openapi/dependant.py | 48 +++++++++++-------- 1 file changed, 28 insertions(+), 20 deletions(-) diff --git a/aws_lambda_powertools/event_handler/openapi/dependant.py b/aws_lambda_powertools/event_handler/openapi/dependant.py index a81107348cf..418a86e083c 100644 --- a/aws_lambda_powertools/event_handler/openapi/dependant.py +++ b/aws_lambda_powertools/event_handler/openapi/dependant.py @@ -199,6 +199,34 @@ def get_dependant( else: add_param_to_fields(field=param_field, dependant=dependant) + _add_return_annotation(dependant, endpoint_signature) + _add_extra_responses(dependant, responses) + + return dependant + + +def _add_extra_responses(dependant: Dependant, responses: Optional[Dict[int, OpenAPIResponse]]): + # Also add the optional extra responses to the dependant model. + if not responses: + return + + for response in responses.values(): + for schema in response.get("content", {}).values(): + if "model" in schema: + response_field = analyze_param( + param_name="return", + annotation=cast(OpenAPIResponseContentModel, schema)["model"], + value=None, + is_path_param=False, + is_response_param=True, + ) + if response_field is None: + raise AssertionError("Response field is None for response model") + + dependant.response_extra_models.append(response_field) + + +def _add_return_annotation(dependant: Dependant, endpoint_signature: inspect.Signature): # If the return annotation is not empty, add it to the dependant model. return_annotation = endpoint_signature.return_annotation if return_annotation is not inspect.Signature.empty: @@ -214,26 +242,6 @@ def get_dependant( dependant.return_param = param_field - # Also add the optional extra responses to the dependant model. - if responses: - for response in responses.values(): - if "content" in response: - for schema in response["content"].values(): - if "model" in schema: - response_field = analyze_param( - param_name="return", - annotation=cast(OpenAPIResponseContentModel, schema)["model"], - value=None, - is_path_param=False, - is_response_param=True, - ) - if response_field is None: - raise AssertionError("Response field is None for response model") - - dependant.response_extra_models.append(response_field) - - return dependant - def is_body_param(*, param_field: ModelField, is_path_param: bool) -> bool: """ From 04adc769cb161f626c9f2894e944e3ed3fdb2c50 Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Fri, 5 Jan 2024 13:20:38 +0100 Subject: [PATCH 5/6] fix: increase coverage --- .../event_handler/test_openapi_responses.py | 21 +++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/tests/functional/event_handler/test_openapi_responses.py b/tests/functional/event_handler/test_openapi_responses.py index cec6b0b5b1a..be5d9bca288 100644 --- a/tests/functional/event_handler/test_openapi_responses.py +++ b/tests/functional/event_handler/test_openapi_responses.py @@ -54,6 +54,27 @@ def handler(): assert 422 not in responses.keys() +def test_openapi_200_custom_schema(): + app = APIGatewayRestResolver(enable_validation=True) + + class User(BaseModel): + pass + + @app.get( + "/", + responses={200: {"description": "Custom response", "content": {"application/json": {"schema": User.schema()}}}}, + ) + def handler(): + return {"message": "hello world"} + + schema = app.get_openapi_schema() + responses = schema.paths["/"].get.responses + assert 200 in responses.keys() + + assert responses[200].description == "Custom response" + assert responses[200].content["application/json"].schema_.title == "User" + + def test_openapi_union_response(): app = APIGatewayRestResolver(enable_validation=True) From bb89ed985e19b5b7ff4b5db9f4f227a7f004882b Mon Sep 17 00:00:00 2001 From: Ruben Fonseca Date: Tue, 16 Jan 2024 15:57:37 +0100 Subject: [PATCH 6/6] chore: update docs --- docs/core/event_handler/api_gateway.md | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/docs/core/event_handler/api_gateway.md b/docs/core/event_handler/api_gateway.md index 4602231a63e..a34a94975bc 100644 --- a/docs/core/event_handler/api_gateway.md +++ b/docs/core/event_handler/api_gateway.md @@ -955,15 +955,15 @@ Customize your API endpoints by adding metadata to endpoint definitions. This pr Here's a breakdown of various customizable fields: -| Field Name | Type | Description | -| ---------------------- | --------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| `summary` | `str` | A concise overview of the main functionality of the endpoint. This brief introduction is usually displayed in autogenerated API documentation and helps consumers quickly understand what the endpoint does. | -| `description` | `str` | A more detailed explanation of the endpoint, which can include information about the operation's behavior, including side effects, error states, and other operational guidelines. | -| `responses` | `Dict[int, Dict[str, Any]]` | A dictionary that maps each HTTP status code to a Response Object as defined by the [OpenAPI Specification](https://swagger.io/specification/#response-object). This allows you to describe expected responses, including default or error messages, and their corresponding schemas for different status codes. | -| `response_description` | `str` | Provides the default textual description of the response sent by the endpoint when the operation is successful. It is intended to give a human-readable understanding of the result. | -| `tags` | `List[str]` | Tags are a way to categorize and group endpoints within the API documentation. They can help organize the operations by resources or other heuristic. | -| `operation_id` | `str` | A unique identifier for the operation, which can be used for referencing this operation in documentation or code. This ID must be unique across all operations described in the API. | -| `include_in_schema` | `bool` | A boolean value that determines whether or not this operation should be included in the OpenAPI schema. Setting it to `False` can hide the endpoint from generated documentation and schema exports, which might be useful for private or experimental endpoints. | +| Field Name | Type | Description | +| ---------------------- |-----------------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| `summary` | `str` | A concise overview of the main functionality of the endpoint. This brief introduction is usually displayed in autogenerated API documentation and helps consumers quickly understand what the endpoint does. | +| `description` | `str` | A more detailed explanation of the endpoint, which can include information about the operation's behavior, including side effects, error states, and other operational guidelines. | +| `responses` | `Dict[int, Dict[str, OpenAPIResponse]]` | A dictionary that maps each HTTP status code to a Response Object as defined by the [OpenAPI Specification](https://swagger.io/specification/#response-object). This allows you to describe expected responses, including default or error messages, and their corresponding schemas or models for different status codes. | +| `response_description` | `str` | Provides the default textual description of the response sent by the endpoint when the operation is successful. It is intended to give a human-readable understanding of the result. | +| `tags` | `List[str]` | Tags are a way to categorize and group endpoints within the API documentation. They can help organize the operations by resources or other heuristic. | +| `operation_id` | `str` | A unique identifier for the operation, which can be used for referencing this operation in documentation or code. This ID must be unique across all operations described in the API. | +| `include_in_schema` | `bool` | A boolean value that determines whether or not this operation should be included in the OpenAPI schema. Setting it to `False` can hide the endpoint from generated documentation and schema exports, which might be useful for private or experimental endpoints. | To implement these customizations, include extra parameters when defining your routes: