diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 46cb5587135..1383b74ada0 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -9,13 +9,35 @@ from enum import Enum from functools import partial from http import HTTPStatus -from typing import Any, Callable, Dict, List, Match, Optional, Pattern, Set, Tuple, Type, Union +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + List, + Match, + Optional, + Pattern, + Sequence, + Set, + Tuple, + Type, + Union, + cast, +) from aws_lambda_powertools.event_handler import content_types from aws_lambda_powertools.event_handler.exceptions import NotFoundError, ServiceError +from aws_lambda_powertools.event_handler.openapi.types import ( + COMPONENT_REF_PREFIX, + METHODS_WITH_BODY, + validation_error_definition, + validation_error_response_definition, +) from aws_lambda_powertools.shared.cookies import Cookie from aws_lambda_powertools.shared.functions import powertools_dev_is_set from aws_lambda_powertools.shared.json_encoder import Encoder +from aws_lambda_powertools.shared.types import Literal from aws_lambda_powertools.utilities.data_classes import ( ALBEvent, APIGatewayProxyEvent, @@ -34,8 +56,26 @@ # API GW/ALB decode non-safe URI chars; we must support them too _UNSAFE_URI = r"%<> \[\]{}|^" _NAMED_GROUP_BOUNDARY_PATTERN = rf"(?P\1[{_SAFE_URI}{_UNSAFE_URI}\\w]+)" +_DEFAULT_OPENAPI_RESPONSE_DESCRIPTION = "Successful Response" _ROUTE_REGEX = "^{}$" +if TYPE_CHECKING: + from aws_lambda_powertools.event_handler.openapi.compat import ( + JsonSchemaValue, + ModelField, + ) + from aws_lambda_powertools.event_handler.openapi.models import ( + Contact, + License, + OpenAPI, + Server, + Tag, + ) + from aws_lambda_powertools.event_handler.openapi.params import Dependant + from aws_lambda_powertools.event_handler.openapi.types import ( + TypeModelOrEnum, + ) + class ProxyEventType(Enum): """An enumerations of the supported proxy event types.""" @@ -203,11 +243,18 @@ class Route: def __init__( self, method: str, + path: str, rule: Pattern, func: Callable, cors: bool, compress: bool, cache_control: Optional[str], + summary: Optional[str], + description: Optional[str], + responses: Optional[Dict[int, Dict[str, Any]]], + response_description: Optional[str], + tags: Optional[List["Tag"]], + operation_id: Optional[str], middlewares: Optional[List[Callable[..., Response]]], ): """ @@ -217,6 +264,8 @@ def __init__( method: str The HTTP method, example "GET" + path: str + The path of the route rule: Pattern The route rule, example "/my/path" func: Callable @@ -227,21 +276,46 @@ def __init__( Whether or not to enable gzip compression for this route cache_control: Optional[str] The cache control header value, example "max-age=3600" + summary: Optional[str] + The OpenAPI summary for this route + description: Optional[str] + The OpenAPI description for this route + responses: Optional[Dict[int, Dict[str, Any]]] + The OpenAPI responses for this route + response_description: Optional[str] + The OpenAPI response description for this route + tags: Optional[List[Tag]] + The list of OpenAPI tags to be used for this route + operation_id: Optional[str] + The OpenAPI operationId for this route middlewares: Optional[List[Callable[..., Response]]] The list of route middlewares to be called in order. """ self.method = method.upper() + self.path = "/" if path.strip() == "" else path self.rule = rule self.func = func self._middleware_stack = func self.cors = cors self.compress = compress self.cache_control = cache_control + self.summary = summary + self.description = description + self.responses = responses + self.response_description = response_description + self.tags = tags or [] self.middlewares = middlewares or [] + self.operation_id = operation_id or self._generate_operation_id() # _middleware_stack_built is used to ensure the middleware stack is only built once. self._middleware_stack_built = False + # _dependant is used to cache the dependant model for the handler function + self._dependant: Optional["Dependant"] = None + + # _body_field is used to cache the dependant model for the body field + self._body_field: Optional["ModelField"] = None + def __call__( self, router_middlewares: List[Callable], @@ -332,6 +406,275 @@ def _build_middleware_stack(self, router_middlewares: List[Callable[..., Any]]) self._middleware_stack_built = True + @property + def dependant(self) -> "Dependant": + if self._dependant is None: + from aws_lambda_powertools.event_handler.openapi.dependant import get_dependant + + self._dependant = get_dependant(path=self.path, call=self.func) + + return self._dependant + + @property + def body_field(self) -> Optional["ModelField"]: + if self._body_field is None: + from aws_lambda_powertools.event_handler.openapi.dependant import get_body_field + + self._body_field = get_body_field(dependant=self.dependant, name=self.operation_id) + + return self._body_field + + def _get_openapi_path( + self, + *, + dependant: "Dependant", + operation_ids: Set[str], + model_name_map: Dict["TypeModelOrEnum", str], + field_mapping: Dict[Tuple["ModelField", Literal["validation", "serialization"]], "JsonSchemaValue"], + ) -> Tuple[Dict[str, Any], Dict[str, Any]]: + """ + Returns the OpenAPI path and definitions for the route. + """ + from aws_lambda_powertools.event_handler.openapi.dependant import get_flat_params + + path = {} + definitions: Dict[str, Any] = {} + + # Gather all the route parameters + operation = self._openapi_operation_metadata(operation_ids=operation_ids) + parameters: List[Dict[str, Any]] = [] + all_route_params = get_flat_params(dependant) + operation_params = self._openapi_operation_parameters( + all_route_params=all_route_params, + model_name_map=model_name_map, + field_mapping=field_mapping, + ) + parameters.extend(operation_params) + + # Add the parameters to the OpenAPI operation + if parameters: + all_parameters = {(param["in"], param["name"]): param for param in parameters} + required_parameters = {(param["in"], param["name"]): param for param in parameters if param.get("required")} + all_parameters.update(required_parameters) + operation["parameters"] = list(all_parameters.values()) + + # Add the request body to the OpenAPI operation, if applicable + if self.method.upper() in METHODS_WITH_BODY: + request_body_oai = self._openapi_operation_request_body( + body_field=self.body_field, + model_name_map=model_name_map, + field_mapping=field_mapping, + ) + if request_body_oai: + operation["requestBody"] = request_body_oai + + # Add the response to the OpenAPI operation + if self.responses: + # If the user supplied responses, we use them and don't set a default 200 response + operation["responses"] = self.responses + else: + # Set the default 200 response + responses = operation.setdefault("responses", self.responses or {}) + success_response = responses.setdefault(200, {}) + success_response["description"] = self.response_description or _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION + success_response["content"] = {"application/json": {"schema": {}}} + json_response = success_response["content"].setdefault("application/json", {}) + + # Add the response schema to the OpenAPI 200 response + json_response.update( + self._openapi_operation_return( + operation_id=self.operation_id, + param=dependant.return_param, + model_name_map=model_name_map, + field_mapping=field_mapping, + ), + ) + + # Add validation failure response (422) + operation["responses"][422] = { + "description": "Validation Error", + "content": { + "application/json": { + "schema": {"$ref": COMPONENT_REF_PREFIX + "HTTPValidationError"}, + }, + }, + } + + # Add the validation error schema to the definitions, but only if it hasn't been added yet + if "ValidationError" not in definitions: + definitions.update( + { + "ValidationError": validation_error_definition, + "HTTPValidationError": validation_error_response_definition, + }, + ) + + path[self.method.lower()] = operation + + # Generate the response schema + return path, definitions + + def _openapi_operation_summary(self) -> str: + """ + Returns the OpenAPI operation summary. If the user has not provided a summary, we + generate one based on the route path and method. + """ + return self.summary or f"{self.method.upper()} {self.path}" + + def _openapi_operation_metadata(self, operation_ids: Set[str]) -> Dict[str, Any]: + """ + Returns the OpenAPI operation metadata. If the user has not provided a description, we + generate one based on the route path and method. + """ + operation: Dict[str, Any] = {} + + # Ensure tags is added to the operation + if self.tags: + operation["tags"] = self.tags + + # Ensure summary is added to the operation + operation["summary"] = self._openapi_operation_summary() + + # Ensure description is added to the operation + if self.description: + operation["description"] = self.description + + # Ensure operationId is unique + if self.operation_id in operation_ids: + message = f"Duplicate Operation ID {self.operation_id} for function {self.func.__name__}" + file_name = getattr(self.func, "__globals__", {}).get("__file__") + if file_name: + message += f" in {file_name}" + warnings.warn(message, stacklevel=1) + + # Adds the operation + operation_ids.add(self.operation_id) + operation["operationId"] = self.operation_id + + return operation + + @staticmethod + def _openapi_operation_request_body( + *, + body_field: Optional["ModelField"], + model_name_map: Dict["TypeModelOrEnum", str], + field_mapping: Dict[Tuple["ModelField", Literal["validation", "serialization"]], "JsonSchemaValue"], + ) -> Optional[Dict[str, Any]]: + """ + Returns the OpenAPI operation request body. + """ + from aws_lambda_powertools.event_handler.openapi.compat import ModelField, get_schema_from_model_field + from aws_lambda_powertools.event_handler.openapi.params import Body + + # Check that there is a body field and it's a Pydantic's model field + if not body_field: + return None + + if not isinstance(body_field, ModelField): + raise AssertionError(f"Expected ModelField, got {body_field}") + + # Generate the request body schema + body_schema = get_schema_from_model_field( + field=body_field, + model_name_map=model_name_map, + field_mapping=field_mapping, + ) + + field_info = cast(Body, body_field.field_info) + request_media_type = field_info.media_type + required = body_field.required + request_body_oai: Dict[str, Any] = {} + if required: + request_body_oai["required"] = required + + # Generate the request body media type + request_media_content: Dict[str, Any] = {"schema": body_schema} + request_body_oai["content"] = {request_media_type: request_media_content} + return request_body_oai + + @staticmethod + def _openapi_operation_parameters( + *, + all_route_params: Sequence["ModelField"], + model_name_map: Dict["TypeModelOrEnum", str], + field_mapping: Dict[ + Tuple["ModelField", Literal["validation", "serialization"]], + "JsonSchemaValue", + ], + ) -> List[Dict[str, Any]]: + """ + Returns the OpenAPI operation parameters. + """ + from aws_lambda_powertools.event_handler.openapi.compat import ( + get_schema_from_model_field, + ) + from aws_lambda_powertools.event_handler.openapi.params import Param + + parameters = [] + for param in all_route_params: + field_info = param.field_info + field_info = cast(Param, field_info) + if not field_info.include_in_schema: + continue + + param_schema = get_schema_from_model_field( + field=param, + model_name_map=model_name_map, + field_mapping=field_mapping, + ) + + parameter = { + "name": param.alias, + "in": field_info.in_.value, + "required": param.required, + "schema": param_schema, + } + + if field_info.description: + parameter["description"] = field_info.description + + if field_info.deprecated: + parameter["deprecated"] = field_info.deprecated + + parameters.append(parameter) + + return parameters + + @staticmethod + def _openapi_operation_return( + *, + operation_id: str, + param: Optional["ModelField"], + model_name_map: Dict["TypeModelOrEnum", str], + field_mapping: Dict[ + Tuple["ModelField", Literal["validation", "serialization"]], + "JsonSchemaValue", + ], + ) -> Dict[str, Any]: + """ + Returns the OpenAPI operation return. + """ + if param is None: + return {} + + from aws_lambda_powertools.event_handler.openapi.compat import ( + get_schema_from_model_field, + ) + + return_schema = get_schema_from_model_field( + field=param, + model_name_map=model_name_map, + field_mapping=field_mapping, + ) + + return {"name": f"Return {operation_id}", "schema": return_schema} + + def _generate_operation_id(self) -> str: + operation_id = self.func.__name__ + self.path + operation_id = re.sub(r"\W", "_", operation_id) + operation_id = operation_id + "_" + self.method.lower() + return operation_id + class ResponseBuilder: """Internally used Response builder""" @@ -443,6 +786,12 @@ def route( cors: Optional[bool] = None, compress: bool = False, cache_control: Optional[str] = None, + summary: Optional[str] = None, + description: Optional[str] = None, + responses: Optional[Dict[int, Dict[str, Any]]] = None, + response_description: str = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, + tags: Optional[List["Tag"]] = None, + operation_id: Optional[str] = None, middlewares: Optional[List[Callable[..., Any]]] = None, ): raise NotImplementedError() @@ -494,6 +843,12 @@ def get( cors: Optional[bool] = None, compress: bool = False, cache_control: Optional[str] = None, + summary: Optional[str] = None, + description: Optional[str] = None, + responses: Optional[Dict[int, Dict[str, Any]]] = None, + response_description: str = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, + tags: Optional[List["Tag"]] = None, + operation_id: Optional[str] = None, middlewares: Optional[List[Callable[..., Any]]] = None, ): """Get route decorator with GET `method` @@ -518,7 +873,20 @@ def lambda_handler(event, context): return app.resolve(event, context) ``` """ - return self.route(rule, "GET", cors, compress, cache_control, middlewares) + return self.route( + rule, + "GET", + cors, + compress, + cache_control, + summary, + description, + responses, + response_description, + tags, + operation_id, + middlewares, + ) def post( self, @@ -526,6 +894,12 @@ def post( cors: Optional[bool] = None, compress: bool = False, cache_control: Optional[str] = None, + summary: Optional[str] = None, + description: Optional[str] = None, + responses: Optional[Dict[int, Dict[str, Any]]] = None, + response_description: str = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, + tags: Optional[List["Tag"]] = None, + operation_id: Optional[str] = None, middlewares: Optional[List[Callable[..., Any]]] = None, ): """Post route decorator with POST `method` @@ -551,7 +925,20 @@ def lambda_handler(event, context): return app.resolve(event, context) ``` """ - return self.route(rule, "POST", cors, compress, cache_control, middlewares) + return self.route( + rule, + "POST", + cors, + compress, + cache_control, + summary, + description, + responses, + response_description, + tags, + operation_id, + middlewares, + ) def put( self, @@ -559,6 +946,12 @@ def put( cors: Optional[bool] = None, compress: bool = False, cache_control: Optional[str] = None, + summary: Optional[str] = None, + description: Optional[str] = None, + responses: Optional[Dict[int, Dict[str, Any]]] = None, + response_description: str = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, + tags: Optional[List["Tag"]] = None, + operation_id: Optional[str] = None, middlewares: Optional[List[Callable[..., Any]]] = None, ): """Put route decorator with PUT `method` @@ -584,7 +977,20 @@ def lambda_handler(event, context): return app.resolve(event, context) ``` """ - return self.route(rule, "PUT", cors, compress, cache_control, middlewares) + return self.route( + rule, + "PUT", + cors, + compress, + cache_control, + summary, + description, + responses, + response_description, + tags, + operation_id, + middlewares, + ) def delete( self, @@ -592,6 +998,12 @@ def delete( cors: Optional[bool] = None, compress: bool = False, cache_control: Optional[str] = None, + summary: Optional[str] = None, + description: Optional[str] = None, + responses: Optional[Dict[int, Dict[str, Any]]] = None, + response_description: str = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, + tags: Optional[List["Tag"]] = None, + operation_id: Optional[str] = None, middlewares: Optional[List[Callable[..., Any]]] = None, ): """Delete route decorator with DELETE `method` @@ -616,7 +1028,20 @@ def lambda_handler(event, context): return app.resolve(event, context) ``` """ - return self.route(rule, "DELETE", cors, compress, cache_control, middlewares) + return self.route( + rule, + "DELETE", + cors, + compress, + cache_control, + summary, + description, + responses, + response_description, + tags, + operation_id, + middlewares, + ) def patch( self, @@ -624,6 +1049,12 @@ def patch( cors: Optional[bool] = None, compress: bool = False, cache_control: Optional[str] = None, + summary: Optional[str] = None, + description: Optional[str] = None, + responses: Optional[Dict[int, Dict[str, Any]]] = None, + response_description: str = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, + tags: Optional[List["Tag"]] = None, + operation_id: Optional[str] = None, middlewares: Optional[List[Callable]] = None, ): """Patch route decorator with PATCH `method` @@ -651,7 +1082,20 @@ def lambda_handler(event, context): return app.resolve(event, context) ``` """ - return self.route(rule, "PATCH", cors, compress, cache_control, middlewares) + return self.route( + rule, + "PATCH", + cors, + compress, + cache_control, + summary, + description, + responses, + response_description, + tags, + operation_id, + middlewares, + ) def _push_processed_stack_frame(self, frame: str): """ @@ -676,7 +1120,7 @@ def clear_context(self): class MiddlewareFrame: """ - creates a Middle Stack Wrapper instance to be used as a "Frame" in the overall stack of + Creates a Middle Stack Wrapper instance to be used as a "Frame" in the overall stack of middleware functions. Each instance contains the current middleware and the next middleware function to be called in the stack. @@ -813,6 +1257,7 @@ def __init__( debug: Optional[bool] = None, serializer: Optional[Callable[[Dict], str]] = None, strip_prefixes: Optional[List[Union[str, Pattern]]] = None, + enable_validation: bool = False, ): """ Parameters @@ -824,12 +1269,14 @@ def __init__( debug: Optional[bool] Enables debug mode, by default False. Can be also be enabled by "POWERTOOLS_DEV" environment variable - serializer : Callable, optional + serializer: Callable, optional function to serialize `obj` to a JSON formatted `str`, by default json.dumps strip_prefixes: List[Union[str, Pattern]], optional optional list of prefixes to be removed from the request path before doing the routing. This is often used with api gateways with multiple custom mappings. Each prefix can be a static string or a compiled regex pattern + enable_validation: Optional[bool] + Enables validation of the request body against the route schema, by default False. """ self._proxy_type = proxy_type self._dynamic_routes: List[Route] = [] @@ -840,6 +1287,7 @@ def __init__( self._cors_enabled: bool = cors is not None self._cors_methods: Set[str] = {"OPTIONS"} self._debug = self._has_debug(debug) + self._enable_validation = enable_validation self._strip_prefixes = strip_prefixes self.context: Dict = {} # early init as customers might add context before event resolution self.processed_stack_frames = [] @@ -847,6 +1295,203 @@ def __init__( # Allow for a custom serializer or a concise json serialization self._serializer = serializer or partial(json.dumps, separators=(",", ":"), cls=Encoder) + if self._enable_validation: + from aws_lambda_powertools.event_handler.middlewares.openapi_validation import OpenAPIValidationMiddleware + + self.use([OpenAPIValidationMiddleware()]) + + # When using validation, we need to skip the serializer, as the middleware is doing it automatically. + # However, if the user is using a custom serializer, we need to abort. + if serializer: + raise ValueError("Cannot use a custom serializer when using validation") + + # Install a dummy serializer + self._serializer = lambda args: args # type: ignore + + def get_openapi_schema( + self, + *, + title: str = "Powertools API", + version: str = "1.0.0", + openapi_version: str = "3.1.0", + summary: Optional[str] = None, + description: Optional[str] = None, + tags: Optional[List["Tag"]] = None, + servers: Optional[List["Server"]] = None, + terms_of_service: Optional[str] = None, + contact: Optional["Contact"] = None, + license_info: Optional["License"] = None, + ) -> "OpenAPI": + """ + Returns the OpenAPI schema as a pydantic model. + + Parameters + ---------- + title: str + The title of the application. + version: str + The version of the OpenAPI document (which is distinct from the OpenAPI Specification version or the API + openapi_version: str, default = "3.1.0" + The version of the OpenAPI Specification (which the document uses). + summary: str, optional + A short summary of what the application does. + description: str, optional + A verbose explanation of the application behavior. + tags: List[Tag], optional + A list of tags used by the specification with additional metadata. + servers: List[Server], optional + An array of Server Objects, which provide connectivity information to a target server. + terms_of_service: str, optional + A URL to the Terms of Service for the API. MUST be in the format of a URL. + contact: Contact, optional + The contact information for the exposed API. + license_info: + The license information for the exposed API. + + Returns + ------- + OpenAPI: pydantic model + The OpenAPI schema as a pydantic model. + """ + + from aws_lambda_powertools.event_handler.openapi.compat import ( + GenerateJsonSchema, + get_compat_model_name_map, + get_definitions, + ) + from aws_lambda_powertools.event_handler.openapi.models import OpenAPI, PathItem, Server + from aws_lambda_powertools.event_handler.openapi.types import ( + COMPONENT_REF_TEMPLATE, + ) + + # Start with the bare minimum required for a valid OpenAPI schema + info: Dict[str, Any] = {"title": title, "version": version} + + optional_fields = { + "summary": summary, + "description": description, + "termsOfService": terms_of_service, + "contact": contact, + "license": license_info, + } + + info.update({field: value for field, value in optional_fields.items() if value}) + + output: Dict[str, Any] = {"openapi": openapi_version, "info": info} + if servers: + output["servers"] = servers + else: + # If the servers property is not provided, or is an empty array, the default value would be a Server Object + # with an url value of /. + output["servers"] = [Server(url="/")] + + components: Dict[str, Dict[str, Any]] = {} + paths: Dict[str, Dict[str, Any]] = {} + operation_ids: Set[str] = set() + + all_routes = self._dynamic_routes + self._static_routes + all_fields = self._get_fields_from_routes(all_routes) + model_name_map = get_compat_model_name_map(all_fields) + + # Collect all models and definitions + schema_generator = GenerateJsonSchema(ref_template=COMPONENT_REF_TEMPLATE) + field_mapping, definitions = get_definitions( + fields=all_fields, + schema_generator=schema_generator, + model_name_map=model_name_map, + ) + + # Add routes to the OpenAPI schema + for route in all_routes: + result = route._get_openapi_path( + dependant=route.dependant, + operation_ids=operation_ids, + model_name_map=model_name_map, + field_mapping=field_mapping, + ) + if result: + path, path_definitions = result + if path: + paths.setdefault(route.path, {}).update(path) + if path_definitions: + definitions.update(path_definitions) + + if definitions: + components["schemas"] = {k: definitions[k] for k in sorted(definitions)} + if components: + output["components"] = components + if tags: + output["tags"] = tags + + output["paths"] = {k: PathItem(**v) for k, v in paths.items()} + + return OpenAPI(**output) + + def get_openapi_json_schema( + self, + *, + title: str = "Powertools API", + version: str = "1.0.0", + openapi_version: str = "3.1.0", + summary: Optional[str] = None, + description: Optional[str] = None, + tags: Optional[List["Tag"]] = None, + servers: Optional[List["Server"]] = None, + terms_of_service: Optional[str] = None, + contact: Optional["Contact"] = None, + license_info: Optional["License"] = None, + ) -> str: + """ + Returns the OpenAPI schema as a JSON serializable dict + + Parameters + ---------- + title: str + The title of the application. + version: str + The version of the OpenAPI document (which is distinct from the OpenAPI Specification version or the API + openapi_version: str, default = "3.1.0" + The version of the OpenAPI Specification (which the document uses). + summary: str, optional + A short summary of what the application does. + description: str, optional + A verbose explanation of the application behavior. + tags: List[Tag], optional + A list of tags used by the specification with additional metadata. + servers: List[Server], optional + An array of Server Objects, which provide connectivity information to a target server. + terms_of_service: str, optional + A URL to the Terms of Service for the API. MUST be in the format of a URL. + contact: Contact, optional + The contact information for the exposed API. + license_info: + The license information for the exposed API. + + Returns + ------- + str + The OpenAPI schema as a JSON serializable dict. + """ + from aws_lambda_powertools.event_handler.openapi.compat import model_json + + return model_json( + self.get_openapi_schema( + title=title, + version=version, + openapi_version=openapi_version, + summary=summary, + description=description, + tags=tags, + servers=servers, + terms_of_service=terms_of_service, + contact=contact, + license_info=license_info, + ), + by_alias=True, + exclude_none=True, + indent=2, + ) + def route( self, rule: str, @@ -854,6 +1499,12 @@ def route( cors: Optional[bool] = None, compress: bool = False, cache_control: Optional[str] = None, + summary: Optional[str] = None, + description: Optional[str] = None, + responses: Optional[Dict[int, Dict[str, Any]]] = None, + response_description: str = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, + tags: Optional[List["Tag"]] = None, + operation_id: Optional[str] = None, middlewares: Optional[List[Callable[..., Any]]] = None, ): """Route decorator includes parameter `method`""" @@ -861,19 +1512,24 @@ def route( def register_resolver(func: Callable): methods = (method,) if isinstance(method, str) else method logger.debug(f"Adding route using rule {rule} and methods: {','.join((m.upper() for m in methods))}") - if cors is None: - cors_enabled = self._cors_enabled - else: - cors_enabled = cors + + cors_enabled = self._cors_enabled if cors is None else cors for item in methods: _route = Route( item, + rule, self._compile_regex(rule), func, cors_enabled, compress, cache_control, + summary, + description, + responses, + response_description, + tags, + operation_id, middlewares, ) @@ -886,13 +1542,8 @@ def register_resolver(func: Callable): else: self._static_routes.append(_route) - route_key = item + rule - if route_key in self._route_keys: - warnings.warn( - f"A route like this was already registered. method: '{item}' rule: '{rule}'", - stacklevel=2, - ) - self._route_keys.append(route_key) + self._create_route_key(item, rule) + if cors_enabled: logger.debug(f"Registering method {item.upper()} to Allow Methods in CORS") self._cors_methods.add(item.upper()) @@ -946,6 +1597,15 @@ def resolve(self, event, context) -> Dict[str, Any]: def __call__(self, event, context) -> Any: return self.resolve(event, context) + def _create_route_key(self, item: str, rule: str): + route_key = item + rule + if route_key in self._route_keys: + warnings.warn( + f"A route like this was already registered. method: '{item}' rule: '{rule}'", + stacklevel=2, + ) + self._route_keys.append(route_key) + @staticmethod def _has_debug(debug: Optional[bool] = None) -> bool: # It might have been explicitly switched off (debug=False) @@ -1104,7 +1764,7 @@ def _call_route(self, route: Route, route_arguments: Dict[str, str]) -> Response logger.exception(exc) if self._debug: # If the user has turned on debug mode, - # we'll let the original exception propagate so + # we'll let the original exception propagate, so # they get more information about what went wrong. return ResponseBuilder( Response( @@ -1229,6 +1889,36 @@ def include_router(self, router: "Router", prefix: Optional[str] = None) -> None # Still need to ignore for mypy checks or will cause failures (false-positive) self.route(*new_route, middlewares=middlewares)(func) # type: ignore + @staticmethod + def _get_fields_from_routes(routes: Sequence[Route]) -> List["ModelField"]: + """ + Returns a list of fields from the routes + """ + + from aws_lambda_powertools.event_handler.openapi.compat import ModelField + from aws_lambda_powertools.event_handler.openapi.dependant import ( + get_flat_params, + ) + + body_fields_from_routes: List["ModelField"] = [] + responses_from_routes: List["ModelField"] = [] + request_fields_from_routes: List["ModelField"] = [] + + for route in routes: + if route.body_field: + if not isinstance(route.body_field, ModelField): + raise AssertionError("A request body myst be a Pydantic Field") + body_fields_from_routes.append(route.body_field) + + params = get_flat_params(route.dependant) + request_fields_from_routes.extend(params) + + if route.dependant.return_param: + responses_from_routes.append(route.dependant.return_param) + + flat_models = list(responses_from_routes + request_fields_from_routes + body_fields_from_routes) + return flat_models + class Router(BaseRouter): """Router helper class to allow splitting ApiGatewayResolver into multiple files""" @@ -1246,13 +1936,31 @@ def route( cors: Optional[bool] = None, compress: bool = False, cache_control: Optional[str] = None, + summary: Optional[str] = None, + description: Optional[str] = None, + responses: Optional[Dict[int, Dict[str, Any]]] = None, + response_description: Optional[str] = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, + tags: Optional[List["Tag"]] = None, + operation_id: Optional[str] = None, middlewares: Optional[List[Callable[..., Any]]] = None, ): def register_route(func: Callable): # Convert methods to tuple. It needs to be hashable as its part of the self._routes dict key methods = (method,) if isinstance(method, str) else tuple(method) - route_key = (rule, methods, cors, compress, cache_control) + route_key = ( + rule, + methods, + cors, + compress, + cache_control, + summary, + description, + responses, + response_description, + tags, + operation_id, + ) # Collate Middleware for routes if middlewares is not None: @@ -1280,9 +1988,17 @@ def __init__( debug: Optional[bool] = None, serializer: Optional[Callable[[Dict], str]] = None, strip_prefixes: Optional[List[Union[str, Pattern]]] = None, + enable_validation: bool = False, ): """Amazon API Gateway REST and HTTP API v1 payload resolver""" - super().__init__(ProxyEventType.APIGatewayProxyEvent, cors, debug, serializer, strip_prefixes) + super().__init__( + ProxyEventType.APIGatewayProxyEvent, + cors, + debug, + serializer, + strip_prefixes, + enable_validation, + ) # override route to ignore trailing "/" in routes for REST API def route( @@ -1292,10 +2008,29 @@ def route( cors: Optional[bool] = None, compress: bool = False, cache_control: Optional[str] = None, + summary: Optional[str] = None, + description: Optional[str] = None, + responses: Optional[Dict[int, Dict[str, Any]]] = None, + response_description: str = _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION, + tags: Optional[List["Tag"]] = None, + operation_id: Optional[str] = None, middlewares: Optional[List[Callable[..., Any]]] = None, ): # NOTE: see #1552 for more context. - return super().route(rule.rstrip("/"), method, cors, compress, cache_control, middlewares) + return super().route( + rule.rstrip("/"), + method, + cors, + compress, + cache_control, + summary, + description, + responses, + response_description, + tags, + operation_id, + middlewares, + ) # Override _compile_regex to exclude trailing slashes for route resolution @staticmethod @@ -1312,9 +2047,17 @@ def __init__( debug: Optional[bool] = None, serializer: Optional[Callable[[Dict], str]] = None, strip_prefixes: Optional[List[Union[str, Pattern]]] = None, + enable_validation: bool = False, ): """Amazon API Gateway HTTP API v2 payload resolver""" - super().__init__(ProxyEventType.APIGatewayProxyEventV2, cors, debug, serializer, strip_prefixes) + super().__init__( + ProxyEventType.APIGatewayProxyEventV2, + cors, + debug, + serializer, + strip_prefixes, + enable_validation, + ) class ALBResolver(ApiGatewayResolver): @@ -1326,6 +2069,7 @@ def __init__( debug: Optional[bool] = None, serializer: Optional[Callable[[Dict], str]] = None, strip_prefixes: Optional[List[Union[str, Pattern]]] = None, + enable_validation: bool = False, ): """Amazon Application Load Balancer (ALB) resolver""" - super().__init__(ProxyEventType.ALBEvent, cors, debug, serializer, strip_prefixes) + super().__init__(ProxyEventType.ALBEvent, cors, debug, serializer, strip_prefixes, enable_validation) diff --git a/aws_lambda_powertools/event_handler/lambda_function_url.py b/aws_lambda_powertools/event_handler/lambda_function_url.py index 433a013ab0b..bacdc8549c7 100644 --- a/aws_lambda_powertools/event_handler/lambda_function_url.py +++ b/aws_lambda_powertools/event_handler/lambda_function_url.py @@ -52,5 +52,13 @@ def __init__( debug: Optional[bool] = None, serializer: Optional[Callable[[Dict], str]] = None, strip_prefixes: Optional[List[Union[str, Pattern]]] = None, + enable_validation: bool = False, ): - super().__init__(ProxyEventType.LambdaFunctionUrlEvent, cors, debug, serializer, strip_prefixes) + super().__init__( + ProxyEventType.LambdaFunctionUrlEvent, + cors, + debug, + serializer, + strip_prefixes, + enable_validation, + ) diff --git a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py new file mode 100644 index 00000000000..ea7b303bfa5 --- /dev/null +++ b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py @@ -0,0 +1,342 @@ +import dataclasses +import json +import logging +from copy import deepcopy +from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple + +from pydantic import BaseModel + +from aws_lambda_powertools.event_handler import Response +from aws_lambda_powertools.event_handler.api_gateway import Route +from aws_lambda_powertools.event_handler.middlewares import BaseMiddlewareHandler, NextMiddleware +from aws_lambda_powertools.event_handler.openapi.compat import ( + ModelField, + _model_dump, + _normalize_errors, + _regenerate_error_with_loc, + get_missing_field_error, +) +from aws_lambda_powertools.event_handler.openapi.encoders import jsonable_encoder +from aws_lambda_powertools.event_handler.openapi.exceptions import RequestValidationError +from aws_lambda_powertools.event_handler.openapi.params import Param +from aws_lambda_powertools.event_handler.openapi.types import IncEx +from aws_lambda_powertools.event_handler.types import EventHandlerInstance + +logger = logging.getLogger(__name__) + + +class OpenAPIValidationMiddleware(BaseMiddlewareHandler): + """ + OpenAPIValidationMiddleware is a middleware that validates the request against the OpenAPI schema defined by the + Lambda handler. It also validates the response against the OpenAPI schema defined by the Lambda handler. It + should not be used directly, but rather through the `enable_validation` parameter of the `ApiGatewayResolver`. + + Examples + -------- + + ```python + from typing import List + + from pydantic import BaseModel + + from aws_lambda_powertools.event_handler.api_gateway import ( + APIGatewayRestResolver, + ) + + class Todo(BaseModel): + name: str + + app = APIGatewayRestResolver(enable_validation=True) + + @app.get("/todos") + def get_todos(): List[Todo]: + return [Todo(name="hello world")] + ``` + """ + + def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) -> Response: + logger.debug("OpenAPIValidationMiddleware handler") + + route: Route = app.context["_route"] + + values: Dict[str, Any] = {} + errors: List[Any] = [] + + try: + # Process path values, which can be found on the route_args + path_values, path_errors = _request_params_to_args( + route.dependant.path_params, + app.context["_route_args"], + ) + + # Process query values + query_values, query_errors = _request_params_to_args( + route.dependant.query_params, + app.current_event.query_string_parameters or {}, + ) + + values.update(path_values) + values.update(query_values) + errors += path_errors + query_errors + + # Process the request body, if it exists + if route.dependant.body_params: + (body_values, body_errors) = _request_body_to_args( + required_params=route.dependant.body_params, + received_body=self._get_body(app), + ) + values.update(body_values) + errors.extend(body_errors) + + if errors: + # Raise the validation errors + raise RequestValidationError(_normalize_errors(errors)) + else: + # Re-write the route_args with the validated values, and call the next middleware + app.context["_route_args"] = values + response = next_middleware(app) + + # Process the response body if it exists + raw_response = jsonable_encoder(response.body) + + # Validate and serialize the response + return self._serialize_response(field=route.dependant.return_param, response_content=raw_response) + except RequestValidationError as e: + return Response( + status_code=422, + content_type="application/json", + body=json.dumps({"detail": e.errors()}), + ) + + def _serialize_response( + self, + *, + field: Optional[ModelField] = None, + response_content: Any, + include: Optional[IncEx] = None, + exclude: Optional[IncEx] = None, + by_alias: bool = True, + exclude_unset: bool = False, + exclude_defaults: bool = False, + exclude_none: bool = False, + ) -> Any: + """ + Serialize the response content according to the field type. + """ + if field: + errors: List[Dict[str, Any]] = [] + # MAINTENANCE: remove this when we drop pydantic v1 + if not hasattr(field, "serializable"): + response_content = self._prepare_response_content( + response_content, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + ) + + value = _validate_field(field=field, value=response_content, loc=("response",), existing_errors=errors) + if errors: + raise RequestValidationError(errors=_normalize_errors(errors), body=response_content) + + if hasattr(field, "serialize"): + return field.serialize( + value, + include=include, + exclude=exclude, + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + ) + + return jsonable_encoder( + value, + include=include, + exclude=exclude, + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + ) + else: + # Just serialize the response content returned from the handler + return jsonable_encoder(response_content) + + def _prepare_response_content( + self, + res: Any, + *, + exclude_unset: bool, + exclude_defaults: bool = False, + exclude_none: bool = False, + ) -> Any: + """ + Prepares the response content for serialization. + """ + if isinstance(res, BaseModel): + return _model_dump( + res, + by_alias=True, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + ) + elif isinstance(res, list): + return [ + self._prepare_response_content(item, exclude_unset=exclude_unset, exclude_defaults=exclude_defaults) + for item in res + ] + elif isinstance(res, dict): + return { + k: self._prepare_response_content(v, exclude_unset=exclude_unset, exclude_defaults=exclude_defaults) + for k, v in res.items() + } + elif dataclasses.is_dataclass(res): + return dataclasses.asdict(res) + return res + + def _get_body(self, app: EventHandlerInstance) -> Dict[str, Any]: + """ + Get the request body from the event, and parse it as JSON. + """ + + content_type_value = app.current_event.get_header_value("content-type") + if not content_type_value or content_type_value.startswith("application/json"): + try: + return app.current_event.json_body + except json.JSONDecodeError as e: + raise RequestValidationError( + [ + { + "type": "json_invalid", + "loc": ("body", e.pos), + "msg": "JSON decode error", + "input": {}, + "ctx": {"error": e.msg}, + }, + ], + body=e.doc, + ) from e + else: + raise NotImplementedError("Only JSON body is supported") + + +def _request_params_to_args( + required_params: Sequence[ModelField], + received_params: Mapping[str, Any], +) -> Tuple[Dict[str, Any], List[Any]]: + """ + Convert the request params to a dictionary of values using validation, and returns a list of errors. + """ + values = {} + errors = [] + + for field in required_params: + value = received_params.get(field.alias) + + field_info = field.field_info + if not isinstance(field_info, Param): + raise AssertionError(f"Expected Param field_info, got {field_info}") + + loc = (field_info.in_.value, field.alias) + + # If we don't have a value, see if it's required or has a default + if value is None: + if field.required: + errors.append(get_missing_field_error(loc=loc)) + else: + values[field.name] = deepcopy(field.default) + continue + + # Finally, validate the value + values[field.name] = _validate_field(field=field, value=value, loc=loc, existing_errors=errors) + + return values, errors + + +def _request_body_to_args( + required_params: List[ModelField], + received_body: Optional[Dict[str, Any]], +) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]: + """ + Convert the request body to a dictionary of values using validation, and returns a list of errors. + """ + values: Dict[str, Any] = {} + errors: List[Dict[str, Any]] = [] + + received_body, field_alias_omitted = _get_embed_body( + field=required_params[0], + required_params=required_params, + received_body=received_body, + ) + + for field in required_params: + # This sets the location to: + # { "user": { object } } if field.alias == user + # { { object } if field_alias is omitted + loc: Tuple[str, ...] = ("body", field.alias) + if field_alias_omitted: + loc = ("body",) + + value: Optional[Any] = None + + # Now that we know what to look for, try to get the value from the received body + if received_body is not None: + try: + value = received_body.get(field.alias) + except AttributeError: + errors.append(get_missing_field_error(loc)) + continue + + # Determine if the field is required + if value is None: + if field.required: + errors.append(get_missing_field_error(loc)) + else: + values[field.name] = deepcopy(field.default) + continue + + # MAINTENANCE: Handle byte and file fields + + # Finally, validate the value + values[field.name] = _validate_field(field=field, value=value, loc=loc, existing_errors=errors) + + return values, errors + + +def _validate_field( + *, + field: ModelField, + value: Any, + loc: Tuple[str, ...], + existing_errors: List[Dict[str, Any]], +): + """ + Validate a field, and append any errors to the existing_errors list. + """ + validated_value, errors = field.validate(value, value, loc=loc) + + if isinstance(errors, list): + processed_errors = _regenerate_error_with_loc(errors=errors, loc_prefix=()) + existing_errors.extend(processed_errors) + elif errors: + existing_errors.append(errors) + + return validated_value + + +def _get_embed_body( + *, + field: ModelField, + required_params: List[ModelField], + received_body: Optional[Dict[str, Any]], +) -> Tuple[Optional[Dict[str, Any]], bool]: + field_info = field.field_info + embed = getattr(field_info, "embed", None) + + # If the field is an embed, and the field alias is omitted, we need to wrap the received body in the field alias. + field_alias_omitted = len(required_params) == 1 and not embed + if field_alias_omitted: + received_body = {field.alias: received_body} + + return received_body, field_alias_omitted diff --git a/aws_lambda_powertools/event_handler/openapi/__init__.py b/aws_lambda_powertools/event_handler/openapi/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/aws_lambda_powertools/event_handler/openapi/compat.py b/aws_lambda_powertools/event_handler/openapi/compat.py new file mode 100644 index 00000000000..54b78f7e5f6 --- /dev/null +++ b/aws_lambda_powertools/event_handler/openapi/compat.py @@ -0,0 +1,497 @@ +# mypy: ignore-errors +# flake8: noqa +from collections import deque +from copy import copy + +# MAINTENANCE: remove when deprecating Pydantic v1. Mypy doesn't handle two different code paths that import different +# versions of a module, so we need to ignore errors here. + +from dataclasses import dataclass, is_dataclass +from enum import Enum +from typing import Any, Dict, List, Set, Tuple, Type, Union, FrozenSet, Deque, Sequence, Mapping + +from typing_extensions import Annotated, Literal, get_origin, get_args + +from pydantic import BaseModel, create_model +from pydantic.fields import FieldInfo + +from aws_lambda_powertools.event_handler.openapi.types import ( + COMPONENT_REF_PREFIX, + PYDANTIC_V2, + ModelNameMap, + UnionType, +) + +sequence_annotation_to_type = { + Sequence: list, + List: list, + list: list, + Tuple: tuple, + tuple: tuple, + Set: set, + set: set, + FrozenSet: frozenset, + frozenset: frozenset, + Deque: deque, + deque: deque, +} + +sequence_types = tuple(sequence_annotation_to_type.keys()) + +RequestErrorModel: Type[BaseModel] = create_model("Request") + +if PYDANTIC_V2: + from pydantic import TypeAdapter, ValidationError + from pydantic._internal._typing_extra import eval_type_lenient + from pydantic.fields import FieldInfo + from pydantic._internal._utils import lenient_issubclass + from pydantic.json_schema import GenerateJsonSchema, JsonSchemaValue + from pydantic_core import PydanticUndefined, PydanticUndefinedType + + from aws_lambda_powertools.event_handler.openapi.types import IncEx + + Undefined = PydanticUndefined + Required = PydanticUndefined + UndefinedType = PydanticUndefinedType + + evaluate_forwardref = eval_type_lenient + + class ErrorWrapper(Exception): + pass + + @dataclass + class ModelField: + field_info: FieldInfo + name: str + mode: Literal["validation", "serialization"] = "validation" + + @property + def alias(self) -> str: + value = self.field_info.alias + return value if value is not None else self.name + + @property + def required(self) -> bool: + return self.field_info.is_required() + + @property + def default(self) -> Any: + return self.get_default() + + @property + def type_(self) -> Any: + return self.field_info.annotation + + def __post_init__(self) -> None: + self._type_adapter: TypeAdapter[Any] = TypeAdapter( + Annotated[self.field_info.annotation, self.field_info], + ) + + def get_default(self) -> Any: + if self.field_info.is_required(): + return Undefined + return self.field_info.get_default(call_default_factory=True) + + def serialize( + self, + value: Any, + *, + mode: Literal["json", "python"] = "json", + include: Union[IncEx, None] = None, + exclude: Union[IncEx, None] = None, + by_alias: bool = True, + exclude_unset: bool = False, + exclude_defaults: bool = False, + exclude_none: bool = False, + ) -> Any: + return self._type_adapter.dump_python( + value, + mode=mode, + include=include, + exclude=exclude, + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + ) + + def validate( + self, value: Any, values: Dict[str, Any] = {}, *, loc: Tuple[Union[int, str], ...] = () + ) -> Tuple[Any, Union[List[Dict[str, Any]], None]]: + try: + return (self._type_adapter.validate_python(value, from_attributes=True), None) + except ValidationError as exc: + return None, _regenerate_error_with_loc(errors=exc.errors(), loc_prefix=loc) + + def __hash__(self) -> int: + # Each ModelField is unique for our purposes + return id(self) + + def get_schema_from_model_field( + *, + field: ModelField, + model_name_map: ModelNameMap, + field_mapping: Dict[ + Tuple[ModelField, Literal["validation", "serialization"]], + JsonSchemaValue, + ], + ) -> Dict[str, Any]: + json_schema = field_mapping[(field, field.mode)] + if "$ref" not in json_schema: + # MAINTENANCE: remove when deprecating Pydantic v1 + # Ref: https://github.com/pydantic/pydantic/blob/d61792cc42c80b13b23e3ffa74bc37ec7c77f7d1/pydantic/schema.py#L207 + json_schema["title"] = field.field_info.title or field.alias.title().replace("_", " ") + return json_schema + + def get_definitions( + *, + fields: List[ModelField], + schema_generator: GenerateJsonSchema, + model_name_map: ModelNameMap, + ) -> Tuple[ + Dict[ + Tuple[ModelField, Literal["validation", "serialization"]], + Dict[str, Any], + ], + Dict[str, Dict[str, Any]], + ]: + inputs = [(field, field.mode, field._type_adapter.core_schema) for field in fields] + field_mapping, definitions = schema_generator.generate_definitions(inputs=inputs) + + return field_mapping, definitions + + def get_compat_model_name_map(fields: List[ModelField]) -> ModelNameMap: + return {} + + def get_annotation_from_field_info(annotation: Any, field_info: FieldInfo, field_name: str) -> Any: + return annotation + + def model_rebuild(model: Type[BaseModel]) -> None: + model.model_rebuild() + + def copy_field_info(*, field_info: FieldInfo, annotation: Any) -> FieldInfo: + return type(field_info).from_annotation(annotation) + + def get_missing_field_error(loc: Tuple[str, ...]) -> Dict[str, Any]: + error = ValidationError.from_exception_data( + "Field required", [{"type": "missing", "loc": loc, "input": {}}] + ).errors()[0] + error["input"] = None + return error + + def is_scalar_field(field: ModelField) -> bool: + from aws_lambda_powertools.event_handler.openapi.params import Body + + return field_annotation_is_scalar(field.field_info.annotation) and not isinstance(field.field_info, Body) + + def is_scalar_sequence_field(field: ModelField) -> bool: + return field_annotation_is_scalar_sequence(field.field_info.annotation) + + def is_sequence_field(field: ModelField) -> bool: + return field_annotation_is_sequence(field.field_info.annotation) + + def is_bytes_field(field: ModelField) -> bool: + return is_bytes_or_nonable_bytes_annotation(field.type_) + + def is_bytes_sequence_field(field: ModelField) -> bool: + return is_bytes_sequence_annotation(field.type_) + + def serialize_sequence_value(*, field: ModelField, value: Any) -> Sequence[Any]: + origin_type = get_origin(field.field_info.annotation) or field.field_info.annotation + if not issubclass(origin_type, sequence_types): # type: ignore[arg-type] + raise AssertionError(f"Expected sequence type, got {origin_type}") + return sequence_annotation_to_type[origin_type](value) # type: ignore[no-any-return] + + def _normalize_errors(errors: Sequence[Any]) -> List[Dict[str, Any]]: + return errors # type: ignore[return-value] + + def create_body_model(*, fields: Sequence[ModelField], model_name: str) -> Type[BaseModel]: + field_params = {f.name: (f.field_info.annotation, f.field_info) for f in fields} + model: Type[BaseModel] = create_model(model_name, **field_params) + return model + + def _model_dump(model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any) -> Any: + return model.model_dump(mode=mode, **kwargs) + + def model_json(model: BaseModel, **kwargs: Any) -> Any: + return model.model_dump_json(**kwargs) + +else: + from pydantic import BaseModel, ValidationError + from pydantic.fields import ( + ModelField, + Required, + Undefined, + UndefinedType, + SHAPE_LIST, + SHAPE_SET, + SHAPE_FROZENSET, + SHAPE_TUPLE, + SHAPE_SEQUENCE, + SHAPE_TUPLE_ELLIPSIS, + SHAPE_SINGLETON, + ) + from pydantic.schema import ( + field_schema, + get_annotation_from_field_info, + get_flat_models_from_fields, + get_model_name_map, + model_process_schema, + ) + from pydantic.errors import MissingError + from pydantic.error_wrappers import ErrorWrapper + from pydantic.utils import lenient_issubclass + from pydantic.typing import evaluate_forwardref + + JsonSchemaValue = Dict[str, Any] + + sequence_shapes = [ + SHAPE_LIST, + SHAPE_SET, + SHAPE_FROZENSET, + SHAPE_TUPLE, + SHAPE_SEQUENCE, + SHAPE_TUPLE_ELLIPSIS, + ] + sequence_shape_to_type = { + SHAPE_LIST: list, + SHAPE_SET: set, + SHAPE_TUPLE: tuple, + SHAPE_SEQUENCE: list, + SHAPE_TUPLE_ELLIPSIS: list, + } + + @dataclass + class GenerateJsonSchema: + ref_template: str + + def get_schema_from_model_field( + *, + field: ModelField, + model_name_map: ModelNameMap, + field_mapping: Dict[ + Tuple[ModelField, Literal["validation", "serialization"]], + JsonSchemaValue, + ], + ) -> Dict[str, Any]: + return field_schema( + field, + model_name_map=model_name_map, + ref_prefix=COMPONENT_REF_PREFIX, + )[0] + + def get_definitions( + *, + fields: List[ModelField], + schema_generator: GenerateJsonSchema, + model_name_map: ModelNameMap, + ) -> Tuple[ + Dict[Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue], + Dict[str, Dict[str, Any]], + ]: + models = get_flat_models_from_fields(fields, known_models=set()) + return {}, get_model_definitions(flat_models=models, model_name_map=model_name_map) + + def get_model_definitions( + *, + flat_models: Set[Union[Type[BaseModel], Type[Enum]]], + model_name_map: ModelNameMap, + ) -> Dict[str, Any]: + definitions: Dict[str, Dict[str, Any]] = {} + for model in flat_models: + m_schema, m_definitions, _ = model_process_schema( + model, + model_name_map=model_name_map, + ref_prefix=COMPONENT_REF_PREFIX, + ) + definitions.update(m_definitions) + model_name = model_name_map[model] + if "description" in m_schema: + m_schema["description"] = m_schema["description"].split("\f")[0] + definitions[model_name] = m_schema + return definitions + + def get_compat_model_name_map(fields: List[ModelField]) -> ModelNameMap: + models = get_flat_models_from_fields(fields, known_models=set()) + return get_model_name_map(models) + + def model_rebuild(model: Type[BaseModel]) -> None: + model.update_forward_refs() + + def copy_field_info(*, field_info: FieldInfo, annotation: Any) -> FieldInfo: + return copy(field_info) + + def is_pv1_scalar_field(field: ModelField) -> bool: + from aws_lambda_powertools.event_handler.openapi.params import Body + + if not ( + field.shape == SHAPE_SINGLETON + and not lenient_issubclass(field.type_, BaseModel) + and not lenient_issubclass(field.type_, dict) + and not field_annotation_is_sequence(field.type_) + and not is_dataclass(field.type_) + and not isinstance(field.field_info, Body) + ): + return False + + if field.sub_fields: + if not all(is_pv1_scalar_sequence_field(f) for f in field.sub_fields): + return False + + return True + + def is_pv1_scalar_sequence_field(field: ModelField) -> bool: + if (field.shape in sequence_shapes) and not lenient_issubclass(field.type_, BaseModel): + if field.sub_fields is not None: + for sub_field in field.sub_fields: + if not is_pv1_scalar_field(sub_field): + return False + return True + if _annotation_is_sequence(field.type_): + return True + return False + + def is_scalar_field(field: ModelField) -> bool: + return is_pv1_scalar_field(field) + + def is_scalar_sequence_field(field: ModelField) -> bool: + return is_pv1_scalar_sequence_field(field) + + def is_sequence_field(field: ModelField) -> bool: + return field.shape in sequence_shapes or _annotation_is_sequence(field.type_) + + def is_bytes_field(field: ModelField) -> bool: + return lenient_issubclass(field.type_, bytes) + + def is_bytes_sequence_field(field: ModelField) -> bool: + return field.shape in sequence_shapes and lenient_issubclass(field.type_, bytes) # type: ignore[attr-defined] + + def _annotation_is_sequence(annotation: Union[Type[Any], None]) -> bool: + if lenient_issubclass(annotation, (str, bytes)): + return False + return lenient_issubclass(annotation, sequence_types) + + def get_missing_field_error(loc: Tuple[str, ...]) -> Dict[str, Any]: + missing_field_error = ErrorWrapper(MissingError(), loc=loc) + new_error = ValidationError([missing_field_error], RequestErrorModel) + return new_error.errors()[0] + + def _normalize_errors(errors: Sequence[Any]) -> List[Dict[str, Any]]: + use_errors: List[Any] = [] + for error in errors: + if isinstance(error, ErrorWrapper): + new_errors = ValidationError(errors=[error], model=RequestErrorModel).errors() # type: ignore[call-arg] + use_errors.extend(new_errors) + elif isinstance(error, list): + use_errors.extend(_normalize_errors(error)) + else: + use_errors.append(error) + return use_errors + + def create_body_model(*, fields: Sequence[ModelField], model_name: str) -> Type[BaseModel]: + body_model = create_model(model_name) + for f in fields: + body_model.__fields__[f.name] = f # type: ignore[index] + return body_model + + def serialize_sequence_value(*, field: ModelField, value: Any) -> Sequence[Any]: + return sequence_shape_to_type[field.shape](value) + + def _model_dump(model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any) -> Any: + return model.dict(**kwargs) + + def model_json(model: BaseModel, **kwargs: Any) -> Any: + return model.json(**kwargs) + + +# Common code for both versions + + +def field_annotation_is_complex(annotation: Union[Type[Any], None]) -> bool: + origin = get_origin(annotation) + if origin is Union or origin is UnionType: + return any(field_annotation_is_complex(arg) for arg in get_args(annotation)) + + return ( + _annotation_is_complex(annotation) + or _annotation_is_complex(origin) + or hasattr(origin, "__pydantic_core_schema__") + or hasattr(origin, "__get_pydantic_core_schema__") + ) + + +def field_annotation_is_scalar(annotation: Any) -> bool: + return annotation is Ellipsis or not field_annotation_is_complex(annotation) + + +def field_annotation_is_sequence(annotation: Union[Type[Any], None]) -> bool: + return _annotation_is_sequence(annotation) or _annotation_is_sequence(get_origin(annotation)) + + +def field_annotation_is_scalar_sequence(annotation: Union[Type[Any], None]) -> bool: + origin = get_origin(annotation) + if origin is Union or origin is UnionType: + at_least_one_scalar_sequence = False + for arg in get_args(annotation): + if field_annotation_is_scalar_sequence(arg): + at_least_one_scalar_sequence = True + continue + elif not field_annotation_is_scalar(arg): + return False + return at_least_one_scalar_sequence + return field_annotation_is_sequence(annotation) and all( + field_annotation_is_scalar(sub_annotation) for sub_annotation in get_args(annotation) + ) + + +def is_bytes_or_nonable_bytes_annotation(annotation: Any) -> bool: + if lenient_issubclass(annotation, bytes): + return True + origin = get_origin(annotation) + if origin is Union or origin is UnionType: + for arg in get_args(annotation): + if lenient_issubclass(arg, bytes): + return True + return False + + +def is_bytes_sequence_annotation(annotation: Any) -> bool: + origin = get_origin(annotation) + if origin is Union or origin is UnionType: + at_least_one = False + for arg in get_args(annotation): + if is_bytes_sequence_annotation(arg): + at_least_one = True + break + return at_least_one + return field_annotation_is_sequence(annotation) and all( + is_bytes_or_nonable_bytes_annotation(sub_annotation) for sub_annotation in get_args(annotation) + ) + + +def value_is_sequence(value: Any) -> bool: + return isinstance(value, sequence_types) and not isinstance(value, (str, bytes)) # type: ignore[arg-type] + + +def _annotation_is_complex(annotation: Union[Type[Any], None]) -> bool: + return ( + lenient_issubclass(annotation, (BaseModel, Mapping)) # TODO: UploadFile + or _annotation_is_sequence(annotation) + or is_dataclass(annotation) + ) + + +def _annotation_is_sequence(annotation: Union[Type[Any], None]) -> bool: + if lenient_issubclass(annotation, (str, bytes)): + return False + return lenient_issubclass(annotation, sequence_types) + + +def _regenerate_error_with_loc( + *, errors: Sequence[Any], loc_prefix: Tuple[Union[str, int], ...] +) -> List[Dict[str, Any]]: + updated_loc_errors: List[Any] = [ + {**err, "loc": loc_prefix + err.get("loc", ())} for err in _normalize_errors(errors) + ] + + return updated_loc_errors diff --git a/aws_lambda_powertools/event_handler/openapi/dependant.py b/aws_lambda_powertools/event_handler/openapi/dependant.py new file mode 100644 index 00000000000..8cbb8b942ed --- /dev/null +++ b/aws_lambda_powertools/event_handler/openapi/dependant.py @@ -0,0 +1,342 @@ +import inspect +import re +from typing import Any, Callable, Dict, ForwardRef, List, Optional, Set, Tuple, Type, cast + +from pydantic import BaseModel + +from aws_lambda_powertools.event_handler.openapi.compat import ( + ModelField, + create_body_model, + evaluate_forwardref, + is_scalar_field, + is_scalar_sequence_field, +) +from aws_lambda_powertools.event_handler.openapi.params import ( + Body, + Dependant, + File, + Form, + Header, + Param, + ParamTypes, + Query, + analyze_param, + create_response_field, + get_flat_dependant, +) + +""" +This turns the opaque function signature into typed, validated models. + +It relies on Pydantic's typing and validation to achieve this in a declarative way. +This enables traits like autocompletion, validation, and declarative structure vs imperative parsing. + +This code parses an OpenAPI operation handler function signature into Pydantic models. It uses inspect to get the +signature and regex to parse path parameters. Each parameter is analyzed to extract its type annotation and generate +a corresponding Pydantic field, which are added to a Dependant model. Return values are handled similarly. + +This modeling allows for type checking, automatic parameter name/location/type extraction, and input validation - +turning the opaque signature into validated models. It relies on Pydantic's typing and validation for a declarative +approach over imperative parsing, enabling autocompletion, validation and structure. +""" + + +def add_param_to_fields( + *, + field: ModelField, + dependant: Dependant, +) -> None: + """ + Adds a parameter to the list of parameters in the dependant model. + + Parameters + ---------- + field: ModelField + The field to add + dependant: Dependant + The dependant model to add the field to + + """ + field_info = cast(Param, field.field_info) + if field_info.in_ == ParamTypes.path: + dependant.path_params.append(field) + elif field_info.in_ == ParamTypes.query: + dependant.query_params.append(field) + elif field_info.in_ == ParamTypes.header: + dependant.header_params.append(field) + else: + if field_info.in_ != ParamTypes.cookie: + raise AssertionError(f"Unsupported param type: {field_info.in_}") + dependant.cookie_params.append(field) + + +def get_typed_annotation(annotation: Any, globalns: Dict[str, Any]) -> Any: + """ + Evaluates a type annotation, which can be a string or a ForwardRef. + """ + if isinstance(annotation, str): + annotation = ForwardRef(annotation) + annotation = evaluate_forwardref(annotation, globalns, globalns) + return annotation + + +def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature: + """ + Returns a typed signature for a callable, resolving forward references. + + Parameters + ---------- + call: Callable[..., Any] + The callable to get the signature for + + Returns + ------- + inspect.Signature + The typed signature + """ + signature = inspect.signature(call) + + # Gets the global namespace for the call. This is used to resolve forward references. + globalns = getattr(call, "__global__", {}) + + typed_params = [ + inspect.Parameter( + name=param.name, + kind=param.kind, + default=param.default, + annotation=get_typed_annotation(param.annotation, globalns), + ) + for param in signature.parameters.values() + ] + + # If the return annotation is not empty, add it to the signature. + if signature.return_annotation is not inspect.Signature.empty: + return_param = inspect.Parameter( + name="Return", + kind=inspect.Parameter.POSITIONAL_OR_KEYWORD, + default=None, + annotation=get_typed_annotation(signature.return_annotation, globalns), + ) + return inspect.Signature(typed_params, return_annotation=return_param.annotation) + else: + return inspect.Signature(typed_params) + + +def get_path_param_names(path: str) -> Set[str]: + """ + Returns the path parameter names from a path template. Those are the strings between < and >. + + Parameters + ---------- + path: str + The path template + + Returns + ------- + Set[str] + The path parameter names + + """ + return set(re.findall("<(.*?)>", path)) + + +def get_dependant( + *, + path: str, + call: Callable[..., Any], + name: Optional[str] = None, +) -> Dependant: + """ + Returns a dependant model for a handler function. A dependant model is a model that contains + the parameters and return value of a handler function. + + Parameters + ---------- + path: str + The path template + call: Callable[..., Any] + The handler function + name: str, optional + The name of the handler function + + Returns + ------- + Dependant + The dependant model for the handler function + """ + path_param_names = get_path_param_names(path) + endpoint_signature = get_typed_signature(call) + signature_params = endpoint_signature.parameters + + dependant = Dependant( + call=call, + name=name, + path=path, + ) + + # Add each parameter to the dependant model + for param_name, param in signature_params.items(): + # If the parameter is a path parameter, we need to set the in_ field to "path". + is_path_param = param_name in path_param_names + + # Analyze the parameter to get the Pydantic field. + param_field = analyze_param( + param_name=param_name, + annotation=param.annotation, + value=param.default, + is_path_param=is_path_param, + is_response_param=False, + ) + if param_field is None: + raise AssertionError(f"Parameter field is None for param: {param_name}") + + if is_body_param(param_field=param_field, is_path_param=is_path_param): + dependant.body_params.append(param_field) + else: + add_param_to_fields(field=param_field, dependant=dependant) + + # If the return annotation is not empty, add it to the dependant model. + return_annotation = endpoint_signature.return_annotation + if return_annotation is not inspect.Signature.empty: + param_field = analyze_param( + param_name="return", + annotation=return_annotation, + value=None, + is_path_param=False, + is_response_param=True, + ) + if param_field is None: + raise AssertionError("Param field is None for return annotation") + + dependant.return_param = param_field + + return dependant + + +def is_body_param(*, param_field: ModelField, is_path_param: bool) -> bool: + """ + Returns whether a parameter is a request body parameter, by checking if it is a scalar field or a body field. + + Parameters + ---------- + param_field: ModelField + The parameter field + is_path_param: bool + Whether the parameter is a path parameter + + Returns + ------- + bool + Whether the parameter is a request body parameter + """ + if is_path_param: + if not is_scalar_field(field=param_field): + raise AssertionError("Path params must be of one of the supported types") + return False + elif is_scalar_field(field=param_field): + return False + elif isinstance(param_field.field_info, (Query, Header)) and is_scalar_sequence_field(param_field): + return False + else: + if not isinstance(param_field.field_info, Body): + raise AssertionError(f"Param: {param_field.name} can only be a request body, use Body()") + return True + + +def get_flat_params(dependant: Dependant) -> List[ModelField]: + """ + Get a list of all the parameters from a Dependant object. + + Parameters + ---------- + dependant : Dependant + The Dependant object containing the parameters. + + Returns + ------- + List[ModelField] + A list of ModelField objects containing the flat parameters from the Dependant object. + + """ + flat_dependant = get_flat_dependant(dependant) + return ( + flat_dependant.path_params + + flat_dependant.query_params + + flat_dependant.header_params + + flat_dependant.cookie_params + ) + + +def get_body_field(*, dependant: Dependant, name: str) -> Optional[ModelField]: + """ + Get the Body field for a given Dependant object. + """ + + flat_dependant = get_flat_dependant(dependant) + if not flat_dependant.body_params: + return None + + first_param = flat_dependant.body_params[0] + field_info = first_param.field_info + + # Handle the case where there is only one body parameter and it is embedded + embed = getattr(field_info, "embed", None) + body_param_names_set = {param.name for param in flat_dependant.body_params} + if len(body_param_names_set) == 1 and not embed: + return first_param + + # If one field requires to embed, all have to be embedded + for param in flat_dependant.body_params: + setattr(param.field_info, "embed", True) # noqa: B010 + + # Generate a custom body model for this endpoint + model_name = "Body_" + name + body_model = create_body_model(fields=flat_dependant.body_params, model_name=model_name) + + required = any(True for f in flat_dependant.body_params if f.required) + + body_field_info, body_field_info_kwargs = get_body_field_info( + body_model=body_model, + flat_dependant=flat_dependant, + required=required, + ) + + final_field = create_response_field( + name="body", + type_=body_model, + required=required, + alias="body", + field_info=body_field_info(**body_field_info_kwargs), + ) + return final_field + + +def get_body_field_info( + *, + body_model: Type[BaseModel], + flat_dependant: Dependant, + required: bool, +) -> Tuple[Type[Body], Dict[str, Any]]: + """ + Get the Body field info and kwargs for a given body model. + """ + + body_field_info_kwargs: Dict[str, Any] = {"annotation": body_model, "alias": "body"} + + if not required: + body_field_info_kwargs["default"] = None + + if any(isinstance(f.field_info, File) for f in flat_dependant.body_params): + body_field_info: Type[Body] = File + elif any(isinstance(f.field_info, Form) for f in flat_dependant.body_params): + body_field_info = Form + else: + body_field_info = Body + + body_param_media_types = [ + f.field_info.media_type for f in flat_dependant.body_params if isinstance(f.field_info, Body) + ] + if len(set(body_param_media_types)) == 1: + body_field_info_kwargs["media_type"] = body_param_media_types[0] + + return body_field_info, body_field_info_kwargs diff --git a/aws_lambda_powertools/event_handler/openapi/encoders.py b/aws_lambda_powertools/event_handler/openapi/encoders.py new file mode 100644 index 00000000000..94c1cb5d659 --- /dev/null +++ b/aws_lambda_powertools/event_handler/openapi/encoders.py @@ -0,0 +1,344 @@ +import dataclasses +import datetime +from collections import defaultdict, deque +from decimal import Decimal +from enum import Enum +from pathlib import Path, PurePath +from re import Pattern +from types import GeneratorType +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union +from uuid import UUID + +from pydantic import BaseModel +from pydantic.color import Color +from pydantic.types import SecretBytes, SecretStr + +from aws_lambda_powertools.event_handler.openapi.compat import _model_dump +from aws_lambda_powertools.event_handler.openapi.types import IncEx + +""" +This module contains the encoders used by jsonable_encoder to convert Python objects to JSON serializable data types. +""" + + +def jsonable_encoder( # noqa: PLR0911 + obj: Any, + include: Optional[IncEx] = None, + exclude: Optional[IncEx] = None, + by_alias: bool = True, + exclude_unset: bool = False, + exclude_defaults: bool = False, + exclude_none: bool = False, +) -> Any: + """ + JSON encodes an arbitrary Python object into JSON serializable data types. + + This is a modified version of fastapi.encoders.jsonable_encoder that supports + encoding of pydantic.BaseModel objects. + + Parameters + ---------- + obj : Any + The object to encode + include : Optional[IncEx], optional + A set or dictionary of strings that specifies which properties should be included, by default None, + meaning everything is included + exclude : Optional[IncEx], optional + A set or dictionary of strings that specifies which properties should be excluded, by default None, + meaning nothing is excluded + by_alias : bool, optional + Whether field aliases should be respected, by default True + exclude_unset : bool, optional + Whether fields that are not set should be excluded, by default False + exclude_defaults : bool, optional + Whether fields that are equal to their default value (as specified in the model) should be excluded, + by default False + exclude_none : bool, optional + Whether fields that are equal to None should be excluded, by default False + + Returns + ------- + Any + The JSON serializable data types + """ + if include is not None and not isinstance(include, (set, dict)): + include = set(include) + if exclude is not None and not isinstance(exclude, (set, dict)): + exclude = set(exclude) + + # Pydantic models + if isinstance(obj, BaseModel): + return _dump_base_model( + obj=obj, + include=include, + exclude=exclude, + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_none=exclude_none, + exclude_defaults=exclude_defaults, + ) + + # Dataclasses + if dataclasses.is_dataclass(obj): + obj_dict = dataclasses.asdict(obj) + return jsonable_encoder( + obj_dict, + include=include, + exclude=exclude, + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + ) + + # Enums + if isinstance(obj, Enum): + return obj.value + + # Paths + if isinstance(obj, PurePath): + return str(obj) + + # Scalars + if isinstance(obj, (str, int, float, type(None))): + return obj + + # Dictionaries + if isinstance(obj, dict): + return _dump_dict( + obj=obj, + include=include, + exclude=exclude, + by_alias=by_alias, + exclude_none=exclude_none, + exclude_unset=exclude_unset, + ) + + # Sequences + if isinstance(obj, (list, set, frozenset, GeneratorType, tuple, deque)): + return _dump_sequence( + obj=obj, + include=include, + exclude=exclude, + by_alias=by_alias, + exclude_none=exclude_none, + exclude_defaults=exclude_defaults, + exclude_unset=exclude_unset, + ) + + # Other types + if type(obj) in ENCODERS_BY_TYPE: + return ENCODERS_BY_TYPE[type(obj)](obj) + + for encoder, classes_tuple in encoders_by_class_tuples.items(): + if isinstance(obj, classes_tuple): + return encoder(obj) + + # Default + return _dump_other( + obj=obj, + include=include, + exclude=exclude, + by_alias=by_alias, + exclude_none=exclude_none, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + ) + + +def _dump_base_model( + *, + obj: Any, + include: Optional[IncEx] = None, + exclude: Optional[IncEx] = None, + by_alias: bool = True, + exclude_unset: bool = False, + exclude_none: bool = False, + exclude_defaults: bool = False, +): + """ + Dump a BaseModel object to a dict, using the same parameters as jsonable_encoder + """ + obj_dict = _model_dump( + obj, + mode="json", + include=include, + exclude=exclude, + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_none=exclude_none, + exclude_defaults=exclude_defaults, + ) + if "__root__" in obj_dict: + obj_dict = obj_dict["__root__"] + + return jsonable_encoder( + obj_dict, + exclude_none=exclude_none, + exclude_defaults=exclude_defaults, + ) + + +def _dump_dict( + *, + obj: Any, + include: Optional[IncEx] = None, + exclude: Optional[IncEx] = None, + by_alias: bool = True, + exclude_unset: bool = False, + exclude_none: bool = False, +) -> Dict[str, Any]: + """ + Dump a dict to a dict, using the same parameters as jsonable_encoder + """ + encoded_dict = {} + allowed_keys = set(obj.keys()) + if include is not None: + allowed_keys &= set(include) + if exclude is not None: + allowed_keys -= set(exclude) + for key, value in obj.items(): + if ( + (not isinstance(key, str) or not key.startswith("_sa")) + and (value is not None or not exclude_none) + and key in allowed_keys + ): + encoded_key = jsonable_encoder( + key, + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_none=exclude_none, + ) + encoded_value = jsonable_encoder( + value, + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_none=exclude_none, + ) + encoded_dict[encoded_key] = encoded_value + return encoded_dict + + +def _dump_sequence( + *, + obj: Any, + include: Optional[IncEx] = None, + exclude: Optional[IncEx] = None, + by_alias: bool = True, + exclude_unset: bool = False, + exclude_none: bool = False, + exclude_defaults: bool = False, +) -> List[Any]: + """ + Dump a sequence to a list, using the same parameters as jsonable_encoder + """ + encoded_list = [] + for item in obj: + encoded_list.append( + jsonable_encoder( + item, + include=include, + exclude=exclude, + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + ), + ) + return encoded_list + + +def _dump_other( + *, + obj: Any, + include: Optional[IncEx] = None, + exclude: Optional[IncEx] = None, + by_alias: bool = True, + exclude_unset: bool = False, + exclude_none: bool = False, + exclude_defaults: bool = False, +) -> Any: + """ + Dump an object to ah hashable object, using the same parameters as jsonable_encoder + """ + try: + data = dict(obj) + except Exception as e: + errors: List[Exception] = [e] + try: + data = vars(obj) + except Exception as e: + errors.append(e) + raise ValueError(errors) from e + return jsonable_encoder( + data, + include=include, + exclude=exclude, + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + ) + + +def iso_format(o: Union[datetime.date, datetime.time]) -> str: + """ + ISO format for date and time + """ + return o.isoformat() + + +def decimal_encoder(dec_value: Decimal) -> Union[int, float]: + """ + Encodes a Decimal as int of there's no exponent, otherwise float + + This is useful when we use ConstrainedDecimal to represent Numeric(x,0) + where an integer (but not int typed) is used. Encoding this as a float + results in failed round-tripping between encode and parse. + + >>> decimal_encoder(Decimal("1.0")) + 1.0 + + >>> decimal_encoder(Decimal("1")) + 1 + """ + if dec_value.as_tuple().exponent >= 0: # type: ignore[operator] + return int(dec_value) + else: + return float(dec_value) + + +# Encoders for types that are not JSON serializable +ENCODERS_BY_TYPE: Dict[Type[Any], Callable[[Any], Any]] = { + bytes: lambda o: o.decode(), + Color: str, + datetime.date: iso_format, + datetime.datetime: iso_format, + datetime.time: iso_format, + datetime.timedelta: lambda td: td.total_seconds(), + Decimal: decimal_encoder, + Enum: lambda o: o.value, + frozenset: list, + deque: list, + GeneratorType: list, + Path: str, + Pattern: lambda o: o.pattern, + SecretBytes: str, + SecretStr: str, + set: list, + UUID: str, +} + + +# Generates a mapping of encoders to a tuple of classes that they can encode +def generate_encoders_by_class_tuples( + type_encoder_map: Dict[Any, Callable[[Any], Any]], +) -> Dict[Callable[[Any], Any], Tuple[Any, ...]]: + encoders: Dict[Callable[[Any], Any], Tuple[Any, ...]] = defaultdict(tuple) + for type_, encoder in type_encoder_map.items(): + encoders[encoder] += (type_,) + return encoders + + +# Mapping of encoders to a tuple of classes that they can encode +encoders_by_class_tuples = generate_encoders_by_class_tuples(ENCODERS_BY_TYPE) diff --git a/aws_lambda_powertools/event_handler/openapi/exceptions.py b/aws_lambda_powertools/event_handler/openapi/exceptions.py new file mode 100644 index 00000000000..fdd829ba9b1 --- /dev/null +++ b/aws_lambda_powertools/event_handler/openapi/exceptions.py @@ -0,0 +1,23 @@ +from typing import Any, Sequence + + +class ValidationException(Exception): + """ + Base exception for all validation errors + """ + + def __init__(self, errors: Sequence[Any]) -> None: + self._errors = errors + + def errors(self) -> Sequence[Any]: + return self._errors + + +class RequestValidationError(ValidationException): + """ + Raised when the request body does not match the OpenAPI schema + """ + + def __init__(self, errors: Sequence[Any], *, body: Any = None) -> None: + super().__init__(errors) + self.body = body diff --git a/aws_lambda_powertools/event_handler/openapi/models.py b/aws_lambda_powertools/event_handler/openapi/models.py new file mode 100644 index 00000000000..80818315f18 --- /dev/null +++ b/aws_lambda_powertools/event_handler/openapi/models.py @@ -0,0 +1,583 @@ +from enum import Enum +from typing import Any, Dict, List, Optional, Set, Union + +from pydantic import AnyUrl, BaseModel, Field + +from aws_lambda_powertools.event_handler.openapi.compat import model_rebuild +from aws_lambda_powertools.event_handler.openapi.types import PYDANTIC_V2 +from aws_lambda_powertools.shared.types import Annotated, Literal + +""" +The code defines Pydantic models for the various OpenAPI objects like OpenAPI, PathItem, Operation, Parameter etc. +These models can be used to parse OpenAPI JSON/YAML files into Python objects, or generate OpenAPI from Python data. +""" + + +# https://swagger.io/specification/#contact-object +class Contact(BaseModel): + name: Optional[str] = None + url: Optional[AnyUrl] = None + email: Optional[str] = None + + if PYDANTIC_V2: + model_config = {"extra": "allow"} + else: + + class Config: + extra = "allow" + + +# https://swagger.io/specification/#license-object +class License(BaseModel): + name: str + identifier: Optional[str] = None + url: Optional[AnyUrl] = None + + if PYDANTIC_V2: + model_config = {"extra": "allow"} + + else: + + class Config: + extra = "allow" + + +# https://swagger.io/specification/#info-object +class Info(BaseModel): + title: str + summary: Optional[str] = None + description: Optional[str] = None + termsOfService: Optional[str] = None + contact: Optional[Contact] = None + license: Optional[License] = None # noqa: A003 + version: str + + if PYDANTIC_V2: + model_config = {"extra": "allow"} + + else: + + class Config: + extra = "allow" + + +# https://swagger.io/specification/#server-variable-object +class ServerVariable(BaseModel): + enum: Annotated[Optional[List[str]], Field(min_length=1)] = None + default: str + description: Optional[str] = None + + if PYDANTIC_V2: + model_config = {"extra": "allow"} + + else: + + class Config: + extra = "allow" + + +# https://swagger.io/specification/#server-object +class Server(BaseModel): + url: Union[AnyUrl, str] + description: Optional[str] = None + variables: Optional[Dict[str, ServerVariable]] = None + + if PYDANTIC_V2: + model_config = {"extra": "allow"} + + else: + + class Config: + extra = "allow" + + +# https://swagger.io/specification/#reference-object +class Reference(BaseModel): + ref: str = Field(alias="$ref") + + +# https://swagger.io/specification/#discriminator-object +class Discriminator(BaseModel): + propertyName: str + mapping: Optional[Dict[str, str]] = None + + +# https://swagger.io/specification/#xml-object +class XML(BaseModel): + name: Optional[str] = None + namespace: Optional[str] = None + prefix: Optional[str] = None + attribute: Optional[bool] = None + wrapped: Optional[bool] = None + + if PYDANTIC_V2: + model_config = {"extra": "allow"} + + else: + + class Config: + extra = "allow" + + +# https://swagger.io/specification/#external-documentation-object +class ExternalDocumentation(BaseModel): + description: Optional[str] = None + url: AnyUrl + + if PYDANTIC_V2: + model_config = {"extra": "allow"} + + else: + + class Config: + extra = "allow" + + +# https://swagger.io/specification/#schema-object +class Schema(BaseModel): + # Ref: JSON Schema 2020-12: https://json-schema.org/draft/2020-12/json-schema-core.html#name-the-json-schema-core-vocabu + # Core Vocabulary + schema_: Optional[str] = Field(default=None, alias="$schema") + vocabulary: Optional[str] = Field(default=None, alias="$vocabulary") + id: Optional[str] = Field(default=None, alias="$id") # noqa: A003 + anchor: Optional[str] = Field(default=None, alias="$anchor") + dynamicAnchor: Optional[str] = Field(default=None, alias="$dynamicAnchor") + ref: Optional[str] = Field(default=None, alias="$ref") + dynamicRef: Optional[str] = Field(default=None, alias="$dynamicRef") + defs: Optional[Dict[str, "SchemaOrBool"]] = Field(default=None, alias="$defs") + comment: Optional[str] = Field(default=None, alias="$comment") + # Ref: JSON Schema 2020-12: https://json-schema.org/draft/2020-12/json-schema-core.html#name-a-vocabulary-for-applying-s + # A Vocabulary for Applying Subschemas + allOf: Optional[List["SchemaOrBool"]] = None + anyOf: Optional[List["SchemaOrBool"]] = None + oneOf: Optional[List["SchemaOrBool"]] = None + not_: Optional["SchemaOrBool"] = Field(default=None, alias="not") + if_: Optional["SchemaOrBool"] = Field(default=None, alias="if") + then: Optional["SchemaOrBool"] = None + else_: Optional["SchemaOrBool"] = Field(default=None, alias="else") + dependentSchemas: Optional[Dict[str, "SchemaOrBool"]] = None + prefixItems: Optional[List["SchemaOrBool"]] = None + # MAINTENANCE: uncomment and remove below when deprecating Pydantic v1 + # MAINTENANCE: It generates a list of schemas for tuples, before prefixItems was available + # MAINTENANCE: items: Optional["SchemaOrBool"] = None + items: Optional[Union["SchemaOrBool", List["SchemaOrBool"]]] = None + contains: Optional["SchemaOrBool"] = None + properties: Optional[Dict[str, "SchemaOrBool"]] = None + patternProperties: Optional[Dict[str, "SchemaOrBool"]] = None + additionalProperties: Optional["SchemaOrBool"] = None + propertyNames: Optional["SchemaOrBool"] = None + unevaluatedItems: Optional["SchemaOrBool"] = None + unevaluatedProperties: Optional["SchemaOrBool"] = None + # Ref: JSON Schema Validation 2020-12: https://json-schema.org/draft/2020-12/json-schema-validation.html#name-a-vocabulary-for-structural + # A Vocabulary for Structural Validation + type: Optional[str] = None # noqa: A003 + enum: Optional[List[Any]] = None + const: Optional[Any] = None + multipleOf: Optional[float] = Field(default=None, gt=0) + maximum: Optional[float] = None + exclusiveMaximum: Optional[float] = None + minimum: Optional[float] = None + exclusiveMinimum: Optional[float] = None + maxLength: Optional[int] = Field(default=None, ge=0) + minLength: Optional[int] = Field(default=None, ge=0) + pattern: Optional[str] = None + maxItems: Optional[int] = Field(default=None, ge=0) + minItems: Optional[int] = Field(default=None, ge=0) + uniqueItems: Optional[bool] = None + maxContains: Optional[int] = Field(default=None, ge=0) + minContains: Optional[int] = Field(default=None, ge=0) + maxProperties: Optional[int] = Field(default=None, ge=0) + minProperties: Optional[int] = Field(default=None, ge=0) + required: Optional[List[str]] = None + dependentRequired: Optional[Dict[str, Set[str]]] = None + # Ref: JSON Schema Validation 2020-12: https://json-schema.org/draft/2020-12/json-schema-validation.html#name-vocabularies-for-semantic-c + # Vocabularies for Semantic Content With "format" + format: Optional[str] = None # noqa: A003 + # Ref: JSON Schema Validation 2020-12: https://json-schema.org/draft/2020-12/json-schema-validation.html#name-a-vocabulary-for-the-conten + # A Vocabulary for the Contents of String-Encoded Data + contentEncoding: Optional[str] = None + contentMediaType: Optional[str] = None + contentSchema: Optional["SchemaOrBool"] = None + # Ref: JSON Schema Validation 2020-12: https://json-schema.org/draft/2020-12/json-schema-validation.html#name-a-vocabulary-for-basic-meta + # A Vocabulary for Basic Meta-Data Annotations + title: Optional[str] = None + description: Optional[str] = None + default: Optional[Any] = None + deprecated: Optional[bool] = None + readOnly: Optional[bool] = None + writeOnly: Optional[bool] = None + examples: Optional[List["Example"]] = None + # Ref: OpenAPI 3.1.0: https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.1.0.md#schema-object + # Schema Object + discriminator: Optional[Discriminator] = None + xml: Optional[XML] = None + externalDocs: Optional[ExternalDocumentation] = None + + if PYDANTIC_V2: + model_config = {"extra": "allow"} + + else: + + class Config: + extra = "allow" + + +# Ref: https://json-schema.org/draft/2020-12/json-schema-core.html#name-json-schema-documents +# A JSON Schema MUST be an object or a boolean. +SchemaOrBool = Union[Schema, bool] + + +# https://swagger.io/specification/#example-object +class Example(BaseModel): + summary: Optional[str] = None + description: Optional[str] = None + value: Optional[Any] = None + externalValue: Optional[AnyUrl] = None + + if PYDANTIC_V2: + model_config = {"extra": "allow"} + + else: + + class Config: + extra = "allow" + + +class ParameterInType(Enum): + query = "query" + header = "header" + path = "path" + cookie = "cookie" + + +# https://swagger.io/specification/#encoding-object +class Encoding(BaseModel): + contentType: Optional[str] = None + headers: Optional[Dict[str, Union["Header", Reference]]] = None + style: Optional[str] = None + explode: Optional[bool] = None + allowReserved: Optional[bool] = None + + if PYDANTIC_V2: + model_config = {"extra": "allow"} + + else: + + class Config: + extra = "allow" + + +# https://swagger.io/specification/#media-type-object +class MediaType(BaseModel): + schema_: Optional[Union[Schema, Reference]] = Field(default=None, alias="schema") + examples: Optional[Dict[str, Union[Example, Reference]]] = None + encoding: Optional[Dict[str, Encoding]] = None + + if PYDANTIC_V2: + model_config = {"extra": "allow"} + + else: + + class Config: + extra = "allow" + + +# https://swagger.io/specification/#parameter-object +class ParameterBase(BaseModel): + description: Optional[str] = None + required: Optional[bool] = None + deprecated: Optional[bool] = None + # Serialization rules for simple scenarios + style: Optional[str] = None + explode: Optional[bool] = None + allowReserved: Optional[bool] = None + schema_: Optional[Union[Schema, Reference]] = Field(default=None, alias="schema") + examples: Optional[Dict[str, Union[Example, Reference]]] = None + # Serialization rules for more complex scenarios + content: Optional[Dict[str, MediaType]] = None + + if PYDANTIC_V2: + model_config = {"extra": "allow"} + + else: + + class Config: + extra = "allow" + + +class Parameter(ParameterBase): + name: str + in_: ParameterInType = Field(alias="in") + + +class Header(ParameterBase): + pass + + +# https://swagger.io/specification/#request-body-object +class RequestBody(BaseModel): + description: Optional[str] = None + content: Dict[str, MediaType] + required: Optional[bool] = None + + if PYDANTIC_V2: + model_config = {"extra": "allow"} + + else: + + class Config: + extra = "allow" + + +# https://swagger.io/specification/#link-object +class Link(BaseModel): + operationRef: Optional[str] = None + operationId: Optional[str] = None + parameters: Optional[Dict[str, Union[Any, str]]] = None + requestBody: Optional[Union[Any, str]] = None + description: Optional[str] = None + server: Optional[Server] = None + + if PYDANTIC_V2: + model_config = {"extra": "allow"} + + else: + + class Config: + extra = "allow" + + +# https://swagger.io/specification/#response-object +class Response(BaseModel): + description: str + headers: Optional[Dict[str, Union[Header, Reference]]] = None + content: Optional[Dict[str, MediaType]] = None + links: Optional[Dict[str, Union[Link, Reference]]] = None + + if PYDANTIC_V2: + model_config = {"extra": "allow"} + + else: + + class Config: + extra = "allow" + + +# https://swagger.io/specification/#operation-object +class Operation(BaseModel): + tags: Optional[List[str]] = None + summary: Optional[str] = None + description: Optional[str] = None + externalDocs: Optional[ExternalDocumentation] = None + operationId: Optional[str] = None + parameters: Optional[List[Union[Parameter, Reference]]] = None + requestBody: Optional[Union[RequestBody, Reference]] = None + # Using Any for Specification Extensions + responses: Optional[Dict[int, Union[Response, Any]]] = None + callbacks: Optional[Dict[str, Union[Dict[str, "PathItem"], Reference]]] = None + deprecated: Optional[bool] = None + security: Optional[List[Dict[str, List[str]]]] = None + servers: Optional[List[Server]] = None + + if PYDANTIC_V2: + model_config = {"extra": "allow"} + + else: + + class Config: + extra = "allow" + + +# https://swagger.io/specification/#path-item-object +class PathItem(BaseModel): + ref: Optional[str] = Field(default=None, alias="$ref") + summary: Optional[str] = None + description: Optional[str] = None + get: Optional[Operation] = None + put: Optional[Operation] = None + post: Optional[Operation] = None + delete: Optional[Operation] = None + options: Optional[Operation] = None + head: Optional[Operation] = None + patch: Optional[Operation] = None + trace: Optional[Operation] = None + servers: Optional[List[Server]] = None + parameters: Optional[List[Union[Parameter, Reference]]] = None + + if PYDANTIC_V2: + model_config = {"extra": "allow"} + + else: + + class Config: + extra = "allow" + + +# https://swagger.io/specification/#security-scheme-object +class SecuritySchemeType(Enum): + apiKey = "apiKey" + http = "http" + oauth2 = "oauth2" + openIdConnect = "openIdConnect" + + +class SecurityBase(BaseModel): + type_: SecuritySchemeType = Field(alias="type") + description: Optional[str] = None + + if PYDANTIC_V2: + model_config = {"extra": "allow"} + + else: + + class Config: + extra = "allow" + + +class APIKeyIn(Enum): + query = "query" + header = "header" + cookie = "cookie" + + +class APIKey(SecurityBase): + type_: SecuritySchemeType = Field(default=SecuritySchemeType.apiKey, alias="type") + in_: APIKeyIn = Field(alias="in") + name: str + + +class HTTPBase(SecurityBase): + type_: SecuritySchemeType = Field(default=SecuritySchemeType.http, alias="type") + scheme: str + + +class HTTPBearer(HTTPBase): + scheme: Literal["bearer"] = "bearer" + bearerFormat: Optional[str] = None + + +class OAuthFlow(BaseModel): + refreshUrl: Optional[str] = None + scopes: Dict[str, str] = {} + + if PYDANTIC_V2: + model_config = {"extra": "allow"} + + else: + + class Config: + extra = "allow" + + +class OAuthFlowImplicit(OAuthFlow): + authorizationUrl: str + + +class OAuthFlowPassword(OAuthFlow): + tokenUrl: str + + +class OAuthFlowClientCredentials(OAuthFlow): + tokenUrl: str + + +class OAuthFlowAuthorizationCode(OAuthFlow): + authorizationUrl: str + tokenUrl: str + + +class OAuthFlows(BaseModel): + implicit: Optional[OAuthFlowImplicit] = None + password: Optional[OAuthFlowPassword] = None + clientCredentials: Optional[OAuthFlowClientCredentials] = None + authorizationCode: Optional[OAuthFlowAuthorizationCode] = None + + if PYDANTIC_V2: + model_config = {"extra": "allow"} + + else: + + class Config: + extra = "allow" + + +class OAuth2(SecurityBase): + type_: SecuritySchemeType = Field(default=SecuritySchemeType.oauth2, alias="type") + flows: OAuthFlows + + +class OpenIdConnect(SecurityBase): + type_: SecuritySchemeType = Field( + default=SecuritySchemeType.openIdConnect, + alias="type", + ) + openIdConnectUrl: str + + +SecurityScheme = Union[APIKey, HTTPBase, OAuth2, OpenIdConnect, HTTPBearer] + + +# https://swagger.io/specification/#components-object +class Components(BaseModel): + schemas: Optional[Dict[str, Union[Schema, Reference]]] = None + responses: Optional[Dict[str, Union[Response, Reference]]] = None + parameters: Optional[Dict[str, Union[Parameter, Reference]]] = None + examples: Optional[Dict[str, Union[Example, Reference]]] = None + requestBodies: Optional[Dict[str, Union[RequestBody, Reference]]] = None + headers: Optional[Dict[str, Union[Header, Reference]]] = None + securitySchemes: Optional[Dict[str, Union[SecurityScheme, Reference]]] = None + links: Optional[Dict[str, Union[Link, Reference]]] = None + # Using Any for Specification Extensions + callbacks: Optional[Dict[str, Union[Dict[str, PathItem], Reference, Any]]] = None + pathItems: Optional[Dict[str, Union[PathItem, Reference]]] = None + + if PYDANTIC_V2: + model_config = {"extra": "allow"} + + else: + + class Config: + extra = "allow" + + +# https://swagger.io/specification/#tag-object +class Tag(BaseModel): + name: str + description: Optional[str] = None + externalDocs: Optional[ExternalDocumentation] = None + + if PYDANTIC_V2: + model_config = {"extra": "allow"} + + else: + + class Config: + extra = "allow" + + +# https://swagger.io/specification/#openapi-object +class OpenAPI(BaseModel): + openapi: str + info: Info + jsonSchemaDialect: Optional[str] = None + servers: Optional[List[Server]] = None + # Using Any for Specification Extensions + paths: Optional[Dict[str, Union[PathItem, Any]]] = None + webhooks: Optional[Dict[str, Union[PathItem, Reference]]] = None + components: Optional[Components] = None + security: Optional[List[Dict[str, List[str]]]] = None + tags: Optional[List[Tag]] = None + externalDocs: Optional[ExternalDocumentation] = None + + if PYDANTIC_V2: + model_config = {"extra": "allow"} + + else: + + class Config: + extra = "allow" + + +model_rebuild(Schema) +model_rebuild(Operation) +model_rebuild(Encoding) diff --git a/aws_lambda_powertools/event_handler/openapi/params.py b/aws_lambda_powertools/event_handler/openapi/params.py new file mode 100644 index 00000000000..797b44f6232 --- /dev/null +++ b/aws_lambda_powertools/event_handler/openapi/params.py @@ -0,0 +1,841 @@ +import inspect +from enum import Enum +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union + +from pydantic import BaseConfig +from pydantic.fields import FieldInfo + +from aws_lambda_powertools.event_handler.openapi.compat import ( + ModelField, + Required, + Undefined, + UndefinedType, + copy_field_info, + field_annotation_is_scalar, + get_annotation_from_field_info, +) +from aws_lambda_powertools.event_handler.openapi.types import PYDANTIC_V2, CacheKey +from aws_lambda_powertools.shared.types import Annotated, Literal, get_args, get_origin + +""" +This turns the low-level function signature into typed, validated Pydantic models for consumption. +""" + + +class ParamTypes(Enum): + query = "query" + header = "header" + path = "path" + cookie = "cookie" + + +# MAINTENANCE: update when deprecating Pydantic v1, remove this alias +_Unset: Any = Undefined + + +class Dependant: + """ + A class used internally to represent a dependency between path operation decorators and the path operation function. + """ + + def __init__( + self, + *, + path_params: Optional[List[ModelField]] = None, + query_params: Optional[List[ModelField]] = None, + header_params: Optional[List[ModelField]] = None, + cookie_params: Optional[List[ModelField]] = None, + body_params: Optional[List[ModelField]] = None, + return_param: Optional[ModelField] = None, + name: Optional[str] = None, + call: Optional[Callable[..., Any]] = None, + request_param_name: Optional[str] = None, + websocket_param_name: Optional[str] = None, + http_connection_param_name: Optional[str] = None, + response_param_name: Optional[str] = None, + background_tasks_param_name: Optional[str] = None, + path: Optional[str] = None, + ) -> None: + self.path_params = path_params or [] + self.query_params = query_params or [] + self.header_params = header_params or [] + self.cookie_params = cookie_params or [] + self.body_params = body_params or [] + self.return_param = return_param or None + self.request_param_name = request_param_name + self.websocket_param_name = websocket_param_name + self.http_connection_param_name = http_connection_param_name + self.response_param_name = response_param_name + self.background_tasks_param_name = background_tasks_param_name + self.name = name + self.call = call + # Store the path to be able to re-generate a dependable from it in overrides + self.path = path + # Save the cache key at creation to optimize performance + self.cache_key: CacheKey = self.call + + +class Param(FieldInfo): + """ + A class used internally to represent a parameter in a path operation. + """ + + in_: ParamTypes + + def __init__( + self, + default: Any = Undefined, + *, + default_factory: Union[Callable[[], Any], None] = _Unset, + annotation: Optional[Any] = None, + alias: Optional[str] = None, + alias_priority: Union[int, None] = _Unset, + # MAINTENANCE: update when deprecating Pydantic v1, import these types + # MAINTENANCE: validation_alias: str | AliasPath | AliasChoices | None + validation_alias: Union[str, None] = None, + serialization_alias: Union[str, None] = None, + title: Optional[str] = None, + description: Optional[str] = None, + gt: Optional[float] = None, + ge: Optional[float] = None, + lt: Optional[float] = None, + le: Optional[float] = None, + min_length: Optional[int] = None, + max_length: Optional[int] = None, + pattern: Optional[str] = None, + discriminator: Union[str, None] = None, + strict: Union[bool, None] = _Unset, + multiple_of: Union[float, None] = _Unset, + allow_inf_nan: Union[bool, None] = _Unset, + max_digits: Union[int, None] = _Unset, + decimal_places: Union[int, None] = _Unset, + examples: Optional[List[Any]] = None, + deprecated: Optional[bool] = None, + include_in_schema: bool = True, + json_schema_extra: Union[Dict[str, Any], None] = None, + **extra: Any, + ): + self.deprecated = deprecated + self.include_in_schema = include_in_schema + + kwargs = dict( + default=default, + default_factory=default_factory, + alias=alias, + title=title, + description=description, + gt=gt, + ge=ge, + lt=lt, + le=le, + min_length=min_length, + max_length=max_length, + discriminator=discriminator, + multiple_of=multiple_of, + allow_nan=allow_inf_nan, + max_digits=max_digits, + decimal_places=decimal_places, + **extra, + ) + if examples is not None: + kwargs["examples"] = examples + + current_json_schema_extra = json_schema_extra or extra + if PYDANTIC_V2: + kwargs.update( + { + "annotation": annotation, + "alias_priority": alias_priority, + "validation_alias": validation_alias, + "serialization_alias": serialization_alias, + "strict": strict, + "json_schema_extra": current_json_schema_extra, + "pattern": pattern, + }, + ) + else: + kwargs["regex"] = pattern + kwargs.update(**current_json_schema_extra) + + use_kwargs = {k: v for k, v in kwargs.items() if v is not _Unset} + + super().__init__(**use_kwargs) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.default})" + + +class Path(Param): + """ + A class used internally to represent a path parameter in a path operation. + """ + + in_ = ParamTypes.path + + def __init__( + self, + default: Any = ..., + *, + default_factory: Union[Callable[[], Any], None] = _Unset, + annotation: Optional[Any] = None, + alias: Optional[str] = None, + alias_priority: Union[int, None] = _Unset, + # MAINTENANCE: update when deprecating Pydantic v1, import these types + # MAINTENANCE: validation_alias: str | AliasPath | AliasChoices | None + validation_alias: Union[str, None] = None, + serialization_alias: Union[str, None] = None, + title: Optional[str] = None, + description: Optional[str] = None, + gt: Optional[float] = None, + ge: Optional[float] = None, + lt: Optional[float] = None, + le: Optional[float] = None, + min_length: Optional[int] = None, + max_length: Optional[int] = None, + pattern: Optional[str] = None, + discriminator: Union[str, None] = None, + strict: Union[bool, None] = _Unset, + multiple_of: Union[float, None] = _Unset, + allow_inf_nan: Union[bool, None] = _Unset, + max_digits: Union[int, None] = _Unset, + decimal_places: Union[int, None] = _Unset, + examples: Optional[List[Any]] = None, + deprecated: Optional[bool] = None, + include_in_schema: bool = True, + json_schema_extra: Union[Dict[str, Any], None] = None, + **extra: Any, + ): + if default is not ...: + raise AssertionError("Path parameters cannot have a default value") + + super(Path, self).__init__( + default=default, + default_factory=default_factory, + annotation=annotation, + alias=alias, + alias_priority=alias_priority, + validation_alias=validation_alias, + serialization_alias=serialization_alias, + title=title, + description=description, + gt=gt, + ge=ge, + lt=lt, + le=le, + min_length=min_length, + max_length=max_length, + pattern=pattern, + discriminator=discriminator, + strict=strict, + multiple_of=multiple_of, + allow_inf_nan=allow_inf_nan, + max_digits=max_digits, + decimal_places=decimal_places, + deprecated=deprecated, + examples=examples, + include_in_schema=include_in_schema, + json_schema_extra=json_schema_extra, + **extra, + ) + + +class Query(Param): + """ + A class used internally to represent a query parameter in a path operation. + """ + + in_ = ParamTypes.query + + def __init__( + self, + default: Any = _Unset, + *, + default_factory: Union[Callable[[], Any], None] = _Unset, + annotation: Optional[Any] = None, + alias: Optional[str] = None, + alias_priority: Union[int, None] = _Unset, + validation_alias: Union[str, None] = None, + serialization_alias: Union[str, None] = None, + title: Optional[str] = None, + description: Optional[str] = None, + gt: Optional[float] = None, + ge: Optional[float] = None, + lt: Optional[float] = None, + le: Optional[float] = None, + min_length: Optional[int] = None, + max_length: Optional[int] = None, + pattern: Optional[str] = None, + discriminator: Union[str, None] = None, + strict: Union[bool, None] = _Unset, + multiple_of: Union[float, None] = _Unset, + allow_inf_nan: Union[bool, None] = _Unset, + max_digits: Union[int, None] = _Unset, + decimal_places: Union[int, None] = _Unset, + examples: Optional[List[Any]] = None, + deprecated: Optional[bool] = None, + include_in_schema: bool = True, + json_schema_extra: Union[Dict[str, Any], None] = None, + **extra: Any, + ): + super().__init__( + default=default, + default_factory=default_factory, + annotation=annotation, + alias=alias, + alias_priority=alias_priority, + validation_alias=validation_alias, + serialization_alias=serialization_alias, + title=title, + description=description, + gt=gt, + ge=ge, + lt=lt, + le=le, + min_length=min_length, + max_length=max_length, + pattern=pattern, + discriminator=discriminator, + strict=strict, + multiple_of=multiple_of, + allow_inf_nan=allow_inf_nan, + max_digits=max_digits, + decimal_places=decimal_places, + deprecated=deprecated, + examples=examples, + include_in_schema=include_in_schema, + json_schema_extra=json_schema_extra, + **extra, + ) + + +class Header(Param): + """ + A class used internally to represent a header parameter in a path operation. + """ + + in_ = ParamTypes.header + + def __init__( + self, + default: Any = Undefined, + *, + default_factory: Union[Callable[[], Any], None] = _Unset, + annotation: Optional[Any] = None, + alias: Optional[str] = None, + alias_priority: Union[int, None] = _Unset, + # MAINTENANCE: update when deprecating Pydantic v1, import these types + # str | AliasPath | AliasChoices | None + validation_alias: Union[str, None] = None, + serialization_alias: Union[str, None] = None, + convert_underscores: bool = True, + title: Optional[str] = None, + description: Optional[str] = None, + gt: Optional[float] = None, + ge: Optional[float] = None, + lt: Optional[float] = None, + le: Optional[float] = None, + min_length: Optional[int] = None, + max_length: Optional[int] = None, + pattern: Optional[str] = None, + discriminator: Union[str, None] = None, + strict: Union[bool, None] = _Unset, + multiple_of: Union[float, None] = _Unset, + allow_inf_nan: Union[bool, None] = _Unset, + max_digits: Union[int, None] = _Unset, + decimal_places: Union[int, None] = _Unset, + examples: Optional[List[Any]] = None, + deprecated: Optional[bool] = None, + include_in_schema: bool = True, + json_schema_extra: Union[Dict[str, Any], None] = None, + **extra: Any, + ): + self.convert_underscores = convert_underscores + super().__init__( + default=default, + default_factory=default_factory, + annotation=annotation, + alias=alias, + alias_priority=alias_priority, + validation_alias=validation_alias, + serialization_alias=serialization_alias, + title=title, + description=description, + gt=gt, + ge=ge, + lt=lt, + le=le, + min_length=min_length, + max_length=max_length, + pattern=pattern, + discriminator=discriminator, + strict=strict, + multiple_of=multiple_of, + allow_inf_nan=allow_inf_nan, + max_digits=max_digits, + decimal_places=decimal_places, + deprecated=deprecated, + examples=examples, + include_in_schema=include_in_schema, + json_schema_extra=json_schema_extra, + **extra, + ) + + +class Body(FieldInfo): + """ + A class used internally to represent a body parameter in a path operation. + """ + + def __init__( + self, + default: Any = Undefined, + *, + default_factory: Union[Callable[[], Any], None] = _Unset, + annotation: Optional[Any] = None, + embed: bool = False, + media_type: str = "application/json", + alias: Optional[str] = None, + alias_priority: Union[int, None] = _Unset, + # MAINTENANCE: update when deprecating Pydantic v1, import these types + # str | AliasPath | AliasChoices | None + validation_alias: Union[str, None] = None, + serialization_alias: Union[str, None] = None, + title: Optional[str] = None, + description: Optional[str] = None, + gt: Optional[float] = None, + ge: Optional[float] = None, + lt: Optional[float] = None, + le: Optional[float] = None, + min_length: Optional[int] = None, + max_length: Optional[int] = None, + pattern: Optional[str] = None, + discriminator: Union[str, None] = None, + strict: Union[bool, None] = _Unset, + multiple_of: Union[float, None] = _Unset, + allow_inf_nan: Union[bool, None] = _Unset, + max_digits: Union[int, None] = _Unset, + decimal_places: Union[int, None] = _Unset, + examples: Optional[List[Any]] = None, + deprecated: Optional[bool] = None, + include_in_schema: bool = True, + json_schema_extra: Union[Dict[str, Any], None] = None, + **extra: Any, + ): + self.embed = embed + self.media_type = media_type + self.deprecated = deprecated + self.include_in_schema = include_in_schema + kwargs = dict( + default=default, + default_factory=default_factory, + alias=alias, + title=title, + description=description, + gt=gt, + ge=ge, + lt=lt, + le=le, + min_length=min_length, + max_length=max_length, + discriminator=discriminator, + multiple_of=multiple_of, + allow_nan=allow_inf_nan, + max_digits=max_digits, + decimal_places=decimal_places, + **extra, + ) + if examples is not None: + kwargs["examples"] = examples + current_json_schema_extra = json_schema_extra or extra + if PYDANTIC_V2: + kwargs.update( + { + "annotation": annotation, + "alias_priority": alias_priority, + "validation_alias": validation_alias, + "serialization_alias": serialization_alias, + "strict": strict, + "json_schema_extra": current_json_schema_extra, + "pattern": pattern, + }, + ) + else: + kwargs["regex"] = pattern + kwargs.update(**current_json_schema_extra) + + use_kwargs = {k: v for k, v in kwargs.items() if v is not _Unset} + + super().__init__(**use_kwargs) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.default})" + + +class Form(Body): + """ + A class used internally to represent a form parameter in a path operation. + """ + + def __init__( + self, + default: Any = Undefined, + *, + default_factory: Union[Callable[[], Any], None] = _Unset, + annotation: Optional[Any] = None, + media_type: str = "application/x-www-form-urlencoded", + alias: Optional[str] = None, + alias_priority: Union[int, None] = _Unset, + # MAINTENANCE: update when deprecating Pydantic v1, import these types + # str | AliasPath | AliasChoices | None + validation_alias: Union[str, None] = None, + serialization_alias: Union[str, None] = None, + title: Optional[str] = None, + description: Optional[str] = None, + gt: Optional[float] = None, + ge: Optional[float] = None, + lt: Optional[float] = None, + le: Optional[float] = None, + min_length: Optional[int] = None, + max_length: Optional[int] = None, + pattern: Optional[str] = None, + discriminator: Union[str, None] = None, + strict: Union[bool, None] = _Unset, + multiple_of: Union[float, None] = _Unset, + allow_inf_nan: Union[bool, None] = _Unset, + max_digits: Union[int, None] = _Unset, + decimal_places: Union[int, None] = _Unset, + examples: Optional[List[Any]] = None, + deprecated: Optional[bool] = None, + include_in_schema: bool = True, + json_schema_extra: Union[Dict[str, Any], None] = None, + **extra: Any, + ): + super().__init__( + default=default, + default_factory=default_factory, + annotation=annotation, + embed=True, + media_type=media_type, + alias=alias, + alias_priority=alias_priority, + validation_alias=validation_alias, + serialization_alias=serialization_alias, + title=title, + description=description, + gt=gt, + ge=ge, + lt=lt, + le=le, + min_length=min_length, + max_length=max_length, + pattern=pattern, + discriminator=discriminator, + strict=strict, + multiple_of=multiple_of, + allow_inf_nan=allow_inf_nan, + max_digits=max_digits, + decimal_places=decimal_places, + deprecated=deprecated, + examples=examples, + include_in_schema=include_in_schema, + json_schema_extra=json_schema_extra, + **extra, + ) + + +class File(Form): + """ + A class used internally to represent a file parameter in a path operation. + """ + + def __init__( + self, + default: Any = Undefined, + *, + default_factory: Union[Callable[[], Any], None] = _Unset, + annotation: Optional[Any] = None, + media_type: str = "multipart/form-data", + alias: Optional[str] = None, + alias_priority: Union[int, None] = _Unset, + # MAINTENANCE: update when deprecating Pydantic v1, import these types + # str | AliasPath | AliasChoices | None + validation_alias: Union[str, None] = None, + serialization_alias: Union[str, None] = None, + title: Optional[str] = None, + description: Optional[str] = None, + gt: Optional[float] = None, + ge: Optional[float] = None, + lt: Optional[float] = None, + le: Optional[float] = None, + min_length: Optional[int] = None, + max_length: Optional[int] = None, + pattern: Optional[str] = None, + discriminator: Union[str, None] = None, + strict: Union[bool, None] = _Unset, + multiple_of: Union[float, None] = _Unset, + allow_inf_nan: Union[bool, None] = _Unset, + max_digits: Union[int, None] = _Unset, + decimal_places: Union[int, None] = _Unset, + examples: Optional[List[Any]] = None, + deprecated: Optional[bool] = None, + include_in_schema: bool = True, + json_schema_extra: Union[Dict[str, Any], None] = None, + **extra: Any, + ): + super().__init__( + default=default, + default_factory=default_factory, + annotation=annotation, + media_type=media_type, + alias=alias, + alias_priority=alias_priority, + validation_alias=validation_alias, + serialization_alias=serialization_alias, + title=title, + description=description, + gt=gt, + ge=ge, + lt=lt, + le=le, + min_length=min_length, + max_length=max_length, + pattern=pattern, + discriminator=discriminator, + strict=strict, + multiple_of=multiple_of, + allow_inf_nan=allow_inf_nan, + max_digits=max_digits, + decimal_places=decimal_places, + deprecated=deprecated, + examples=examples, + include_in_schema=include_in_schema, + json_schema_extra=json_schema_extra, + **extra, + ) + + +def get_flat_dependant( + dependant: Dependant, + visited: Optional[List[CacheKey]] = None, +) -> Dependant: + """ + Flatten a recursive Dependant model structure. + + This function recursively concatenates the parameter fields of a Dependant model and its dependencies into a flat + Dependant structure. This is useful for scenarios like parameter validation where the nested structure is not + relevant. + + Parameters + ---------- + dependant: Dependant + The dependant model to flatten + skip_repeats: bool + If True, child Dependents already visited will be skipped to avoid duplicates + visited: List[CacheKey], optional + Keeps track of visited Dependents to avoid infinite recursion. Defaults to empty list. + + Returns + ------- + Dependant + The flattened Dependant model + """ + if visited is None: + visited = [] + visited.append(dependant.cache_key) + + return Dependant( + path_params=dependant.path_params.copy(), + query_params=dependant.query_params.copy(), + header_params=dependant.header_params.copy(), + cookie_params=dependant.cookie_params.copy(), + body_params=dependant.body_params.copy(), + path=dependant.path, + ) + + +def analyze_param( + *, + param_name: str, + annotation: Any, + value: Any, + is_path_param: bool, + is_response_param: bool, +) -> Optional[ModelField]: + """ + Analyze a parameter annotation and value to determine the type and default value of the parameter. + + Parameters + ---------- + param_name: str + The name of the parameter + annotation + The annotation of the parameter + value + The value of the parameter + is_path_param + Whether the parameter is a path parameter + is_response_param + Whether the parameter is the return annotation + + Returns + ------- + Optional[ModelField] + The type annotation and the Pydantic field representing the parameter + """ + field_info, type_annotation = get_field_info_and_type_annotation(annotation, value, is_path_param) + + # If the value is a FieldInfo, we use it as the FieldInfo for the parameter + if isinstance(value, FieldInfo): + if field_info is not None: + raise AssertionError("Cannot use a FieldInfo as a parameter annotation and pass a FieldInfo as a value") + field_info = value + + if PYDANTIC_V2: + field_info.annotation = type_annotation # type: ignore[attr-defined,unused-ignore] + + # If we didn't determine the FieldInfo yet, we create a default one + if field_info is None: + default_value = value if value is not inspect.Signature.empty else Required + + # Check if the parameter is part of the path. Otherwise, defaults to query. + if is_path_param: + field_info = Path(annotation=type_annotation) + elif not field_annotation_is_scalar(annotation=type_annotation): + field_info = Body(annotation=type_annotation, default=default_value) + else: + field_info = Query(annotation=type_annotation, default=default_value) + + # When we have a response field, we need to set the default value to Required + if is_response_param: + field_info.default = Required + + field = _create_model_field(field_info, type_annotation, param_name, is_path_param) + return field + + +def get_field_info_and_type_annotation(annotation, value, is_path_param: bool) -> Tuple[Optional[FieldInfo], Any]: + """ + Get the FieldInfo and type annotation from an annotation and value. + """ + field_info: Optional[FieldInfo] = None + type_annotation: Any = Any + + if annotation is not inspect.Signature.empty: + # If the annotation is an Annotated type, we need to extract the type annotation and the FieldInfo + if get_origin(annotation) is Annotated: + field_info, type_annotation = get_field_info_annotated_type(annotation, value, is_path_param) + # If the annotation is not an Annotated type, we use it as the type annotation + else: + type_annotation = annotation + + return field_info, type_annotation + + +def get_field_info_annotated_type(annotation, value, is_path_param: bool) -> Tuple[Optional[FieldInfo], Any]: + """ + Get the FieldInfo and type annotation from an Annotated type. + """ + field_info: Optional[FieldInfo] = None + annotated_args = get_args(annotation) + type_annotation = annotated_args[0] + powertools_annotations = [arg for arg in annotated_args[1:] if isinstance(arg, FieldInfo)] + + if len(powertools_annotations) > 1: + raise AssertionError("Only one FieldInfo can be used per parameter") + + powertools_annotation = next(iter(powertools_annotations), None) + + if isinstance(powertools_annotation, FieldInfo): + # Copy `field_info` because we mutate `field_info.default` later + field_info = copy_field_info( + field_info=powertools_annotation, + annotation=annotation, + ) + if field_info.default not in [Undefined, Required]: + raise AssertionError("FieldInfo needs to have a default value of Undefined or Required") + + if value is not inspect.Signature.empty: + if is_path_param: + raise AssertionError("Cannot use a FieldInfo as a path parameter and pass a value") + field_info.default = value + else: + field_info.default = Required + + return field_info, type_annotation + + +def create_response_field( + name: str, + type_: Type[Any], + default: Optional[Any] = Undefined, + required: Union[bool, UndefinedType] = Undefined, + model_config: Type[BaseConfig] = BaseConfig, + field_info: Optional[FieldInfo] = None, + alias: Optional[str] = None, + mode: Literal["validation", "serialization"] = "validation", +) -> ModelField: + """ + Create a new response field. Raises if type_ is invalid. + """ + if PYDANTIC_V2: + field_info = field_info or FieldInfo( + annotation=type_, + default=default, + alias=alias, + ) + else: + field_info = field_info or FieldInfo() + kwargs = {"name": name, "field_info": field_info} + + if PYDANTIC_V2: + kwargs.update({"mode": mode}) + else: + kwargs.update( + { + "type_": type_, + "class_validators": {}, + "default": default, + "required": required, + "model_config": model_config, + "alias": alias, + }, + ) + return ModelField(**kwargs) # type: ignore[arg-type] + + +def _create_model_field( + field_info: Optional[FieldInfo], + type_annotation: Any, + param_name: str, + is_path_param: bool, +) -> Optional[ModelField]: + """ + Create a new ModelField from a FieldInfo and type annotation. + """ + if field_info is None: + return None + + if is_path_param: + if not isinstance(field_info, Path): + raise AssertionError("Path parameters must be of type Path") + elif isinstance(field_info, Param) and getattr(field_info, "in_", None) is None: + field_info.in_ = ParamTypes.query + + # If the field_info is a Param, we use the `in_` attribute to determine the type annotation + use_annotation = get_annotation_from_field_info(type_annotation, field_info, param_name) + + # If the field doesn't have a defined alias, we use the param name + if not field_info.alias and getattr(field_info, "convert_underscores", None): + alias = param_name.replace("_", "-") + else: + alias = field_info.alias or param_name + field_info.alias = alias + + return create_response_field( + name=param_name, + type_=use_annotation, + default=field_info.default, + alias=alias, + required=field_info.default in (Required, Undefined), + field_info=field_info, + ) diff --git a/aws_lambda_powertools/event_handler/openapi/types.py b/aws_lambda_powertools/event_handler/openapi/types.py new file mode 100644 index 00000000000..9161d8dc170 --- /dev/null +++ b/aws_lambda_powertools/event_handler/openapi/types.py @@ -0,0 +1,52 @@ +import types +from enum import Enum +from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Set, Type, Union + +if TYPE_CHECKING: + from pydantic import BaseModel # noqa: F401 + +CacheKey = Optional[Callable[..., Any]] +IncEx = Union[Set[int], Set[str], Dict[int, Any], Dict[str, Any]] +ModelNameMap = Dict[Union[Type["BaseModel"], Type[Enum]], str] +TypeModelOrEnum = Union[Type["BaseModel"], Type[Enum]] +UnionType = getattr(types, "UnionType", Union) + + +COMPONENT_REF_PREFIX = "#/components/schemas/" +COMPONENT_REF_TEMPLATE = "#/components/schemas/{model}" +METHODS_WITH_BODY = {"GET", "HEAD", "POST", "PUT", "DELETE", "PATCH"} + +try: + from pydantic.version import VERSION as PYDANTIC_VERSION + + PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.") +except ImportError: + PYDANTIC_V2 = False + + +validation_error_definition = { + "title": "ValidationError", + "type": "object", + "properties": { + "loc": { + "title": "Location", + "type": "array", + "items": {"anyOf": [{"type": "string"}, {"type": "integer"}]}, + }, + "msg": {"title": "Message", "type": "string"}, + "type": {"title": "Error Type", "type": "string"}, + }, + "required": ["loc", "msg", "type"], +} + +validation_error_response_definition = { + "title": "HTTPValidationError", + "type": "object", + "properties": { + "detail": { + "title": "Detail", + "type": "array", + "items": {"$ref": COMPONENT_REF_PREFIX + "ValidationError"}, + }, + }, +} diff --git a/aws_lambda_powertools/event_handler/vpc_lattice.py b/aws_lambda_powertools/event_handler/vpc_lattice.py index bcee046e382..4fa8d061afb 100644 --- a/aws_lambda_powertools/event_handler/vpc_lattice.py +++ b/aws_lambda_powertools/event_handler/vpc_lattice.py @@ -48,9 +48,10 @@ def __init__( debug: Optional[bool] = None, serializer: Optional[Callable[[Dict], str]] = None, strip_prefixes: Optional[List[Union[str, Pattern]]] = None, + enable_validation: bool = False, ): """Amazon VPC Lattice resolver""" - super().__init__(ProxyEventType.VPCLatticeEvent, cors, debug, serializer, strip_prefixes) + super().__init__(ProxyEventType.VPCLatticeEvent, cors, debug, serializer, strip_prefixes, enable_validation) class VPCLatticeV2Resolver(ApiGatewayResolver): @@ -93,6 +94,7 @@ def __init__( debug: Optional[bool] = None, serializer: Optional[Callable[[Dict], str]] = None, strip_prefixes: Optional[List[Union[str, Pattern]]] = None, + enable_validation: bool = False, ): """Amazon VPC Lattice resolver""" - super().__init__(ProxyEventType.VPCLatticeEventV2, cors, debug, serializer, strip_prefixes) + super().__init__(ProxyEventType.VPCLatticeEventV2, cors, debug, serializer, strip_prefixes, enable_validation) diff --git a/aws_lambda_powertools/shared/types.py b/aws_lambda_powertools/shared/types.py index 633db46c587..100005159e4 100644 --- a/aws_lambda_powertools/shared/types.py +++ b/aws_lambda_powertools/shared/types.py @@ -6,6 +6,10 @@ else: from typing_extensions import Literal, Protocol, TypedDict +if sys.version_info >= (3, 9): + from typing import Annotated +else: + from typing_extensions import Annotated if sys.version_info >= (3, 11): from typing import NotRequired @@ -13,13 +17,15 @@ from typing_extensions import NotRequired +# Even though `get_args` and `get_origin` were added in Python 3.8, they only handle Annotated correctly on 3.10. +# So for python < 3.10 we use the backport from typing_extensions. if sys.version_info >= (3, 10): - from typing import TypeAlias + from typing import TypeAlias, get_args, get_origin else: - from typing_extensions import TypeAlias + from typing_extensions import TypeAlias, get_args, get_origin AnyCallableT = TypeVar("AnyCallableT", bound=Callable[..., Any]) # noqa: VNE001 # JSON primitives only, mypy doesn't support recursive tho JSONType = Union[str, int, float, bool, None, Dict[str, Any], List[Any]] -__all__ = ["Protocol", "TypedDict", "Literal", "NotRequired", "TypeAlias"] +__all__ = ["get_args", "get_origin", "Annotated", "Protocol", "TypedDict", "Literal", "NotRequired", "TypeAlias"] diff --git a/codecov.yml b/codecov.yml new file mode 100644 index 00000000000..caff9bcbfab --- /dev/null +++ b/codecov.yml @@ -0,0 +1,2 @@ +ignore: + - "aws_lambda_powertools/event_handler/openapi/compat.py" diff --git a/pyproject.toml b/pyproject.toml index d3883b6167c..b1702929515 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -169,6 +169,7 @@ exclude = ''' | buck-out | build | dist + | aws_lambda_powertools/event_handler/openapi/compat.py )/ | example ) diff --git a/ruff.toml b/ruff.toml index a0f8e4fe74f..553a8c47b3d 100644 --- a/ruff.toml +++ b/ruff.toml @@ -87,5 +87,6 @@ split-on-trailing-comma = true "tests/e2e/utils/data_fetcher/__init__.py" = ["F401"] "aws_lambda_powertools/utilities/data_classes/s3_event.py" = ["A003"] "aws_lambda_powertools/utilities/parser/models/__init__.py" = ["E402"] +"aws_lambda_powertools/event_handler/openapi/compat.py" = ["F401"] # Maintenance: we're keeping EphemeralMetrics code in case of Hyrum's law so we can quickly revert it "aws_lambda_powertools/metrics/metrics.py" = ["ERA001"] diff --git a/tests/functional/event_handler/test_openapi_encoders.py b/tests/functional/event_handler/test_openapi_encoders.py new file mode 100644 index 00000000000..4062384b16e --- /dev/null +++ b/tests/functional/event_handler/test_openapi_encoders.py @@ -0,0 +1,195 @@ +import math +from dataclasses import dataclass +from typing import List + +import pytest +from pydantic import BaseModel +from pydantic.color import Color + +from aws_lambda_powertools.event_handler.openapi.encoders import jsonable_encoder + + +@pytest.fixture +def pydanticv1_only(): + from pydantic import __version__ + + version = __version__.split(".") + if version[0] != "1": + pytest.skip("pydanticv1 test only") + + +def test_openapi_encode_include(): + class User(BaseModel): + name: str + age: int + + result = jsonable_encoder(User(name="John", age=20), include=["name"]) + assert result == {"name": "John"} + + +def test_openapi_encode_exclude(): + class User(BaseModel): + name: str + age: int + + result = jsonable_encoder(User(name="John", age=20), exclude=["age"]) + assert result == {"name": "John"} + + +def test_openapi_encode_pydantic(): + class Order(BaseModel): + quantity: int + + class User(BaseModel): + name: str + order: Order + + result = jsonable_encoder(User(name="John", order=Order(quantity=2))) + assert result == {"name": "John", "order": {"quantity": 2}} + + +@pytest.mark.usefixtures("pydanticv1_only") +def test_openapi_encode_pydantic_root_types(): + class User(BaseModel): + __root__: List[str] + + result = jsonable_encoder(User(__root__=["John", "Jane"])) + assert result == ["John", "Jane"] + + +def test_openapi_encode_dataclass(): + @dataclass + class Order: + quantity: int + + @dataclass + class User: + name: str + order: Order + + result = jsonable_encoder(User(name="John", order=Order(quantity=2))) + assert result == {"name": "John", "order": {"quantity": 2}} + + +def test_openapi_encode_enum(): + from enum import Enum + + class Color(Enum): + RED = "red" + GREEN = "green" + BLUE = "blue" + + result = jsonable_encoder(Color.RED) + assert result == "red" + + +def test_openapi_encode_purepath(): + from pathlib import PurePath + + result = jsonable_encoder(PurePath("/foo/bar")) + assert result == "/foo/bar" + + +def test_openapi_encode_scalars(): + result = jsonable_encoder("foo") + assert result == "foo" + + result = jsonable_encoder(1) + assert result == 1 + + result = jsonable_encoder(1.0) + assert math.isclose(result, 1.0) + + result = jsonable_encoder(True) + assert result is True + + result = jsonable_encoder(None) + assert result is None + + +def test_openapi_encode_dict(): + result = jsonable_encoder({"foo": "bar"}) + assert result == {"foo": "bar"} + + +def test_openapi_encode_dict_with_include(): + result = jsonable_encoder({"foo": "bar", "bar": "foo"}, include=["foo"]) + assert result == {"foo": "bar"} + + +def test_openapi_encode_dict_with_exclude(): + result = jsonable_encoder({"foo": "bar", "bar": "foo"}, exclude=["bar"]) + assert result == {"foo": "bar"} + + +def test_openapi_encode_sequences(): + result = jsonable_encoder(["foo", "bar"]) + assert result == ["foo", "bar"] + + result = jsonable_encoder(("foo", "bar")) + assert result == ["foo", "bar"] + + result = jsonable_encoder({"foo", "bar"}) + assert set(result) == {"foo", "bar"} + + result = jsonable_encoder(frozenset(("foo", "bar"))) + assert set(result) == {"foo", "bar"} + + +def test_openapi_encode_bytes(): + result = jsonable_encoder(b"foo") + assert result == "foo" + + +def test_openapi_encode_timedelta(): + from datetime import timedelta + + result = jsonable_encoder(timedelta(seconds=1)) + assert result == 1 + + +def test_openapi_encode_decimal(): + from decimal import Decimal + + result = jsonable_encoder(Decimal("1.0")) + assert math.isclose(result, 1.0) + + result = jsonable_encoder(Decimal("1")) + assert result == 1 + + +def test_openapi_encode_uuid(): + from uuid import UUID + + result = jsonable_encoder(UUID("123e4567-e89b-12d3-a456-426614174000")) + assert result == "123e4567-e89b-12d3-a456-426614174000" + + +def test_openapi_encode_encodable(): + from datetime import date, datetime, time + + result = jsonable_encoder(date(2021, 1, 1)) + assert result == "2021-01-01" + + result = jsonable_encoder(datetime(2021, 1, 1, 0, 0, 0)) + assert result == "2021-01-01T00:00:00" + + result = jsonable_encoder(time(0, 0, 0)) + assert result == "00:00:00" + + +def test_openapi_encode_subclasses(): + class MyColor(Color): + pass + + result = jsonable_encoder(MyColor("red")) + assert result == "red" + + +def test_openapi_encode_other(): + class User: + def __init__(self, name: str): + self.name = name + + result = jsonable_encoder(User(name="John")) + assert result == {"name": "John"} diff --git a/tests/functional/event_handler/test_openapi_params.py b/tests/functional/event_handler/test_openapi_params.py new file mode 100644 index 00000000000..b5f4afa9fbe --- /dev/null +++ b/tests/functional/event_handler/test_openapi_params.py @@ -0,0 +1,315 @@ +from dataclasses import dataclass +from datetime import datetime +from typing import List + +from pydantic import BaseModel + +from aws_lambda_powertools.event_handler.api_gateway import APIGatewayRestResolver +from aws_lambda_powertools.event_handler.openapi.models import ( + Example, + Parameter, + ParameterInType, + Schema, +) +from aws_lambda_powertools.event_handler.openapi.params import ( + Body, + Header, + Param, + ParamTypes, + Query, + _create_model_field, +) +from aws_lambda_powertools.shared.types import Annotated + +JSON_CONTENT_TYPE = "application/json" + + +def test_openapi_no_params(): + app = APIGatewayRestResolver() + + @app.get("/") + def handler(): + raise NotImplementedError() + + schema = app.get_openapi_schema() + assert schema.info.title == "Powertools API" + assert schema.info.version == "1.0.0" + + assert len(schema.paths.keys()) == 1 + assert "/" in schema.paths + + path = schema.paths["/"] + assert path.get + + get = path.get + assert get.summary == "GET /" + assert get.operationId == "handler__get" + + assert get.responses is not None + assert 200 in get.responses.keys() + response = get.responses[200] + assert response.description == "Successful Response" + + assert JSON_CONTENT_TYPE in response.content + json_response = response.content[JSON_CONTENT_TYPE] + assert json_response.schema_ == Schema() + assert not json_response.examples + assert not json_response.encoding + + +def test_openapi_with_scalar_params(): + app = APIGatewayRestResolver() + + @app.get("/users/") + def handler(user_id: str, include_extra: bool = False): + raise NotImplementedError() + + schema = app.get_openapi_schema(title="My API", version="0.2.2") + assert schema.info.title == "My API" + assert schema.info.version == "0.2.2" + + assert len(schema.paths.keys()) == 1 + assert "/users/" in schema.paths + + path = schema.paths["/users/"] + assert path.get + + get = path.get + assert get.summary == "GET /users/" + assert get.operationId == "handler_users__user_id__get" + assert len(get.parameters) == 2 + + parameter = get.parameters[0] + assert isinstance(parameter, Parameter) + assert parameter.in_ == ParameterInType.path + assert parameter.name == "user_id" + assert parameter.required is True + assert parameter.schema_.default is None + assert parameter.schema_.type == "string" + assert parameter.schema_.title == "User Id" + + parameter = get.parameters[1] + assert isinstance(parameter, Parameter) + assert parameter.in_ == ParameterInType.query + assert parameter.name == "include_extra" + assert parameter.required is False + assert parameter.schema_.default is False + assert parameter.schema_.type == "boolean" + assert parameter.schema_.title == "Include Extra" + + +def test_openapi_with_custom_params(): + app = APIGatewayRestResolver() + + @app.get("/users", summary="Get Users", operation_id="GetUsers", description="Get paginated users", tags=["Users"]) + def handler( + count: Annotated[ + int, + Query(gt=0, lt=100, examples=[Example(summary="Example 1", value=10)]), + ] = 1, + ): + print(count) + raise NotImplementedError() + + schema = app.get_openapi_schema() + + get = schema.paths["/users"].get + assert len(get.parameters) == 1 + assert get.summary == "Get Users" + assert get.operationId == "GetUsers" + assert get.description == "Get paginated users" + assert get.tags == ["Users"] + + parameter = get.parameters[0] + assert parameter.required is False + assert parameter.name == "count" + assert parameter.in_ == ParameterInType.query + assert parameter.schema_.type == "integer" + assert parameter.schema_.default == 1 + assert parameter.schema_.title == "Count" + assert parameter.schema_.exclusiveMinimum == 0 + assert parameter.schema_.exclusiveMaximum == 100 + assert len(parameter.schema_.examples) == 1 + assert parameter.schema_.examples[0].summary == "Example 1" + assert parameter.schema_.examples[0].value == 10 + + +def test_openapi_with_scalar_returns(): + app = APIGatewayRestResolver() + + @app.get("/") + def handler() -> str: + return "Hello, world" + + schema = app.get_openapi_schema() + assert len(schema.paths.keys()) == 1 + + get = schema.paths["/"].get + assert get.parameters is None + + response = get.responses[200].content[JSON_CONTENT_TYPE] + assert response.schema_.title == "Return" + assert response.schema_.type == "string" + + +def test_openapi_with_pydantic_returns(): + app = APIGatewayRestResolver() + + class User(BaseModel): + name: str + + @app.get("/") + def handler() -> User: + return User(name="Ruben Fonseca") + + schema = app.get_openapi_schema() + assert len(schema.paths.keys()) == 1 + + get = schema.paths["/"].get + assert get.parameters is None + + response = get.responses[200].content[JSON_CONTENT_TYPE] + reference = response.schema_ + assert reference.ref == "#/components/schemas/User" + + assert "User" in schema.components.schemas + user_schema = schema.components.schemas["User"] + assert isinstance(user_schema, Schema) + assert user_schema.title == "User" + assert "name" in user_schema.properties + + +def test_openapi_with_pydantic_nested_returns(): + app = APIGatewayRestResolver() + + class Order(BaseModel): + date: datetime + + class User(BaseModel): + name: str + orders: List[Order] + + @app.get("/") + def handler() -> User: + return User(name="Ruben Fonseca", orders=[Order(date=datetime.now())]) + + schema = app.get_openapi_schema() + assert len(schema.paths.keys()) == 1 + + assert "User" in schema.components.schemas + assert "Order" in schema.components.schemas + + user_schema = schema.components.schemas["User"] + assert "orders" in user_schema.properties + assert user_schema.properties["orders"].type == "array" + + +def test_openapi_with_dataclass_return(): + app = APIGatewayRestResolver() + + @dataclass + class User: + surname: str + + @app.get("/") + def handler() -> User: + return User(surname="Fonseca") + + schema = app.get_openapi_schema() + assert len(schema.paths.keys()) == 1 + + get = schema.paths["/"].get + assert get.parameters is None + + response = get.responses[200].content[JSON_CONTENT_TYPE] + reference = response.schema_ + assert reference.ref == "#/components/schemas/User" + + assert "User" in schema.components.schemas + user_schema = schema.components.schemas["User"] + assert isinstance(user_schema, Schema) + assert user_schema.title == "User" + assert "surname" in user_schema.properties + + +def test_openapi_with_body_param(): + app = APIGatewayRestResolver() + + class User(BaseModel): + name: str + + @app.post("/users") + def handler(user: User): + print(user) + + schema = app.get_openapi_schema() + assert len(schema.paths.keys()) == 1 + + post = schema.paths["/users"].post + assert post.parameters is None + assert post.requestBody is not None + + request_body = post.requestBody + assert request_body.required is True + assert request_body.content[JSON_CONTENT_TYPE].schema_.ref == "#/components/schemas/User" + + +def test_openapi_with_embed_body_param(): + app = APIGatewayRestResolver() + + class User(BaseModel): + name: str + + @app.post("/users") + def handler(user: Annotated[User, Body(embed=True)]): + print(user) + + schema = app.get_openapi_schema() + assert len(schema.paths.keys()) == 1 + + post = schema.paths["/users"].post + assert post.parameters is None + assert post.requestBody is not None + + request_body = post.requestBody + assert request_body.required is True + # Notice here we craft a specific schema for the embedded user + assert request_body.content[JSON_CONTENT_TYPE].schema_.ref == "#/components/schemas/Body_handler_users_post" + + # Ensure that the custom body schema actually points to the real user class + components = schema.components + assert "Body_handler_users_post" in components.schemas + body_post_handler_schema = components.schemas["Body_handler_users_post"] + assert body_post_handler_schema.properties["user"].ref == "#/components/schemas/User" + + +def test_create_header(): + header = Header(convert_underscores=True) + assert header.convert_underscores is True + + +def test_create_body(): + body = Body(embed=True, examples=[Example(summary="Example 1", value=10)]) + assert body.embed is True + + +# Tests that when we try to create a model without a field type, we return None +def test_create_empty_model_field(): + result = _create_model_field(None, int, "name", False) + assert result is None + + +# Tests that when we try to crate a param model without a source, we default to "query" +def test_create_model_field_with_empty_in(): + field_info = Param() + + result = _create_model_field(field_info, int, "name", False) + assert result.field_info.in_ == ParamTypes.query + + +# Tests that when we try to create a model field with convert_underscore, we convert the field name +def test_create_model_field_convert_underscore(): + field_info = Header(alias=None, convert_underscores=True) + + result = _create_model_field(field_info, int, "user_id", False) + assert result.alias == "user-id" diff --git a/tests/functional/event_handler/test_openapi_responses.py b/tests/functional/event_handler/test_openapi_responses.py new file mode 100644 index 00000000000..bd470867428 --- /dev/null +++ b/tests/functional/event_handler/test_openapi_responses.py @@ -0,0 +1,49 @@ +from aws_lambda_powertools.event_handler import APIGatewayRestResolver + + +def test_openapi_default_response(): + app = APIGatewayRestResolver(enable_validation=True) + + @app.get("/") + def handler(): + pass + + schema = app.get_openapi_schema() + responses = schema.paths["/"].get.responses + assert 200 in responses.keys() + assert responses[200].description == "Successful Response" + + assert 422 in responses.keys() + assert responses[422].description == "Validation Error" + + +def test_openapi_200_response_with_description(): + app = APIGatewayRestResolver(enable_validation=True) + + @app.get("/", response_description="Custom response") + def handler(): + return {"message": "hello world"} + + schema = app.get_openapi_schema() + responses = schema.paths["/"].get.responses + assert 200 in responses.keys() + assert responses[200].description == "Custom response" + + assert 422 in responses.keys() + assert responses[422].description == "Validation Error" + + +def test_openapi_200_custom_response(): + app = APIGatewayRestResolver(enable_validation=True) + + @app.get("/", responses={202: {"description": "Custom response"}}) + def handler(): + return {"message": "hello world"} + + schema = app.get_openapi_schema() + responses = schema.paths["/"].get.responses + assert 202 in responses.keys() + assert responses[202].description == "Custom response" + + assert 200 not in responses.keys() + assert 422 not in responses.keys() diff --git a/tests/functional/event_handler/test_openapi_serialization.py b/tests/functional/event_handler/test_openapi_serialization.py new file mode 100644 index 00000000000..63f1c0e4f9d --- /dev/null +++ b/tests/functional/event_handler/test_openapi_serialization.py @@ -0,0 +1,39 @@ +import json +from typing import Dict + +import pytest + +from aws_lambda_powertools.event_handler import APIGatewayRestResolver + + +def test_openapi_duplicated_serialization(): + # GIVEN APIGatewayRestResolver is initialized with enable_validation=True + app = APIGatewayRestResolver(enable_validation=True) + + # WHEN we have duplicated operations + @app.get("/") + def handler(): + pass + + @app.get("/") + def handler(): # noqa: F811 + pass + + # THEN we should get a warning + with pytest.warns(UserWarning, match="Duplicate Operation*"): + app.get_openapi_schema() + + +def test_openapi_serialize_json(): + # GIVEN APIGatewayRestResolver is initialized with enable_validation=True + app = APIGatewayRestResolver(enable_validation=True) + + @app.get("/") + def handler(): + pass + + # WHEN we serialize as json_schema + schema = json.loads(app.get_openapi_json_schema()) + + # THEN we should get a dictionary + assert isinstance(schema, Dict) diff --git a/tests/functional/event_handler/test_openapi_servers.py b/tests/functional/event_handler/test_openapi_servers.py new file mode 100644 index 00000000000..a1ae70a1237 --- /dev/null +++ b/tests/functional/event_handler/test_openapi_servers.py @@ -0,0 +1,26 @@ +from aws_lambda_powertools.event_handler.api_gateway import APIGatewayRestResolver +from aws_lambda_powertools.event_handler.openapi.models import Server + + +def test_openapi_schema_default_server(): + app = APIGatewayRestResolver() + + schema = app.get_openapi_schema(title="Hello API", version="1.0.0") + assert schema.servers + assert len(schema.servers) == 1 + assert schema.servers[0].url == "/" + + +def test_openapi_schema_custom_server(): + app = APIGatewayRestResolver() + + schema = app.get_openapi_schema( + title="Hello API", + version="1.0.0", + servers=[Server(url="https://example.org/", description="Example website")], + ) + + assert schema.servers + assert len(schema.servers) == 1 + assert str(schema.servers[0].url) == "https://example.org/" + assert schema.servers[0].description == "Example website" diff --git a/tests/functional/event_handler/test_openapi_validation_middleware.py b/tests/functional/event_handler/test_openapi_validation_middleware.py new file mode 100644 index 00000000000..6b4b94405d8 --- /dev/null +++ b/tests/functional/event_handler/test_openapi_validation_middleware.py @@ -0,0 +1,290 @@ +import json +from dataclasses import dataclass +from enum import Enum +from pathlib import PurePath +from typing import Tuple + +from pydantic import BaseModel + +from aws_lambda_powertools.event_handler import APIGatewayRestResolver +from aws_lambda_powertools.event_handler.openapi.params import Body +from aws_lambda_powertools.shared.types import Annotated +from tests.functional.utils import load_event + +LOAD_GW_EVENT = load_event("apiGatewayProxyEvent.json") + + +def test_validate_scalars(): + # GIVEN an APIGatewayRestResolver with validation enabled + app = APIGatewayRestResolver(enable_validation=True) + + # WHEN a handler is defined with a scalar parameter + @app.get("/users/") + def handler(user_id: int): + print(user_id) + + # sending a number + LOAD_GW_EVENT["path"] = "/users/123" + + # THEN the handler should be invoked and return 200 + result = app(LOAD_GW_EVENT, {}) + assert result["statusCode"] == 200 + + # sending a string + LOAD_GW_EVENT["path"] = "/users/abc" + + # THEN the handler should be invoked and return 422 + result = app(LOAD_GW_EVENT, {}) + assert result["statusCode"] == 422 + assert any(text in result["body"] for text in ["type_error.integer", "int_parsing"]) + + +def test_validate_scalars_with_default(): + # GIVEN an APIGatewayRestResolver with validation enabled + app = APIGatewayRestResolver(enable_validation=True) + + # WHEN a handler is defined with a default scalar parameter + @app.get("/users/") + def handler(user_id: int = 123): + print(user_id) + + # sending a number + LOAD_GW_EVENT["path"] = "/users/123" + + # THEN the handler should be invoked and return 200 + result = app(LOAD_GW_EVENT, {}) + assert result["statusCode"] == 200 + + # sending a string + LOAD_GW_EVENT["path"] = "/users/abc" + + # THEN the handler should be invoked and return 422 + result = app(LOAD_GW_EVENT, {}) + assert result["statusCode"] == 422 + assert any(text in result["body"] for text in ["type_error.integer", "int_parsing"]) + + +def test_validate_scalars_with_default_and_optional(): + # GIVEN an APIGatewayRestResolver with validation enabled + app = APIGatewayRestResolver(enable_validation=True) + + # WHEN a handler is defined with a default scalar parameter + @app.get("/users/") + def handler(user_id: int = 123, include_extra: bool = False): + print(user_id) + + # sending a number + LOAD_GW_EVENT["path"] = "/users/123" + + # THEN the handler should be invoked and return 200 + result = app(LOAD_GW_EVENT, {}) + assert result["statusCode"] == 200 + + # sending a string + LOAD_GW_EVENT["path"] = "/users/abc" + + # THEN the handler should be invoked and return 422 + result = app(LOAD_GW_EVENT, {}) + assert result["statusCode"] == 422 + assert any(text in result["body"] for text in ["type_error.integer", "int_parsing"]) + + +def test_validate_return_type(): + # GIVEN an APIGatewayRestResolver with validation enabled + app = APIGatewayRestResolver(enable_validation=True) + + # WHEN a handler is defined with a return type + @app.get("/") + def handler() -> int: + return 123 + + LOAD_GW_EVENT["path"] = "/" + + # THEN the handler should be invoked and return 200 + # THEN the body must be 123 + result = app(LOAD_GW_EVENT, {}) + assert result["statusCode"] == 200 + assert result["body"] == 123 + + +def test_validate_return_tuple(): + # GIVEN an APIGatewayRestResolver with validation enabled + app = APIGatewayRestResolver(enable_validation=True) + + sample_tuple = (1, 2, 3) + + # WHEN a handler is defined with a return type as Tuple + @app.get("/") + def handler() -> Tuple: + return sample_tuple + + LOAD_GW_EVENT["path"] = "/" + + # THEN the handler should be invoked and return 200 + # THEN the body must be a tuple + result = app(LOAD_GW_EVENT, {}) + assert result["statusCode"] == 200 + assert result["body"] == list(sample_tuple) + + +def test_validate_return_purepath(): + # GIVEN an APIGatewayRestResolver with validation enabled + app = APIGatewayRestResolver(enable_validation=True) + + sample_path = PurePath(__file__) + + # WHEN a handler is defined with a return type as string + # WHEN return value is a PurePath + @app.get("/") + def handler() -> str: + return sample_path + + LOAD_GW_EVENT["path"] = "/" + + # THEN the handler should be invoked and return 200 + # THEN the body must be a string + result = app(LOAD_GW_EVENT, {}) + assert result["statusCode"] == 200 + assert result["body"] == sample_path.as_posix() + + +def test_validate_return_enum(): + # GIVEN an APIGatewayRestResolver with validation enabled + app = APIGatewayRestResolver(enable_validation=True) + + class Model(Enum): + name = "powertools" + + # WHEN a handler is defined with a return type as Enum + @app.get("/") + def handler() -> Model: + return Model.name.value + + LOAD_GW_EVENT["path"] = "/" + + # THEN the handler should be invoked and return 200 + # THEN the body must be a string + result = app(LOAD_GW_EVENT, {}) + assert result["statusCode"] == 200 + assert result["body"] == "powertools" + + +def test_validate_return_dataclass(): + # GIVEN an APIGatewayRestResolver with validation enabled + app = APIGatewayRestResolver(enable_validation=True) + + @dataclass + class Model: + name: str + age: int + + # WHEN a handler is defined with a return type as dataclass + @app.get("/") + def handler() -> Model: + return Model(name="John", age=30) + + LOAD_GW_EVENT["path"] = "/" + + # THEN the handler should be invoked and return 200 + # THEN the body must be a dict + result = app(LOAD_GW_EVENT, {}) + assert result["statusCode"] == 200 + assert result["body"] == {"name": "John", "age": 30} + + +def test_validate_return_model(): + # GIVEN an APIGatewayRestResolver with validation enabled + app = APIGatewayRestResolver(enable_validation=True) + + class Model(BaseModel): + name: str + age: int + + # WHEN a handler is defined with a return type as Pydantic model + @app.get("/") + def handler() -> Model: + return Model(name="John", age=30) + + LOAD_GW_EVENT["path"] = "/" + + # THEN the handler should be invoked and return 200 + # THEN the body must be a dict + result = app(LOAD_GW_EVENT, {}) + assert result["statusCode"] == 200 + assert result["body"] == {"name": "John", "age": 30} + + +def test_validate_invalid_return_model(): + # GIVEN an APIGatewayRestResolver with validation enabled + app = APIGatewayRestResolver(enable_validation=True) + + class Model(BaseModel): + name: str + age: int + + # WHEN a handler is defined with a return type as Pydantic model + @app.get("/") + def handler() -> Model: + return {"name": "John"} # type: ignore + + LOAD_GW_EVENT["path"] = "/" + + # THEN the handler should be invoked and return 422 + # THEN the body must be a dict + result = app(LOAD_GW_EVENT, {}) + assert result["statusCode"] == 422 + assert "missing" in result["body"] + + +def test_validate_body_param(): + # GIVEN an APIGatewayRestResolver with validation enabled + app = APIGatewayRestResolver(enable_validation=True) + + class Model(BaseModel): + name: str + age: int + + # WHEN a handler is defined with a body parameter + @app.post("/") + def handler(user: Model) -> Model: + return user + + LOAD_GW_EVENT["httpMethod"] = "POST" + LOAD_GW_EVENT["path"] = "/" + LOAD_GW_EVENT["body"] = json.dumps({"name": "John", "age": 30}) + + # THEN the handler should be invoked and return 200 + # THEN the body must be a dict + result = app(LOAD_GW_EVENT, {}) + assert result["statusCode"] == 200 + assert result["body"] == {"name": "John", "age": 30} + + +def test_validate_embed_body_param(): + # GIVEN an APIGatewayRestResolver with validation enabled + app = APIGatewayRestResolver(enable_validation=True) + + class Model(BaseModel): + name: str + age: int + + # WHEN a handler is defined with a body parameter + @app.post("/") + def handler(user: Annotated[Model, Body(embed=True)]) -> Model: + return user + + LOAD_GW_EVENT["httpMethod"] = "POST" + LOAD_GW_EVENT["path"] = "/" + LOAD_GW_EVENT["body"] = json.dumps({"name": "John", "age": 30}) + + # THEN the handler should be invoked and return 422 + # THEN the body must be a dict + result = app(LOAD_GW_EVENT, {}) + assert result["statusCode"] == 422 + assert "missing" in result["body"] + + # THEN the handler should be invoked and return 200 + # THEN the body must be a dict + LOAD_GW_EVENT["body"] = json.dumps({"user": {"name": "John", "age": 30}}) + result = app(LOAD_GW_EVENT, {}) + assert result["statusCode"] == 200