From edbb14b4eb7f9db92eb5dcb95a3894935df3069e Mon Sep 17 00:00:00 2001 From: Leandro Damascena Date: Wed, 11 Sep 2024 14:30:53 +0100 Subject: [PATCH] Make sure Bedrock Agent works with Pydantic v2 --- .../event_handler/bedrock_agent.py | 107 ++++++++++++++++++ .../event_handler/test_bedrock_agent.py | 21 +++- 2 files changed, 127 insertions(+), 1 deletion(-) diff --git a/aws_lambda_powertools/event_handler/bedrock_agent.py b/aws_lambda_powertools/event_handler/bedrock_agent.py index faf551d1646..1c305cd4197 100644 --- a/aws_lambda_powertools/event_handler/bedrock_agent.py +++ b/aws_lambda_powertools/event_handler/bedrock_agent.py @@ -1,5 +1,6 @@ from __future__ import annotations +import json from typing import TYPE_CHECKING, Any, Callable from typing_extensions import override @@ -10,10 +11,12 @@ ProxyEventType, ResponseBuilder, ) +from aws_lambda_powertools.event_handler.openapi.constants import DEFAULT_API_VERSION, DEFAULT_OPENAPI_VERSION if TYPE_CHECKING: from re import Match + from aws_lambda_powertools.event_handler.openapi.models import Contact, License, SecurityScheme, Server, Tag from aws_lambda_powertools.event_handler.openapi.types import OpenAPIResponse from aws_lambda_powertools.utilities.data_classes import BedrockAgentEvent @@ -273,3 +276,107 @@ def _convert_matches_into_route_keys(self, match: Match) -> dict[str, str]: if match.groupdict() and self.current_event.parameters: parameters = {parameter["name"]: parameter["value"] for parameter in self.current_event.parameters} return parameters + + @override + def get_openapi_json_schema( + self, + *, + title: str = "Powertools API", + version: str = DEFAULT_API_VERSION, + openapi_version: str = DEFAULT_OPENAPI_VERSION, + summary: str | None = None, + description: str | None = None, + tags: list[Tag | str] | None = None, + servers: list[Server] | None = None, + terms_of_service: str | None = None, + contact: Contact | None = None, + license_info: License | None = None, + security_schemes: dict[str, SecurityScheme] | None = None, + security: list[dict[str, list[str]]] | None = None, + ) -> str: + """ + Returns the OpenAPI schema as a JSON serializable dict. + Since Bedrock Agents only support OpenAPI 3.0.0, we convert OpenAPI 3.1.0 schemas + and enforce 3.0.0 compatibility for seamless integration. + + Parameters + ---------- + title: str + The title of the application. + version: str + The version of the OpenAPI document (which is distinct from the OpenAPI Specification version or the API + openapi_version: str, default = "3.0.0" + The version of the OpenAPI Specification (which the document uses). + summary: str, optional + A short summary of what the application does. + description: str, optional + A verbose explanation of the application behavior. + tags: list[Tag, str], optional + A list of tags used by the specification with additional metadata. + servers: list[Server], optional + An array of Server Objects, which provide connectivity information to a target server. + terms_of_service: str, optional + A URL to the Terms of Service for the API. MUST be in the format of a URL. + contact: Contact, optional + The contact information for the exposed API. + license_info: License, optional + The license information for the exposed API. + security_schemes: dict[str, SecurityScheme]], optional + A declaration of the security schemes available to be used in the specification. + security: list[dict[str, list[str]]], optional + A declaration of which security mechanisms are applied globally across the API. + + Returns + ------- + str + The OpenAPI schema as a JSON serializable dict. + """ + from aws_lambda_powertools.event_handler.openapi.compat import model_json + + schema = super().get_openapi_schema( + title=title, + version=version, + openapi_version=openapi_version, + summary=summary, + description=description, + tags=tags, + servers=servers, + terms_of_service=terms_of_service, + contact=contact, + license_info=license_info, + security_schemes=security_schemes, + security=security, + ) + schema.openapi = "3.0.3" + + # Transform OpenAPI 3.1 into 3.0 + def inner(yaml_dict): + if isinstance(yaml_dict, dict): + if "anyOf" in yaml_dict and isinstance((anyOf := yaml_dict["anyOf"]), list): + for i, item in enumerate(anyOf): + if isinstance(item, dict) and item.get("type") == "null": + anyOf.pop(i) + yaml_dict["nullable"] = True + if "examples" in yaml_dict: + examples = yaml_dict["examples"] + del yaml_dict["examples"] + if isinstance(examples, list) and len(examples): + yaml_dict["example"] = examples[0] + for value in yaml_dict.values(): + inner(value) + elif isinstance(yaml_dict, list): + for item in yaml_dict: + inner(item) + + model = json.loads( + model_json( + schema, + by_alias=True, + exclude_none=True, + indent=2, + ), + ) + + inner(model) + + return json.dumps(model) diff --git a/tests/functional/event_handler/test_bedrock_agent.py b/tests/functional/event_handler/test_bedrock_agent.py index 7d2da8c0486..6dcc55c2da5 100644 --- a/tests/functional/event_handler/test_bedrock_agent.py +++ b/tests/functional/event_handler/test_bedrock_agent.py @@ -1,6 +1,7 @@ import json -from typing import Any, Dict +from typing import Any, Dict, Optional +import pytest from typing_extensions import Annotated from aws_lambda_powertools.event_handler import BedrockAgentResolver, Response, content_types @@ -181,3 +182,21 @@ def send_reminders( # THEN return the correct result body = result["response"]["responseBody"]["application/json"]["body"] assert json.loads(body) is True + + +@pytest.mark.usefixtures("pydanticv2_only") +def test_openapi_schema_for_pydanticv2(openapi30_schema): + # GIVEN BedrockAgentResolver is initialized with enable_validation=True + app = BedrockAgentResolver(enable_validation=True) + + # WHEN we have a simple handler + @app.get("/", description="Testing") + def handler() -> Optional[Dict]: + pass + + # WHEN we get the schema + schema = json.loads(app.get_openapi_json_schema()) + + # THEN the schema must be a valid 3.0.3 version + assert openapi30_schema(schema) + assert schema.get("openapi") == "3.0.3"