diff --git a/google/generativeai/types/content_types.py b/google/generativeai/types/content_types.py index f3db610e1..a5eb1cd6d 100644 --- a/google/generativeai/types/content_types.py +++ b/google/generativeai/types/content_types.py @@ -21,8 +21,9 @@ import mimetypes import pathlib import typing -from typing import Any, Callable, Union -from typing_extensions import TypedDict +from typing import Any, Callable, Union, get_type_hints, get_origin, get_args +from typing_extensions import TypedDict, is_typeddict +import dataclasses import pydantic @@ -334,9 +335,42 @@ def to_contents(contents: ContentsType) -> list[protos.Content]: return contents -def _schema_for_class(cls: TypedDict) -> dict[str, Any]: +def _schema_for_class(cls: type) -> dict[str, Any]: schema = _build_schema("dummy", {"dummy": (cls, pydantic.Field())}) - return schema["properties"]["dummy"] + properties = schema["properties"]["dummy"] + + # Handling TypedDict + if is_typeddict(cls): + required_keys = [] + type_hints = get_type_hints(cls) + for key, type_hint in type_hints.items(): + if key in cls.__required_keys__: + # Check if the type is Optional, i.e., Union[..., NoneType] + if get_origin(type_hint) is Union and type(None) in get_args(type_hint): + continue + required_keys.append(key) + properties["required"] = required_keys + + # Handling dataclasses + elif dataclasses.is_dataclass(cls): + required_keys = [] + for field in dataclasses.fields(cls): + if field.default is dataclasses.MISSING and field.default_factory is dataclasses.MISSING: + required_keys.append(field.name) # Field is required if it has no default value + properties["required"] = required_keys + + # Handling Pydantic models + elif issubclass(cls, pydantic.BaseModel): + required_keys = [name for name, field in cls.__fields__.items() if field.is_required()] + properties["required"] = required_keys + + # Bug that it sets default values in case default exists + # TODO: Should be handled in the schema generation or not be allowed + + for key in properties["properties"]: + if 'default' in properties["properties"][key]: + properties["properties"][key].pop('default') + return properties def _schema_for_function(