diff --git a/pydantic_ai_slim/pydantic_ai/__init__.py b/pydantic_ai_slim/pydantic_ai/__init__.py index 21ef4dec6..43d985dc4 100644 --- a/pydantic_ai_slim/pydantic_ai/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/__init__.py @@ -12,7 +12,7 @@ ) from .format_prompt import format_as_xml from .messages import AudioUrl, BinaryContent, DocumentUrl, ImageUrl, VideoUrl -from .result import ToolOutput +from .result import JsonSchemaOutput, PromptedJsonOutput, ToolOutput from .tools import RunContext, Tool __all__ = ( @@ -43,6 +43,8 @@ 'RunContext', # result 'ToolOutput', + 'JsonSchemaOutput', + 'PromptedJsonOutput', # format_prompt 'format_as_xml', ) diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index 1c83a2852..44602715e 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -7,7 +7,7 @@ from contextlib import asynccontextmanager, contextmanager from contextvars import ContextVar from dataclasses import field -from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union, cast +from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union from opentelemetry.trace import Tracer from typing_extensions import TypeGuard, TypeVar, assert_never @@ -90,7 +90,7 @@ class GraphAgentDeps(Generic[DepsT, OutputDataT]): end_strategy: EndStrategy get_instructions: Callable[[RunContext[DepsT]], Awaitable[str | None]] - output_schema: _output.OutputSchema[OutputDataT] | None + output_schema: _output.OutputSchema[OutputDataT] output_validators: list[_output.OutputValidator[DepsT, OutputDataT]] function_tools: dict[str, Tool[DepsT]] = dataclasses.field(repr=False) @@ -262,8 +262,10 @@ async def add_mcp_server_tools(server: MCPServer) -> None: output_schema = ctx.deps.output_schema return models.ModelRequestParameters( function_tools=function_tool_defs, - allow_text_output=_output.allow_text_output(output_schema), - output_tools=output_schema.tool_defs() if output_schema is not None else [], + output_mode=output_schema.mode, + output_object=output_schema.object_def if isinstance(output_schema, _output.JsonTextOutputSchema) else None, + output_tools=output_schema.tool_defs() if isinstance(output_schema, _output.ToolOutputSchema) else [], + allow_text_output=isinstance(output_schema, _output.TextOutputSchema), ) @@ -448,7 +450,7 @@ async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]: # when the model has already returned text along side tool calls # in this scenario, if text responses are allowed, we return text from the most recent model # response, if any - if _output.allow_text_output(ctx.deps.output_schema): + if isinstance(ctx.deps.output_schema, _output.TextOutputSchema): for message in reversed(ctx.state.message_history): if isinstance(message, _messages.ModelResponse): last_texts = [p.content for p in message.parts if isinstance(p, _messages.TextPart)] @@ -471,10 +473,11 @@ async def _handle_tool_calls( output_schema = ctx.deps.output_schema run_context = build_run_context(ctx) - # first, look for the output tool call final_result: result.FinalResult[NodeRunEndT] | None = None parts: list[_messages.ModelRequestPart] = [] - if output_schema is not None: + + # first, look for the output tool call + if isinstance(output_schema, _output.ToolOutputSchema): for call, output_tool in output_schema.find_tool(tool_calls): try: result_data = await output_tool.process(call, run_context) @@ -532,9 +535,9 @@ async def _handle_text_response( text = '\n\n'.join(texts) try: - if _output.allow_text_output(output_schema): - # The following cast is safe because we know `str` is an allowed result type - result_data = cast(NodeRunEndT, text) + if isinstance(output_schema, _output.TextOutputSchema): + run_context = build_run_context(ctx) + result_data = await output_schema.process(text, run_context) else: m = _messages.RetryPromptPart( content='Plain text responses are not permitted, please include your response in a tool call', @@ -633,7 +636,7 @@ async def process_function_tools( # noqa C901 yield event call_index_to_event_id[len(calls_to_run)] = event.call_id calls_to_run.append((mcp_tool, call)) - elif output_schema is not None and call.tool_name in output_schema.tools: + elif call.tool_name in output_schema.tools: # if tool_name is in output_schema, it means we found a output tool but an error occurred in # validation, we don't add another part here if output_tool_name is not None: @@ -762,7 +765,9 @@ def _unknown_tool( ) -> _messages.RetryPromptPart: ctx.state.increment_retries(ctx.deps.max_result_retries) tool_names = list(ctx.deps.function_tools.keys()) - if output_schema := ctx.deps.output_schema: + + output_schema = ctx.deps.output_schema + if isinstance(output_schema, _output.ToolOutputSchema): tool_names.extend(output_schema.tool_names()) if tool_names: @@ -839,7 +844,7 @@ def get_captured_run_messages() -> _RunMessages: def build_agent_graph( name: str | None, deps_type: type[DepsT], - output_type: _output.OutputType[OutputT], + output_type: _output.OutputSpec[OutputT], ) -> Graph[GraphAgentState, GraphAgentDeps[DepsT, result.FinalResult[OutputT]], result.FinalResult[OutputT]]: """Build the execution [Graph][pydantic_graph.Graph] for a given agent.""" nodes = ( diff --git a/pydantic_ai_slim/pydantic_ai/_output.py b/pydantic_ai_slim/pydantic_ai/_output.py index 916ddb7e8..37787bc9b 100644 --- a/pydantic_ai_slim/pydantic_ai/_output.py +++ b/pydantic_ai_slim/pydantic_ai/_output.py @@ -1,18 +1,20 @@ from __future__ import annotations as _annotations import inspect +import json +from abc import ABC, abstractmethod from collections.abc import Awaitable, Iterable, Iterator, Sequence from dataclasses import dataclass, field -from typing import Any, Callable, Generic, Literal, Union, cast +from typing import Any, Callable, Generic, Literal, Union, cast, overload from pydantic import TypeAdapter, ValidationError from pydantic_core import SchemaValidator -from typing_extensions import TypeAliasType, TypedDict, TypeVar, get_args, get_origin +from typing_extensions import TypeAliasType, TypedDict, TypeVar, assert_never, get_args, get_origin from typing_inspection import typing_objects from typing_inspection.introspection import is_union_origin from . import _function_schema, _utils, messages as _messages -from .exceptions import ModelRetry +from .exceptions import ModelRetry, UserError from .tools import AgentDepsT, GenerateToolJsonSchema, ObjectJsonSchema, RunContext, ToolDefinition T = TypeVar('T') @@ -113,7 +115,7 @@ def __init__(self, tool_retry: _messages.RetryPromptPart): class ToolOutput(Generic[OutputDataT]): """Marker class to use tools for outputs, and customize the tool.""" - output_type: SimpleOutputType[OutputDataT] + output: OutputTypeOrFunction[OutputDataT] name: str | None description: str | None max_retries: int | None @@ -121,120 +123,479 @@ class ToolOutput(Generic[OutputDataT]): def __init__( self, - type_: SimpleOutputType[OutputDataT], + type_: OutputTypeOrFunction[OutputDataT], *, name: str | None = None, description: str | None = None, max_retries: int | None = None, strict: bool | None = None, ): - self.output_type = type_ + self.output = type_ self.name = name self.description = description self.max_retries = max_retries self.strict = strict +@dataclass +class TextOutput(Generic[OutputDataT]): + """Marker class to use text output with an output function.""" + + output_function: TextOutputFunction[OutputDataT] + + +@dataclass(init=False) +class JsonSchemaOutput(Generic[OutputDataT]): + """Marker class to use JSON schema output for outputs.""" + + outputs: Sequence[OutputTypeOrFunction[OutputDataT]] + name: str | None + description: str | None + strict: bool | None + + def __init__( + self, + type_: OutputTypeOrFunction[OutputDataT] | Sequence[OutputTypeOrFunction[OutputDataT]], + *, + name: str | None = None, + description: str | None = None, + strict: bool | None = True, + ): + self.outputs = flatten_output_spec(type_) + self.name = name + self.description = description + self.strict = strict + + +class PromptedJsonOutput(Generic[OutputDataT]): + """Marker class to use prompted JSON mode for outputs.""" + + outputs: Sequence[OutputTypeOrFunction[OutputDataT]] + name: str | None + description: str | None + + def __init__( + self, + type_: OutputTypeOrFunction[OutputDataT] | Sequence[OutputTypeOrFunction[OutputDataT]], + *, + name: str | None = None, + description: str | None = None, + ): + self.outputs = flatten_output_spec(type_) + self.name = name + self.description = description + + T_co = TypeVar('T_co', covariant=True) -# output_type=Type or output_type=function or output_type=object.method -SimpleOutputType = TypeAliasType( - 'SimpleOutputType', Union[type[T_co], Callable[..., Union[Awaitable[T_co], T_co]]], type_params=(T_co,) + +OutputTypeOrFunction = TypeAliasType( + 'OutputTypeOrFunction', Union[type[T_co], Callable[..., Union[Awaitable[T_co], T_co]]], type_params=(T_co,) ) -# output_type=ToolOutput() or -SimpleOutputTypeOrMarker = TypeAliasType( - 'SimpleOutputTypeOrMarker', Union[SimpleOutputType[T_co], ToolOutput[T_co]], type_params=(T_co,) +OutputSpec = TypeAliasType( + 'OutputSpec', + Union[ + OutputTypeOrFunction[T_co], + ToolOutput[T_co], + TextOutput[T_co], + Sequence[Union[OutputTypeOrFunction[T_co], ToolOutput[T_co], TextOutput[T_co]]], + JsonSchemaOutput[T_co], + PromptedJsonOutput[T_co], + ], + type_params=(T_co,), ) -# output_type= or [, ...] -OutputType = TypeAliasType( - 'OutputType', Union[SimpleOutputTypeOrMarker[T_co], Sequence[SimpleOutputTypeOrMarker[T_co]]], type_params=(T_co,) + +TextOutputFunction = TypeAliasType( + 'TextOutputFunction', + Union[ + Callable[[RunContext, str], Union[Awaitable[T_co], T_co]], + Callable[[str], Union[Awaitable[T_co], T_co]], + ], + type_params=(T_co,), ) -@dataclass -class OutputSchema(Generic[OutputDataT]): - """Model the final output from an agent run. +OutputMode = Literal['text', 'tool', 'json_schema', 'prompted_json', 'tool_or_text'] +"""All output modes.""" +SupportableOutputMode = Literal['tool', 'json_schema'] +"""Output modes that require specific support by a model (class). Used by ModelProfile.output_modes""" +StructuredOutputMode = Literal['tool', 'json_schema', 'prompted_json'] +"""Output modes that can be used for any structured output. Used by ModelProfile.default_output_mode""" + + +class BaseOutputSchema(ABC, Generic[OutputDataT]): + @property + @abstractmethod + def mode(self) -> OutputMode | None: + raise NotImplementedError() + + @abstractmethod + def with_default_mode(self, mode: StructuredOutputMode) -> OutputSchema[OutputDataT]: + raise NotImplementedError() + + @abstractmethod + def is_supported(self, supported_modes: set[SupportableOutputMode]) -> bool: + """Whether the mode is supported by the model.""" + raise NotImplementedError() + + @property + def tools(self) -> dict[str, OutputTool[OutputDataT]]: + """Get the tools for this output schema.""" + return {} + - Similar to `Tool` but for the final output of running an agent. - """ +@dataclass(init=False) +class OutputSchema(BaseOutputSchema[OutputDataT], ABC): + """Model the final output from an agent run.""" - tools: dict[str, OutputTool[OutputDataT]] - allow_text_output: bool + @classmethod + @overload + def build( + cls, + output_spec: OutputSpec[OutputDataT], + *, + default_mode: StructuredOutputMode, + name: str | None = None, + description: str | None = None, + strict: bool | None = None, + ) -> OutputSchema[OutputDataT]: ... @classmethod + @overload def build( - cls: type[OutputSchema[OutputDataT]], - output_type: OutputType[OutputDataT], + cls, + output_spec: OutputSpec[OutputDataT], + *, + default_mode: None = None, name: str | None = None, description: str | None = None, strict: bool | None = None, - ) -> OutputSchema[OutputDataT] | None: + ) -> OutputSchemaWithoutMode[OutputDataT]: ... + + @classmethod + def build( + cls, + output_spec: OutputSpec[OutputDataT], + *, + default_mode: StructuredOutputMode | None = None, + name: str | None = None, + description: str | None = None, + strict: bool | None = None, + ) -> BaseOutputSchema[OutputDataT]: """Build an OutputSchema dataclass from an output type.""" - if output_type is str: - return None + if output_spec is str: + return PlainTextOutputSchema() + + if isinstance(output_spec, JsonSchemaOutput): + return JsonSchemaOutputSchema( + cls._build_processor( + output_spec.outputs, + name=output_spec.name, + description=output_spec.description, + strict=output_spec.strict, + ), + ) - output_types: Sequence[SimpleOutputTypeOrMarker[OutputDataT]] - if isinstance(output_type, Sequence): - output_types = output_type - else: - output_types = (output_type,) + if isinstance(output_spec, PromptedJsonOutput): + return PromptedJsonOutputSchema( + cls._build_processor(output_spec.outputs, name=output_spec.name, description=output_spec.description), + ) - output_types_flat: list[SimpleOutputTypeOrMarker[OutputDataT]] = [] - for output_type in output_types: - if union_types := get_union_args(output_type): - output_types_flat.extend(union_types) + text_outputs: Sequence[type[str] | TextOutput[OutputDataT]] = [] + tool_outputs: Sequence[ToolOutput[OutputDataT]] = [] + other_outputs: Sequence[OutputTypeOrFunction[OutputDataT]] = [] + for output in flatten_output_spec(output_spec): + if output is str: + text_outputs.append(cast(type[str], output)) + elif isinstance(output, TextOutput): + text_outputs.append(output) + elif isinstance(output, ToolOutput): + tool_outputs.append(output) else: - output_types_flat.append(output_type) + other_outputs.append(output) - allow_text_output = False - if str in output_types_flat: - allow_text_output = True - output_types_flat = [t for t in output_types_flat if t is not str] + tools = cls._build_tools(tool_outputs + other_outputs, name=name, description=description, strict=strict) - multiple = len(output_types_flat) > 1 + if len(text_outputs) > 0: + if len(text_outputs) > 1: + raise UserError('Only one text output is allowed.') + text_output = text_outputs[0] - default_tool_name = name or DEFAULT_OUTPUT_TOOL_NAME - default_tool_description = description - default_tool_strict = strict + text_output_schema = None + if isinstance(text_output, TextOutput): + text_output_schema = PlainTextOutputProcessor(text_output.output_function) + if len(tools) == 0: + return PlainTextOutputSchema(text_output_schema) + else: + return ToolOrTextOutputSchema(processor=text_output_schema, tools=tools) + + if len(tool_outputs) > 0: + return ToolOutputSchema(tools) + + if len(other_outputs) > 0: + schema = OutputSchemaWithoutMode( + processor=cls._build_processor(other_outputs, name=name, description=description, strict=strict), + tools=tools, + ) + if default_mode: + schema = schema.with_default_mode(default_mode) + return schema + + raise UserError('No output type provided.') # pragma: no cover + + @staticmethod + def _build_tools( + outputs: list[OutputTypeOrFunction[OutputDataT] | ToolOutput[OutputDataT]], + name: str | None = None, + description: str | None = None, + strict: bool | None = None, + ) -> dict[str, OutputTool[OutputDataT]]: tools: dict[str, OutputTool[OutputDataT]] = {} - for output_type in output_types_flat: - tool_name = None - tool_description = None - tool_strict = None - if isinstance(output_type, ToolOutput): - tool_output_type = output_type.output_type + + default_name = name or DEFAULT_OUTPUT_TOOL_NAME + default_description = description + default_strict = strict + + multiple = len(outputs) > 1 + for output in outputs: + name = None + description = None + strict = None + if isinstance(output, ToolOutput): # do we need to error on conflicts here? (DavidM): If this is internal maybe doesn't matter, if public, use overloads - tool_name = output_type.name - tool_description = output_type.description - tool_strict = output_type.strict - else: - tool_output_type = output_type + name = output.name + description = output.description + strict = output.strict + + output = output.output - if tool_name is None: - tool_name = default_tool_name + if name is None: + name = default_name if multiple: - tool_name += f'_{tool_output_type.__name__}' + name += f'_{output.__name__}' i = 1 - original_tool_name = tool_name - while tool_name in tools: + original_name = name + while name in tools: i += 1 - tool_name = f'{original_tool_name}_{i}' + name = f'{original_name}_{i}' + + description = description or default_description + if strict is None: + strict = default_strict + + processor = ObjectOutputProcessor(output=output, description=description, strict=strict) + tools[name] = OutputTool(name=name, processor=processor, multiple=multiple) + + return tools + + @staticmethod + def _build_processor( + outputs: Sequence[OutputTypeOrFunction[OutputDataT]], + name: str | None = None, + description: str | None = None, + strict: bool | None = None, + ) -> ObjectOutputProcessor[OutputDataT] | UnionOutputProcessor[OutputDataT]: + outputs = flatten_output_spec(outputs) + if len(outputs) == 1: + return ObjectOutputProcessor(output=outputs[0], name=name, description=description, strict=strict) + + return UnionOutputProcessor(outputs=outputs, strict=strict, name=name, description=description) + + @property + @abstractmethod + def mode(self) -> OutputMode: + raise NotImplementedError() - tool_description = tool_description or default_tool_description - if tool_strict is None: - tool_strict = default_tool_strict + def with_default_mode(self, mode: StructuredOutputMode) -> OutputSchema[OutputDataT]: + return self - parameters_schema = OutputObjectSchema( - output_type=tool_output_type, description=tool_description, strict=tool_strict + +@dataclass(init=False) +class OutputSchemaWithoutMode(BaseOutputSchema[OutputDataT]): + processor: ObjectOutputProcessor[OutputDataT] | UnionOutputProcessor[OutputDataT] + _tools: dict[str, OutputTool[OutputDataT]] = field(default_factory=dict) + + def __init__( + self, + processor: ObjectOutputProcessor[OutputDataT] | UnionOutputProcessor[OutputDataT], + tools: dict[str, OutputTool[OutputDataT]], + ): + self.processor = processor + self._tools = tools + + @property + def mode(self) -> None: + return None # pragma: no cover + + def with_default_mode(self, mode: StructuredOutputMode) -> OutputSchema[OutputDataT]: + if mode == 'json_schema': + return JsonSchemaOutputSchema( + self.processor, + ) + elif mode == 'prompted_json': + return PromptedJsonOutputSchema( + self.processor, ) - tools[tool_name] = OutputTool(name=tool_name, parameters_schema=parameters_schema, multiple=multiple) + elif mode == 'tool': + return ToolOutputSchema(tools=self.tools) + else: + assert_never(mode) + + def is_supported(self, supported_modes: set[SupportableOutputMode]) -> bool: + """Whether the mode is supported by the model.""" + return False # pragma: no cover + + @property + def tools(self) -> dict[str, OutputTool[OutputDataT]]: + """Get the tools for this output schema.""" + # We return tools here as they're checked in Agent._register_tool. + # At that point we may don't know yet what output mode we're going to use if no model was provided or it was deferred until agent.run time. + return self._tools + + +class TextOutputSchema(OutputSchema[OutputDataT], ABC): + @abstractmethod + async def process( + self, + text: str, + run_context: RunContext[AgentDepsT], + allow_partial: bool = False, + wrap_validation_errors: bool = True, + ) -> OutputDataT: + raise NotImplementedError() + + +@dataclass +class PlainTextOutputSchema(TextOutputSchema[OutputDataT]): + processor: PlainTextOutputProcessor[OutputDataT] | None = None + + @property + def mode(self) -> OutputMode: + return 'text' + + def is_supported(self, supported_modes: set[SupportableOutputMode]) -> bool: + """Whether the mode is supported by the model.""" + return True + + async def process( + self, + text: str, + run_context: RunContext[AgentDepsT], + allow_partial: bool = False, + wrap_validation_errors: bool = True, + ) -> OutputDataT: + """Validate an output message. + + Args: + text: The output text to validate. + run_context: The current run context. + allow_partial: If true, allow partial validation. + wrap_validation_errors: If true, wrap the validation errors in a retry message. + + Returns: + Either the validated output data (left) or a retry message (right). + """ + if self.processor is None: + return cast(OutputDataT, text) + + return await self.processor.process( + text, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors + ) + + +@dataclass +class JsonTextOutputSchema(TextOutputSchema[OutputDataT], ABC): + processor: ObjectOutputProcessor[OutputDataT] | UnionOutputProcessor[OutputDataT] + + @property + def object_def(self) -> OutputObjectDefinition: + return self.processor.object_def + + async def process( + self, + text: str, + run_context: RunContext[AgentDepsT], + allow_partial: bool = False, + wrap_validation_errors: bool = True, + ) -> OutputDataT: + """Validate an output message. + + Args: + text: The output text to validate. + run_context: The current run context. + allow_partial: If true, allow partial validation. + wrap_validation_errors: If true, wrap the validation errors in a retry message. - return cls( - tools=tools, - allow_text_output=allow_text_output, + Returns: + Either the validated output data (left) or a retry message (right). + """ + text = _utils.strip_markdown_fences(text) + + return await self.processor.process( + text, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors ) + +class JsonSchemaOutputSchema(JsonTextOutputSchema[OutputDataT]): + @property + def mode(self) -> OutputMode: + return 'json_schema' + + def is_supported(self, supported_modes: set[SupportableOutputMode]) -> bool: + """Whether the mode is supported by the model.""" + return 'json_schema' in supported_modes + + +class PromptedJsonOutputSchema(JsonTextOutputSchema[OutputDataT]): + @property + def mode(self) -> OutputMode: + return 'prompted_json' + + def is_supported(self, supported_modes: set[SupportableOutputMode]) -> bool: + """Whether the mode is supported by the model.""" + return True + + def instructions(self, template: str) -> str: + """Get instructions for model to output manual JSON matching the schema.""" + object_def = self.object_def + schema = object_def.json_schema.copy() + if object_def.name: + schema['title'] = object_def.name + if object_def.description: + schema['description'] = object_def.description + + return template.format(schema=json.dumps(schema)) + + +@dataclass(init=False) +class ToolOutputSchema(OutputSchema[OutputDataT]): + _tools: dict[str, OutputTool[OutputDataT]] = field(default_factory=dict) + + def __init__(self, tools: dict[str, OutputTool[OutputDataT]]): + self._tools = tools + + @property + def mode(self) -> OutputMode: + return 'tool' + + def is_supported(self, supported_modes: set[SupportableOutputMode]) -> bool: + """Whether the mode is supported by the model.""" + return 'tool' in supported_modes + + @property + def tools(self) -> dict[str, OutputTool[OutputDataT]]: + """Get the tools for this output schema.""" + return self._tools + + def tool_names(self) -> list[str]: + """Return the names of the tools.""" + return list(self.tools.keys()) + + def tool_defs(self) -> list[ToolDefinition]: + """Get tool definitions to register with the model.""" + return [t.tool_def for t in self.tools.values()] + def find_named_tool( self, parts: Iterable[_messages.ModelResponsePart], tool_name: str ) -> tuple[_messages.ToolCallPart, OutputTool[OutputDataT]] | None: @@ -254,61 +615,82 @@ def find_tool( if result := self.tools.get(part.tool_name): yield part, result - def tool_names(self) -> list[str]: - """Return the names of the tools.""" - return list(self.tools.keys()) - def tool_defs(self) -> list[ToolDefinition]: - """Get tool definitions to register with the model.""" - return [t.tool_def for t in self.tools.values()] +@dataclass(init=False) +class ToolOrTextOutputSchema(PlainTextOutputSchema[OutputDataT], ToolOutputSchema[OutputDataT]): + def __init__( + self, + processor: PlainTextOutputProcessor[OutputDataT] | None, + tools: dict[str, OutputTool[OutputDataT]], + ): + self.processor = processor + self._tools = tools + @property + def mode(self) -> OutputMode: + return 'tool_or_text' -def allow_text_output(output_schema: OutputSchema[Any] | None) -> bool: - return output_schema is None or output_schema.allow_text_output + def is_supported(self, supported_modes: set[SupportableOutputMode]) -> bool: + """Whether the mode is supported by the model.""" + return 'tool' in supported_modes @dataclass class OutputObjectDefinition: - name: str json_schema: ObjectJsonSchema + name: str | None = None description: str | None = None strict: bool | None = None @dataclass(init=False) -class OutputObjectSchema(Generic[OutputDataT]): - definition: OutputObjectDefinition - validator: SchemaValidator - function_schema: _function_schema.FunctionSchema | None = None +class BaseOutputProcessor(ABC, Generic[OutputDataT]): + @abstractmethod + async def process( + self, + data: str, + run_context: RunContext[AgentDepsT], + allow_partial: bool = False, + wrap_validation_errors: bool = True, + ) -> OutputDataT: + """Process an output message, performing validation and (if necessary) calling the output function.""" + raise NotImplementedError() + + +@dataclass(init=False) +class ObjectOutputProcessor(BaseOutputProcessor[OutputDataT]): + object_def: OutputObjectDefinition outer_typed_dict_key: str | None = None + _validator: SchemaValidator + _function_schema: _function_schema.FunctionSchema | None = None def __init__( self, + output: OutputTypeOrFunction[OutputDataT], *, - output_type: SimpleOutputType[OutputDataT], name: str | None = None, description: str | None = None, strict: bool | None = None, ): - if inspect.isfunction(output_type) or inspect.ismethod(output_type): - self.function_schema = _function_schema.function_schema(output_type, GenerateToolJsonSchema) - self.validator = self.function_schema.validator - json_schema = self.function_schema.json_schema - json_schema['description'] = self.function_schema.description + if inspect.isfunction(output) or inspect.ismethod(output): + self._function_schema = _function_schema.function_schema(output, GenerateToolJsonSchema) + self._validator = self._function_schema.validator + json_schema = self._function_schema.json_schema + json_schema['description'] = self._function_schema.description else: type_adapter: TypeAdapter[Any] - if _utils.is_model_like(output_type): - type_adapter = TypeAdapter(output_type) + if _utils.is_model_like(output): + type_adapter = TypeAdapter(output) else: self.outer_typed_dict_key = 'response' response_data_typed_dict = TypedDict( # noqa: UP013 'response_data_typed_dict', - {'response': cast(type[OutputDataT], output_type)}, # pyright: ignore[reportInvalidTypeForm] + {'response': cast(type[OutputDataT], output)}, # pyright: ignore[reportInvalidTypeForm] ) type_adapter = TypeAdapter(response_data_typed_dict) # Really a PluggableSchemaValidator, but it's API-compatible - self.validator = cast(SchemaValidator, type_adapter.validator) + self._validator = cast(SchemaValidator, type_adapter.validator) json_schema = _utils.check_object_json_schema( type_adapter.json_schema(schema_generator=GenerateToolJsonSchema) ) @@ -323,8 +705,8 @@ def __init__( else: description = f'{description}. {json_schema_description}' - self.definition = OutputObjectDefinition( - name=name or getattr(output_type, '__name__', DEFAULT_OUTPUT_TOOL_NAME), + self.object_def = OutputObjectDefinition( + name=name or getattr(output, '__name__', None), description=description, json_schema=json_schema, strict=strict, @@ -335,6 +717,7 @@ async def process( data: str | dict[str, Any] | None, run_context: RunContext[AgentDepsT], allow_partial: bool = False, + wrap_validation_errors: bool = True, ) -> OutputDataT: """Process an output message, performing validation and (if necessary) calling the output function. @@ -342,45 +725,235 @@ async def process( data: The output data to validate. run_context: The current run context. allow_partial: If true, allow partial validation. + wrap_validation_errors: If true, wrap the validation errors in a retry message. Returns: Either the validated output data (left) or a retry message (right). """ - pyd_allow_partial: Literal['off', 'trailing-strings'] = 'trailing-strings' if allow_partial else 'off' - if isinstance(data, str): - output = self.validator.validate_json(data or '{}', allow_partial=pyd_allow_partial) - else: - output = self.validator.validate_python(data or {}, allow_partial=pyd_allow_partial) - - if self.function_schema: - output = await self.function_schema.call(output, run_context) + try: + pyd_allow_partial: Literal['off', 'trailing-strings'] = 'trailing-strings' if allow_partial else 'off' + if isinstance(data, str): + output = self._validator.validate_json(data or '{}', allow_partial=pyd_allow_partial) + else: + output = self._validator.validate_python(data or {}, allow_partial=pyd_allow_partial) + except ValidationError as e: + if wrap_validation_errors: + m = _messages.RetryPromptPart( + content=e.errors(include_url=False), + ) + raise ToolRetryError(m) from e + else: + raise # pragma: lax no cover if k := self.outer_typed_dict_key: output = output[k] + + if self._function_schema: + try: + output = await self._function_schema.call(output, run_context) + except ModelRetry as r: + if wrap_validation_errors: + m = _messages.RetryPromptPart( + content=r.message, + ) + raise ToolRetryError(m) from r + else: + raise # pragma: lax no cover + return output +@dataclass +class UnionOutputResult: + kind: str + data: ObjectJsonSchema + + +@dataclass +class UnionOutputModel: + result: UnionOutputResult + + +@dataclass(init=False) +class UnionOutputProcessor(BaseOutputProcessor[OutputDataT]): + object_def: OutputObjectDefinition + _union_processor: ObjectOutputProcessor[UnionOutputModel] + _processors: dict[str, ObjectOutputProcessor[OutputDataT]] + + def __init__( + self, + outputs: Sequence[OutputTypeOrFunction[OutputDataT]], + *, + name: str | None = None, + description: str | None = None, + strict: bool | None = None, + ): + self._union_processor = ObjectOutputProcessor(output=UnionOutputModel) + + json_schemas: list[ObjectJsonSchema] = [] + self._processors = {} + for output in outputs: + processor = ObjectOutputProcessor(output=output, strict=strict) + object_def = processor.object_def + + object_key = object_def.name or output.__name__ + i = 1 + original_key = object_key + while object_key in self._processors: + i += 1 + object_key = f'{original_key}_{i}' + + self._processors[object_key] = processor + + json_schema = object_def.json_schema + if object_def.name: # pragma: no branch + json_schema['title'] = object_def.name + if object_def.description: + json_schema['description'] = object_def.description + + json_schemas.append(json_schema) + + json_schemas, all_defs = _utils.merge_json_schema_defs(json_schemas) + + discriminated_json_schemas: list[ObjectJsonSchema] = [] + for object_key, json_schema in zip(self._processors.keys(), json_schemas): + title = json_schema.pop('title', None) + description = json_schema.pop('description', None) + + discriminated_json_schema = { + 'type': 'object', + 'properties': { + 'kind': { + 'type': 'string', + 'const': object_key, + }, + 'data': json_schema, + }, + 'required': ['kind', 'data'], + 'additionalProperties': False, + } + if title: # pragma: no branch + discriminated_json_schema['title'] = title + if description: + discriminated_json_schema['description'] = description + + discriminated_json_schemas.append(discriminated_json_schema) + + json_schema = { + 'type': 'object', + 'properties': { + 'result': { + 'anyOf': discriminated_json_schemas, + } + }, + 'required': ['result'], + 'additionalProperties': False, + } + if all_defs: + json_schema['$defs'] = all_defs + + self.object_def = OutputObjectDefinition( + json_schema=json_schema, + strict=strict, + name=name, + description=description, + ) + + async def process( + self, + data: str | dict[str, Any] | None, + run_context: RunContext[AgentDepsT], + allow_partial: bool = False, + wrap_validation_errors: bool = True, + ) -> OutputDataT: + union_object = await self._union_processor.process( + data, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors + ) + + result = union_object.result + kind = result.kind + data = result.data + try: + processor = self._processors[kind] + except KeyError as e: # pragma: no cover + if wrap_validation_errors: + m = _messages.RetryPromptPart(content=f'Invalid kind: {kind}') + raise ToolRetryError(m) from e + else: + raise + + return await processor.process( + data, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors + ) + + +@dataclass(init=False) +class PlainTextOutputProcessor(BaseOutputProcessor[OutputDataT]): + _function_schema: _function_schema.FunctionSchema + _str_argument_name: str + + def __init__( + self, + output_function: TextOutputFunction[OutputDataT], + ): + self._function_schema = _function_schema.function_schema(output_function, GenerateToolJsonSchema) + + arguments_schema = self._function_schema.json_schema.get('properties', {}) + argument_name = next(iter(arguments_schema.keys()), None) + if argument_name and arguments_schema.get(argument_name, {}).get('type') == 'string': + self._str_argument_name = argument_name + return + + raise UserError('TextOutput must take a function taking a `str`') + + @property + def object_def(self) -> None: + return None # pragma: no cover + + async def process( + self, + data: str, + run_context: RunContext[AgentDepsT], + allow_partial: bool = False, + wrap_validation_errors: bool = True, + ) -> OutputDataT: + args = {self._str_argument_name: data} + + try: + output = await self._function_schema.call(args, run_context) + except ModelRetry as r: + if wrap_validation_errors: + m = _messages.RetryPromptPart( + content=r.message, + ) + raise ToolRetryError(m) from r + else: + raise # pragma: lax no cover + + return cast(OutputDataT, output) + + @dataclass(init=False) class OutputTool(Generic[OutputDataT]): - parameters_schema: OutputObjectSchema[OutputDataT] + processor: ObjectOutputProcessor[OutputDataT] tool_def: ToolDefinition - def __init__(self, *, name: str, parameters_schema: OutputObjectSchema[OutputDataT], multiple: bool): - self.parameters_schema = parameters_schema - definition = parameters_schema.definition + def __init__(self, *, name: str, processor: ObjectOutputProcessor[OutputDataT], multiple: bool): + self.processor = processor + object_def = processor.object_def - description = definition.description + description = object_def.description if not description: description = DEFAULT_OUTPUT_TOOL_DESCRIPTION if multiple: - description = f'{definition.name}: {description}' + description = f'{object_def.name}: {description}' self.tool_def = ToolDefinition( name=name, description=description, - parameters_json_schema=definition.json_schema, - strict=definition.strict, - outer_typed_dict_key=parameters_schema.outer_typed_dict_key, + parameters_json_schema=object_def.json_schema, + strict=object_def.strict, + outer_typed_dict_key=processor.outer_typed_dict_key, ) async def process( @@ -402,7 +975,9 @@ async def process( Either the validated output data (left) or a retry message (right). """ try: - output = await self.parameters_schema.process(tool_call.args, run_context, allow_partial=allow_partial) + output = await self.processor.process( + tool_call.args, run_context, allow_partial=allow_partial, wrap_validation_errors=False + ) except ValidationError as e: if wrap_validation_errors: m = _messages.RetryPromptPart( @@ -437,3 +1012,19 @@ def get_union_args(tp: Any) -> tuple[Any, ...]: return get_args(tp) else: return () + + +def flatten_output_spec(output_spec: T | Sequence[T]) -> list[T]: + outputs: Sequence[T] + if isinstance(output_spec, Sequence): + outputs = output_spec + else: + outputs = (output_spec,) + + outputs_flat: list[T] = [] + for output in outputs: + if union_types := get_union_args(output): + outputs_flat.extend(union_types) + else: + outputs_flat.append(output) + return outputs_flat diff --git a/pydantic_ai_slim/pydantic_ai/_utils.py b/pydantic_ai_slim/pydantic_ai/_utils.py index 77c34fbac..d4e770601 100644 --- a/pydantic_ai_slim/pydantic_ai/_utils.py +++ b/pydantic_ai_slim/pydantic_ai/_utils.py @@ -1,6 +1,7 @@ from __future__ import annotations as _annotations import asyncio +import re import time import uuid from collections.abc import AsyncIterable, AsyncIterator, Iterator @@ -9,7 +10,7 @@ from datetime import datetime, timezone from functools import partial from types import GenericAlias -from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union +from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union, cast from anyio.to_thread import run_sync from pydantic import BaseModel, TypeAdapter @@ -302,3 +303,94 @@ def dataclasses_no_defaults_repr(self: Any) -> str: def number_to_datetime(x: int | float) -> datetime: return TypeAdapter(datetime).validate_python(x) + + +def _update_mapped_json_schema_refs(s: dict[str, Any], name_mapping: dict[str, str]) -> None: + """Update $refs in a schema to use the new names from name_mapping.""" + if '$ref' in s: + ref = s['$ref'] + if ref.startswith('#/$defs/'): # pragma: no branch + original_name = ref[8:] # Remove '#/$defs/' + new_name = name_mapping.get(original_name, original_name) + s['$ref'] = f'#/$defs/{new_name}' + + # Recursively update refs in properties + if 'properties' in s: + props: dict[str, Any] = s['properties'] + for prop in props.values(): + if isinstance(prop, dict): + prop = cast(dict[str, Any], prop) + _update_mapped_json_schema_refs(prop, name_mapping) + + # Handle arrays + if 'items' in s and isinstance(s['items'], dict): + items: dict[str, Any] = s['items'] + _update_mapped_json_schema_refs(items, name_mapping) + if 'prefixItems' in s: + prefix_items: list[dict[str, Any]] = s['prefixItems'] + for item in prefix_items: + if isinstance(item, dict): + _update_mapped_json_schema_refs(item, name_mapping) + + # Handle unions + for union_type in ['anyOf', 'oneOf']: + if union_type in s: + union_items: list[dict[str, Any]] = s[union_type] + for item in union_items: + if isinstance(item, dict): + _update_mapped_json_schema_refs(item, name_mapping) + + +def merge_json_schema_defs(schemas: list[dict[str, Any]]) -> tuple[list[dict[str, Any]], dict[str, dict[str, Any]]]: + """Merges the `$defs` from different JSON schemas into a single deduplicated `$defs`, handling name collisions of `$defs` that are not the same, and rewrites `$ref`s to point to the new `$defs`. + + Returns a tuple of the rewritten schemas and a dictionary of the new `$defs`. + """ + all_defs: dict[str, dict[str, Any]] = {} + rewritten_schemas: list[dict[str, Any]] = [] + + for schema in schemas: + if '$defs' not in schema: + rewritten_schemas.append(schema) + continue + + schema = schema.copy() + defs = schema.pop('$defs', None) + schema_name_mapping: dict[str, str] = {} + + # Process definitions and build mapping + for name, def_schema in defs.items(): + if name not in all_defs: + all_defs[name] = def_schema + schema_name_mapping[name] = name + elif def_schema != all_defs[name]: + new_name = name + if title := schema.get('title'): + new_name = f'{title}_{name}' + + i = 1 + original_new_name = new_name + new_name = f'{new_name}_{i}' + while new_name in all_defs: + i += 1 + new_name = f'{original_new_name}_{i}' + + all_defs[new_name] = def_schema + schema_name_mapping[name] = new_name + + _update_mapped_json_schema_refs(schema, schema_name_mapping) + rewritten_schemas.append(schema) + + return rewritten_schemas, all_defs + + +def strip_markdown_fences(text: str) -> str: + if text.startswith('{'): + return text + + regex = r'```(?:\w+)?\n(\{.*\})\n```' + match = re.search(regex, text, re.DOTALL) + if match: + return match.group(1) + + return text diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index e8a636d32..b48650adf 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -14,6 +14,7 @@ from pydantic.json_schema import GenerateJsonSchema from typing_extensions import Literal, Never, Self, TypeIs, TypeVar, deprecated +from pydantic_ai.profiles import ModelProfile from pydantic_graph import End, Graph, GraphRun, GraphRunContext from pydantic_graph._utils import get_event_loop @@ -127,7 +128,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]): be merged with this value, with the runtime argument taking priority. """ - output_type: _output.OutputType[OutputDataT] + output_type: _output.OutputSpec[OutputDataT] """ The type of data output by agent runs, used to validate the data returned by the model, defaults to `str`. """ @@ -140,7 +141,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]): _deps_type: type[AgentDepsT] = dataclasses.field(repr=False) _deprecated_result_tool_name: str | None = dataclasses.field(repr=False) _deprecated_result_tool_description: str | None = dataclasses.field(repr=False) - _output_schema: _output.OutputSchema[OutputDataT] | None = dataclasses.field(repr=False) + _output_schema: _output.BaseOutputSchema[OutputDataT] = dataclasses.field(repr=False) _output_validators: list[_output.OutputValidator[AgentDepsT, OutputDataT]] = dataclasses.field(repr=False) _instructions: str | None = dataclasses.field(repr=False) _instructions_functions: list[_system_prompt.SystemPromptRunner[AgentDepsT]] = dataclasses.field(repr=False) @@ -162,7 +163,7 @@ def __init__( self, model: models.Model | models.KnownModelName | str | None = None, *, - output_type: _output.OutputType[OutputDataT] = str, + output_type: _output.OutputSpec[OutputDataT] = str, instructions: str | _system_prompt.SystemPromptFunc[AgentDepsT] | Sequence[str | _system_prompt.SystemPromptFunc[AgentDepsT]] @@ -317,8 +318,12 @@ def __init__( warnings.warn('`result_retries` is deprecated, use `max_result_retries` instead', DeprecationWarning) output_retries = result_retries + default_output_mode = self.model.profile.default_output_mode if isinstance(self.model, models.Model) else None self._output_schema = _output.OutputSchema[OutputDataT].build( - output_type, self._deprecated_result_tool_name, self._deprecated_result_tool_description + output_type, + default_mode=default_output_mode, + name=self._deprecated_result_tool_name, + description=self._deprecated_result_tool_description, ) self._output_validators = [] @@ -374,7 +379,7 @@ async def run( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, - output_type: _output.OutputType[RunOutputDataT], + output_type: _output.OutputSpec[RunOutputDataT], message_history: list[_messages.ModelMessage] | None = None, model: models.Model | models.KnownModelName | str | None = None, deps: AgentDepsT = None, @@ -404,7 +409,7 @@ async def run( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, - output_type: _output.OutputType[RunOutputDataT] | None = None, + output_type: _output.OutputSpec[RunOutputDataT] | None = None, message_history: list[_messages.ModelMessage] | None = None, model: models.Model | models.KnownModelName | str | None = None, deps: AgentDepsT = None, @@ -492,7 +497,7 @@ def iter( self, user_prompt: str | Sequence[_messages.UserContent] | None, *, - output_type: _output.OutputType[RunOutputDataT], + output_type: _output.OutputSpec[RunOutputDataT], message_history: list[_messages.ModelMessage] | None = None, model: models.Model | models.KnownModelName | str | None = None, deps: AgentDepsT = None, @@ -524,7 +529,7 @@ async def iter( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, - output_type: _output.OutputType[RunOutputDataT] | None = None, + output_type: _output.OutputSpec[RunOutputDataT] | None = None, message_history: list[_messages.ModelMessage] | None = None, model: models.Model | models.KnownModelName | str | None = None, deps: AgentDepsT = None, @@ -623,7 +628,7 @@ async def main(): deps = self._get_deps(deps) new_message_index = len(message_history) if message_history else 0 - output_schema = self._prepare_output_schema(output_type) + output_schema = self._prepare_output_schema(output_type, model_used.profile) output_type_ = output_type or self.output_type @@ -666,13 +671,20 @@ async def main(): ) async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None: - if self._instructions is None and not self._instructions_functions: - return None + parts = [ + self._instructions, + *[await func.run(run_context) for func in self._instructions_functions], + ] + + if isinstance(output_schema, _output.PromptedJsonOutputSchema): + template = model_used.profile.prompted_json_output_instructions + instructions = output_schema.instructions(template) + parts.append(instructions) - instructions = self._instructions or '' - for instructions_runner in self._instructions_functions: - instructions += '\n' + await instructions_runner.run(run_context) - return instructions.strip() + parts = [p for p in parts if p] + if not parts: + return None + return '\n\n'.join(parts).strip() # Copy the function tools so that retry state is agent-run-specific # Note that the retry count is reset to 0 when this happens due to the `default=0` and `init=False`. @@ -769,7 +781,7 @@ def run_sync( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, - output_type: _output.OutputType[RunOutputDataT] | None = None, + output_type: _output.OutputSpec[RunOutputDataT] | None = None, message_history: list[_messages.ModelMessage] | None = None, model: models.Model | models.KnownModelName | str | None = None, deps: AgentDepsT = None, @@ -799,7 +811,7 @@ def run_sync( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, - output_type: _output.OutputType[RunOutputDataT] | None = None, + output_type: _output.OutputSpec[RunOutputDataT] | None = None, message_history: list[_messages.ModelMessage] | None = None, model: models.Model | models.KnownModelName | str | None = None, deps: AgentDepsT = None, @@ -882,7 +894,7 @@ def run_stream( self, user_prompt: str | Sequence[_messages.UserContent], *, - output_type: _output.OutputType[RunOutputDataT], + output_type: _output.OutputSpec[RunOutputDataT], message_history: list[_messages.ModelMessage] | None = None, model: models.Model | models.KnownModelName | str | None = None, deps: AgentDepsT = None, @@ -913,7 +925,7 @@ async def run_stream( # noqa C901 self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, - output_type: _output.OutputType[RunOutputDataT] | None = None, + output_type: _output.OutputSpec[RunOutputDataT] | None = None, message_history: list[_messages.ModelMessage] | None = None, model: models.Model | models.KnownModelName | str | None = None, deps: AgentDepsT = None, @@ -992,10 +1004,13 @@ async def stream_to_final( async for maybe_part_event in streamed_response: if isinstance(maybe_part_event, _messages.PartStartEvent): new_part = maybe_part_event.part - if isinstance(new_part, _messages.TextPart): - if _output.allow_text_output(output_schema): - return FinalResult(s, None, None) - elif isinstance(new_part, _messages.ToolCallPart) and output_schema: + if isinstance(new_part, _messages.TextPart) and isinstance( + output_schema, _output.TextOutputSchema + ): + return FinalResult(s, None, None) + elif isinstance(new_part, _messages.ToolCallPart) and isinstance( + output_schema, _output.ToolOutputSchema + ): # pragma: no branch for call, _ in output_schema.find_tool([new_part]): return FinalResult(s, call.tool_name, call.tool_call_id) return None @@ -1551,8 +1566,8 @@ def _register_tool(self, tool: Tool[AgentDepsT]) -> None: if tool.name in self._function_tools: raise exceptions.UserError(f'Tool name conflicts with existing tool: {tool.name!r}') - if self._output_schema and tool.name in self._output_schema.tools: - raise exceptions.UserError(f'Tool name conflicts with result schema name: {tool.name!r}') + if tool.name in self._output_schema.tools: + raise exceptions.UserError(f'Tool name conflicts with output tool name: {tool.name!r}') self._function_tools[tool.name] = tool @@ -1627,18 +1642,25 @@ def last_run_messages(self) -> list[_messages.ModelMessage]: raise AttributeError('The `last_run_messages` attribute has been removed, use `capture_run_messages` instead.') def _prepare_output_schema( - self, output_type: _output.OutputType[RunOutputDataT] | None - ) -> _output.OutputSchema[RunOutputDataT] | None: + self, output_type: _output.OutputSpec[RunOutputDataT] | None, model_profile: ModelProfile + ) -> _output.OutputSchema[RunOutputDataT]: if output_type is not None: if self._output_validators: raise exceptions.UserError('Cannot set a custom run `output_type` when the agent has output validators') - return _output.OutputSchema[RunOutputDataT].build( + schema = _output.OutputSchema[RunOutputDataT].build( output_type, - self._deprecated_result_tool_name, - self._deprecated_result_tool_description, + name=self._deprecated_result_tool_name, + description=self._deprecated_result_tool_description, + default_mode=model_profile.default_output_mode, ) else: - return self._output_schema # pyright: ignore[reportReturnType] + schema = self._output_schema.with_default_mode(model_profile.default_output_mode) + + if not schema.is_supported(model_profile.output_modes): + modes = ', '.join(f"'{m}'" for m in model_profile.output_modes) + raise exceptions.UserError(f"Output mode '{schema.mode}' is not among supported modes: {modes}") + + return schema # pyright: ignore[reportReturnType] @staticmethod def is_model_request_node( diff --git a/pydantic_ai_slim/pydantic_ai/models/__init__.py b/pydantic_ai_slim/pydantic_ai/models/__init__.py index 1ff370133..214c4c4fc 100644 --- a/pydantic_ai_slim/pydantic_ai/models/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/models/__init__.py @@ -20,6 +20,7 @@ from pydantic_ai.profiles import DEFAULT_PROFILE, ModelProfile, ModelProfileSpec +from .._output import OutputMode, OutputObjectDefinition from .._parts_manager import ModelResponsePartsManager from ..exceptions import UserError from ..messages import FileUrl, ModelMessage, ModelRequest, ModelResponse, ModelResponseStreamEvent, VideoUrl @@ -299,8 +300,11 @@ class ModelRequestParameters: """Configuration for an agent's request to a model, specifically related to tools and output handling.""" function_tools: list[ToolDefinition] = field(default_factory=list) - allow_text_output: bool = True + + output_mode: OutputMode = 'text' + output_object: OutputObjectDefinition | None = None output_tools: list[ToolDefinition] = field(default_factory=list) + allow_text_output: bool = True class Model(ABC): @@ -345,6 +349,11 @@ def customize_request_parameters(self, model_request_parameters: ModelRequestPar function_tools=[_customize_tool_def(transformer, t) for t in model_request_parameters.function_tools], output_tools=[_customize_tool_def(transformer, t) for t in model_request_parameters.output_tools], ) + if output_object := model_request_parameters.output_object: + model_request_parameters = replace( + model_request_parameters, + output_object=_customize_output_object(transformer, output_object), + ) return model_request_parameters @@ -712,3 +721,11 @@ def _customize_tool_def(transformer: type[JsonSchemaTransformer], t: ToolDefinit if t.strict is None: t = replace(t, strict=schema_transformer.is_strict_compatible) return replace(t, parameters_json_schema=parameters_json_schema) + + +def _customize_output_object(transformer: type[JsonSchemaTransformer], o: OutputObjectDefinition): + schema_transformer = transformer(o.json_schema, strict=o.strict) + son_schema = schema_transformer.walk() + if o.strict is None: + o = replace(o, strict=schema_transformer.is_strict_compatible) + return replace(o, json_schema=son_schema) diff --git a/pydantic_ai_slim/pydantic_ai/models/function.py b/pydantic_ai_slim/pydantic_ai/models/function.py index 22bcddffb..ce97e4196 100644 --- a/pydantic_ai_slim/pydantic_ai/models/function.py +++ b/pydantic_ai_slim/pydantic_ai/models/function.py @@ -11,6 +11,8 @@ from typing_extensions import TypeAlias, assert_never, overload +from pydantic_ai.profiles import ModelProfileSpec + from .. import _utils, usage from .._utils import PeekableAsyncStream from ..messages import ( @@ -48,14 +50,27 @@ class FunctionModel(Model): _system: str = field(default='function', repr=False) @overload - def __init__(self, function: FunctionDef, *, model_name: str | None = None) -> None: ... + def __init__( + self, function: FunctionDef, *, model_name: str | None = None, profile: ModelProfileSpec | None = None + ) -> None: ... @overload - def __init__(self, *, stream_function: StreamFunctionDef, model_name: str | None = None) -> None: ... + def __init__( + self, + *, + stream_function: StreamFunctionDef, + model_name: str | None = None, + profile: ModelProfileSpec | None = None, + ) -> None: ... @overload def __init__( - self, function: FunctionDef, *, stream_function: StreamFunctionDef, model_name: str | None = None + self, + function: FunctionDef, + *, + stream_function: StreamFunctionDef, + model_name: str | None = None, + profile: ModelProfileSpec | None = None, ) -> None: ... def __init__( @@ -64,6 +79,7 @@ def __init__( *, stream_function: StreamFunctionDef | None = None, model_name: str | None = None, + profile: ModelProfileSpec | None = None, ): """Initialize a `FunctionModel`. @@ -73,6 +89,7 @@ def __init__( function: The function to call for non-streamed requests. stream_function: The function to call for streamed requests. model_name: The name of the model. If not provided, a name is generated from the function names. + profile: The model profile to use. """ if function is None and stream_function is None: raise TypeError('Either `function` or `stream_function` must be provided') @@ -82,6 +99,7 @@ def __init__( function_name = self.function.__name__ if self.function is not None else '' stream_function_name = self.stream_function.__name__ if self.stream_function is not None else '' self._model_name = model_name or f'function:{function_name}:{stream_function_name}' + self._profile = profile async def request( self, diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py b/pydantic_ai_slim/pydantic_ai/models/gemini.py index 76284c8f0..f8ccfc22f 100644 --- a/pydantic_ai_slim/pydantic_ai/models/gemini.py +++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py @@ -16,6 +16,8 @@ from pydantic_ai.providers import Provider, infer_provider from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage +from .._output import OutputObjectDefinition +from ..exceptions import UserError from ..messages import ( BinaryContent, FileUrl, @@ -190,12 +192,10 @@ def _get_tools(self, model_request_parameters: ModelRequestParameters) -> _Gemin def _get_tool_config( self, model_request_parameters: ModelRequestParameters, tools: _GeminiTools | None ) -> _GeminiToolConfig | None: - if model_request_parameters.allow_text_output: - return None - elif tools: + if not model_request_parameters.allow_text_output and tools: return _tool_config([t['name'] for t in tools['function_declarations']]) else: - return _tool_config([]) # pragma: no cover + return None @asynccontextmanager async def _make_request( @@ -218,6 +218,20 @@ async def _make_request( request_data['toolConfig'] = tool_config generation_config = _settings_to_generation_config(model_settings) + + output_mode = model_request_parameters.output_mode + if output_mode == 'json_schema': + generation_config['response_mime_type'] = 'application/json' + + output_object = model_request_parameters.output_object + assert output_object is not None + generation_config['response_schema'] = self._map_response_schema(output_object) + + if tools: + raise UserError('Google does not support JSON schema output and tools at the same time.') + elif output_mode == 'prompted_json' and not tools: + generation_config['response_mime_type'] = 'application/json' + if generation_config: request_data['generationConfig'] = generation_config @@ -363,6 +377,15 @@ async def _map_user_prompt(self, part: UserPromptPart) -> list[_GeminiPartUnion] assert_never(item) return content + def _map_response_schema(self, o: OutputObjectDefinition) -> dict[str, Any]: + response_schema = o.json_schema.copy() + if o.name: + response_schema['title'] = o.name + if o.description: + response_schema['description'] = o.description + + return response_schema + def _settings_to_generation_config(model_settings: GeminiModelSettings) -> _GeminiGenerationConfig: config: _GeminiGenerationConfig = {} @@ -564,6 +587,8 @@ class _GeminiGenerationConfig(TypedDict, total=False): frequency_penalty: float stop_sequences: list[str] thinking_config: ThinkingConfig + response_mime_type: str + response_schema: dict[str, Any] class _GeminiContent(TypedDict): diff --git a/pydantic_ai_slim/pydantic_ai/models/google.py b/pydantic_ai_slim/pydantic_ai/models/google.py index b9b491916..16bf5e5f9 100644 --- a/pydantic_ai_slim/pydantic_ai/models/google.py +++ b/pydantic_ai_slim/pydantic_ai/models/google.py @@ -10,9 +10,10 @@ from typing_extensions import assert_never -from pydantic_ai.providers import Provider +from pydantic_ai._output import OutputObjectDefinition from .. import UnexpectedModelBehavior, _utils, usage +from ..exceptions import UserError from ..messages import ( BinaryContent, FileUrl, @@ -30,6 +31,7 @@ VideoUrl, ) from ..profiles import ModelProfileSpec +from ..providers import Provider from ..settings import ModelSettings from ..tools import ToolDefinition from . import ( @@ -211,9 +213,7 @@ def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[T def _get_tool_config( self, model_request_parameters: ModelRequestParameters, tools: list[ToolDict] | None ) -> ToolConfigDict | None: - if model_request_parameters.allow_text_output: - return None - elif tools: + if not model_request_parameters.allow_text_output and tools: names: list[str] = [] for tool in tools: for function_declaration in tool.get('function_declarations') or []: @@ -221,7 +221,7 @@ def _get_tool_config( names.append(name) return _tool_config(names) else: - return _tool_config([]) # pragma: no cover + return None @overload async def _generate_content( @@ -249,6 +249,22 @@ async def _generate_content( model_request_parameters: ModelRequestParameters, ) -> GenerateContentResponse | Awaitable[AsyncIterator[GenerateContentResponse]]: tools = self._get_tools(model_request_parameters) + + output_mode = model_request_parameters.output_mode + response_mime_type = None + response_schema = None + if output_mode == 'json_schema': + response_mime_type = 'application/json' + + output_object = model_request_parameters.output_object + assert output_object is not None + response_schema = self._map_response_schema(output_object) + + if tools: + raise UserError('Google does not support JSON schema output and tools at the same time/') + elif output_mode == 'prompted_json' and not tools: + response_mime_type = 'application/json' + tool_config = self._get_tool_config(model_request_parameters, tools) system_instruction, contents = await self._map_messages(messages) @@ -266,6 +282,8 @@ async def _generate_content( labels=model_settings.get('google_labels'), tools=cast(ToolListUnionDict, tools), tool_config=tool_config, + response_mime_type=response_mime_type, + response_schema=response_schema, ) func = self.client.aio.models.generate_content_stream if stream else self.client.aio.models.generate_content @@ -383,6 +401,15 @@ async def _map_user_prompt(self, part: UserPromptPart) -> list[PartDict]: assert_never(item) return content + def _map_response_schema(self, o: OutputObjectDefinition) -> dict[str, Any]: + response_schema = o.json_schema.copy() + if o.name: + response_schema['title'] = o.name + if o.description: + response_schema['description'] = o.description + + return response_schema + @dataclass class GeminiStreamedResponse(StreamedResponse): diff --git a/pydantic_ai_slim/pydantic_ai/models/mistral.py b/pydantic_ai_slim/pydantic_ai/models/mistral.py index 34483a928..8c558bca7 100644 --- a/pydantic_ai_slim/pydantic_ai/models/mistral.py +++ b/pydantic_ai_slim/pydantic_ai/models/mistral.py @@ -250,6 +250,7 @@ async def _stream_completions_create( ) elif model_request_parameters.output_tools: + # TODO: Port to native "manual JSON" mode # Json Mode parameters_json_schemas = [tool.parameters_json_schema for tool in model_request_parameters.output_tools] user_output_format_message = self._generate_user_output_format(parameters_json_schemas) @@ -258,7 +259,9 @@ async def _stream_completions_create( response = await self.client.chat.stream_async( model=str(self._model_name), messages=mistral_messages, - response_format={'type': 'json_object'}, + response_format={ + 'type': 'json_object' + }, # TODO: Should be able to use json_schema now: https://docs.mistral.ai/capabilities/structured-output/custom_structured_output/, https://github.com/mistralai/client-python/blob/bc4adf335968c8a272e1ab7da8461c9943d8e701/src/mistralai/extra/utils/response_format.py#L9 stream=True, http_headers={'User-Agent': get_user_agent()}, ) @@ -566,6 +569,7 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: # Attempt to produce an output tool call from the received text if self._output_tools: self._delta_content += text + # TODO: Port to native "manual JSON" mode maybe_tool_call_part = self._try_get_output_tool_from_text(self._delta_content, self._output_tools) if maybe_tool_call_part: yield self._parts_manager.handle_tool_call_part( diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py index 430661582..75b02e38b 100644 --- a/pydantic_ai_slim/pydantic_ai/models/openai.py +++ b/pydantic_ai_slim/pydantic_ai/models/openai.py @@ -14,6 +14,7 @@ from pydantic_ai.providers import Provider, infer_provider from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage +from .._output import DEFAULT_OUTPUT_TOOL_NAME, OutputObjectDefinition from .._utils import guard_tool_call_id as _guard_tool_call_id, number_to_datetime from ..messages import ( AudioUrl, @@ -270,8 +271,6 @@ async def _completions_create( model_request_parameters: ModelRequestParameters, ) -> chat.ChatCompletion | AsyncStream[ChatCompletionChunk]: tools = self._get_tools(model_request_parameters) - - # standalone function to make it easier to override if not tools: tool_choice: Literal['none', 'required', 'auto'] | None = None elif not model_request_parameters.allow_text_output: @@ -281,6 +280,18 @@ async def _completions_create( openai_messages = await self._map_messages(messages) + response_format: chat.completion_create_params.ResponseFormat | None = None + output_mode = model_request_parameters.output_mode + if output_mode == 'json_schema': + output_object = model_request_parameters.output_object + assert output_object is not None + response_format = self._map_json_schema(output_object) + elif ( + output_mode == 'prompted_json' + and OpenAIModelProfile.from_profile(self.profile).openai_supports_json_object_response_format + ): + response_format = {'type': 'json_object'} + sampling_settings = ( model_settings if OpenAIModelProfile.from_profile(self.profile).openai_supports_sampling_settings @@ -301,6 +312,7 @@ async def _completions_create( stop=model_settings.get('stop_sequences', NOT_GIVEN), max_completion_tokens=model_settings.get('max_tokens', NOT_GIVEN), timeout=model_settings.get('timeout', NOT_GIVEN), + response_format=response_format or NOT_GIVEN, seed=model_settings.get('seed', NOT_GIVEN), reasoning_effort=model_settings.get('openai_reasoning_effort', NOT_GIVEN), user=model_settings.get('openai_user', NOT_GIVEN), @@ -420,6 +432,21 @@ def _map_tool_call(t: ToolCallPart) -> chat.ChatCompletionMessageToolCallParam: function={'name': t.tool_name, 'arguments': t.args_as_json_str()}, ) + @staticmethod + def _map_json_schema(o: OutputObjectDefinition) -> chat.completion_create_params.ResponseFormat: + response_format_param: chat.completion_create_params.ResponseFormatJSONSchema = { # pyright: ignore[reportPrivateImportUsage] + 'type': 'json_schema', + 'json_schema': { + 'name': o.name or DEFAULT_OUTPUT_TOOL_NAME, + 'schema': o.json_schema, + }, + } + if o.description: + response_format_param['json_schema']['description'] = o.description + if o.strict: # pragma: no branch + response_format_param['json_schema']['strict'] = o.strict + return response_format_param + def _map_tool_definition(self, f: ToolDefinition) -> chat.ChatCompletionToolParam: tool_param: chat.ChatCompletionToolParam = { 'type': 'function', @@ -659,7 +686,6 @@ async def _responses_create( tools = self._get_tools(model_request_parameters) tools = list(model_settings.get('openai_builtin_tools', [])) + tools - # standalone function to make it easier to override if not tools: tool_choice: Literal['none', 'required', 'auto'] | None = None elif not model_request_parameters.allow_text_output: @@ -670,6 +696,25 @@ async def _responses_create( instructions, openai_messages = await self._map_messages(messages) reasoning = self._get_reasoning(model_settings) + text: responses.ResponseTextConfigParam | None = None + output_mode = model_request_parameters.output_mode + if output_mode == 'json_schema': + output_object = model_request_parameters.output_object + assert output_object is not None + text = {'format': self._map_json_schema(output_object)} + elif ( + output_mode == 'prompted_json' + and OpenAIModelProfile.from_profile(self.profile).openai_supports_json_object_response_format + ): + text = {'format': {'type': 'json_object'}} + + # Without this trick, we'd hit this error: + # > Response input messages must contain the word 'json' in some form to use 'text.format' of type 'json_object'. + # Apparently they're only checking input messages for "JSON", not instructions. + assert isinstance(instructions, str) + openai_messages.insert(0, responses.EasyInputMessageParam(role='system', content=instructions)) + instructions = NOT_GIVEN + sampling_settings = ( model_settings if OpenAIModelProfile.from_profile(self.profile).openai_supports_sampling_settings @@ -694,6 +739,7 @@ async def _responses_create( timeout=model_settings.get('timeout', NOT_GIVEN), reasoning=reasoning, user=model_settings.get('openai_user', NOT_GIVEN), + text=text or NOT_GIVEN, extra_headers=extra_headers, extra_body=model_settings.get('extra_body'), ) @@ -785,6 +831,19 @@ def _map_tool_call(t: ToolCallPart) -> responses.ResponseFunctionToolCallParam: type='function_call', ) + @staticmethod + def _map_json_schema(o: OutputObjectDefinition) -> responses.ResponseFormatTextJSONSchemaConfigParam: + response_format_param: responses.ResponseFormatTextJSONSchemaConfigParam = { + 'type': 'json_schema', + 'name': o.name or DEFAULT_OUTPUT_TOOL_NAME, + 'schema': o.json_schema, + } + if o.description: + response_format_param['description'] = o.description + if o.strict: # pragma: no branch + response_format_param['strict'] = o.strict + return response_format_param + @staticmethod async def _map_user_prompt(part: UserPromptPart) -> responses.EasyInputMessageParam: content: str | list[responses.ResponseInputContentParam] diff --git a/pydantic_ai_slim/pydantic_ai/models/test.py b/pydantic_ai_slim/pydantic_ai/models/test.py index 0daad25bc..8b8453ac6 100644 --- a/pydantic_ai_slim/pydantic_ai/models/test.py +++ b/pydantic_ai_slim/pydantic_ai/models/test.py @@ -130,7 +130,7 @@ def _get_tool_calls(self, model_request_parameters: ModelRequestParameters) -> l def _get_output(self, model_request_parameters: ModelRequestParameters) -> _WrappedTextOutput | _WrappedToolOutput: if self.custom_output_text is not None: - assert model_request_parameters.allow_text_output, ( + assert model_request_parameters.output_mode != 'tool', ( 'Plain response not allowed, but `custom_output_text` is set.' ) assert self.custom_output_args is None, 'Cannot set both `custom_output_text` and `custom_output_args`.' diff --git a/pydantic_ai_slim/pydantic_ai/profiles/__init__.py b/pydantic_ai_slim/pydantic_ai/profiles/__init__.py index 3792c95a6..5f7119bff 100644 --- a/pydantic_ai_slim/pydantic_ai/profiles/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/profiles/__init__.py @@ -1,10 +1,13 @@ from __future__ import annotations as _annotations -from dataclasses import dataclass, fields, replace +from dataclasses import dataclass, field, fields, replace +from textwrap import dedent from typing import Callable, Union from typing_extensions import Self +from pydantic_ai._output import StructuredOutputMode, SupportableOutputMode + from ._json_schema import JsonSchemaTransformer @@ -13,6 +16,22 @@ class ModelProfile: """Describes how requests to a specific model or family of models need to be constructed to get the best results, independent of the model and provider classes used.""" json_schema_transformer: type[JsonSchemaTransformer] | None = None + """The transformer to use to make JSON schemas for tools and structured output compatible with the model.""" + output_modes: set[SupportableOutputMode] = field(default_factory=lambda: {'tool'}) + """The output modes supported by the model. Essentially all models support `tool` mode, but some also support `json_schema` mode, which needs to be specifically implemented on the model class.""" + default_output_mode: StructuredOutputMode = 'tool' + """The default output mode to use for the model.""" + + prompted_json_output_instructions: str = dedent( + """ + Always respond with a JSON object that's compatible with this schema: + + {schema} + + Don't include any text or Markdown fencing before or after. + """ + ) + """The instructions to use for prompted JSON output. The schema placeholder will be replaced with the JSON schema for the output.""" @classmethod def from_profile(cls, profile: ModelProfile | None) -> Self: diff --git a/pydantic_ai_slim/pydantic_ai/profiles/google.py b/pydantic_ai_slim/pydantic_ai/profiles/google.py index c544e185b..a0cdc61fc 100644 --- a/pydantic_ai_slim/pydantic_ai/profiles/google.py +++ b/pydantic_ai_slim/pydantic_ai/profiles/google.py @@ -10,7 +10,7 @@ def google_model_profile(model_name: str) -> ModelProfile | None: """Get the model profile for a Google model.""" - return ModelProfile(json_schema_transformer=GoogleJsonSchemaTransformer) + return ModelProfile(json_schema_transformer=GoogleJsonSchemaTransformer, output_modes={'tool', 'json_schema'}) class GoogleJsonSchemaTransformer(JsonSchemaTransformer): diff --git a/pydantic_ai_slim/pydantic_ai/profiles/openai.py b/pydantic_ai_slim/pydantic_ai/profiles/openai.py index 464f7d20b..83e82f9ba 100644 --- a/pydantic_ai_slim/pydantic_ai/profiles/openai.py +++ b/pydantic_ai_slim/pydantic_ai/profiles/openai.py @@ -21,12 +21,21 @@ class OpenAIModelProfile(ModelProfile): openai_supports_sampling_settings: bool = True """Turn off to don't send sampling settings like `temperature` and `top_p` to models that don't support them, like OpenAI's o-series reasoning models.""" + openai_supports_json_object_response_format: bool = True + """This can be set by a provider or user if the OpenAI-"compatible" API doesn't support the `json_object` `response_format`. + Note that if a model does not support the `json_schema` `response_format`, that value should be removed from `ModelProfile.output_modes`. + """ + def openai_model_profile(model_name: str) -> ModelProfile: """Get the model profile for an OpenAI model.""" is_reasoning_model = model_name.startswith('o') + # `json_schema` output_mode is only supported with the gpt-4o-mini, gpt-4o-mini-2024-07-18, and gpt-4o-2024-08-06 model snapshots and later, + # but we leave it in here for all models because the `default_output_mode` is `'tool'`, so `json_schema` is only used + # when the user specifically uses the JsonSchemaOutput marker, so an error from the API is acceptable. return OpenAIModelProfile( json_schema_transformer=OpenAIJsonSchemaTransformer, + output_modes={'tool', 'json_schema'}, openai_supports_sampling_settings=not is_reasoning_model, ) diff --git a/pydantic_ai_slim/pydantic_ai/result.py b/pydantic_ai_slim/pydantic_ai/result.py index 443e98b32..0f7416062 100644 --- a/pydantic_ai_slim/pydantic_ai/result.py +++ b/pydantic_ai_slim/pydantic_ai/result.py @@ -5,24 +5,39 @@ from copy import copy from dataclasses import dataclass, field from datetime import datetime -from typing import Generic, cast +from typing import Generic -from typing_extensions import TypeVar, assert_type, deprecated, overload +from pydantic import ValidationError +from typing_extensions import TypeVar, deprecated, overload -from . import _output, _utils, exceptions, messages as _messages, models +from . import _utils, exceptions, messages as _messages, models from ._output import ( + JsonSchemaOutput, OutputDataT, OutputDataT_inv, OutputSchema, OutputValidator, OutputValidatorFunc, + PlainTextOutputSchema, + PromptedJsonOutput, + TextOutput, + TextOutputSchema, ToolOutput, + ToolOutputSchema, ) from .messages import AgentStreamEvent, FinalResultEvent from .tools import AgentDepsT, RunContext from .usage import Usage, UsageLimits -__all__ = 'OutputDataT', 'OutputDataT_inv', 'ToolOutput', 'OutputValidatorFunc' +__all__ = ( + 'OutputDataT', + 'OutputDataT_inv', + 'ToolOutput', + 'TextOutput', + 'JsonSchemaOutput', + 'PromptedJsonOutput', + 'OutputValidatorFunc', +) T = TypeVar('T') @@ -32,7 +47,7 @@ @dataclass class AgentStream(Generic[AgentDepsT, OutputDataT]): _raw_stream_response: models.StreamedResponse - _output_schema: OutputSchema[OutputDataT] | None + _output_schema: OutputSchema[OutputDataT] _output_validators: list[OutputValidator[AgentDepsT, OutputDataT]] _run_ctx: RunContext[AgentDepsT] _usage_limits: UsageLimits | None @@ -80,7 +95,7 @@ async def _validate_response( ) -> OutputDataT: """Validate a structured result message.""" call = None - if self._output_schema is not None and output_tool_name is not None: + if isinstance(self._output_schema, ToolOutputSchema) and output_tool_name is not None: match = self._output_schema.find_named_tool(message.parts, output_tool_name) if match is None: raise exceptions.UnexpectedModelBehavior( # pragma: no cover @@ -91,10 +106,16 @@ async def _validate_response( result_data = await output_tool.process( call, self._run_ctx, allow_partial=allow_partial, wrap_validation_errors=False ) - else: + elif isinstance(self._output_schema, TextOutputSchema): text = '\n\n'.join(x.content for x in message.parts if isinstance(x, _messages.TextPart)) - # The following cast is safe because we know `str` is an allowed output type - result_data = cast(OutputDataT, text) + + result_data = await self._output_schema.process( + text, self._run_ctx, allow_partial=allow_partial, wrap_validation_errors=False + ) + else: + raise exceptions.UnexpectedModelBehavior( # pragma: no cover + 'Invalid response, unable to process text output' + ) for validator in self._output_validators: result_data = await validator.validate(result_data, call, self._run_ctx) @@ -117,14 +138,12 @@ def _get_final_result_event(e: _messages.ModelResponseStreamEvent) -> _messages. """Return an appropriate FinalResultEvent if `e` corresponds to a part that will produce a final result.""" if isinstance(e, _messages.PartStartEvent): new_part = e.part - if isinstance(new_part, _messages.ToolCallPart): - if output_schema: - for call, _ in output_schema.find_tool([new_part]): # pragma: no branch - return _messages.FinalResultEvent( - tool_name=call.tool_name, tool_call_id=call.tool_call_id - ) - elif _output.allow_text_output(output_schema): # pragma: no branch - assert_type(e, _messages.PartStartEvent) + if isinstance(new_part, _messages.ToolCallPart) and isinstance(output_schema, ToolOutputSchema): + for call, _ in output_schema.find_tool([new_part]): # pragma: no branch + return _messages.FinalResultEvent(tool_name=call.tool_name, tool_call_id=call.tool_call_id) + elif isinstance(new_part, _messages.TextPart) and isinstance( + output_schema, TextOutputSchema + ): # pragma: no branch return _messages.FinalResultEvent(tool_name=None, tool_call_id=None) usage_checking_stream = _get_usage_checking_stream_response( @@ -155,7 +174,7 @@ class StreamedRunResult(Generic[AgentDepsT, OutputDataT]): _usage_limits: UsageLimits | None _stream_response: models.StreamedResponse - _output_schema: OutputSchema[OutputDataT] | None + _output_schema: OutputSchema[OutputDataT] _run_ctx: RunContext[AgentDepsT] _output_validators: list[OutputValidator[AgentDepsT, OutputDataT]] _output_tool_name: str | None @@ -296,7 +315,11 @@ async def stream(self, *, debounce_by: float | None = 0.1) -> AsyncIterator[Outp An async iterable of the response data. """ async for structured_message, is_last in self.stream_structured(debounce_by=debounce_by): - yield await self.validate_structured_output(structured_message, allow_partial=not is_last) + try: + yield await self.validate_structured_output(structured_message, allow_partial=not is_last) + except ValidationError: + if is_last: + raise # pragma: lax no cover async def stream_text(self, *, delta: bool = False, debounce_by: float | None = 0.1) -> AsyncIterator[str]: """Stream the text result as an async iterable. @@ -311,7 +334,7 @@ async def stream_text(self, *, delta: bool = False, debounce_by: float | None = Debouncing is particularly important for long structured responses to reduce the overhead of performing validation as each token is received. """ - if self._output_schema and not self._output_schema.allow_text_output: + if not isinstance(self._output_schema, PlainTextOutputSchema): raise exceptions.UserError('stream_text() can only be used with text responses') if delta: @@ -390,7 +413,7 @@ async def validate_structured_output( ) -> OutputDataT: """Validate a structured result message.""" call = None - if self._output_schema is not None and self._output_tool_name is not None: + if isinstance(self._output_schema, ToolOutputSchema) and self._output_tool_name is not None: match = self._output_schema.find_named_tool(message.parts, self._output_tool_name) if match is None: raise exceptions.UnexpectedModelBehavior( # pragma: no cover @@ -401,9 +424,16 @@ async def validate_structured_output( result_data = await output_tool.process( call, self._run_ctx, allow_partial=allow_partial, wrap_validation_errors=False ) - else: + elif isinstance(self._output_schema, TextOutputSchema): text = '\n\n'.join(x.content for x in message.parts if isinstance(x, _messages.TextPart)) - result_data = cast(OutputDataT, text) + + result_data = await self._output_schema.process( + text, self._run_ctx, allow_partial=allow_partial, wrap_validation_errors=False + ) + else: + raise exceptions.UnexpectedModelBehavior( # pragma: no cover + 'Invalid response, unable to process text output' + ) for validator in self._output_validators: result_data = await validator.validate(result_data, call, self._run_ctx) # pragma: no cover diff --git a/tests/models/cassettes/test_anthropic/test_anthropic_prompted_json_output.yaml b/tests/models/cassettes/test_anthropic/test_anthropic_prompted_json_output.yaml new file mode 100644 index 000000000..e88afebdf --- /dev/null +++ b/tests/models/cassettes/test_anthropic/test_anthropic_prompted_json_output.yaml @@ -0,0 +1,161 @@ +interactions: +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '740' + content-type: + - application/json + host: + - api.anthropic.com + method: POST + parsed_body: + max_tokens: 1024 + messages: + - content: + - text: What is the largest city in the user country? Use the get_user_country tool and then your own world knowledge. + type: text + role: user + model: claude-3-5-sonnet-latest + stream: false + system: |+ + Always respond with a JSON object that's compatible with this schema: + + {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} + + Don't include any text or Markdown fencing before or after. + + tool_choice: + type: auto + tools: + - description: '' + input_schema: + additionalProperties: false + properties: {} + type: object + name: get_user_country + uri: https://api.anthropic.com/v1/messages?beta=true + response: + headers: + connection: + - keep-alive + content-length: + - '397' + content-type: + - application/json + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + content: + - id: toolu_017UryVwtsKsjonhFV3cgV3X + input: {} + name: get_user_country + type: tool_use + id: msg_014CpBKzioMqUyLWrMihpvsz + model: claude-3-5-sonnet-20241022 + role: assistant + stop_reason: tool_use + stop_sequence: null + type: message + usage: + cache_creation_input_tokens: 0 + cache_read_input_tokens: 0 + input_tokens: 459 + output_tokens: 38 + service_tier: standard + status: + code: 200 + message: OK +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '1002' + content-type: + - application/json + host: + - api.anthropic.com + method: POST + parsed_body: + max_tokens: 1024 + messages: + - content: + - text: What is the largest city in the user country? Use the get_user_country tool and then your own world knowledge. + type: text + role: user + - content: + - id: toolu_017UryVwtsKsjonhFV3cgV3X + input: {} + name: get_user_country + type: tool_use + role: assistant + - content: + - content: Mexico + is_error: false + tool_use_id: toolu_017UryVwtsKsjonhFV3cgV3X + type: tool_result + role: user + model: claude-3-5-sonnet-latest + stream: false + system: |+ + Always respond with a JSON object that's compatible with this schema: + + {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} + + Don't include any text or Markdown fencing before or after. + + tool_choice: + type: auto + tools: + - description: '' + input_schema: + additionalProperties: false + properties: {} + type: object + name: get_user_country + uri: https://api.anthropic.com/v1/messages?beta=true + response: + headers: + connection: + - keep-alive + content-length: + - '380' + content-type: + - application/json + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + content: + - text: '{"city": "Mexico City", "country": "Mexico"}' + type: text + id: msg_014JeWCouH6DpdqzMTaBdkpJ + model: claude-3-5-sonnet-20241022 + role: assistant + stop_reason: end_turn + stop_sequence: null + type: message + usage: + cache_creation_input_tokens: 0 + cache_read_input_tokens: 0 + input_tokens: 510 + output_tokens: 17 + service_tier: standard + status: + code: 200 + message: OK +version: 1 +... diff --git a/tests/models/cassettes/test_anthropic/test_anthropic_prompted_json_output_multiple.yaml b/tests/models/cassettes/test_anthropic/test_anthropic_prompted_json_output_multiple.yaml new file mode 100644 index 000000000..183daa406 --- /dev/null +++ b/tests/models/cassettes/test_anthropic/test_anthropic_prompted_json_output_multiple.yaml @@ -0,0 +1,66 @@ +interactions: +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '1268' + content-type: + - application/json + host: + - api.anthropic.com + method: POST + parsed_body: + max_tokens: 1024 + messages: + - content: + - text: What is the largest city in Mexico? + type: text + role: user + model: claude-3-5-sonnet-latest + stream: false + system: |+ + Always respond with a JSON object that's compatible with this schema: + + {"type": "object", "properties": {"result": {"anyOf": [{"type": "object", "properties": {"kind": {"type": "string", "const": "CityLocation"}, "data": {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"}}, "description": "CityLocation", "required": ["kind", "data"], "additionalProperties": false}, {"type": "object", "properties": {"kind": {"type": "string", "const": "CountryLanguage"}, "data": {"properties": {"country": {"type": "string"}, "language": {"type": "string"}}, "required": ["country", "language"], "title": "CountryLanguage", "type": "object"}}, "description": "CountryLanguage", "required": ["kind", "data"], "additionalProperties": false}]}}, "required": ["result"], "additionalProperties": false} + + Don't include any text or Markdown fencing before or after. + + uri: https://api.anthropic.com/v1/messages?beta=true + response: + headers: + connection: + - keep-alive + content-length: + - '434' + content-type: + - application/json + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + content: + - text: '{"result": {"kind": "CityLocation", "data": {"city": "Mexico City", "country": "Mexico"}}}' + type: text + id: msg_013ttUi3HCcKt7PkJpoWs5FT + model: claude-3-5-sonnet-20241022 + role: assistant + stop_reason: end_turn + stop_sequence: null + type: message + usage: + cache_creation_input_tokens: 0 + cache_read_input_tokens: 0 + input_tokens: 281 + output_tokens: 31 + service_tier: standard + status: + code: 200 + message: OK +version: 1 +... diff --git a/tests/models/cassettes/test_anthropic/test_anthropic_text_output_function.yaml b/tests/models/cassettes/test_anthropic/test_anthropic_text_output_function.yaml new file mode 100644 index 000000000..ad365d4ac --- /dev/null +++ b/tests/models/cassettes/test_anthropic/test_anthropic_text_output_function.yaml @@ -0,0 +1,156 @@ +interactions: +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '409' + content-type: + - application/json + host: + - api.anthropic.com + method: POST + parsed_body: + max_tokens: 1024 + messages: + - content: + - text: What is the largest city in the user country? Use the get_user_country tool and then your own world knowledge. + type: text + role: user + model: claude-3-5-sonnet-latest + stream: false + tool_choice: + type: auto + tools: + - description: '' + input_schema: + additionalProperties: false + properties: {} + type: object + name: get_user_country + uri: https://api.anthropic.com/v1/messages?beta=true + response: + headers: + connection: + - keep-alive + content-length: + - '540' + content-type: + - application/json + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + content: + - text: I'll help you find the largest city in your country. Let me first check your country using the get_user_country + tool. + type: text + - id: toolu_01EZuxfc6MsPsPgrAKQohw3e + input: {} + name: get_user_country + type: tool_use + id: msg_014NE4yfV1Yz2vLAJzapxxef + model: claude-3-5-sonnet-20241022 + role: assistant + stop_reason: tool_use + stop_sequence: null + type: message + usage: + cache_creation_input_tokens: 0 + cache_read_input_tokens: 0 + input_tokens: 383 + output_tokens: 66 + service_tier: standard + status: + code: 200 + message: OK +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '814' + content-type: + - application/json + host: + - api.anthropic.com + method: POST + parsed_body: + max_tokens: 1024 + messages: + - content: + - text: What is the largest city in the user country? Use the get_user_country tool and then your own world knowledge. + type: text + role: user + - content: + - text: I'll help you find the largest city in your country. Let me first check your country using the get_user_country + tool. + type: text + - id: toolu_01EZuxfc6MsPsPgrAKQohw3e + input: {} + name: get_user_country + type: tool_use + role: assistant + - content: + - content: Mexico + is_error: false + tool_use_id: toolu_01EZuxfc6MsPsPgrAKQohw3e + type: tool_result + role: user + model: claude-3-5-sonnet-latest + stream: false + tool_choice: + type: auto + tools: + - description: '' + input_schema: + additionalProperties: false + properties: {} + type: object + name: get_user_country + uri: https://api.anthropic.com/v1/messages?beta=true + response: + headers: + connection: + - keep-alive + content-length: + - '801' + content-type: + - application/json + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + content: + - text: Based on the result, you are located in Mexico. The largest city in Mexico is Mexico City (Ciudad de México), + which is also the nation's capital. Mexico City has a population of approximately 9.2 million people in the city + proper, and over 21 million people in its metropolitan area, making it one of the largest urban agglomerations in + the world. It is both the political and economic center of Mexico, located in the Valley of Mexico in the central + part of the country. + type: text + id: msg_0193srwo7TCx49h97wDwc7K7 + model: claude-3-5-sonnet-20241022 + role: assistant + stop_reason: end_turn + stop_sequence: null + type: message + usage: + cache_creation_input_tokens: 0 + cache_read_input_tokens: 0 + input_tokens: 461 + output_tokens: 107 + service_tier: standard + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/cassettes/test_anthropic/test_anthropic_tool_output.yaml b/tests/models/cassettes/test_anthropic/test_anthropic_tool_output.yaml new file mode 100644 index 000000000..560e1f34c --- /dev/null +++ b/tests/models/cassettes/test_anthropic/test_anthropic_tool_output.yaml @@ -0,0 +1,176 @@ +interactions: +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '585' + content-type: + - application/json + host: + - api.anthropic.com + method: POST + parsed_body: + max_tokens: 1024 + messages: + - content: + - text: What is the largest city in the user country? + type: text + role: user + model: claude-3-5-sonnet-latest + stream: false + tool_choice: + type: any + tools: + - description: '' + input_schema: + additionalProperties: false + properties: {} + type: object + name: get_user_country + - description: The final response which ends this conversation + input_schema: + properties: + city: + type: string + country: + type: string + required: + - city + - country + title: CityLocation + type: object + name: final_result + uri: https://api.anthropic.com/v1/messages?beta=true + response: + headers: + connection: + - keep-alive + content-length: + - '397' + content-type: + - application/json + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + content: + - id: toolu_019pMboNVRg5jkw4PKkofQ6Y + input: {} + name: get_user_country + type: tool_use + id: msg_01EnfsDTixCmHjqvk9QarBj4 + model: claude-3-5-sonnet-20241022 + role: assistant + stop_reason: tool_use + stop_sequence: null + type: message + usage: + cache_creation_input_tokens: 0 + cache_read_input_tokens: 0 + input_tokens: 445 + output_tokens: 23 + service_tier: standard + status: + code: 200 + message: OK +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '847' + content-type: + - application/json + host: + - api.anthropic.com + method: POST + parsed_body: + max_tokens: 1024 + messages: + - content: + - text: What is the largest city in the user country? + type: text + role: user + - content: + - id: toolu_019pMboNVRg5jkw4PKkofQ6Y + input: {} + name: get_user_country + type: tool_use + role: assistant + - content: + - content: Mexico + is_error: false + tool_use_id: toolu_019pMboNVRg5jkw4PKkofQ6Y + type: tool_result + role: user + model: claude-3-5-sonnet-latest + stream: false + tool_choice: + type: any + tools: + - description: '' + input_schema: + additionalProperties: false + properties: {} + type: object + name: get_user_country + - description: The final response which ends this conversation + input_schema: + properties: + city: + type: string + country: + type: string + required: + - city + - country + title: CityLocation + type: object + name: final_result + uri: https://api.anthropic.com/v1/messages?beta=true + response: + headers: + connection: + - keep-alive + content-length: + - '432' + content-type: + - application/json + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + content: + - id: toolu_01V4d2H4EWp5LDM2aXaeyR6W + input: + city: Mexico City + country: Mexico + name: final_result + type: tool_use + id: msg_01Hbm5BtKzfVtWs8Eb7rCNNx + model: claude-3-5-sonnet-20241022 + role: assistant + stop_reason: tool_use + stop_sequence: null + type: message + usage: + cache_creation_input_tokens: 0 + cache_read_input_tokens: 0 + input_tokens: 497 + output_tokens: 56 + service_tier: standard + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/cassettes/test_gemini/test_gemini_json_schema_output.yaml b/tests/models/cassettes/test_gemini/test_gemini_json_schema_output.yaml new file mode 100644 index 000000000..d7f14c9ca --- /dev/null +++ b/tests/models/cassettes/test_gemini/test_gemini_json_schema_output.yaml @@ -0,0 +1,79 @@ +interactions: +- request: + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '305' + content-type: + - application/json + host: + - generativelanguage.googleapis.com + method: POST + parsed_body: + contents: + - parts: + - text: What is the largest city in Mexico? + role: user + generationConfig: + response_mime_type: application/json + response_schema: + properties: + city: + type: string + country: + type: string + required: + - city + - country + title: CityLocation + type: object + uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent + response: + headers: + alt-svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + content-length: + - '710' + content-type: + - application/json; charset=UTF-8 + server-timing: + - gfet4t7; dur=819 + transfer-encoding: + - chunked + vary: + - Origin + - X-Origin + - Referer + parsed_body: + candidates: + - avgLogprobs: -0.00018302639946341515 + content: + parts: + - text: |- + { + "city": "Mexico City", + "country": "Mexico" + } + role: model + finishReason: STOP + modelVersion: gemini-2.0-flash + responseId: SEVIaJvJHICK7dcP3OzRiQQ + usageMetadata: + candidatesTokenCount: 20 + candidatesTokensDetails: + - modality: TEXT + tokenCount: 20 + promptTokenCount: 17 + promptTokensDetails: + - modality: TEXT + tokenCount: 17 + totalTokenCount: 37 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/cassettes/test_gemini/test_gemini_json_schema_output_multiple.yaml b/tests/models/cassettes/test_gemini/test_gemini_json_schema_output_multiple.yaml new file mode 100644 index 000000000..3b306d133 --- /dev/null +++ b/tests/models/cassettes/test_gemini/test_gemini_json_schema_output_multiple.yaml @@ -0,0 +1,120 @@ +interactions: +- request: + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '791' + content-type: + - application/json + host: + - generativelanguage.googleapis.com + method: POST + parsed_body: + contents: + - parts: + - text: What is the primarily language spoken in Mexico? + role: user + generationConfig: + response_mime_type: application/json + response_schema: + properties: + result: + anyOf: + - description: CityLocation + properties: + data: + properties: + city: + type: string + country: + type: string + required: + - city + - country + type: object + kind: + enum: + - CityLocation + type: string + required: + - kind + - data + type: object + - description: CountryLanguage + properties: + data: + properties: + country: + type: string + language: + type: string + required: + - country + - language + type: object + kind: + enum: + - CountryLanguage + type: string + required: + - kind + - data + type: object + required: + - result + type: object + uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent + response: + headers: + alt-svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + content-length: + - '800' + content-type: + - application/json; charset=UTF-8 + server-timing: + - gfet4t7; dur=963 + transfer-encoding: + - chunked + vary: + - Origin + - X-Origin + - Referer + parsed_body: + candidates: + - avgLogprobs: -3.3667640072172103e-06 + content: + parts: + - text: |- + { + "result": { + "data": { + "country": "Mexico", + "language": "Spanish" + }, + "kind": "CountryLanguage" + } + } + role: model + finishReason: STOP + modelVersion: gemini-2.0-flash + responseId: 2jxIaPucEYCK7dcP3OzRiQQ + usageMetadata: + candidatesTokenCount: 46 + candidatesTokensDetails: + - modality: TEXT + tokenCount: 46 + promptTokenCount: 46 + promptTokensDetails: + - modality: TEXT + tokenCount: 46 + totalTokenCount: 92 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/cassettes/test_gemini/test_gemini_prompted_json_output.yaml b/tests/models/cassettes/test_gemini/test_gemini_prompted_json_output.yaml new file mode 100644 index 000000000..2268e7f84 --- /dev/null +++ b/tests/models/cassettes/test_gemini/test_gemini_prompted_json_output.yaml @@ -0,0 +1,74 @@ +interactions: +- request: + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '521' + content-type: + - application/json + host: + - generativelanguage.googleapis.com + method: POST + parsed_body: + contents: + - parts: + - text: What is the largest city in Mexico? + role: user + generationConfig: + response_mime_type: application/json + systemInstruction: + parts: + - text: |- + Always respond with a JSON object that's compatible with this schema: + + {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} + + Don't include any text or Markdown fencing before or after. + role: user + uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent + response: + headers: + alt-svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + content-length: + - '880' + content-type: + - application/json; charset=UTF-8 + server-timing: + - gfet4t7; dur=841 + transfer-encoding: + - chunked + vary: + - Origin + - X-Origin + - Referer + parsed_body: + candidates: + - avgLogprobs: -0.007913463882037572 + content: + parts: + - text: '{"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], + "title": "CityLocation", "type": "object", "city": "Mexico City", "country": "Mexico"}' + role: model + finishReason: STOP + modelVersion: gemini-2.0-flash + responseId: 2zxIaIiLE4CK7dcP3OzRiQQ + usageMetadata: + candidatesTokenCount: 56 + candidatesTokensDetails: + - modality: TEXT + tokenCount: 56 + promptTokenCount: 80 + promptTokensDetails: + - modality: TEXT + tokenCount: 80 + totalTokenCount: 136 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/cassettes/test_gemini/test_gemini_prompted_json_output_multiple.yaml b/tests/models/cassettes/test_gemini/test_gemini_prompted_json_output_multiple.yaml new file mode 100644 index 000000000..e96fc20d7 --- /dev/null +++ b/tests/models/cassettes/test_gemini/test_gemini_prompted_json_output_multiple.yaml @@ -0,0 +1,73 @@ +interactions: +- request: + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '1287' + content-type: + - application/json + host: + - generativelanguage.googleapis.com + method: POST + parsed_body: + contents: + - parts: + - text: What is the largest city in Mexico? + role: user + generationConfig: + response_mime_type: application/json + systemInstruction: + parts: + - text: |- + Always respond with a JSON object that's compatible with this schema: + + {"type": "object", "properties": {"result": {"anyOf": [{"type": "object", "properties": {"kind": {"type": "string", "const": "CityLocation"}, "data": {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"}}, "description": "CityLocation", "required": ["kind", "data"], "additionalProperties": false}, {"type": "object", "properties": {"kind": {"type": "string", "const": "CountryLanguage"}, "data": {"properties": {"country": {"type": "string"}, "language": {"type": "string"}}, "required": ["country", "language"], "title": "CountryLanguage", "type": "object"}}, "description": "CountryLanguage", "required": ["kind", "data"], "additionalProperties": false}]}}, "required": ["result"], "additionalProperties": false} + + Don't include any text or Markdown fencing before or after. + role: user + uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent + response: + headers: + alt-svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + content-length: + - '757' + content-type: + - application/json; charset=UTF-8 + server-timing: + - gfet4t7; dur=823 + transfer-encoding: + - chunked + vary: + - Origin + - X-Origin + - Referer + parsed_body: + candidates: + - avgLogprobs: -0.0030997690779191477 + content: + parts: + - text: '{"result": {"kind": "CityLocation", "data": {"city": "Mexico City", "country": "Mexico"}}}' + role: model + finishReason: STOP + modelVersion: gemini-2.0-flash + responseId: Wz1IaOH5OdGU7dcPjoS34QI + usageMetadata: + candidatesTokenCount: 27 + candidatesTokensDetails: + - modality: TEXT + tokenCount: 27 + promptTokenCount: 253 + promptTokensDetails: + - modality: TEXT + tokenCount: 253 + totalTokenCount: 280 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/cassettes/test_gemini/test_gemini_prompted_json_output_with_tools.yaml b/tests/models/cassettes/test_gemini/test_gemini_prompted_json_output_with_tools.yaml new file mode 100644 index 000000000..f10da3ad7 --- /dev/null +++ b/tests/models/cassettes/test_gemini/test_gemini_prompted_json_output_with_tools.yaml @@ -0,0 +1,157 @@ +interactions: +- request: + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '615' + content-type: + - application/json + host: + - generativelanguage.googleapis.com + method: POST + parsed_body: + contents: + - parts: + - text: What is the largest city in the user country? Use the get_user_country tool and then your own world knowledge. + role: user + systemInstruction: + parts: + - text: |- + Always respond with a JSON object that's compatible with this schema: + + {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} + + Don't include any text or Markdown fencing before or after. + role: user + tools: + functionDeclarations: + - description: '' + name: get_user_country + uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-pro-preview-05-06:generateContent + response: + headers: + alt-svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + content-length: + - '653' + content-type: + - application/json; charset=UTF-8 + server-timing: + - gfet4t7; dur=4501 + transfer-encoding: + - chunked + vary: + - Origin + - X-Origin + - Referer + parsed_body: + candidates: + - content: + parts: + - functionCall: + args: {} + name: get_user_country + role: model + finishReason: STOP + index: 0 + modelVersion: models/gemini-2.5-pro-preview-05-06 + responseId: rj9IaPTzNdCBqtsPg-GD6QU + usageMetadata: + candidatesTokenCount: 12 + promptTokenCount: 123 + promptTokensDetails: + - modality: TEXT + tokenCount: 123 + thoughtsTokenCount: 318 + totalTokenCount: 453 + status: + code: 200 + message: OK +- request: + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '809' + content-type: + - application/json + host: + - generativelanguage.googleapis.com + method: POST + parsed_body: + contents: + - parts: + - text: What is the largest city in the user country? Use the get_user_country tool and then your own world knowledge. + role: user + - parts: + - functionCall: + args: {} + name: get_user_country + role: model + - parts: + - functionResponse: + name: get_user_country + response: + return_value: Mexico + role: user + systemInstruction: + parts: + - text: |- + Always respond with a JSON object that's compatible with this schema: + + {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} + + Don't include any text or Markdown fencing before or after. + role: user + tools: + functionDeclarations: + - description: '' + name: get_user_country + uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-pro-preview-05-06:generateContent + response: + headers: + alt-svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + content-length: + - '616' + content-type: + - application/json; charset=UTF-8 + server-timing: + - gfet4t7; dur=1823 + transfer-encoding: + - chunked + vary: + - Origin + - X-Origin + - Referer + parsed_body: + candidates: + - content: + parts: + - text: '{"city": "Mexico City", "country": "Mexico"}' + role: model + finishReason: STOP + index: 0 + modelVersion: models/gemini-2.5-pro-preview-05-06 + responseId: sD9IaOCyLPqumtkP6p_T0AE + usageMetadata: + candidatesTokenCount: 13 + promptTokenCount: 154 + promptTokensDetails: + - modality: TEXT + tokenCount: 154 + thoughtsTokenCount: 94 + totalTokenCount: 261 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/cassettes/test_gemini/test_gemini_text_output_function.yaml b/tests/models/cassettes/test_gemini/test_gemini_text_output_function.yaml new file mode 100644 index 000000000..7d54ce938 --- /dev/null +++ b/tests/models/cassettes/test_gemini/test_gemini_text_output_function.yaml @@ -0,0 +1,63 @@ +interactions: +- request: + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '87' + content-type: + - application/json + host: + - generativelanguage.googleapis.com + method: POST + parsed_body: + contents: + - parts: + - text: What is the largest city in Mexico? + role: user + uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-pro-preview-05-06:generateContent + response: + headers: + alt-svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + content-length: + - '753' + content-type: + - application/json; charset=UTF-8 + server-timing: + - gfet4t7; dur=6856 + transfer-encoding: + - chunked + vary: + - Origin + - X-Origin + - Referer + parsed_body: + candidates: + - content: + parts: + - text: |- + The largest city in Mexico is **Mexico City (Ciudad de México, CDMX)**. + + It's the capital of Mexico and one of the largest metropolitan areas in the world, both by population and land area. + role: model + finishReason: STOP + index: 0 + modelVersion: models/gemini-2.5-pro-preview-05-06 + responseId: TT9IaNfGN_DmqtsPzKnE4AE + usageMetadata: + candidatesTokenCount: 44 + promptTokenCount: 9 + promptTokensDetails: + - modality: TEXT + tokenCount: 9 + thoughtsTokenCount: 545 + totalTokenCount: 598 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/cassettes/test_gemini/test_gemini_tool_output.yaml b/tests/models/cassettes/test_gemini/test_gemini_tool_output.yaml new file mode 100644 index 000000000..f0c7adc68 --- /dev/null +++ b/tests/models/cassettes/test_gemini/test_gemini_tool_output.yaml @@ -0,0 +1,183 @@ +interactions: +- request: + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '511' + content-type: + - application/json + host: + - generativelanguage.googleapis.com + method: POST + parsed_body: + contents: + - parts: + - text: What is the largest city in the user country? + role: user + toolConfig: + function_calling_config: + allowed_function_names: + - get_user_country + - final_result + mode: ANY + tools: + functionDeclarations: + - description: '' + name: get_user_country + - description: The final response which ends this conversation + name: final_result + parameters: + properties: + city: + type: string + country: + type: string + required: + - city + - country + type: object + uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent + response: + headers: + alt-svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + content-length: + - '733' + content-type: + - application/json; charset=UTF-8 + server-timing: + - gfet4t7; dur=591 + transfer-encoding: + - chunked + vary: + - Origin + - X-Origin + - Referer + parsed_body: + candidates: + - avgLogprobs: 5.670217797160149e-06 + content: + parts: + - functionCall: + args: {} + name: get_user_country + role: model + finishReason: STOP + modelVersion: gemini-2.0-flash + responseId: SDxIaMqaGOS9nvgPzP-Y2QE + usageMetadata: + candidatesTokenCount: 5 + candidatesTokensDetails: + - modality: TEXT + tokenCount: 5 + promptTokenCount: 32 + promptTokensDetails: + - modality: TEXT + tokenCount: 32 + totalTokenCount: 37 + status: + code: 200 + message: OK +- request: + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '705' + content-type: + - application/json + host: + - generativelanguage.googleapis.com + method: POST + parsed_body: + contents: + - parts: + - text: What is the largest city in the user country? + role: user + - parts: + - functionCall: + args: {} + name: get_user_country + role: model + - parts: + - functionResponse: + name: get_user_country + response: + return_value: Mexico + role: user + toolConfig: + function_calling_config: + allowed_function_names: + - get_user_country + - final_result + mode: ANY + tools: + functionDeclarations: + - description: '' + name: get_user_country + - description: The final response which ends this conversation + name: final_result + parameters: + properties: + city: + type: string + country: + type: string + required: + - city + - country + type: object + uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent + response: + headers: + alt-svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + content-length: + - '821' + content-type: + - application/json; charset=UTF-8 + server-timing: + - gfet4t7; dur=613 + transfer-encoding: + - chunked + vary: + - Origin + - X-Origin + - Referer + parsed_body: + candidates: + - avgLogprobs: -3.3069271012209356e-05 + content: + parts: + - functionCall: + args: + city: Mexico City + country: Mexico + name: final_result + role: model + finishReason: STOP + modelVersion: gemini-2.0-flash + responseId: SDxIaNHrNcy3nvgPm5DhwQo + usageMetadata: + candidatesTokenCount: 8 + candidatesTokensDetails: + - modality: TEXT + tokenCount: 8 + promptTokenCount: 46 + promptTokensDetails: + - modality: TEXT + tokenCount: 46 + totalTokenCount: 54 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/cassettes/test_google/test_google_json_schema_output.yaml b/tests/models/cassettes/test_google/test_google_json_schema_output.yaml new file mode 100644 index 000000000..1d9ae0339 --- /dev/null +++ b/tests/models/cassettes/test_google/test_google_json_schema_output.yaml @@ -0,0 +1,86 @@ +interactions: +- request: + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '453' + content-type: + - application/json + host: + - generativelanguage.googleapis.com + method: POST + parsed_body: + contents: + - parts: + - text: What is the largest city in Mexico? + role: user + generationConfig: + responseMimeType: application/json + responseSchema: + properties: + city: + type: STRING + country: + type: STRING + property_ordering: + - city + - country + required: + - city + - country + title: CityLocation + type: OBJECT + toolConfig: + functionCallingConfig: + allowedFunctionNames: [] + mode: ANY + uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent + response: + headers: + alt-svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + content-length: + - '710' + content-type: + - application/json; charset=UTF-8 + server-timing: + - gfet4t7; dur=780 + transfer-encoding: + - chunked + vary: + - Origin + - X-Origin + - Referer + parsed_body: + candidates: + - avgLogprobs: -0.0002309985226020217 + content: + parts: + - text: |- + { + "city": "Mexico City", + "country": "Mexico" + } + role: model + finishReason: STOP + modelVersion: gemini-2.0-flash + responseId: Gm9HaNr3KteI_NUPmYvnoA8 + usageMetadata: + candidatesTokenCount: 20 + candidatesTokensDetails: + - modality: TEXT + tokenCount: 20 + promptTokenCount: 19 + promptTokensDetails: + - modality: TEXT + tokenCount: 19 + totalTokenCount: 39 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/cassettes/test_google/test_google_json_schema_output_multiple.yaml b/tests/models/cassettes/test_google/test_google_json_schema_output_multiple.yaml new file mode 100644 index 000000000..74dd03c89 --- /dev/null +++ b/tests/models/cassettes/test_google/test_google_json_schema_output_multiple.yaml @@ -0,0 +1,138 @@ +interactions: +- request: + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '1200' + content-type: + - application/json + host: + - generativelanguage.googleapis.com + method: POST + parsed_body: + contents: + - parts: + - text: What is the primarily language spoken in Mexico? + role: user + generationConfig: + responseMimeType: application/json + responseSchema: + description: The final response which ends this conversation + properties: + result: + any_of: + - description: CityLocation + properties: + data: + properties: + city: + type: STRING + country: + type: STRING + property_ordering: + - city + - country + required: + - city + - country + type: OBJECT + kind: + enum: + - CityLocation + type: STRING + property_ordering: + - kind + - data + required: + - kind + - data + type: OBJECT + - description: CountryLanguage + properties: + data: + properties: + country: + type: STRING + language: + type: STRING + property_ordering: + - country + - language + required: + - country + - language + type: OBJECT + kind: + enum: + - CountryLanguage + type: STRING + property_ordering: + - kind + - data + required: + - kind + - data + type: OBJECT + required: + - result + title: final_result + type: OBJECT + toolConfig: + functionCallingConfig: + allowedFunctionNames: [] + mode: ANY + uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent + response: + headers: + alt-svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + content-length: + - '800' + content-type: + - application/json; charset=UTF-8 + server-timing: + - gfet4t7; dur=884 + transfer-encoding: + - chunked + vary: + - Origin + - X-Origin + - Referer + parsed_body: + candidates: + - avgLogprobs: -0.0002536005138055138 + content: + parts: + - text: |- + { + "result": { + "kind": "CountryLanguage", + "data": { + "country": "Mexico", + "language": "Spanish" + } + } + } + role: model + finishReason: STOP + modelVersion: gemini-2.0-flash + responseId: W29HaJzGMNGU7dcPjoS34QI + usageMetadata: + candidatesTokenCount: 46 + candidatesTokensDetails: + - modality: TEXT + tokenCount: 46 + promptTokenCount: 64 + promptTokensDetails: + - modality: TEXT + tokenCount: 64 + totalTokenCount: 110 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/cassettes/test_google/test_google_prompted_json_output.yaml b/tests/models/cassettes/test_google/test_google_prompted_json_output.yaml new file mode 100644 index 000000000..3b241acae --- /dev/null +++ b/tests/models/cassettes/test_google/test_google_prompted_json_output.yaml @@ -0,0 +1,78 @@ +interactions: +- request: + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '619' + content-type: + - application/json + host: + - generativelanguage.googleapis.com + method: POST + parsed_body: + contents: + - parts: + - text: What is the largest city in Mexico? + role: user + generationConfig: + responseMimeType: application/json + systemInstruction: + parts: + - text: |- + Always respond with a JSON object that's compatible with this schema: + + {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} + + Don't include any text or Markdown fencing before or after. + role: user + toolConfig: + functionCallingConfig: + allowedFunctionNames: [] + mode: ANY + uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent + response: + headers: + alt-svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + content-length: + - '879' + content-type: + - application/json; charset=UTF-8 + server-timing: + - gfet4t7; dur=829 + transfer-encoding: + - chunked + vary: + - Origin + - X-Origin + - Referer + parsed_body: + candidates: + - avgLogprobs: -0.010130892906870161 + content: + parts: + - text: '{"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], + "title": "CityLocation", "type": "object", "city": "Mexico City", "country": "Mexico"}' + role: model + finishReason: STOP + modelVersion: gemini-2.0-flash + responseId: 4HlHaK75MdGU7dcPjoS34QI + usageMetadata: + candidatesTokenCount: 56 + candidatesTokensDetails: + - modality: TEXT + tokenCount: 56 + promptTokenCount: 80 + promptTokensDetails: + - modality: TEXT + tokenCount: 80 + totalTokenCount: 136 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/cassettes/test_google/test_google_prompted_json_output_multiple.yaml b/tests/models/cassettes/test_google/test_google_prompted_json_output_multiple.yaml new file mode 100644 index 000000000..33383473f --- /dev/null +++ b/tests/models/cassettes/test_google/test_google_prompted_json_output_multiple.yaml @@ -0,0 +1,77 @@ +interactions: +- request: + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '1341' + content-type: + - application/json + host: + - generativelanguage.googleapis.com + method: POST + parsed_body: + contents: + - parts: + - text: What is the largest city in Mexico? + role: user + generationConfig: + responseMimeType: application/json + systemInstruction: + parts: + - text: |- + Always respond with a JSON object that's compatible with this schema: + + {"type": "object", "properties": {"result": {"anyOf": [{"type": "object", "properties": {"kind": {"const": "CityLocation"}, "data": {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"}}, "description": "CityLocation", "required": ["kind", "data"], "additionalProperties": false}, {"type": "object", "properties": {"kind": {"const": "CountryLanguage"}, "data": {"properties": {"country": {"type": "string"}, "language": {"type": "string"}}, "required": ["country", "language"], "title": "CountryLanguage", "type": "object"}}, "description": "CountryLanguage", "required": ["kind", "data"], "additionalProperties": false}]}}, "required": ["result"], "additionalProperties": false} + + Don't include any text or Markdown fencing before or after. + role: user + toolConfig: + functionCallingConfig: + allowedFunctionNames: [] + mode: ANY + uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent + response: + headers: + alt-svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + content-length: + - '758' + content-type: + - application/json; charset=UTF-8 + server-timing: + - gfet4t7; dur=734 + transfer-encoding: + - chunked + vary: + - Origin + - X-Origin + - Referer + parsed_body: + candidates: + - avgLogprobs: -0.0008548707873732956 + content: + parts: + - text: '{"result": {"kind": "CityLocation", "data": {"city": "Mexico City", "country": "Mexico"}}}' + role: model + finishReason: STOP + modelVersion: gemini-2.0-flash + responseId: 6nlHaO_5GdeI_NUPmYvnoA8 + usageMetadata: + candidatesTokenCount: 27 + candidatesTokensDetails: + - modality: TEXT + tokenCount: 27 + promptTokenCount: 241 + promptTokensDetails: + - modality: TEXT + tokenCount: 241 + totalTokenCount: 268 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/cassettes/test_google/test_google_prompted_json_output_with_tools.yaml b/tests/models/cassettes/test_google/test_google_prompted_json_output_with_tools.yaml new file mode 100644 index 000000000..976533c66 --- /dev/null +++ b/tests/models/cassettes/test_google/test_google_prompted_json_output_with_tools.yaml @@ -0,0 +1,164 @@ +interactions: +- request: + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '658' + content-type: + - application/json + host: + - generativelanguage.googleapis.com + method: POST + parsed_body: + contents: + - parts: + - text: What is the largest city in the user country? Use the get_user_country tool and then your own world knowledge. + role: user + generationConfig: {} + systemInstruction: + parts: + - text: |- + Always respond with a JSON object that's compatible with this schema: + + {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} + + Don't include any text or Markdown fencing before or after. + role: user + tools: + - functionDeclarations: + - description: '' + name: get_user_country + uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-pro-preview-05-06:generateContent + response: + headers: + alt-svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + content-length: + - '653' + content-type: + - application/json; charset=UTF-8 + server-timing: + - gfet4t7; dur=3776 + transfer-encoding: + - chunked + vary: + - Origin + - X-Origin + - Referer + parsed_body: + candidates: + - content: + parts: + - functionCall: + args: {} + name: get_user_country + role: model + finishReason: STOP + index: 0 + modelVersion: models/gemini-2.5-pro-preview-05-06 + responseId: FnpHaOqcKrzQz7IPkuLo8QE + usageMetadata: + candidatesTokenCount: 12 + promptTokenCount: 123 + promptTokensDetails: + - modality: TEXT + tokenCount: 123 + thoughtsTokenCount: 266 + totalTokenCount: 401 + status: + code: 200 + message: OK +- request: + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '967' + content-type: + - application/json + host: + - generativelanguage.googleapis.com + method: POST + parsed_body: + contents: + - parts: + - text: What is the largest city in the user country? Use the get_user_country tool and then your own world knowledge. + role: user + - parts: + - functionCall: + args: {} + id: pyd_ai_479a74a75212414fb3c7bd2242e9b669 + name: get_user_country + role: model + - parts: + - functionResponse: + id: pyd_ai_479a74a75212414fb3c7bd2242e9b669 + name: get_user_country + response: + return_value: Mexico + role: user + generationConfig: {} + systemInstruction: + parts: + - text: |- + Always respond with a JSON object that's compatible with this schema: + + {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} + + Don't include any text or Markdown fencing before or after. + role: user + tools: + - functionDeclarations: + - description: '' + name: get_user_country + uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-pro-preview-05-06:generateContent + response: + headers: + alt-svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + content-length: + - '630' + content-type: + - application/json; charset=UTF-8 + server-timing: + - gfet4t7; dur=1888 + transfer-encoding: + - chunked + vary: + - Origin + - X-Origin + - Referer + parsed_body: + candidates: + - content: + parts: + - text: |- + ```json + {"city": "Mexico City", "country": "Mexico"} + ``` + role: model + finishReason: STOP + index: 0 + modelVersion: models/gemini-2.5-pro-preview-05-06 + responseId: GHpHaOPkI43Qz7IPxt6T2Ac + usageMetadata: + candidatesTokenCount: 18 + promptTokenCount: 154 + promptTokensDetails: + - modality: TEXT + tokenCount: 154 + thoughtsTokenCount: 94 + totalTokenCount: 266 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/cassettes/test_google/test_google_text_output_function.yaml b/tests/models/cassettes/test_google/test_google_text_output_function.yaml new file mode 100644 index 000000000..ebfbfc86f --- /dev/null +++ b/tests/models/cassettes/test_google/test_google_text_output_function.yaml @@ -0,0 +1,147 @@ +interactions: +- request: + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '279' + content-type: + - application/json + host: + - generativelanguage.googleapis.com + method: POST + parsed_body: + contents: + - parts: + - text: What is the largest city in the user country? Use the get_user_country tool and then your own world knowledge. + role: user + generationConfig: {} + tools: + - functionDeclarations: + - description: '' + name: get_user_country + uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-pro-preview-05-06:generateContent + response: + headers: + alt-svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + content-length: + - '769' + content-type: + - application/json; charset=UTF-8 + server-timing: + - gfet4t7; dur=2956 + transfer-encoding: + - chunked + vary: + - Origin + - X-Origin + - Referer + parsed_body: + candidates: + - content: + parts: + - text: | + Okay, I can help with that. First, I need to determine your country. + - functionCall: + args: {} + name: get_user_country + role: model + finishReason: STOP + index: 0 + modelVersion: models/gemini-2.5-pro-preview-05-06 + responseId: J25HaLv8GvDQz7IPp_zUiQo + usageMetadata: + candidatesTokenCount: 30 + promptTokenCount: 49 + promptTokensDetails: + - modality: TEXT + tokenCount: 49 + thoughtsTokenCount: 159 + totalTokenCount: 238 + status: + code: 200 + message: OK +- request: + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '672' + content-type: + - application/json + host: + - generativelanguage.googleapis.com + method: POST + parsed_body: + contents: + - parts: + - text: What is the largest city in the user country? Use the get_user_country tool and then your own world knowledge. + role: user + - parts: + - text: | + Okay, I can help with that. First, I need to determine your country. + - functionCall: + args: {} + id: pyd_ai_82dd46d016b24cf999ce5d812b383f1a + name: get_user_country + role: model + - parts: + - functionResponse: + id: pyd_ai_82dd46d016b24cf999ce5d812b383f1a + name: get_user_country + response: + return_value: Mexico + role: user + generationConfig: {} + tools: + - functionDeclarations: + - description: '' + name: get_user_country + uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-pro-preview-05-06:generateContent + response: + headers: + alt-svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + content-length: + - '637' + content-type: + - application/json; charset=UTF-8 + server-timing: + - gfet4t7; dur=1426 + transfer-encoding: + - chunked + vary: + - Origin + - X-Origin + - Referer + parsed_body: + candidates: + - content: + parts: + - text: Based on the information I have, the largest city in Mexico is Mexico City. + role: model + finishReason: STOP + index: 0 + modelVersion: models/gemini-2.5-pro-preview-05-06 + responseId: KG5HaKT3Nc2fz7IPy9KsuQU + usageMetadata: + candidatesTokenCount: 16 + promptTokenCount: 98 + promptTokensDetails: + - modality: TEXT + tokenCount: 98 + thoughtsTokenCount: 45 + totalTokenCount: 159 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/cassettes/test_google/test_google_tool_output.yaml b/tests/models/cassettes/test_google/test_google_tool_output.yaml new file mode 100644 index 000000000..bebefdf6a --- /dev/null +++ b/tests/models/cassettes/test_google/test_google_tool_output.yaml @@ -0,0 +1,187 @@ +interactions: +- request: + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '568' + content-type: + - application/json + host: + - generativelanguage.googleapis.com + method: POST + parsed_body: + contents: + - parts: + - text: What is the largest city in the user country? + role: user + generationConfig: {} + toolConfig: + functionCallingConfig: + allowedFunctionNames: + - get_user_country + - final_result + mode: ANY + tools: + - functionDeclarations: + - description: '' + name: get_user_country + - description: The final response which ends this conversation + name: final_result + parameters: + properties: + city: + type: STRING + country: + type: STRING + required: + - city + - country + type: OBJECT + uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent + response: + headers: + alt-svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + content-length: + - '733' + content-type: + - application/json; charset=UTF-8 + server-timing: + - gfet4t7; dur=644 + transfer-encoding: + - chunked + vary: + - Origin + - X-Origin + - Referer + parsed_body: + candidates: + - avgLogprobs: 5.670217797160149e-06 + content: + parts: + - functionCall: + args: {} + name: get_user_country + role: model + finishReason: STOP + modelVersion: gemini-2.0-flash + responseId: F21HaLmGI5m2nvgP-__7yAg + usageMetadata: + candidatesTokenCount: 5 + candidatesTokensDetails: + - modality: TEXT + tokenCount: 5 + promptTokenCount: 32 + promptTokensDetails: + - modality: TEXT + tokenCount: 32 + totalTokenCount: 37 + status: + code: 200 + message: OK +- request: + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '877' + content-type: + - application/json + host: + - generativelanguage.googleapis.com + method: POST + parsed_body: + contents: + - parts: + - text: What is the largest city in the user country? + role: user + - parts: + - functionCall: + args: {} + id: pyd_ai_9bbd9b896939438e8ff5aba64fed8674 + name: get_user_country + role: model + - parts: + - functionResponse: + id: pyd_ai_9bbd9b896939438e8ff5aba64fed8674 + name: get_user_country + response: + return_value: Mexico + role: user + generationConfig: {} + toolConfig: + functionCallingConfig: + allowedFunctionNames: + - get_user_country + - final_result + mode: ANY + tools: + - functionDeclarations: + - description: '' + name: get_user_country + - description: The final response which ends this conversation + name: final_result + parameters: + properties: + city: + type: STRING + country: + type: STRING + required: + - city + - country + type: OBJECT + uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent + response: + headers: + alt-svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + content-length: + - '821' + content-type: + - application/json; charset=UTF-8 + server-timing: + - gfet4t7; dur=531 + transfer-encoding: + - chunked + vary: + - Origin + - X-Origin + - Referer + parsed_body: + candidates: + - avgLogprobs: -3.289666346972808e-05 + content: + parts: + - functionCall: + args: + city: Mexico City + country: Mexico + name: final_result + role: model + finishReason: STOP + modelVersion: gemini-2.0-flash + responseId: GG1HaMXtBoW8nvgPkaDy0Ag + usageMetadata: + candidatesTokenCount: 8 + candidatesTokensDetails: + - modality: TEXT + tokenCount: 8 + promptTokenCount: 46 + promptTokensDetails: + - modality: TEXT + tokenCount: 46 + totalTokenCount: 54 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/cassettes/test_openai/test_openai_json_schema_output.yaml b/tests/models/cassettes/test_openai/test_openai_json_schema_output.yaml new file mode 100644 index 000000000..ff4477f3d --- /dev/null +++ b/tests/models/cassettes/test_openai/test_openai_json_schema_output.yaml @@ -0,0 +1,223 @@ +interactions: +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '522' + content-type: + - application/json + host: + - api.openai.com + method: POST + parsed_body: + messages: + - content: What is the largest city in the user country? + role: user + model: gpt-4o + n: 1 + response_format: + json_schema: + name: result + schema: + properties: + city: + type: string + country: + type: string + required: + - city + - country + type: object + strict: false + type: json_schema + stream: false + tool_choice: auto + tools: + - function: + description: '' + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + type: function + uri: https://api.openai.com/v1/chat/completions + response: + headers: + access-control-expose-headers: + - X-Request-ID + alt-svc: + - h3=":443"; ma=86400 + connection: + - keep-alive + content-length: + - '1066' + content-type: + - application/json + openai-organization: + - pydantic-28gund + openai-processing-ms: + - '341' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + choices: + - finish_reason: tool_calls + index: 0 + logprobs: null + message: + annotations: [] + content: null + refusal: null + role: assistant + tool_calls: + - function: + arguments: '{}' + name: get_user_country + id: call_PkRGedQNRFUzJp2R7dO7avWR + type: function + created: 1746142582 + id: chatcmpl-BSXjyBwGuZrtuuSzNCeaWMpGv2MZ3 + model: gpt-4o-2024-08-06 + object: chat.completion + service_tier: default + system_fingerprint: fp_f5bdcc3276 + usage: + completion_tokens: 12 + completion_tokens_details: + accepted_prediction_tokens: 0 + audio_tokens: 0 + reasoning_tokens: 0 + rejected_prediction_tokens: 0 + prompt_tokens: 71 + prompt_tokens_details: + audio_tokens: 0 + cached_tokens: 0 + total_tokens: 83 + status: + code: 200 + message: OK +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '753' + content-type: + - application/json + cookie: + - __cf_bm=dOa3_E1SoWV.vbgZ8L7tx8o9S.XyNTE.YS0K0I3JHq4-1746142583-1.0.1.1-0TuvhdYsoD.J1522DBXH0yrAP_M9MlzvlcpyfwQQNZy.KO5gri6ejQ.gFuwLV5hGhuY0W2uI1dN7ZF1lirVHKeEnEz5s_89aJjrMWjyBd8M; + _cfuvid=xQIJVHkOP28w5fPnAvDHPiCRlU7kkNj6iFV87W4u8Ds-1746142583128-0.0.1.1-604800000 + host: + - api.openai.com + method: POST + parsed_body: + messages: + - content: What is the largest city in the user country? + role: user + - role: assistant + tool_calls: + - function: + arguments: '{}' + name: get_user_country + id: call_PkRGedQNRFUzJp2R7dO7avWR + type: function + - content: Mexico + role: tool + tool_call_id: call_PkRGedQNRFUzJp2R7dO7avWR + model: gpt-4o + n: 1 + response_format: + json_schema: + name: result + schema: + properties: + city: + type: string + country: + type: string + required: + - city + - country + type: object + strict: false + type: json_schema + stream: false + tool_choice: auto + tools: + - function: + description: '' + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + type: function + uri: https://api.openai.com/v1/chat/completions + response: + headers: + access-control-expose-headers: + - X-Request-ID + alt-svc: + - h3=":443"; ma=86400 + connection: + - keep-alive + content-length: + - '852' + content-type: + - application/json + openai-organization: + - pydantic-28gund + openai-processing-ms: + - '553' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + choices: + - finish_reason: stop + index: 0 + logprobs: null + message: + annotations: [] + content: '{"city":"Mexico City","country":"Mexico"}' + refusal: null + role: assistant + created: 1746142583 + id: chatcmpl-BSXjzYGu67dhTy5r8KmjJvQ4HhDVO + model: gpt-4o-2024-08-06 + object: chat.completion + service_tier: default + system_fingerprint: fp_f5bdcc3276 + usage: + completion_tokens: 15 + completion_tokens_details: + accepted_prediction_tokens: 0 + audio_tokens: 0 + reasoning_tokens: 0 + rejected_prediction_tokens: 0 + prompt_tokens: 92 + prompt_tokens_details: + audio_tokens: 0 + cached_tokens: 0 + total_tokens: 107 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/cassettes/test_openai/test_openai_json_schema_output_multiple.yaml b/tests/models/cassettes/test_openai/test_openai_json_schema_output_multiple.yaml new file mode 100644 index 000000000..d01e28ab0 --- /dev/null +++ b/tests/models/cassettes/test_openai/test_openai_json_schema_output_multiple.yaml @@ -0,0 +1,293 @@ +interactions: +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '1120' + content-type: + - application/json + host: + - api.openai.com + method: POST + parsed_body: + messages: + - content: What is the largest city in the user country? + role: user + model: gpt-4o + response_format: + json_schema: + description: The final response which ends this conversation + name: final_result + schema: + additionalProperties: false + properties: + result: + anyOf: + - additionalProperties: false + description: CityLocation + properties: + data: + properties: + city: + type: string + country: + type: string + required: + - city + - country + type: object + kind: + const: CityLocation + required: + - kind + - data + type: object + - additionalProperties: false + description: CountryLanguage + properties: + data: + properties: + country: + type: string + language: + type: string + required: + - country + - language + type: object + kind: + const: CountryLanguage + required: + - kind + - data + type: object + required: + - result + type: object + type: json_schema + stream: false + tool_choice: auto + tools: + - function: + description: '' + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + type: function + uri: https://api.openai.com/v1/chat/completions + response: + headers: + access-control-expose-headers: + - X-Request-ID + alt-svc: + - h3=":443"; ma=86400 + connection: + - keep-alive + content-length: + - '1068' + content-type: + - application/json + openai-organization: + - pydantic-28gund + openai-processing-ms: + - '868' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + choices: + - finish_reason: tool_calls + index: 0 + logprobs: null + message: + annotations: [] + content: null + refusal: null + role: assistant + tool_calls: + - function: + arguments: '{}' + name: get_user_country + id: call_SIttSeiOistt33Htj4oiHOOX + type: function + created: 1749511286 + id: chatcmpl-Bgg5utuCSXMQ38j0n2qgfdQKcR9VD + model: gpt-4o-2024-08-06 + object: chat.completion + service_tier: default + system_fingerprint: fp_9bddfca6e2 + usage: + completion_tokens: 11 + completion_tokens_details: + accepted_prediction_tokens: 0 + audio_tokens: 0 + reasoning_tokens: 0 + rejected_prediction_tokens: 0 + prompt_tokens: 160 + prompt_tokens_details: + audio_tokens: 0 + cached_tokens: 0 + total_tokens: 171 + status: + code: 200 + message: OK +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '1351' + content-type: + - application/json + cookie: + - __cf_bm=OFzdr.HrmtC0DNdnfrTQYsK8_PwAVR9GUqjYSCgwtVM-1749511286-1.0.1.1-9_dbth7ET4rzl01UDRTw3fY1nJ20FnMCC0BBmd57gzKF8n5DnNQaI4K1mT.23nn9IUsMyHAZUNn6t1EML3d7GfGJyiLZOxrTWaqacALgzlM; + _cfuvid=f32dQYPsRd6Jc7kg.3hHa1QYAyG8f_aMMXUF.bC6gmY-1749511286914-0.0.1.1-604800000 + host: + - api.openai.com + method: POST + parsed_body: + messages: + - content: What is the largest city in the user country? + role: user + - role: assistant + tool_calls: + - function: + arguments: '{}' + name: get_user_country + id: call_SIttSeiOistt33Htj4oiHOOX + type: function + - content: Mexico + role: tool + tool_call_id: call_SIttSeiOistt33Htj4oiHOOX + model: gpt-4o + response_format: + json_schema: + description: The final response which ends this conversation + name: final_result + schema: + additionalProperties: false + properties: + result: + anyOf: + - additionalProperties: false + description: CityLocation + properties: + data: + properties: + city: + type: string + country: + type: string + required: + - city + - country + type: object + kind: + const: CityLocation + required: + - kind + - data + type: object + - additionalProperties: false + description: CountryLanguage + properties: + data: + properties: + country: + type: string + language: + type: string + required: + - country + - language + type: object + kind: + const: CountryLanguage + required: + - kind + - data + type: object + required: + - result + type: object + type: json_schema + stream: false + tool_choice: auto + tools: + - function: + description: '' + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + type: function + uri: https://api.openai.com/v1/chat/completions + response: + headers: + access-control-expose-headers: + - X-Request-ID + alt-svc: + - h3=":443"; ma=86400 + connection: + - keep-alive + content-length: + - '903' + content-type: + - application/json + openai-organization: + - pydantic-28gund + openai-processing-ms: + - '920' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + choices: + - finish_reason: stop + index: 0 + logprobs: null + message: + annotations: [] + content: '{"result":{"kind":"CityLocation","data":{"city":"Mexico City","country":"Mexico"}}}' + refusal: null + role: assistant + created: 1749511287 + id: chatcmpl-Bgg5vrxUtCDlvgMreoxYxPaKxANmd + model: gpt-4o-2024-08-06 + object: chat.completion + service_tier: default + system_fingerprint: fp_9bddfca6e2 + usage: + completion_tokens: 25 + completion_tokens_details: + accepted_prediction_tokens: 0 + audio_tokens: 0 + reasoning_tokens: 0 + rejected_prediction_tokens: 0 + prompt_tokens: 181 + prompt_tokens_details: + audio_tokens: 0 + cached_tokens: 0 + total_tokens: 206 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/cassettes/test_openai/test_openai_prompted_json_output.yaml b/tests/models/cassettes/test_openai/test_openai_prompted_json_output.yaml new file mode 100644 index 000000000..4eed79085 --- /dev/null +++ b/tests/models/cassettes/test_openai/test_openai_prompted_json_output.yaml @@ -0,0 +1,209 @@ +interactions: +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '690' + content-type: + - application/json + host: + - api.openai.com + method: POST + parsed_body: + messages: + - content: |- + Always respond with a JSON object that's compatible with this schema: + + {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} + + Don't include any text or Markdown fencing before or after. + role: system + - content: What is the largest city in the user country? + role: user + model: gpt-4o + response_format: + type: json_object + stream: false + tool_choice: auto + tools: + - function: + description: '' + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + type: function + uri: https://api.openai.com/v1/chat/completions + response: + headers: + access-control-expose-headers: + - X-Request-ID + alt-svc: + - h3=":443"; ma=86400 + connection: + - keep-alive + content-length: + - '1068' + content-type: + - application/json + openai-organization: + - pydantic-28gund + openai-processing-ms: + - '569' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + choices: + - finish_reason: tool_calls + index: 0 + logprobs: null + message: + annotations: [] + content: null + refusal: null + role: assistant + tool_calls: + - function: + arguments: '{}' + name: get_user_country + id: call_s7oT9jaLAsEqTgvxZTmFh0wB + type: function + created: 1749514895 + id: chatcmpl-Bgh27PeOaFW6qmF04qC5uI2H9mviw + model: gpt-4o-2024-08-06 + object: chat.completion + service_tier: default + system_fingerprint: fp_07871e2ad8 + usage: + completion_tokens: 11 + completion_tokens_details: + accepted_prediction_tokens: 0 + audio_tokens: 0 + reasoning_tokens: 0 + rejected_prediction_tokens: 0 + prompt_tokens: 109 + prompt_tokens_details: + audio_tokens: 0 + cached_tokens: 0 + total_tokens: 120 + status: + code: 200 + message: OK +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '921' + content-type: + - application/json + cookie: + - __cf_bm=jcec.FXQ2vs1UTNFhcDbuMrvzdFu7d7L1To24_vRFiQ-1749514896-1.0.1.1-PEeul2ZYkvLFmEXXk4Xlgvun2HcuGEJ0UUliLVWKx17kMCjZ8WiZbB2Yavq3RRGlxsJZsAWIVMQQ10Vb_2aqGVtQ2aiYTlnDMX3Ktkuciyk; + _cfuvid=zanrNpp5OAiS0wLKfkW9LCs3qTO2FvIaiBZptR_D2P0-1749514896187-0.0.1.1-604800000 + host: + - api.openai.com + method: POST + parsed_body: + messages: + - content: |- + Always respond with a JSON object that's compatible with this schema: + + {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} + + Don't include any text or Markdown fencing before or after. + role: system + - content: What is the largest city in the user country? + role: user + - role: assistant + tool_calls: + - function: + arguments: '{}' + name: get_user_country + id: call_s7oT9jaLAsEqTgvxZTmFh0wB + type: function + - content: Mexico + role: tool + tool_call_id: call_s7oT9jaLAsEqTgvxZTmFh0wB + model: gpt-4o + response_format: + type: json_object + stream: false + tool_choice: auto + tools: + - function: + description: '' + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + type: function + uri: https://api.openai.com/v1/chat/completions + response: + headers: + access-control-expose-headers: + - X-Request-ID + alt-svc: + - h3=":443"; ma=86400 + connection: + - keep-alive + content-length: + - '853' + content-type: + - application/json + openai-organization: + - pydantic-28gund + openai-processing-ms: + - '718' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + choices: + - finish_reason: stop + index: 0 + logprobs: null + message: + annotations: [] + content: '{"city":"Mexico City","country":"Mexico"}' + refusal: null + role: assistant + created: 1749514896 + id: chatcmpl-Bgh28advCSFhGHPnzUevVS6g6Uwg0 + model: gpt-4o-2024-08-06 + object: chat.completion + service_tier: default + system_fingerprint: fp_07871e2ad8 + usage: + completion_tokens: 11 + completion_tokens_details: + accepted_prediction_tokens: 0 + audio_tokens: 0 + reasoning_tokens: 0 + rejected_prediction_tokens: 0 + prompt_tokens: 130 + prompt_tokens_details: + audio_tokens: 0 + cached_tokens: 0 + total_tokens: 141 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/cassettes/test_openai/test_openai_prompted_json_output_multiple.yaml b/tests/models/cassettes/test_openai/test_openai_prompted_json_output_multiple.yaml new file mode 100644 index 000000000..3d3ba886a --- /dev/null +++ b/tests/models/cassettes/test_openai/test_openai_prompted_json_output_multiple.yaml @@ -0,0 +1,209 @@ +interactions: +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '1412' + content-type: + - application/json + host: + - api.openai.com + method: POST + parsed_body: + messages: + - content: |- + Always respond with a JSON object that's compatible with this schema: + + {"type": "object", "properties": {"result": {"anyOf": [{"type": "object", "properties": {"kind": {"const": "CityLocation"}, "data": {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"}}, "description": "CityLocation", "required": ["kind", "data"], "additionalProperties": false}, {"type": "object", "properties": {"kind": {"const": "CountryLanguage"}, "data": {"properties": {"country": {"type": "string"}, "language": {"type": "string"}}, "required": ["country", "language"], "title": "CountryLanguage", "type": "object"}}, "description": "CountryLanguage", "required": ["kind", "data"], "additionalProperties": false}]}}, "required": ["result"], "additionalProperties": false} + + Don't include any text or Markdown fencing before or after. + role: system + - content: What is the largest city in the user country? + role: user + model: gpt-4o + response_format: + type: json_object + stream: false + tool_choice: auto + tools: + - function: + description: '' + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + type: function + uri: https://api.openai.com/v1/chat/completions + response: + headers: + access-control-expose-headers: + - X-Request-ID + alt-svc: + - h3=":443"; ma=86400 + connection: + - keep-alive + content-length: + - '1068' + content-type: + - application/json + openai-organization: + - pydantic-28gund + openai-processing-ms: + - '428' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + choices: + - finish_reason: tool_calls + index: 0 + logprobs: null + message: + annotations: [] + content: null + refusal: null + role: assistant + tool_calls: + - function: + arguments: '{}' + name: get_user_country + id: call_wJD14IyJ4KKVtjCrGyNCHO09 + type: function + created: 1749514898 + id: chatcmpl-Bgh2AW2NXGgMc7iS639MJXNRgtatR + model: gpt-4o-2024-08-06 + object: chat.completion + service_tier: default + system_fingerprint: fp_9bddfca6e2 + usage: + completion_tokens: 11 + completion_tokens_details: + accepted_prediction_tokens: 0 + audio_tokens: 0 + reasoning_tokens: 0 + rejected_prediction_tokens: 0 + prompt_tokens: 273 + prompt_tokens_details: + audio_tokens: 0 + cached_tokens: 0 + total_tokens: 284 + status: + code: 200 + message: OK +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '1643' + content-type: + - application/json + cookie: + - __cf_bm=gqjIEMZSez95CPkkPVuU_AoDutHrobFMbFPjq43G66M-1749514899-1.0.1.1-5TGB9WajW5pzCRtVtWeQfiwyQUZs1JwWy9qC8VGlgq7s5pQWKerukQtYB7GqNDrdb.1pbtFyt2HZ9xV3YiSbK4H1bZS_hS1CCeoup_3IQW0; + _cfuvid=ZN6eoNau4b.bJ8kvRn2z9R0HgTUd9nOsupKUtLXQowU-1749514899280-0.0.1.1-604800000 + host: + - api.openai.com + method: POST + parsed_body: + messages: + - content: |- + Always respond with a JSON object that's compatible with this schema: + + {"type": "object", "properties": {"result": {"anyOf": [{"type": "object", "properties": {"kind": {"const": "CityLocation"}, "data": {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"}}, "description": "CityLocation", "required": ["kind", "data"], "additionalProperties": false}, {"type": "object", "properties": {"kind": {"const": "CountryLanguage"}, "data": {"properties": {"country": {"type": "string"}, "language": {"type": "string"}}, "required": ["country", "language"], "title": "CountryLanguage", "type": "object"}}, "description": "CountryLanguage", "required": ["kind", "data"], "additionalProperties": false}]}}, "required": ["result"], "additionalProperties": false} + + Don't include any text or Markdown fencing before or after. + role: system + - content: What is the largest city in the user country? + role: user + - role: assistant + tool_calls: + - function: + arguments: '{}' + name: get_user_country + id: call_wJD14IyJ4KKVtjCrGyNCHO09 + type: function + - content: Mexico + role: tool + tool_call_id: call_wJD14IyJ4KKVtjCrGyNCHO09 + model: gpt-4o + response_format: + type: json_object + stream: false + tool_choice: auto + tools: + - function: + description: '' + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + type: function + uri: https://api.openai.com/v1/chat/completions + response: + headers: + access-control-expose-headers: + - X-Request-ID + alt-svc: + - h3=":443"; ma=86400 + connection: + - keep-alive + content-length: + - '903' + content-type: + - application/json + openai-organization: + - pydantic-28gund + openai-processing-ms: + - '763' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + choices: + - finish_reason: stop + index: 0 + logprobs: null + message: + annotations: [] + content: '{"result":{"kind":"CityLocation","data":{"city":"Mexico City","country":"Mexico"}}}' + refusal: null + role: assistant + created: 1749514899 + id: chatcmpl-Bgh2BthuopRnSqCuUgMbBnOqgkDHC + model: gpt-4o-2024-08-06 + object: chat.completion + service_tier: default + system_fingerprint: fp_9bddfca6e2 + usage: + completion_tokens: 21 + completion_tokens_details: + accepted_prediction_tokens: 0 + audio_tokens: 0 + reasoning_tokens: 0 + rejected_prediction_tokens: 0 + prompt_tokens: 294 + prompt_tokens_details: + audio_tokens: 0 + cached_tokens: 0 + total_tokens: 315 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/cassettes/test_openai/test_openai_text_output_function.yaml b/tests/models/cassettes/test_openai/test_openai_text_output_function.yaml new file mode 100644 index 000000000..9a2f3c06f --- /dev/null +++ b/tests/models/cassettes/test_openai/test_openai_text_output_function.yaml @@ -0,0 +1,191 @@ +interactions: +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '303' + content-type: + - application/json + host: + - api.openai.com + method: POST + parsed_body: + messages: + - content: What is the largest city in the user country? + role: user + model: gpt-4o + stream: false + tool_choice: auto + tools: + - function: + description: '' + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + type: function + uri: https://api.openai.com/v1/chat/completions + response: + headers: + access-control-expose-headers: + - X-Request-ID + alt-svc: + - h3=":443"; ma=86400 + connection: + - keep-alive + content-length: + - '1066' + content-type: + - application/json + openai-organization: + - pydantic-28gund + openai-processing-ms: + - '432' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + choices: + - finish_reason: tool_calls + index: 0 + logprobs: null + message: + annotations: [] + content: null + refusal: null + role: assistant + tool_calls: + - function: + arguments: '{}' + name: get_user_country + id: call_J1YabdC7G7kzEZNbbZopwenH + type: function + created: 1749504053 + id: chatcmpl-BgeDFS85bfHosRFEEAvq8reaCPCZ8 + model: gpt-4o-2024-08-06 + object: chat.completion + service_tier: default + system_fingerprint: fp_9bddfca6e2 + usage: + completion_tokens: 11 + completion_tokens_details: + accepted_prediction_tokens: 0 + audio_tokens: 0 + reasoning_tokens: 0 + rejected_prediction_tokens: 0 + prompt_tokens: 42 + prompt_tokens_details: + audio_tokens: 0 + cached_tokens: 0 + total_tokens: 53 + status: + code: 200 + message: OK +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '534' + content-type: + - application/json + cookie: + - __cf_bm=YTub3t5GuFdFQZLwCTHT2eGO.fT0zx3Sk2kEY.wvtik-1749504053-1.0.1.1-BMg98yRknUs3LAtnRn_3w1W2X4aoKkKWHIwaBFv.1bdfOF._ZCV0pIGVcI1saCXHR9BMUfQzhTdEPeLlXocUxVzzYQCNTOAxf21UZXcs.ks; + _cfuvid=u8gIns9XYwRGSqmjviw_hUFmKp.LpNqiNvoFMcyyK40-1749504053813-0.0.1.1-604800000 + host: + - api.openai.com + method: POST + parsed_body: + messages: + - content: What is the largest city in the user country? + role: user + - role: assistant + tool_calls: + - function: + arguments: '{}' + name: get_user_country + id: call_J1YabdC7G7kzEZNbbZopwenH + type: function + - content: Mexico + role: tool + tool_call_id: call_J1YabdC7G7kzEZNbbZopwenH + model: gpt-4o + stream: false + tool_choice: auto + tools: + - function: + description: '' + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + type: function + uri: https://api.openai.com/v1/chat/completions + response: + headers: + access-control-expose-headers: + - X-Request-ID + alt-svc: + - h3=":443"; ma=86400 + connection: + - keep-alive + content-length: + - '844' + content-type: + - application/json + openai-organization: + - pydantic-28gund + openai-processing-ms: + - '449' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + choices: + - finish_reason: stop + index: 0 + logprobs: null + message: + annotations: [] + content: The largest city in Mexico is Mexico City. + refusal: null + role: assistant + created: 1749504054 + id: chatcmpl-BgeDGX9eDyVrEI56aP2vtIHahBzFH + model: gpt-4o-2024-08-06 + object: chat.completion + service_tier: default + system_fingerprint: fp_9bddfca6e2 + usage: + completion_tokens: 10 + completion_tokens_details: + accepted_prediction_tokens: 0 + audio_tokens: 0 + reasoning_tokens: 0 + rejected_prediction_tokens: 0 + prompt_tokens: 63 + prompt_tokens_details: + audio_tokens: 0 + cached_tokens: 0 + total_tokens: 73 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/cassettes/test_openai/test_openai_tool_output.yaml b/tests/models/cassettes/test_openai/test_openai_tool_output.yaml new file mode 100644 index 000000000..56f7441f1 --- /dev/null +++ b/tests/models/cassettes/test_openai/test_openai_tool_output.yaml @@ -0,0 +1,227 @@ +interactions: +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '561' + content-type: + - application/json + host: + - api.openai.com + method: POST + parsed_body: + messages: + - content: What is the largest city in the user country? + role: user + model: gpt-4o + n: 1 + stream: false + tool_choice: required + tools: + - function: + description: '' + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + type: function + - function: + description: The final response which ends this conversation + name: final_result + parameters: + properties: + city: + type: string + country: + type: string + required: + - city + - country + type: object + type: function + uri: https://api.openai.com/v1/chat/completions + response: + headers: + access-control-expose-headers: + - X-Request-ID + alt-svc: + - h3=":443"; ma=86400 + connection: + - keep-alive + content-length: + - '1066' + content-type: + - application/json + openai-organization: + - pydantic-28gund + openai-processing-ms: + - '348' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + choices: + - finish_reason: tool_calls + index: 0 + logprobs: null + message: + annotations: [] + content: null + refusal: null + role: assistant + tool_calls: + - function: + arguments: '{}' + name: get_user_country + id: call_iXFttys57ap0o16JSlC8yhYo + type: function + created: 1746142584 + id: chatcmpl-BSXk0dWkG4hfPt0lph4oFO35iT73I + model: gpt-4o-2024-08-06 + object: chat.completion + service_tier: default + system_fingerprint: fp_f5bdcc3276 + usage: + completion_tokens: 12 + completion_tokens_details: + accepted_prediction_tokens: 0 + audio_tokens: 0 + reasoning_tokens: 0 + rejected_prediction_tokens: 0 + prompt_tokens: 68 + prompt_tokens_details: + audio_tokens: 0 + cached_tokens: 0 + total_tokens: 80 + status: + code: 200 + message: OK +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '792' + content-type: + - application/json + cookie: + - __cf_bm=yM.C6I_kAzJk3Dm7H52actN1zAEW8fj.Gd2yeJ7tKN0-1746142584-1.0.1.1-xk91aElDtLLC8aROrOKHlp5vck_h.R.zQkS6OrsiBOwuFA8rE1kGswpactMEtYxV9WgWDN2B4S2B4zs8heyxmcfiNjmOf075n.OPqYpVla4; + _cfuvid=JCllInpf6fg1JdOS7xSj3bZOXYf9PYJ8uoamRTx7ku4-1746142584855-0.0.1.1-604800000 + host: + - api.openai.com + method: POST + parsed_body: + messages: + - content: What is the largest city in the user country? + role: user + - role: assistant + tool_calls: + - function: + arguments: '{}' + name: get_user_country + id: call_iXFttys57ap0o16JSlC8yhYo + type: function + - content: Mexico + role: tool + tool_call_id: call_iXFttys57ap0o16JSlC8yhYo + model: gpt-4o + n: 1 + stream: false + tool_choice: required + tools: + - function: + description: '' + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + type: function + - function: + description: The final response which ends this conversation + name: final_result + parameters: + properties: + city: + type: string + country: + type: string + required: + - city + - country + type: object + type: function + uri: https://api.openai.com/v1/chat/completions + response: + headers: + access-control-expose-headers: + - X-Request-ID + alt-svc: + - h3=":443"; ma=86400 + connection: + - keep-alive + content-length: + - '1113' + content-type: + - application/json + openai-organization: + - pydantic-28gund + openai-processing-ms: + - '1919' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + choices: + - finish_reason: tool_calls + index: 0 + logprobs: null + message: + annotations: [] + content: null + refusal: null + role: assistant + tool_calls: + - function: + arguments: '{"city": "Mexico City", "country": "Mexico"}' + name: final_result + id: call_gmD2oUZUzSoCkmNmp3JPUF7R + type: function + created: 1746142585 + id: chatcmpl-BSXk1xGHYzbhXgUkSutK08bdoNv5s + model: gpt-4o-2024-08-06 + object: chat.completion + service_tier: default + system_fingerprint: fp_f5bdcc3276 + usage: + completion_tokens: 36 + completion_tokens_details: + accepted_prediction_tokens: 0 + audio_tokens: 0 + reasoning_tokens: 0 + rejected_prediction_tokens: 0 + prompt_tokens: 89 + prompt_tokens_details: + audio_tokens: 0 + cached_tokens: 0 + total_tokens: 125 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/cassettes/test_openai_responses/test_json_schema_output.yaml b/tests/models/cassettes/test_openai_responses/test_json_schema_output.yaml new file mode 100644 index 000000000..9fd1b6989 --- /dev/null +++ b/tests/models/cassettes/test_openai_responses/test_json_schema_output.yaml @@ -0,0 +1,288 @@ +interactions: +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '533' + content-type: + - application/json + host: + - api.openai.com + method: POST + parsed_body: + input: + - content: What is the largest city in the user country? + role: user + model: gpt-4o + stream: false + text: + format: + name: CityLocation + schema: + additionalProperties: false + properties: + city: + type: string + country: + type: string + required: + - city + - country + type: object + strict: true + type: json_schema + tool_choice: auto + tools: + - description: '' + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + strict: false + type: function + uri: https://api.openai.com/v1/responses + response: + headers: + alt-svc: + - h3=":443"; ma=86400 + connection: + - keep-alive + content-length: + - '1808' + content-type: + - application/json + openai-organization: + - pydantic-28gund + openai-processing-ms: + - '636' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + background: false + created_at: 1749516047 + error: null + id: resp_68477f0f220081a1a621d6bcdc7f31a50b8591d9001d2329 + incomplete_details: null + instructions: null + max_output_tokens: null + metadata: {} + model: gpt-4o-2024-08-06 + object: response + output: + - arguments: '{}' + call_id: call_tTAThu8l2S9hNky2krdwijGP + id: fc_68477f0fa7c081a19a525f7c6f180f310b8591d9001d2329 + name: get_user_country + status: completed + type: function_call + parallel_tool_calls: true + previous_response_id: null + reasoning: + effort: null + summary: null + service_tier: default + status: completed + store: true + temperature: 1.0 + text: + format: + description: null + name: CityLocation + schema: + additionalProperties: false + properties: + city: + type: string + country: + type: string + required: + - city + - country + type: object + strict: true + type: json_schema + tool_choice: auto + tools: + - description: null + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + strict: false + type: function + top_p: 1.0 + truncation: disabled + usage: + input_tokens: 66 + input_tokens_details: + cached_tokens: 0 + output_tokens: 12 + output_tokens_details: + reasoning_tokens: 0 + total_tokens: 78 + user: null + status: + code: 200 + message: OK +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '769' + content-type: + - application/json + cookie: + - __cf_bm=My3TWVEPFsaYcjJ.iWxTB6P67jFSuxSF.n13qHpH9BA-1749516047-1.0.1.1-2bg2ltV1yu2uhfqewI9eEG1ulzfU_gq8pLx9YwHte33BTk2PgxBwaRdyegdEs_dVkAbaCoAPsQRIQmW21QPf_U2Fd1vdibnoExA_.rvTYv8; + _cfuvid=_7XoQBGwU.UsQgiPHVWMTXLLbADtbSwhrO9PY7I_3Dw-1749516047790-0.0.1.1-604800000 + host: + - api.openai.com + method: POST + parsed_body: + input: + - content: What is the largest city in the user country? + role: user + - content: '' + role: assistant + - arguments: '{}' + call_id: call_tTAThu8l2S9hNky2krdwijGP + name: get_user_country + type: function_call + - call_id: call_tTAThu8l2S9hNky2krdwijGP + output: Mexico + type: function_call_output + model: gpt-4o + stream: false + text: + format: + name: CityLocation + schema: + additionalProperties: false + properties: + city: + type: string + country: + type: string + required: + - city + - country + type: object + strict: true + type: json_schema + tool_choice: auto + tools: + - description: '' + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + strict: false + type: function + uri: https://api.openai.com/v1/responses + response: + headers: + alt-svc: + - h3=":443"; ma=86400 + connection: + - keep-alive + content-length: + - '1902' + content-type: + - application/json + openai-organization: + - pydantic-28gund + openai-processing-ms: + - '883' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + background: false + created_at: 1749516047 + error: null + id: resp_68477f0fde708192989000a62809c6e5020197534e39cc1f + incomplete_details: null + instructions: null + max_output_tokens: null + metadata: {} + model: gpt-4o-2024-08-06 + object: response + output: + - content: + - annotations: [] + text: '{"city":"Mexico City","country":"Mexico"}' + type: output_text + id: msg_68477f10846c81929f1e833b0785e6f3020197534e39cc1f + role: assistant + status: completed + type: message + parallel_tool_calls: true + previous_response_id: null + reasoning: + effort: null + summary: null + service_tier: default + status: completed + store: true + temperature: 1.0 + text: + format: + description: null + name: CityLocation + schema: + additionalProperties: false + properties: + city: + type: string + country: + type: string + required: + - city + - country + type: object + strict: true + type: json_schema + tool_choice: auto + tools: + - description: null + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + strict: false + type: function + top_p: 1.0 + truncation: disabled + usage: + input_tokens: 89 + input_tokens_details: + cached_tokens: 0 + output_tokens: 16 + output_tokens_details: + reasoning_tokens: 0 + total_tokens: 105 + user: null + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/cassettes/test_openai_responses/test_json_schema_output_multiple.yaml b/tests/models/cassettes/test_openai_responses/test_json_schema_output_multiple.yaml new file mode 100644 index 000000000..9c411f3c7 --- /dev/null +++ b/tests/models/cassettes/test_openai_responses/test_json_schema_output_multiple.yaml @@ -0,0 +1,444 @@ +interactions: +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '1143' + content-type: + - application/json + host: + - api.openai.com + method: POST + parsed_body: + input: + - content: What is the largest city in the user country? + role: user + model: gpt-4o + stream: false + text: + format: + name: final_result + schema: + additionalProperties: false + properties: + result: + anyOf: + - additionalProperties: false + description: CityLocation + properties: + data: + additionalProperties: false + properties: + city: + type: string + country: + type: string + required: + - city + - country + type: object + kind: + const: CityLocation + type: string + required: + - kind + - data + type: object + - additionalProperties: false + description: CountryLanguage + properties: + data: + additionalProperties: false + properties: + country: + type: string + language: + type: string + required: + - country + - language + type: object + kind: + const: CountryLanguage + type: string + required: + - kind + - data + type: object + required: + - result + type: object + strict: true + type: json_schema + tool_choice: auto + tools: + - description: '' + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + strict: false + type: function + uri: https://api.openai.com/v1/responses + response: + headers: + alt-svc: + - h3=":443"; ma=86400 + connection: + - keep-alive + content-length: + - '3657' + content-type: + - application/json + openai-organization: + - pydantic-28gund + openai-processing-ms: + - '562' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + background: false + created_at: 1749516048 + error: null + id: resp_68477f10f2d081a39b3438f413b3bafc0dd57d732903c563 + incomplete_details: null + instructions: null + max_output_tokens: null + metadata: {} + model: gpt-4o-2024-08-06 + object: response + output: + - arguments: '{}' + call_id: call_UaLahjOtaM2tTyYZLxTCbOaP + id: fc_68477f1168a081a3981e847cd94275080dd57d732903c563 + name: get_user_country + status: completed + type: function_call + parallel_tool_calls: true + previous_response_id: null + reasoning: + effort: null + summary: null + service_tier: default + status: completed + store: true + temperature: 1.0 + text: + format: + description: null + name: final_result + schema: + additionalProperties: false + properties: + result: + anyOf: + - additionalProperties: false + description: CityLocation + properties: + data: + additionalProperties: false + properties: + city: + type: string + country: + type: string + required: + - city + - country + type: object + kind: + const: CityLocation + type: string + required: + - kind + - data + type: object + - additionalProperties: false + description: CountryLanguage + properties: + data: + additionalProperties: false + properties: + country: + type: string + language: + type: string + required: + - country + - language + type: object + kind: + const: CountryLanguage + type: string + required: + - kind + - data + type: object + required: + - result + type: object + strict: true + type: json_schema + tool_choice: auto + tools: + - description: null + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + strict: false + type: function + top_p: 1.0 + truncation: disabled + usage: + input_tokens: 153 + input_tokens_details: + cached_tokens: 0 + output_tokens: 12 + output_tokens_details: + reasoning_tokens: 0 + total_tokens: 165 + user: null + status: + code: 200 + message: OK +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '1379' + content-type: + - application/json + cookie: + - __cf_bm=3Nl1ERbtfVAI.dGjzCYYN1u71YD5eEoLU0iCrvPPPL0-1749516049-1.0.1.1-LnI7tJwKr.C_wA15Shsl8pcGd32zrRqqv_9u4S84nXtNCopx1iBIKYDsyMg3u1Z3lJ_1Cd1YVM8uKAMjiKmgoqS8GFQ3Z_vV_Mahvqbi4KA; + _cfuvid=oc_k9l86fnMo2ml.0aop6a3eVJEvjxB0lnxWK0_kJq8-1749516049524-0.0.1.1-604800000 + host: + - api.openai.com + method: POST + parsed_body: + input: + - content: What is the largest city in the user country? + role: user + - content: '' + role: assistant + - arguments: '{}' + call_id: call_UaLahjOtaM2tTyYZLxTCbOaP + name: get_user_country + type: function_call + - call_id: call_UaLahjOtaM2tTyYZLxTCbOaP + output: Mexico + type: function_call_output + model: gpt-4o + stream: false + text: + format: + name: final_result + schema: + additionalProperties: false + properties: + result: + anyOf: + - additionalProperties: false + description: CityLocation + properties: + data: + additionalProperties: false + properties: + city: + type: string + country: + type: string + required: + - city + - country + type: object + kind: + const: CityLocation + type: string + required: + - kind + - data + type: object + - additionalProperties: false + description: CountryLanguage + properties: + data: + additionalProperties: false + properties: + country: + type: string + language: + type: string + required: + - country + - language + type: object + kind: + const: CountryLanguage + type: string + required: + - kind + - data + type: object + required: + - result + type: object + strict: true + type: json_schema + tool_choice: auto + tools: + - description: '' + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + strict: false + type: function + uri: https://api.openai.com/v1/responses + response: + headers: + alt-svc: + - h3=":443"; ma=86400 + connection: + - keep-alive + content-length: + - '3800' + content-type: + - application/json + openai-organization: + - pydantic-28gund + openai-processing-ms: + - '1042' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + background: false + created_at: 1749516049 + error: null + id: resp_68477f119830819da162aa6e10552035061ad97e2eef7871 + incomplete_details: null + instructions: null + max_output_tokens: null + metadata: {} + model: gpt-4o-2024-08-06 + object: response + output: + - content: + - annotations: [] + text: '{"result":{"kind":"CityLocation","data":{"city":"Mexico City","country":"Mexico"}}}' + type: output_text + id: msg_68477f1235b8819d898adc64709c7ebf061ad97e2eef7871 + role: assistant + status: completed + type: message + parallel_tool_calls: true + previous_response_id: null + reasoning: + effort: null + summary: null + service_tier: default + status: completed + store: true + temperature: 1.0 + text: + format: + description: null + name: final_result + schema: + additionalProperties: false + properties: + result: + anyOf: + - additionalProperties: false + description: CityLocation + properties: + data: + additionalProperties: false + properties: + city: + type: string + country: + type: string + required: + - city + - country + type: object + kind: + const: CityLocation + type: string + required: + - kind + - data + type: object + - additionalProperties: false + description: CountryLanguage + properties: + data: + additionalProperties: false + properties: + country: + type: string + language: + type: string + required: + - country + - language + type: object + kind: + const: CountryLanguage + type: string + required: + - kind + - data + type: object + required: + - result + type: object + strict: true + type: json_schema + tool_choice: auto + tools: + - description: null + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + strict: false + type: function + top_p: 1.0 + truncation: disabled + usage: + input_tokens: 176 + input_tokens_details: + cached_tokens: 0 + output_tokens: 26 + output_tokens_details: + reasoning_tokens: 0 + total_tokens: 202 + user: null + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/cassettes/test_openai_responses/test_prompted_json_output.yaml b/tests/models/cassettes/test_openai_responses/test_prompted_json_output.yaml new file mode 100644 index 000000000..35783c516 --- /dev/null +++ b/tests/models/cassettes/test_openai_responses/test_prompted_json_output.yaml @@ -0,0 +1,248 @@ +interactions: +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '689' + content-type: + - application/json + host: + - api.openai.com + method: POST + parsed_body: + input: + - content: |- + Always respond with a JSON object that's compatible with this schema: + + {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} + + Don't include any text or Markdown fencing before or after. + role: system + - content: What is the largest city in the user country? + role: user + model: gpt-4o + stream: false + text: + format: + type: json_object + tool_choice: auto + tools: + - description: '' + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + strict: false + type: function + uri: https://api.openai.com/v1/responses + response: + headers: + alt-svc: + - h3=":443"; ma=86400 + connection: + - keep-alive + content-length: + - '1408' + content-type: + - application/json + openai-organization: + - pydantic-28gund + openai-processing-ms: + - '8314' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + background: false + created_at: 1749561106 + error: null + id: resp_68482f12d63881a1830201ed101ecfbf02f8ef7f2fb42b50 + incomplete_details: null + instructions: null + max_output_tokens: null + metadata: {} + model: gpt-4o-2024-08-06 + object: response + output: + - arguments: '{}' + call_id: call_FrlL4M0CbAy8Dhv4VqF1Shom + id: fc_68482f1b0ff081a1b37b9170ee740d1e02f8ef7f2fb42b50 + name: get_user_country + status: completed + type: function_call + parallel_tool_calls: true + previous_response_id: null + reasoning: + effort: null + summary: null + service_tier: default + status: completed + store: true + temperature: 1.0 + text: + format: + type: json_object + tool_choice: auto + tools: + - description: null + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + strict: false + type: function + top_p: 1.0 + truncation: disabled + usage: + input_tokens: 107 + input_tokens_details: + cached_tokens: 0 + output_tokens: 12 + output_tokens_details: + reasoning_tokens: 0 + total_tokens: 119 + user: null + status: + code: 200 + message: OK +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '925' + content-type: + - application/json + cookie: + - __cf_bm=8a8rNQQYozQt3YjcA61k6KGe.AlrMMrtcIvKv.D1s1E-1749561115-1.0.1.1-OFcqg8xD2_HdbeO74bU2.mLTqDuiK.ploHeu3_ITPvDlGwrVkwk8erMkHagxk4UDxACCCAygnUs1HL.F4AGjQCaZm1m2eYiMVbLqp0iQh7g; + _cfuvid=wKTRRc2dbdYNYnYwA2vRxVjUvqqkQovvKDwULW0Xwns-1749561115173-0.0.1.1-604800000 + host: + - api.openai.com + method: POST + parsed_body: + input: + - content: |- + Always respond with a JSON object that's compatible with this schema: + + {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} + + Don't include any text or Markdown fencing before or after. + role: system + - content: What is the largest city in the user country? + role: user + - content: '' + role: assistant + - arguments: '{}' + call_id: call_FrlL4M0CbAy8Dhv4VqF1Shom + name: get_user_country + type: function_call + - call_id: call_FrlL4M0CbAy8Dhv4VqF1Shom + output: Mexico + type: function_call_output + model: gpt-4o + stream: false + text: + format: + type: json_object + tool_choice: auto + tools: + - description: '' + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + strict: false + type: function + uri: https://api.openai.com/v1/responses + response: + headers: + alt-svc: + - h3=":443"; ma=86400 + connection: + - keep-alive + content-length: + - '1501' + content-type: + - application/json + openai-organization: + - pydantic-28gund + openai-processing-ms: + - '1098' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + background: false + created_at: 1749561115 + error: null + id: resp_68482f1b556081918d64c9088a470bf0044fdb7d019d4115 + incomplete_details: null + instructions: null + max_output_tokens: null + metadata: {} + model: gpt-4o-2024-08-06 + object: response + output: + - content: + - annotations: [] + text: '{"city":"Mexico City","country":"Mexico"}' + type: output_text + id: msg_68482f1c159081918a2405f458009a6a044fdb7d019d4115 + role: assistant + status: completed + type: message + parallel_tool_calls: true + previous_response_id: null + reasoning: + effort: null + summary: null + service_tier: default + status: completed + store: true + temperature: 1.0 + text: + format: + type: json_object + tool_choice: auto + tools: + - description: null + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + strict: false + type: function + top_p: 1.0 + truncation: disabled + usage: + input_tokens: 130 + input_tokens_details: + cached_tokens: 0 + output_tokens: 12 + output_tokens_details: + reasoning_tokens: 0 + total_tokens: 142 + user: null + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/cassettes/test_openai_responses/test_prompted_json_output_multiple.yaml b/tests/models/cassettes/test_openai_responses/test_prompted_json_output_multiple.yaml new file mode 100644 index 000000000..1a3b4dc00 --- /dev/null +++ b/tests/models/cassettes/test_openai_responses/test_prompted_json_output_multiple.yaml @@ -0,0 +1,248 @@ +interactions: +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '1455' + content-type: + - application/json + host: + - api.openai.com + method: POST + parsed_body: + input: + - content: |- + Always respond with a JSON object that's compatible with this schema: + + {"type": "object", "properties": {"result": {"anyOf": [{"type": "object", "properties": {"kind": {"type": "string", "const": "CityLocation"}, "data": {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"}}, "description": "CityLocation", "required": ["kind", "data"], "additionalProperties": false}, {"type": "object", "properties": {"kind": {"type": "string", "const": "CountryLanguage"}, "data": {"properties": {"country": {"type": "string"}, "language": {"type": "string"}}, "required": ["country", "language"], "title": "CountryLanguage", "type": "object"}}, "description": "CountryLanguage", "required": ["kind", "data"], "additionalProperties": false}]}}, "required": ["result"], "additionalProperties": false} + + Don't include any text or Markdown fencing before or after. + role: system + - content: What is the largest city in the user country? + role: user + model: gpt-4o + stream: false + text: + format: + type: json_object + tool_choice: auto + tools: + - description: '' + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + strict: false + type: function + uri: https://api.openai.com/v1/responses + response: + headers: + alt-svc: + - h3=":443"; ma=86400 + connection: + - keep-alive + content-length: + - '1408' + content-type: + - application/json + openai-organization: + - pydantic-28gund + openai-processing-ms: + - '11445' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + background: false + created_at: 1749561117 + error: null + id: resp_68482f1d38e081a1ac828acda978aa6b08e79646fe74d5ee + incomplete_details: null + instructions: null + max_output_tokens: null + metadata: {} + model: gpt-4o-2024-08-06 + object: response + output: + - arguments: '{}' + call_id: call_my4OyoVXRT0m7bLWmsxcaCQI + id: fc_68482f2889d481a199caa61de7ccb62c08e79646fe74d5ee + name: get_user_country + status: completed + type: function_call + parallel_tool_calls: true + previous_response_id: null + reasoning: + effort: null + summary: null + service_tier: default + status: completed + store: true + temperature: 1.0 + text: + format: + type: json_object + tool_choice: auto + tools: + - description: null + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + strict: false + type: function + top_p: 1.0 + truncation: disabled + usage: + input_tokens: 283 + input_tokens_details: + cached_tokens: 0 + output_tokens: 12 + output_tokens_details: + reasoning_tokens: 0 + total_tokens: 295 + user: null + status: + code: 200 + message: OK +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '1691' + content-type: + - application/json + cookie: + - __cf_bm=l95LdgPzGHw0UAhBwse9ADphgmMDWrhYqgiO4gdmSy4-1749561128-1.0.1.1-9zPIs3d5_ipszLpQ7yBaCZEStp8qoRIGFshR93V6n7Z_7AznH0MfuczwuoiaW8e6cEVeVHLhskjXScolO9gP5TmpsaFo37GRuHsHZTRgEeI; + _cfuvid=5L5qtbtbFCFzMmoVufSY.ksn06ay8AFs.UXFEv07pkY-1749561128680-0.0.1.1-604800000 + host: + - api.openai.com + method: POST + parsed_body: + input: + - content: |- + Always respond with a JSON object that's compatible with this schema: + + {"type": "object", "properties": {"result": {"anyOf": [{"type": "object", "properties": {"kind": {"type": "string", "const": "CityLocation"}, "data": {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"}}, "description": "CityLocation", "required": ["kind", "data"], "additionalProperties": false}, {"type": "object", "properties": {"kind": {"type": "string", "const": "CountryLanguage"}, "data": {"properties": {"country": {"type": "string"}, "language": {"type": "string"}}, "required": ["country", "language"], "title": "CountryLanguage", "type": "object"}}, "description": "CountryLanguage", "required": ["kind", "data"], "additionalProperties": false}]}}, "required": ["result"], "additionalProperties": false} + + Don't include any text or Markdown fencing before or after. + role: system + - content: What is the largest city in the user country? + role: user + - content: '' + role: assistant + - arguments: '{}' + call_id: call_my4OyoVXRT0m7bLWmsxcaCQI + name: get_user_country + type: function_call + - call_id: call_my4OyoVXRT0m7bLWmsxcaCQI + output: Mexico + type: function_call_output + model: gpt-4o + stream: false + text: + format: + type: json_object + tool_choice: auto + tools: + - description: '' + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + strict: false + type: function + uri: https://api.openai.com/v1/responses + response: + headers: + alt-svc: + - h3=":443"; ma=86400 + connection: + - keep-alive + content-length: + - '1551' + content-type: + - application/json + openai-organization: + - pydantic-28gund + openai-processing-ms: + - '2545' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + background: false + created_at: 1749561128 + error: null + id: resp_68482f28c1b081a1ae73cbbee012ee4906b4ab2d00d03024 + incomplete_details: null + instructions: null + max_output_tokens: null + metadata: {} + model: gpt-4o-2024-08-06 + object: response + output: + - content: + - annotations: [] + text: '{"result":{"kind":"CityLocation","data":{"city":"Mexico City","country":"Mexico"}}}' + type: output_text + id: msg_68482f296bfc81a18665547d4008ab2c06b4ab2d00d03024 + role: assistant + status: completed + type: message + parallel_tool_calls: true + previous_response_id: null + reasoning: + effort: null + summary: null + service_tier: default + status: completed + store: true + temperature: 1.0 + text: + format: + type: json_object + tool_choice: auto + tools: + - description: null + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + strict: false + type: function + top_p: 1.0 + truncation: disabled + usage: + input_tokens: 306 + input_tokens_details: + cached_tokens: 0 + output_tokens: 22 + output_tokens_details: + reasoning_tokens: 0 + total_tokens: 328 + user: null + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/cassettes/test_openai_responses/test_text_output_function.yaml b/tests/models/cassettes/test_openai_responses/test_text_output_function.yaml new file mode 100644 index 000000000..ff4ff9acf --- /dev/null +++ b/tests/models/cassettes/test_openai_responses/test_text_output_function.yaml @@ -0,0 +1,228 @@ +interactions: +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '302' + content-type: + - application/json + host: + - api.openai.com + method: POST + parsed_body: + input: + - content: What is the largest city in the user country? + role: user + model: gpt-4o + stream: false + tool_choice: auto + tools: + - description: '' + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + strict: false + type: function + uri: https://api.openai.com/v1/responses + response: + headers: + alt-svc: + - h3=":443"; ma=86400 + connection: + - keep-alive + content-length: + - '1399' + content-type: + - application/json + openai-organization: + - pydantic-28gund + openai-processing-ms: + - '490' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + background: false + created_at: 1749516045 + error: null + id: resp_68477f0d9494819ea4f123bba707c9ee0356a60c98816d6a + incomplete_details: null + instructions: null + max_output_tokens: null + metadata: {} + model: gpt-4o-2024-08-06 + object: response + output: + - arguments: '{}' + call_id: call_aTJhYjzmixZaVGqwl5gn2Ncr + id: fc_68477f0dff5c819ea17a1ffbaea621e00356a60c98816d6a + name: get_user_country + status: completed + type: function_call + parallel_tool_calls: true + previous_response_id: null + reasoning: + effort: null + summary: null + service_tier: default + status: completed + store: true + temperature: 1.0 + text: + format: + type: text + tool_choice: auto + tools: + - description: null + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + strict: false + type: function + top_p: 1.0 + truncation: disabled + usage: + input_tokens: 36 + input_tokens_details: + cached_tokens: 0 + output_tokens: 12 + output_tokens_details: + reasoning_tokens: 0 + total_tokens: 48 + user: null + status: + code: 200 + message: OK +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '538' + content-type: + - application/json + cookie: + - __cf_bm=JZXeUMfyA2MKPG61ecku4K0wMqhJgj2ih66RpjtdqZk-1749516046-1.0.1.1-ZF5eievVR.Y5iPpLK_dVCJNl_ANFmiDhY4iZDFbopdvvhXnZvwLMCQVFWg.S.nQ0TvOw0it63SRuHbjo3jcjuD0lnI5oRQBJUOLiQElZ_j4; + _cfuvid=K7T3n3fgO8pCRHtSCoIpwTW2UEh0En8ro1rV5aPciGo-1749516046095-0.0.1.1-604800000 + host: + - api.openai.com + method: POST + parsed_body: + input: + - content: What is the largest city in the user country? + role: user + - content: '' + role: assistant + - arguments: '{}' + call_id: call_aTJhYjzmixZaVGqwl5gn2Ncr + name: get_user_country + type: function_call + - call_id: call_aTJhYjzmixZaVGqwl5gn2Ncr + output: Mexico + type: function_call_output + model: gpt-4o + stream: false + tool_choice: auto + tools: + - description: '' + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + strict: false + type: function + uri: https://api.openai.com/v1/responses + response: + headers: + alt-svc: + - h3=":443"; ma=86400 + connection: + - keep-alive + content-length: + - '1485' + content-type: + - application/json + openai-organization: + - pydantic-28gund + openai-processing-ms: + - '825' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + background: false + created_at: 1749516046 + error: null + id: resp_68477f0e2b28819d9c828ef4ee526d6a03434b607c02582d + incomplete_details: null + instructions: null + max_output_tokens: null + metadata: {} + model: gpt-4o-2024-08-06 + object: response + output: + - content: + - annotations: [] + text: The largest city in Mexico is Mexico City. + type: output_text + id: msg_68477f0ebf54819d88a44fa87aadaff503434b607c02582d + role: assistant + status: completed + type: message + parallel_tool_calls: true + previous_response_id: null + reasoning: + effort: null + summary: null + service_tier: default + status: completed + store: true + temperature: 1.0 + text: + format: + type: text + tool_choice: auto + tools: + - description: null + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + strict: false + type: function + top_p: 1.0 + truncation: disabled + usage: + input_tokens: 59 + input_tokens_details: + cached_tokens: 0 + output_tokens: 11 + output_tokens_details: + reasoning_tokens: 0 + total_tokens: 70 + user: null + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/cassettes/test_openai_responses/test_tool_output.yaml b/tests/models/cassettes/test_openai_responses/test_tool_output.yaml new file mode 100644 index 000000000..bc201f7c1 --- /dev/null +++ b/tests/models/cassettes/test_openai_responses/test_tool_output.yaml @@ -0,0 +1,282 @@ +interactions: +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '556' + content-type: + - application/json + host: + - api.openai.com + method: POST + parsed_body: + input: + - content: What is the largest city in the user country? + role: user + model: gpt-4o + stream: false + tool_choice: required + tools: + - description: '' + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + strict: false + type: function + - description: The final response which ends this conversation + name: final_result + parameters: + properties: + city: + type: string + country: + type: string + required: + - city + - country + type: object + strict: false + type: function + uri: https://api.openai.com/v1/responses + response: + headers: + alt-svc: + - h3=":443"; ma=86400 + connection: + - keep-alive + content-length: + - '1854' + content-type: + - application/json + openai-organization: + - pydantic-28gund + openai-processing-ms: + - '568' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + background: false + created_at: 1749516043 + error: null + id: resp_68477f0b40a8819cb8d55594bc2c232a001fd29e2d5573f7 + incomplete_details: null + instructions: null + max_output_tokens: null + metadata: {} + model: gpt-4o-2024-08-06 + object: response + output: + - arguments: '{}' + call_id: call_ZWkVhdUjupo528U9dqgFeRkH + id: fc_68477f0bb8e4819cba6d781e174d77f8001fd29e2d5573f7 + name: get_user_country + status: completed + type: function_call + parallel_tool_calls: true + previous_response_id: null + reasoning: + effort: null + summary: null + service_tier: default + status: completed + store: true + temperature: 1.0 + text: + format: + type: text + tool_choice: required + tools: + - description: null + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + strict: false + type: function + - description: The final response which ends this conversation + name: final_result + parameters: + properties: + city: + type: string + country: + type: string + required: + - city + - country + type: object + strict: false + type: function + top_p: 1.0 + truncation: disabled + usage: + input_tokens: 62 + input_tokens_details: + cached_tokens: 0 + output_tokens: 12 + output_tokens_details: + reasoning_tokens: 0 + total_tokens: 74 + user: null + status: + code: 200 + message: OK +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '792' + content-type: + - application/json + cookie: + - __cf_bm=78_bxRDp8.6VLECkU4_YSNYd7PlmVGdN1E4j5KBkoOA-1749516043-1.0.1.1-Z9ZwaEzQZcS64A536kPafni6AZEqjCr1xDJ1h2WXjDrs0G_LuZPuq7Z27rs6w0.2DAk_UEY0.H.YMVFpWwe0QTOI28mlvDMbZvVsP6LT4Ug; + _cfuvid=Qym79CFc.nJ8O7pqDQfy1eFUEqIDIX3VuqfAl93F07o-1749516043838-0.0.1.1-604800000 + host: + - api.openai.com + method: POST + parsed_body: + input: + - content: What is the largest city in the user country? + role: user + - content: '' + role: assistant + - arguments: '{}' + call_id: call_ZWkVhdUjupo528U9dqgFeRkH + name: get_user_country + type: function_call + - call_id: call_ZWkVhdUjupo528U9dqgFeRkH + output: Mexico + type: function_call_output + model: gpt-4o + stream: false + tool_choice: required + tools: + - description: '' + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + strict: false + type: function + - description: The final response which ends this conversation + name: final_result + parameters: + properties: + city: + type: string + country: + type: string + required: + - city + - country + type: object + strict: false + type: function + uri: https://api.openai.com/v1/responses + response: + headers: + alt-svc: + - h3=":443"; ma=86400 + connection: + - keep-alive + content-length: + - '1898' + content-type: + - application/json + openai-organization: + - pydantic-28gund + openai-processing-ms: + - '840' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + background: false + created_at: 1749516044 + error: null + id: resp_68477f0bfda8819ea65458cd7cc389b801dc81d4bc91f560 + incomplete_details: null + instructions: null + max_output_tokens: null + metadata: {} + model: gpt-4o-2024-08-06 + object: response + output: + - arguments: '{"city":"Mexico City","country":"Mexico"}' + call_id: call_iFBd0zULhSZRR908DfH73VwN + id: fc_68477f0c91cc819e8024e7e633f0f09401dc81d4bc91f560 + name: final_result + status: completed + type: function_call + parallel_tool_calls: true + previous_response_id: null + reasoning: + effort: null + summary: null + service_tier: default + status: completed + store: true + temperature: 1.0 + text: + format: + type: text + tool_choice: required + tools: + - description: null + name: get_user_country + parameters: + additionalProperties: false + properties: {} + type: object + strict: false + type: function + - description: The final response which ends this conversation + name: final_result + parameters: + properties: + city: + type: string + country: + type: string + required: + - city + - country + type: object + strict: false + type: function + top_p: 1.0 + truncation: disabled + usage: + input_tokens: 85 + input_tokens_details: + cached_tokens: 0 + output_tokens: 20 + output_tokens_details: + reasoning_tokens: 0 + total_tokens: 105 + user: null + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/test_anthropic.py b/tests/models/test_anthropic.py index 8490d6204..c5d769e79 100644 --- a/tests/models/test_anthropic.py +++ b/tests/models/test_anthropic.py @@ -11,6 +11,7 @@ import httpx import pytest from inline_snapshot import snapshot +from pydantic import BaseModel from pydantic_ai import Agent, ModelHTTPError, ModelRetry from pydantic_ai.messages import ( @@ -26,7 +27,7 @@ ToolReturnPart, UserPromptPart, ) -from pydantic_ai.result import Usage +from pydantic_ai.result import PromptedJsonOutput, TextOutput, ToolOutput, Usage from pydantic_ai.settings import ModelSettings from ..conftest import IsDatetime, IsNow, IsStr, TestEnv, raise_if_exception, try_import @@ -1063,3 +1064,342 @@ async def test_anthropic_model_empty_message_on_history(allow_model_requests: No What specifically would you like to know about potatoes?\ """) + + +@pytest.mark.vcr() +async def test_anthropic_tool_output(allow_model_requests: None, anthropic_api_key: str): + m = AnthropicModel('claude-3-5-sonnet-latest', provider=AnthropicProvider(api_key=anthropic_api_key)) + + class CityLocation(BaseModel): + city: str + country: str + + agent = Agent(m, output_type=ToolOutput(CityLocation)) + + @agent.tool_plain + async def get_user_country() -> str: + return 'Mexico' + + result = await agent.run('What is the largest city in the user country?') + assert result.output == snapshot(CityLocation(city='Mexico City', country='Mexico')) + + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='What is the largest city in the user country?', + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[ + ToolCallPart(tool_name='get_user_country', args={}, tool_call_id='toolu_019pMboNVRg5jkw4PKkofQ6Y') + ], + usage=Usage( + requests=1, + request_tokens=445, + response_tokens=23, + total_tokens=468, + details={ + 'cache_creation_input_tokens': 0, + 'cache_read_input_tokens': 0, + 'input_tokens': 445, + 'output_tokens': 23, + }, + ), + model_name='claude-3-5-sonnet-20241022', + timestamp=IsDatetime(), + vendor_id='msg_01EnfsDTixCmHjqvk9QarBj4', + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='get_user_country', + content='Mexico', + tool_call_id='toolu_019pMboNVRg5jkw4PKkofQ6Y', + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[ + ToolCallPart( + tool_name='final_result', + args={'city': 'Mexico City', 'country': 'Mexico'}, + tool_call_id='toolu_01V4d2H4EWp5LDM2aXaeyR6W', + ) + ], + usage=Usage( + requests=1, + request_tokens=497, + response_tokens=56, + total_tokens=553, + details={ + 'cache_creation_input_tokens': 0, + 'cache_read_input_tokens': 0, + 'input_tokens': 497, + 'output_tokens': 56, + }, + ), + model_name='claude-3-5-sonnet-20241022', + timestamp=IsDatetime(), + vendor_id='msg_01Hbm5BtKzfVtWs8Eb7rCNNx', + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='final_result', + content='Final result processed.', + tool_call_id='toolu_01V4d2H4EWp5LDM2aXaeyR6W', + timestamp=IsDatetime(), + ) + ] + ), + ] + ) + + +@pytest.mark.vcr() +async def test_anthropic_text_output_function(allow_model_requests: None, anthropic_api_key: str): + m = AnthropicModel('claude-3-5-sonnet-latest', provider=AnthropicProvider(api_key=anthropic_api_key)) + + def upcase(text: str) -> str: + return text.upper() + + agent = Agent(m, output_type=TextOutput(upcase)) + + @agent.tool_plain + async def get_user_country() -> str: + return 'Mexico' + + result = await agent.run( + 'What is the largest city in the user country? Use the get_user_country tool and then your own world knowledge.' + ) + assert result.output == snapshot( + "BASED ON THE RESULT, YOU ARE LOCATED IN MEXICO. THE LARGEST CITY IN MEXICO IS MEXICO CITY (CIUDAD DE MÉXICO), WHICH IS ALSO THE NATION'S CAPITAL. MEXICO CITY HAS A POPULATION OF APPROXIMATELY 9.2 MILLION PEOPLE IN THE CITY PROPER, AND OVER 21 MILLION PEOPLE IN ITS METROPOLITAN AREA, MAKING IT ONE OF THE LARGEST URBAN AGGLOMERATIONS IN THE WORLD. IT IS BOTH THE POLITICAL AND ECONOMIC CENTER OF MEXICO, LOCATED IN THE VALLEY OF MEXICO IN THE CENTRAL PART OF THE COUNTRY." + ) + + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='What is the largest city in the user country? Use the get_user_country tool and then your own world knowledge.', + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[ + TextPart( + content="I'll help you find the largest city in your country. Let me first check your country using the get_user_country tool." + ), + ToolCallPart(tool_name='get_user_country', args={}, tool_call_id='toolu_01EZuxfc6MsPsPgrAKQohw3e'), + ], + usage=Usage( + requests=1, + request_tokens=383, + response_tokens=66, + total_tokens=449, + details={ + 'cache_creation_input_tokens': 0, + 'cache_read_input_tokens': 0, + 'input_tokens': 383, + 'output_tokens': 66, + }, + ), + model_name='claude-3-5-sonnet-20241022', + timestamp=IsDatetime(), + vendor_id='msg_014NE4yfV1Yz2vLAJzapxxef', + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='get_user_country', + content='Mexico', + tool_call_id='toolu_01EZuxfc6MsPsPgrAKQohw3e', + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[ + TextPart( + content="Based on the result, you are located in Mexico. The largest city in Mexico is Mexico City (Ciudad de México), which is also the nation's capital. Mexico City has a population of approximately 9.2 million people in the city proper, and over 21 million people in its metropolitan area, making it one of the largest urban agglomerations in the world. It is both the political and economic center of Mexico, located in the Valley of Mexico in the central part of the country." + ) + ], + usage=Usage( + requests=1, + request_tokens=461, + response_tokens=107, + total_tokens=568, + details={ + 'cache_creation_input_tokens': 0, + 'cache_read_input_tokens': 0, + 'input_tokens': 461, + 'output_tokens': 107, + }, + ), + model_name='claude-3-5-sonnet-20241022', + timestamp=IsDatetime(), + vendor_id='msg_0193srwo7TCx49h97wDwc7K7', + ), + ] + ) + + +@pytest.mark.vcr() +async def test_anthropic_prompted_json_output(allow_model_requests: None, anthropic_api_key: str): + m = AnthropicModel('claude-3-5-sonnet-latest', provider=AnthropicProvider(api_key=anthropic_api_key)) + + class CityLocation(BaseModel): + city: str + country: str + + agent = Agent(m, output_type=PromptedJsonOutput(CityLocation)) + + @agent.tool_plain + async def get_user_country() -> str: + return 'Mexico' + + result = await agent.run( + 'What is the largest city in the user country? Use the get_user_country tool and then your own world knowledge.' + ) + assert result.output == snapshot(CityLocation(city='Mexico City', country='Mexico')) + + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='What is the largest city in the user country? Use the get_user_country tool and then your own world knowledge.', + timestamp=IsDatetime(), + ) + ], + instructions="""\ +Always respond with a JSON object that's compatible with this schema: + +{"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} + +Don't include any text or Markdown fencing before or after.\ +""", + ), + ModelResponse( + parts=[ + ToolCallPart(tool_name='get_user_country', args={}, tool_call_id='toolu_017UryVwtsKsjonhFV3cgV3X') + ], + usage=Usage( + requests=1, + request_tokens=459, + response_tokens=38, + total_tokens=497, + details={ + 'cache_creation_input_tokens': 0, + 'cache_read_input_tokens': 0, + 'input_tokens': 459, + 'output_tokens': 38, + }, + ), + model_name='claude-3-5-sonnet-20241022', + timestamp=IsDatetime(), + vendor_id='msg_014CpBKzioMqUyLWrMihpvsz', + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='get_user_country', + content='Mexico', + tool_call_id='toolu_017UryVwtsKsjonhFV3cgV3X', + timestamp=IsDatetime(), + ) + ], + instructions="""\ +Always respond with a JSON object that's compatible with this schema: + +{"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} + +Don't include any text or Markdown fencing before or after.\ +""", + ), + ModelResponse( + parts=[TextPart(content='{"city": "Mexico City", "country": "Mexico"}')], + usage=Usage( + requests=1, + request_tokens=510, + response_tokens=17, + total_tokens=527, + details={ + 'cache_creation_input_tokens': 0, + 'cache_read_input_tokens': 0, + 'input_tokens': 510, + 'output_tokens': 17, + }, + ), + model_name='claude-3-5-sonnet-20241022', + timestamp=IsDatetime(), + vendor_id='msg_014JeWCouH6DpdqzMTaBdkpJ', + ), + ] + ) + + +@pytest.mark.vcr() +async def test_anthropic_prompted_json_output_multiple(allow_model_requests: None, anthropic_api_key: str): + m = AnthropicModel('claude-3-5-sonnet-latest', provider=AnthropicProvider(api_key=anthropic_api_key)) + + class CityLocation(BaseModel): + city: str + country: str + + class CountryLanguage(BaseModel): + country: str + language: str + + agent = Agent(m, output_type=PromptedJsonOutput([CityLocation, CountryLanguage])) + + result = await agent.run('What is the largest city in Mexico?') + assert result.output == snapshot(CityLocation(city='Mexico City', country='Mexico')) + + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='What is the largest city in Mexico?', + timestamp=IsDatetime(), + ) + ], + instructions="""\ +Always respond with a JSON object that's compatible with this schema: + +{"type": "object", "properties": {"result": {"anyOf": [{"type": "object", "properties": {"kind": {"type": "string", "const": "CityLocation"}, "data": {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "type": "object"}}, "required": ["kind", "data"], "additionalProperties": false, "title": "CityLocation"}, {"type": "object", "properties": {"kind": {"type": "string", "const": "CountryLanguage"}, "data": {"properties": {"country": {"type": "string"}, "language": {"type": "string"}}, "required": ["country", "language"], "type": "object"}}, "required": ["kind", "data"], "additionalProperties": false, "title": "CountryLanguage"}]}}, "required": ["result"], "additionalProperties": false} + +Don't include any text or Markdown fencing before or after.\ +""", + ), + ModelResponse( + parts=[ + TextPart( + content='{"result": {"kind": "CityLocation", "data": {"city": "Mexico City", "country": "Mexico"}}}' + ) + ], + usage=Usage( + requests=1, + request_tokens=281, + response_tokens=31, + total_tokens=312, + details={ + 'cache_creation_input_tokens': 0, + 'cache_read_input_tokens': 0, + 'input_tokens': 281, + 'output_tokens': 31, + }, + ), + model_name='claude-3-5-sonnet-20241022', + timestamp=IsDatetime(), + vendor_id='msg_013ttUi3HCcKt7PkJpoWs5FT', + ), + ] + ) diff --git a/tests/models/test_bedrock.py b/tests/models/test_bedrock.py index 0a0dec8f8..78f8f4e85 100644 --- a/tests/models/test_bedrock.py +++ b/tests/models/test_bedrock.py @@ -684,7 +684,7 @@ async def test_bedrock_anthropic_no_tool_choice(bedrock_provider: BedrockProvide 'This is my tool', {'type': 'object', 'title': 'Result', 'properties': {'spam': {'type': 'number'}}}, ) - mrp = ModelRequestParameters(function_tools=[my_tool], allow_text_output=False, output_tools=[]) + mrp = ModelRequestParameters(output_mode='tool', function_tools=[my_tool], allow_text_output=False, output_tools=[]) # Models other than Anthropic support tool_choice model = BedrockConverseModel('us.amazon.nova-micro-v1:0', provider=bedrock_provider) diff --git a/tests/models/test_fallback.py b/tests/models/test_fallback.py index db6277527..c88876832 100644 --- a/tests/models/test_fallback.py +++ b/tests/models/test_fallback.py @@ -127,7 +127,7 @@ def test_first_failed_instrumented(capfire: CaptureLogfire) -> None: 'end_time': 3000000000, 'attributes': { 'gen_ai.operation.name': 'chat', - 'model_request_parameters': '{"function_tools": [], "allow_text_output": true, "output_tools": []}', + 'model_request_parameters': '{"function_tools": [], "output_mode": "text", "output_object": null, "output_tools": [], "allow_text_output": true}', 'logfire.span_type': 'span', 'logfire.msg': 'chat fallback:function:failure_response:,function:success_response:', 'gen_ai.system': 'function', @@ -200,7 +200,7 @@ async def test_first_failed_instrumented_stream(capfire: CaptureLogfire) -> None 'end_time': 3000000000, 'attributes': { 'gen_ai.operation.name': 'chat', - 'model_request_parameters': '{"function_tools": [], "allow_text_output": true, "output_tools": []}', + 'model_request_parameters': '{"function_tools": [], "output_mode": "text", "output_object": null, "output_tools": [], "allow_text_output": true}', 'logfire.span_type': 'span', 'logfire.msg': 'chat fallback:function::failure_response_stream,function::success_response_stream', 'gen_ai.system': 'function', @@ -272,7 +272,7 @@ def test_all_failed_instrumented(capfire: CaptureLogfire) -> None: 'gen_ai.operation.name': 'chat', 'gen_ai.system': 'fallback:function,function', 'gen_ai.request.model': 'fallback:function:failure_response:,function:failure_response:', - 'model_request_parameters': '{"function_tools": [], "allow_text_output": true, "output_tools": []}', + 'model_request_parameters': '{"function_tools": [], "output_mode": "text", "output_object": null, "output_tools": [], "allow_text_output": true}', 'logfire.json_schema': '{"type": "object", "properties": {"model_request_parameters": {"type": "object"}}}', 'logfire.span_type': 'span', 'logfire.msg': 'chat fallback:function:failure_response:,function:failure_response:', diff --git a/tests/models/test_gemini.py b/tests/models/test_gemini.py index bdf2dfc8a..a3810fe0c 100644 --- a/tests/models/test_gemini.py +++ b/tests/models/test_gemini.py @@ -53,7 +53,7 @@ _GeminiUsageMetaData, ) from pydantic_ai.providers.google_gla import GoogleGLAProvider -from pydantic_ai.result import Usage +from pydantic_ai.result import JsonSchemaOutput, PromptedJsonOutput, TextOutput, ToolOutput, Usage from pydantic_ai.tools import ToolDefinition from ..conftest import ClientWithHandler, IsDatetime, IsNow, IsStr, TestEnv @@ -67,7 +67,9 @@ async def test_model_simple(allow_model_requests: None): assert m.model_name == 'gemini-1.5-flash' assert 'x-goog-api-key' in m.client.headers - mrp = ModelRequestParameters(function_tools=[], allow_text_output=True, output_tools=[]) + mrp = ModelRequestParameters( + function_tools=[], allow_text_output=True, output_tools=[], output_mode='text', output_object=None + ) mrp = m.customize_request_parameters(mrp) tools = m._get_tools(mrp) tool_config = m._get_tool_config(mrp, tools) @@ -100,7 +102,13 @@ async def test_model_tools(allow_model_requests: None): {'type': 'object', 'title': 'Result', 'properties': {'spam': {'type': 'number'}}, 'required': ['spam']}, ) - mrp = ModelRequestParameters(function_tools=tools, allow_text_output=True, output_tools=[output_tool]) + mrp = ModelRequestParameters( + function_tools=tools, + allow_text_output=True, + output_tools=[output_tool], + output_mode='text', + output_object=None, + ) mrp = m.customize_request_parameters(mrp) tools = m._get_tools(mrp) tool_config = m._get_tool_config(mrp, tools) @@ -142,7 +150,13 @@ async def test_require_response_tool(allow_model_requests: None): 'This is the tool for the final Result', {'type': 'object', 'title': 'Result', 'properties': {'spam': {'type': 'number'}}}, ) - mrp = ModelRequestParameters(function_tools=[], allow_text_output=False, output_tools=[output_tool]) + mrp = ModelRequestParameters( + function_tools=[], + allow_text_output=False, + output_tools=[output_tool], + output_mode='tool', + output_object=None, + ) mrp = m.customize_request_parameters(mrp) tools = m._get_tools(mrp) tool_config = m._get_tool_config(mrp, tools) @@ -223,7 +237,13 @@ class Locations(BaseModel): 'This is the tool for the final Result', json_schema, ) - mrp = ModelRequestParameters(function_tools=[], allow_text_output=True, output_tools=[output_tool]) + mrp = ModelRequestParameters( + function_tools=[], + allow_text_output=True, + output_tools=[output_tool], + output_mode='text', + output_object=None, + ) mrp = m.customize_request_parameters(mrp) assert m._get_tools(mrp) == snapshot( { @@ -302,7 +322,13 @@ class QueryDetails(BaseModel): 'This is the tool for the final Result', json_schema, ) - mrp = ModelRequestParameters(function_tools=[], allow_text_output=True, output_tools=[output_tool]) + mrp = ModelRequestParameters( + function_tools=[], + output_mode='text', + allow_text_output=True, + output_tools=[output_tool], + output_object=None, + ) mrp = m.customize_request_parameters(mrp) # This tests that the enum values are properly converted to strings for Gemini @@ -344,7 +370,13 @@ class Locations(BaseModel): 'This is the tool for the final Result', json_schema, ) - mrp = ModelRequestParameters(function_tools=[], allow_text_output=True, output_tools=[output_tool]) + mrp = ModelRequestParameters( + function_tools=[], + allow_text_output=True, + output_tools=[output_tool], + output_mode='text', + output_object=None, + ) mrp = m.customize_request_parameters(mrp) assert m._get_tools(mrp) == snapshot( _GeminiTools( @@ -408,7 +440,13 @@ class Location(BaseModel): json_schema, ) with pytest.raises(UserError, match=r'Recursive `\$ref`s in JSON Schema are not supported by Gemini'): - mrp = ModelRequestParameters(function_tools=[], allow_text_output=True, output_tools=[output_tool]) + mrp = ModelRequestParameters( + function_tools=[], + allow_text_output=True, + output_tools=[output_tool], + output_mode='text', + output_object=None, + ) mrp = m.customize_request_parameters(mrp) @@ -440,7 +478,13 @@ class FormattedStringFields(BaseModel): 'This is the tool for the final Result', json_schema, ) - mrp = ModelRequestParameters(function_tools=[], allow_text_output=True, output_tools=[output_tool]) + mrp = ModelRequestParameters( + function_tools=[], + allow_text_output=True, + output_tools=[output_tool], + output_mode='text', + output_object=None, + ) mrp = m.customize_request_parameters(mrp) assert m._get_tools(mrp) == snapshot( _GeminiTools( @@ -1398,3 +1442,462 @@ async def test_response_with_thought_part(get_gemini_client: GetGeminiClient): assert result.output == 'Hello from thought test' assert result.usage() == snapshot(Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3)) + + +@pytest.mark.vcr() +async def test_gemini_tool_output(allow_model_requests: None, gemini_api_key: str): + m = GeminiModel('gemini-2.0-flash', provider=GoogleGLAProvider(api_key=gemini_api_key)) + + class CityLocation(BaseModel): + city: str + country: str + + agent = Agent(m, output_type=ToolOutput(CityLocation)) + + @agent.tool_plain + async def get_user_country() -> str: + return 'Mexico' + + result = await agent.run('What is the largest city in the user country?') + assert result.output == snapshot(CityLocation(city='Mexico City', country='Mexico')) + + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='What is the largest city in the user country?', + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[ToolCallPart(tool_name='get_user_country', args={}, tool_call_id=IsStr())], + usage=Usage( + requests=1, + request_tokens=32, + response_tokens=5, + total_tokens=37, + details={'text_prompt_tokens': 32, 'text_candidates_tokens': 5}, + ), + model_name='gemini-2.0-flash', + timestamp=IsDatetime(), + vendor_details={'finish_reason': 'STOP'}, + vendor_id=IsStr(), + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='get_user_country', + content='Mexico', + tool_call_id=IsStr(), + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[ + ToolCallPart( + tool_name='final_result', + args={'country': 'Mexico', 'city': 'Mexico City'}, + tool_call_id=IsStr(), + ) + ], + usage=Usage( + requests=1, + request_tokens=46, + response_tokens=8, + total_tokens=54, + details={'text_prompt_tokens': 46, 'text_candidates_tokens': 8}, + ), + model_name='gemini-2.0-flash', + timestamp=IsDatetime(), + vendor_details={'finish_reason': 'STOP'}, + vendor_id=IsStr(), + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='final_result', + content='Final result processed.', + tool_call_id=IsStr(), + timestamp=IsDatetime(), + ) + ] + ), + ] + ) + + +@pytest.mark.vcr() +async def test_gemini_text_output_function(allow_model_requests: None, gemini_api_key: str): + m = GeminiModel('gemini-2.5-pro-preview-05-06', provider=GoogleGLAProvider(api_key=gemini_api_key)) + + def upcase(text: str) -> str: + return text.upper() + + agent = Agent(m, output_type=TextOutput(upcase)) + + result = await agent.run('What is the largest city in Mexico?') + assert result.output == snapshot("""\ +THE LARGEST CITY IN MEXICO IS **MEXICO CITY (CIUDAD DE MÉXICO, CDMX)**. + +IT'S THE CAPITAL OF MEXICO AND ONE OF THE LARGEST METROPOLITAN AREAS IN THE WORLD, BOTH BY POPULATION AND LAND AREA.\ +""") + + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='What is the largest city in Mexico?', + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[ + TextPart( + content="""\ +The largest city in Mexico is **Mexico City (Ciudad de México, CDMX)**. + +It's the capital of Mexico and one of the largest metropolitan areas in the world, both by population and land area.\ +""" + ) + ], + usage=Usage( + requests=1, + request_tokens=9, + response_tokens=44, + total_tokens=598, + details={'thoughts_tokens': 545, 'text_prompt_tokens': 9}, + ), + model_name='models/gemini-2.5-pro-preview-05-06', + timestamp=IsDatetime(), + vendor_details={'finish_reason': 'STOP'}, + vendor_id='TT9IaNfGN_DmqtsPzKnE4AE', + ), + ] + ) + + +@pytest.mark.vcr() +async def test_gemini_json_schema_output_with_tools(allow_model_requests: None, gemini_api_key: str): + m = GeminiModel('gemini-2.0-flash', provider=GoogleGLAProvider(api_key=gemini_api_key)) + + class CityLocation(BaseModel): + city: str + country: str + + agent = Agent(m, output_type=JsonSchemaOutput(CityLocation)) + + @agent.tool_plain + async def get_user_country() -> str: + return 'Mexico' # pragma: no cover + + with pytest.raises(UserError, match='Google does not support JSON schema output and tools at the same time.'): + await agent.run('What is the largest city in the user country?') + + +@pytest.mark.vcr() +async def test_gemini_json_schema_output(allow_model_requests: None, gemini_api_key: str): + m = GeminiModel('gemini-2.0-flash', provider=GoogleGLAProvider(api_key=gemini_api_key)) + + class CityLocation(BaseModel): + """A city and its country.""" + + city: str + country: str + + agent = Agent(m, output_type=JsonSchemaOutput(CityLocation)) + + result = await agent.run('What is the largest city in Mexico?') + assert result.output == snapshot(CityLocation(city='Mexico City', country='Mexico')) + + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='What is the largest city in Mexico?', + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[ + TextPart( + content="""\ +{ + "city": "Mexico City", + "country": "Mexico" +}\ +""" + ) + ], + usage=Usage( + requests=1, + request_tokens=17, + response_tokens=20, + total_tokens=37, + details={'text_prompt_tokens': 17, 'text_candidates_tokens': 20}, + ), + model_name='gemini-2.0-flash', + timestamp=IsDatetime(), + vendor_details={'finish_reason': 'STOP'}, + vendor_id=IsStr(), + ), + ] + ) + + +@pytest.mark.vcr() +async def test_gemini_json_schema_output_multiple(allow_model_requests: None, gemini_api_key: str): + m = GeminiModel('gemini-2.0-flash', provider=GoogleGLAProvider(api_key=gemini_api_key)) + + class CityLocation(BaseModel): + city: str + country: str + + class CountryLanguage(BaseModel): + country: str + language: str + + agent = Agent(m, output_type=JsonSchemaOutput([CityLocation, CountryLanguage])) + + result = await agent.run('What is the primarily language spoken in Mexico?') + assert result.output == snapshot(CountryLanguage(country='Mexico', language='Spanish')) + + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='What is the primarily language spoken in Mexico?', + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[ + TextPart( + content="""\ +{ + "result": { + "data": { + "country": "Mexico", + "language": "Spanish" + }, + "kind": "CountryLanguage" + } +}\ +""" + ) + ], + usage=Usage( + requests=1, + request_tokens=46, + response_tokens=46, + total_tokens=92, + details={'text_prompt_tokens': 46, 'text_candidates_tokens': 46}, + ), + model_name='gemini-2.0-flash', + timestamp=IsDatetime(), + vendor_details={'finish_reason': 'STOP'}, + vendor_id=IsStr(), + ), + ] + ) + + +@pytest.mark.vcr() +async def test_gemini_prompted_json_output(allow_model_requests: None, gemini_api_key: str): + m = GeminiModel('gemini-2.0-flash', provider=GoogleGLAProvider(api_key=gemini_api_key)) + + class CityLocation(BaseModel): + city: str + country: str + + agent = Agent(m, output_type=PromptedJsonOutput(CityLocation)) + + result = await agent.run('What is the largest city in Mexico?') + assert result.output == snapshot(CityLocation(city='Mexico City', country='Mexico')) + + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='What is the largest city in Mexico?', + timestamp=IsDatetime(), + ) + ], + instructions="""\ +Always respond with a JSON object that's compatible with this schema: + +{"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} + +Don't include any text or Markdown fencing before or after.\ +""", + ), + ModelResponse( + parts=[ + TextPart( + content='{"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object", "city": "Mexico City", "country": "Mexico"}' + ) + ], + usage=Usage( + requests=1, + request_tokens=80, + response_tokens=56, + total_tokens=136, + details={'text_prompt_tokens': 80, 'text_candidates_tokens': 56}, + ), + model_name='gemini-2.0-flash', + timestamp=IsDatetime(), + vendor_details={'finish_reason': 'STOP'}, + vendor_id=IsStr(), + ), + ] + ) + + +@pytest.mark.vcr() +async def test_gemini_prompted_json_output_with_tools(allow_model_requests: None, gemini_api_key: str): + m = GeminiModel('gemini-2.5-pro-preview-05-06', provider=GoogleGLAProvider(api_key=gemini_api_key)) + + class CityLocation(BaseModel): + city: str + country: str + + agent = Agent(m, output_type=PromptedJsonOutput(CityLocation)) + + @agent.tool_plain + async def get_user_country() -> str: + return 'Mexico' + + result = await agent.run( + 'What is the largest city in the user country? Use the get_user_country tool and then your own world knowledge.' + ) + assert result.output == snapshot(CityLocation(city='Mexico City', country='Mexico')) + + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='What is the largest city in the user country? Use the get_user_country tool and then your own world knowledge.', + timestamp=IsDatetime(), + ) + ], + instructions="""\ +Always respond with a JSON object that's compatible with this schema: + +{"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} + +Don't include any text or Markdown fencing before or after.\ +""", + ), + ModelResponse( + parts=[ToolCallPart(tool_name='get_user_country', args={}, tool_call_id=IsStr())], + usage=Usage( + requests=1, + request_tokens=123, + response_tokens=12, + total_tokens=453, + details={'thoughts_tokens': 318, 'text_prompt_tokens': 123}, + ), + model_name='models/gemini-2.5-pro-preview-05-06', + timestamp=IsDatetime(), + vendor_details={'finish_reason': 'STOP'}, + vendor_id=IsStr(), + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='get_user_country', + content='Mexico', + tool_call_id=IsStr(), + timestamp=IsDatetime(), + ) + ], + instructions="""\ +Always respond with a JSON object that's compatible with this schema: + +{"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} + +Don't include any text or Markdown fencing before or after.\ +""", + ), + ModelResponse( + parts=[TextPart(content='{"city": "Mexico City", "country": "Mexico"}')], + usage=Usage( + requests=1, + request_tokens=154, + response_tokens=13, + total_tokens=261, + details={'thoughts_tokens': 94, 'text_prompt_tokens': 154}, + ), + model_name='models/gemini-2.5-pro-preview-05-06', + timestamp=IsDatetime(), + vendor_details={'finish_reason': 'STOP'}, + vendor_id=IsStr(), + ), + ] + ) + + +@pytest.mark.vcr() +async def test_gemini_prompted_json_output_multiple(allow_model_requests: None, gemini_api_key: str): + m = GeminiModel('gemini-2.0-flash', provider=GoogleGLAProvider(api_key=gemini_api_key)) + + class CityLocation(BaseModel): + city: str + country: str + + class CountryLanguage(BaseModel): + country: str + language: str + + agent = Agent(m, output_type=PromptedJsonOutput([CityLocation, CountryLanguage])) + + result = await agent.run('What is the largest city in Mexico?') + assert result.output == snapshot(CityLocation(city='Mexico City', country='Mexico')) + + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='What is the largest city in Mexico?', + timestamp=IsDatetime(), + ) + ], + instructions="""\ +Always respond with a JSON object that's compatible with this schema: + +{"type": "object", "properties": {"result": {"anyOf": [{"type": "object", "properties": {"kind": {"type": "string", "const": "CityLocation"}, "data": {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "type": "object"}}, "required": ["kind", "data"], "additionalProperties": false, "title": "CityLocation"}, {"type": "object", "properties": {"kind": {"type": "string", "const": "CountryLanguage"}, "data": {"properties": {"country": {"type": "string"}, "language": {"type": "string"}}, "required": ["country", "language"], "type": "object"}}, "required": ["kind", "data"], "additionalProperties": false, "title": "CountryLanguage"}]}}, "required": ["result"], "additionalProperties": false} + +Don't include any text or Markdown fencing before or after.\ +""", + ), + ModelResponse( + parts=[ + TextPart( + content='{"result": {"kind": "CityLocation", "data": {"city": "Mexico City", "country": "Mexico"}}}' + ) + ], + usage=Usage( + requests=1, + request_tokens=253, + response_tokens=27, + total_tokens=280, + details={'text_prompt_tokens': 253, 'text_candidates_tokens': 27}, + ), + model_name='gemini-2.0-flash', + timestamp=IsDatetime(), + vendor_details={'finish_reason': 'STOP'}, + vendor_id=IsStr(), + ), + ] + ) diff --git a/tests/models/test_google.py b/tests/models/test_google.py index 6ef80cca5..bdceef479 100644 --- a/tests/models/test_google.py +++ b/tests/models/test_google.py @@ -8,6 +8,7 @@ import pytest from httpx import Request from inline_snapshot import Is, snapshot +from pydantic import BaseModel from pytest_mock import MockerFixture from typing_extensions import TypedDict @@ -34,7 +35,7 @@ UserPromptPart, VideoUrl, ) -from pydantic_ai.usage import Usage +from pydantic_ai.result import JsonSchemaOutput, PromptedJsonOutput, TextOutput, ToolOutput, Usage from ..conftest import IsDatetime, IsInstance, IsStr, try_import @@ -726,3 +727,473 @@ async def test_google_gs_url_force_download_raises_user_error(allow_model_reques url = ImageUrl(url='gs://pydantic-ai-dev/wikipedia_screenshot.png', force_download=True) with pytest.raises(UserError, match='Downloading from protocol "gs://" is not supported.'): _ = await agent.run(['What is the main content of this URL?', url]) + + +async def test_google_tool_output(allow_model_requests: None, google_provider: GoogleProvider): + m = GoogleModel('gemini-2.0-flash', provider=google_provider) + + class CityLocation(BaseModel): + city: str + country: str + + agent = Agent(m, output_type=ToolOutput(CityLocation)) + + @agent.tool_plain + async def get_user_country() -> str: + return 'Mexico' + + result = await agent.run('What is the largest city in the user country?') + assert result.output == snapshot(CityLocation(city='Mexico City', country='Mexico')) + + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='What is the largest city in the user country?', + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[ToolCallPart(tool_name='get_user_country', args={}, tool_call_id=IsStr())], + usage=Usage( + requests=1, + request_tokens=32, + response_tokens=5, + total_tokens=37, + details={'text_candidates_tokens': 5, 'text_prompt_tokens': 32}, + ), + model_name='gemini-2.0-flash', + timestamp=IsDatetime(), + vendor_details={'finish_reason': 'STOP'}, + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='get_user_country', + content='Mexico', + tool_call_id=IsStr(), + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[ + ToolCallPart( + tool_name='final_result', + args={'city': 'Mexico City', 'country': 'Mexico'}, + tool_call_id=IsStr(), + ) + ], + usage=Usage( + requests=1, + request_tokens=46, + response_tokens=8, + total_tokens=54, + details={'text_candidates_tokens': 8, 'text_prompt_tokens': 46}, + ), + model_name='gemini-2.0-flash', + timestamp=IsDatetime(), + vendor_details={'finish_reason': 'STOP'}, + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='final_result', + content='Final result processed.', + tool_call_id=IsStr(), + timestamp=IsDatetime(), + ) + ] + ), + ] + ) + + +async def test_google_text_output_function(allow_model_requests: None, google_provider: GoogleProvider): + m = GoogleModel('gemini-2.5-pro-preview-05-06', provider=google_provider) + + def upcase(text: str) -> str: + return text.upper() + + agent = Agent(m, output_type=TextOutput(upcase)) + + @agent.tool_plain + async def get_user_country() -> str: + return 'Mexico' + + result = await agent.run( + 'What is the largest city in the user country? Use the get_user_country tool and then your own world knowledge.' + ) + assert result.output == snapshot('BASED ON THE INFORMATION I HAVE, THE LARGEST CITY IN MEXICO IS MEXICO CITY.') + + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='What is the largest city in the user country? Use the get_user_country tool and then your own world knowledge.', + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[ + TextPart(content='Okay, I can help with that. First, I need to determine your country.\n'), + ToolCallPart(tool_name='get_user_country', args={}, tool_call_id=IsStr()), + ], + usage=Usage( + requests=1, + request_tokens=49, + response_tokens=30, + total_tokens=238, + details={'thoughts_tokens': 159, 'text_prompt_tokens': 49}, + ), + model_name='models/gemini-2.5-pro-preview-05-06', + timestamp=IsDatetime(), + vendor_details={'finish_reason': 'STOP'}, + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='get_user_country', + content='Mexico', + tool_call_id=IsStr(), + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[TextPart(content='Based on the information I have, the largest city in Mexico is Mexico City.')], + usage=Usage( + requests=1, + request_tokens=98, + response_tokens=16, + total_tokens=159, + details={'thoughts_tokens': 45, 'text_prompt_tokens': 98}, + ), + model_name='models/gemini-2.5-pro-preview-05-06', + timestamp=IsDatetime(), + vendor_details={'finish_reason': 'STOP'}, + ), + ] + ) + + +async def test_google_json_schema_output_with_tools(allow_model_requests: None, google_provider: GoogleProvider): + m = GoogleModel('gemini-2.0-flash', provider=google_provider) + + class CityLocation(BaseModel): + city: str + country: str + + agent = Agent(m, output_type=JsonSchemaOutput(CityLocation)) + + @agent.tool_plain + async def get_user_country() -> str: + return 'Mexico' # pragma: no cover + + with pytest.raises(UserError, match='Google does not support JSON schema output and tools at the same time.'): + await agent.run('What is the largest city in the user country?') + + +async def test_google_json_schema_output(allow_model_requests: None, google_provider: GoogleProvider): + m = GoogleModel('gemini-2.0-flash', provider=google_provider) + + class CityLocation(BaseModel): + """A city and its country.""" + + city: str + country: str + + agent = Agent(m, output_type=JsonSchemaOutput(CityLocation)) + + result = await agent.run('What is the largest city in Mexico?') + assert result.output == snapshot(CityLocation(city='Mexico City', country='Mexico')) + + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='What is the largest city in Mexico?', + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[ + TextPart( + content="""\ +{ + "city": "Mexico City", + "country": "Mexico" +}\ +""" + ) + ], + usage=Usage( + requests=1, + request_tokens=19, + response_tokens=20, + total_tokens=39, + details={'text_candidates_tokens': 20, 'text_prompt_tokens': 19}, + ), + model_name='gemini-2.0-flash', + timestamp=IsDatetime(), + vendor_details={'finish_reason': 'STOP'}, + ), + ] + ) + + +async def test_google_json_schema_output_multiple(allow_model_requests: None, google_provider: GoogleProvider): + m = GoogleModel('gemini-2.0-flash', provider=google_provider) + + class CityLocation(BaseModel): + city: str + country: str + + class CountryLanguage(BaseModel): + country: str + language: str + + agent = Agent(m, output_type=JsonSchemaOutput([CityLocation, CountryLanguage])) + + result = await agent.run('What is the primarily language spoken in Mexico?') + assert result.output == snapshot(CountryLanguage(country='Mexico', language='Spanish')) + + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='What is the primarily language spoken in Mexico?', + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[ + TextPart( + content="""\ +{ + "result": { + "kind": "CountryLanguage", + "data": { + "country": "Mexico", + "language": "Spanish" + } + } +}\ +""" + ) + ], + usage=Usage( + requests=1, + request_tokens=64, + response_tokens=46, + total_tokens=110, + details={'text_candidates_tokens': 46, 'text_prompt_tokens': 64}, + ), + model_name='gemini-2.0-flash', + timestamp=IsDatetime(), + vendor_details={'finish_reason': 'STOP'}, + ), + ] + ) + + +async def test_google_prompted_json_output(allow_model_requests: None, google_provider: GoogleProvider): + m = GoogleModel('gemini-2.0-flash', provider=google_provider) + + class CityLocation(BaseModel): + city: str + country: str + + agent = Agent(m, output_type=PromptedJsonOutput(CityLocation)) + + result = await agent.run('What is the largest city in Mexico?') + assert result.output == snapshot(CityLocation(city='Mexico City', country='Mexico')) + + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='What is the largest city in Mexico?', + timestamp=IsDatetime(), + ) + ], + instructions="""\ +Always respond with a JSON object that's compatible with this schema: + +{"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} + +Don't include any text or Markdown fencing before or after.\ +""", + ), + ModelResponse( + parts=[ + TextPart( + content='{"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object", "city": "Mexico City", "country": "Mexico"}' + ) + ], + usage=Usage( + requests=1, + request_tokens=80, + response_tokens=56, + total_tokens=136, + details={'text_candidates_tokens': 56, 'text_prompt_tokens': 80}, + ), + model_name='gemini-2.0-flash', + timestamp=IsDatetime(), + vendor_details={'finish_reason': 'STOP'}, + ), + ] + ) + + +async def test_google_prompted_json_output_with_tools(allow_model_requests: None, google_provider: GoogleProvider): + m = GoogleModel('gemini-2.5-pro-preview-05-06', provider=google_provider) + + class CityLocation(BaseModel): + city: str + country: str + + agent = Agent(m, output_type=PromptedJsonOutput(CityLocation)) + + @agent.tool_plain + async def get_user_country() -> str: + return 'Mexico' + + result = await agent.run( + 'What is the largest city in the user country? Use the get_user_country tool and then your own world knowledge.' + ) + assert result.output == snapshot(CityLocation(city='Mexico City', country='Mexico')) + + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='What is the largest city in the user country? Use the get_user_country tool and then your own world knowledge.', + timestamp=IsDatetime(), + ) + ], + instructions="""\ +Always respond with a JSON object that's compatible with this schema: + +{"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} + +Don't include any text or Markdown fencing before or after.\ +""", + ), + ModelResponse( + parts=[ToolCallPart(tool_name='get_user_country', args={}, tool_call_id=IsStr())], + usage=Usage( + requests=1, + request_tokens=123, + response_tokens=12, + total_tokens=401, + details={'thoughts_tokens': 266, 'text_prompt_tokens': 123}, + ), + model_name='models/gemini-2.5-pro-preview-05-06', + timestamp=IsDatetime(), + vendor_details={'finish_reason': 'STOP'}, + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='get_user_country', + content='Mexico', + tool_call_id=IsStr(), + timestamp=IsDatetime(), + ) + ], + instructions="""\ +Always respond with a JSON object that's compatible with this schema: + +{"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} + +Don't include any text or Markdown fencing before or after.\ +""", + ), + ModelResponse( + parts=[ + TextPart( + content="""\ +```json +{"city": "Mexico City", "country": "Mexico"} +```\ +""" + ) + ], + usage=Usage( + requests=1, + request_tokens=154, + response_tokens=18, + total_tokens=266, + details={'thoughts_tokens': 94, 'text_prompt_tokens': 154}, + ), + model_name='models/gemini-2.5-pro-preview-05-06', + timestamp=IsDatetime(), + vendor_details={'finish_reason': 'STOP'}, + ), + ] + ) + + +async def test_google_prompted_json_output_multiple(allow_model_requests: None, google_provider: GoogleProvider): + m = GoogleModel('gemini-2.0-flash', provider=google_provider) + + class CityLocation(BaseModel): + city: str + country: str + + class CountryLanguage(BaseModel): + country: str + language: str + + agent = Agent(m, output_type=PromptedJsonOutput([CityLocation, CountryLanguage])) + + result = await agent.run('What is the largest city in Mexico?') + assert result.output == snapshot(CityLocation(city='Mexico City', country='Mexico')) + + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='What is the largest city in Mexico?', + timestamp=IsDatetime(), + ) + ], + instructions="""\ +Always respond with a JSON object that's compatible with this schema: + +{"type": "object", "properties": {"result": {"anyOf": [{"type": "object", "properties": {"kind": {"type": "string", "const": "CityLocation"}, "data": {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "type": "object"}}, "required": ["kind", "data"], "additionalProperties": false, "title": "CityLocation"}, {"type": "object", "properties": {"kind": {"type": "string", "const": "CountryLanguage"}, "data": {"properties": {"country": {"type": "string"}, "language": {"type": "string"}}, "required": ["country", "language"], "type": "object"}}, "required": ["kind", "data"], "additionalProperties": false, "title": "CountryLanguage"}]}}, "required": ["result"], "additionalProperties": false} + +Don't include any text or Markdown fencing before or after.\ +""", + ), + ModelResponse( + parts=[ + TextPart( + content='{"result": {"kind": "CityLocation", "data": {"city": "Mexico City", "country": "Mexico"}}}' + ) + ], + usage=Usage( + requests=1, + request_tokens=241, + response_tokens=27, + total_tokens=268, + details={'text_candidates_tokens': 27, 'text_prompt_tokens': 241}, + ), + model_name='gemini-2.0-flash', + timestamp=IsDatetime(), + vendor_details={'finish_reason': 'STOP'}, + ), + ] + ) diff --git a/tests/models/test_instrumented.py b/tests/models/test_instrumented.py index f7caad399..ac5383663 100644 --- a/tests/models/test_instrumented.py +++ b/tests/models/test_instrumented.py @@ -134,6 +134,8 @@ async def test_instrumented_model(capfire: CaptureLogfire): function_tools=[], allow_text_output=True, output_tools=[], + output_mode='text', + output_object=None, ), ) @@ -151,7 +153,7 @@ async def test_instrumented_model(capfire: CaptureLogfire): 'gen_ai.request.model': 'my_model', 'server.address': 'example.com', 'server.port': 8000, - 'model_request_parameters': '{"function_tools": [], "allow_text_output": true, "output_tools": []}', + 'model_request_parameters': '{"function_tools": [], "output_mode": "text", "output_object": null, "output_tools": [], "allow_text_output": true}', 'logfire.json_schema': '{"type": "object", "properties": {"model_request_parameters": {"type": "object"}}}', 'gen_ai.request.temperature': 1, 'logfire.msg': 'chat my_model', @@ -330,6 +332,8 @@ async def test_instrumented_model_not_recording(): function_tools=[], allow_text_output=True, output_tools=[], + output_mode='text', + output_object=None, ), ) @@ -352,6 +356,8 @@ async def test_instrumented_model_stream(capfire: CaptureLogfire): function_tools=[], allow_text_output=True, output_tools=[], + output_mode='text', + output_object=None, ), ) as response_stream: assert [event async for event in response_stream] == snapshot( @@ -375,7 +381,7 @@ async def test_instrumented_model_stream(capfire: CaptureLogfire): 'gen_ai.request.model': 'my_model', 'server.address': 'example.com', 'server.port': 8000, - 'model_request_parameters': '{"function_tools": [], "allow_text_output": true, "output_tools": []}', + 'model_request_parameters': '{"function_tools": [], "output_mode": "text", "output_object": null, "output_tools": [], "allow_text_output": true}', 'logfire.json_schema': '{"type": "object", "properties": {"model_request_parameters": {"type": "object"}}}', 'gen_ai.request.temperature': 1, 'logfire.msg': 'chat my_model', @@ -440,6 +446,8 @@ async def test_instrumented_model_stream_break(capfire: CaptureLogfire): function_tools=[], allow_text_output=True, output_tools=[], + output_mode='text', + output_object=None, ), ) as response_stream: async for event in response_stream: # pragma: no branch @@ -460,7 +468,7 @@ async def test_instrumented_model_stream_break(capfire: CaptureLogfire): 'gen_ai.request.model': 'my_model', 'server.address': 'example.com', 'server.port': 8000, - 'model_request_parameters': '{"function_tools": [], "allow_text_output": true, "output_tools": []}', + 'model_request_parameters': '{"function_tools": [], "output_mode": "text", "output_object": null, "output_tools": [], "allow_text_output": true}', 'logfire.json_schema': '{"type": "object", "properties": {"model_request_parameters": {"type": "object"}}}', 'gen_ai.request.temperature': 1, 'logfire.msg': 'chat my_model', @@ -543,6 +551,8 @@ async def test_instrumented_model_attributes_mode(capfire: CaptureLogfire): function_tools=[], allow_text_output=True, output_tools=[], + output_mode='text', + output_object=None, ), ) @@ -560,7 +570,7 @@ async def test_instrumented_model_attributes_mode(capfire: CaptureLogfire): 'gen_ai.request.model': 'my_model', 'server.address': 'example.com', 'server.port': 8000, - 'model_request_parameters': '{"function_tools": [], "allow_text_output": true, "output_tools": []}', + 'model_request_parameters': '{"function_tools": [], "output_mode": "text", "output_object": null, "output_tools": [], "allow_text_output": true}', 'gen_ai.request.temperature': 1, 'logfire.msg': 'chat my_model', 'logfire.span_type': 'span', diff --git a/tests/models/test_model_request_parameters.py b/tests/models/test_model_request_parameters.py index 03910db11..98a6d1ccc 100644 --- a/tests/models/test_model_request_parameters.py +++ b/tests/models/test_model_request_parameters.py @@ -4,9 +4,13 @@ def test_model_request_parameters_are_serializable(): - params = ModelRequestParameters(function_tools=[], allow_text_output=False, output_tools=[]) + params = ModelRequestParameters( + function_tools=[], output_mode='text', allow_text_output=True, output_tools=[], output_object=None + ) assert TypeAdapter(ModelRequestParameters).dump_python(params) == { 'function_tools': [], - 'allow_text_output': False, + 'output_mode': 'text', + 'allow_text_output': True, 'output_tools': [], + 'output_object': None, } diff --git a/tests/models/test_openai.py b/tests/models/test_openai.py index 177b0fd80..99238690c 100644 --- a/tests/models/test_openai.py +++ b/tests/models/test_openai.py @@ -34,7 +34,7 @@ from pydantic_ai.profiles._json_schema import InlineDefsJsonSchemaTransformer from pydantic_ai.profiles.openai import OpenAIModelProfile, openai_model_profile from pydantic_ai.providers.google_gla import GoogleGLAProvider -from pydantic_ai.result import Usage +from pydantic_ai.result import JsonSchemaOutput, PromptedJsonOutput, TextOutput, ToolOutput, Usage from pydantic_ai.settings import ModelSettings from pydantic_ai.tools import ToolDefinition @@ -538,6 +538,31 @@ async def test_stream_structured_finish_reason(allow_model_requests: None): assert result.is_complete +async def test_stream_structured_json_schema_output(allow_model_requests: None): + stream = [ + chunk([]), + text_chunk('{"first": "One'), + text_chunk('", "second": "Two"'), + text_chunk('}'), + chunk([]), + ] + mock_client = MockOpenAI.create_mock_stream(stream) + m = OpenAIModel('gpt-4o', provider=OpenAIProvider(openai_client=mock_client)) + agent = Agent(m, output_type=JsonSchemaOutput(MyTypedDict)) + + async with agent.run_stream('') as result: + assert not result.is_complete + assert [dict(c) async for c in result.stream(debounce_by=None)] == snapshot( + [ + {'first': 'One'}, + {'first': 'One', 'second': 'Two'}, + {'first': 'One', 'second': 'Two'}, + {'first': 'One', 'second': 'Two'}, + ] + ) + assert result.is_complete + + async def test_no_content(allow_model_requests: None): stream = [chunk([ChoiceDelta()]), chunk([ChoiceDelta()])] mock_client = MockOpenAI.create_mock_stream(stream) @@ -1759,3 +1784,548 @@ def test_openai_response_timestamp_milliseconds(allow_model_requests: None): result = agent.run_sync('Hello') response = cast(ModelResponse, result.all_messages()[-1]) assert response.timestamp == snapshot(datetime(2025, 6, 1, 3, 7, 48, tzinfo=timezone.utc)) + + +@pytest.mark.vcr() +async def test_openai_tool_output(allow_model_requests: None, openai_api_key: str): + m = OpenAIModel('gpt-4o', provider=OpenAIProvider(api_key=openai_api_key)) + + class CityLocation(BaseModel): + city: str + country: str + + agent = Agent(m, output_type=ToolOutput(CityLocation)) + + @agent.tool_plain + async def get_user_country() -> str: + return 'Mexico' + + result = await agent.run('What is the largest city in the user country?') + assert result.output == snapshot(CityLocation(city='Mexico City', country='Mexico')) + + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='What is the largest city in the user country?', + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[ToolCallPart(tool_name='get_user_country', args='{}', tool_call_id=IsStr())], + usage=Usage( + requests=1, + request_tokens=68, + response_tokens=12, + total_tokens=80, + details={ + 'accepted_prediction_tokens': 0, + 'audio_tokens': 0, + 'reasoning_tokens': 0, + 'rejected_prediction_tokens': 0, + 'cached_tokens': 0, + }, + ), + model_name='gpt-4o-2024-08-06', + timestamp=IsDatetime(), + vendor_id='chatcmpl-BSXk0dWkG4hfPt0lph4oFO35iT73I', + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='get_user_country', + content='Mexico', + tool_call_id=IsStr(), + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[ + ToolCallPart( + tool_name='final_result', + args='{"city": "Mexico City", "country": "Mexico"}', + tool_call_id=IsStr(), + ) + ], + usage=Usage( + requests=1, + request_tokens=89, + response_tokens=36, + total_tokens=125, + details={ + 'accepted_prediction_tokens': 0, + 'audio_tokens': 0, + 'reasoning_tokens': 0, + 'rejected_prediction_tokens': 0, + 'cached_tokens': 0, + }, + ), + model_name='gpt-4o-2024-08-06', + timestamp=IsDatetime(), + vendor_id='chatcmpl-BSXk1xGHYzbhXgUkSutK08bdoNv5s', + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='final_result', + content='Final result processed.', + tool_call_id=IsStr(), + timestamp=IsDatetime(), + ) + ] + ), + ] + ) + + +@pytest.mark.vcr() +async def test_openai_text_output_function(allow_model_requests: None, openai_api_key: str): + m = OpenAIModel('gpt-4o', provider=OpenAIProvider(api_key=openai_api_key)) + + def upcase(text: str) -> str: + return text.upper() + + agent = Agent(m, output_type=TextOutput(upcase)) + + @agent.tool_plain + async def get_user_country() -> str: + return 'Mexico' + + result = await agent.run('What is the largest city in the user country?') + assert result.output == snapshot('THE LARGEST CITY IN MEXICO IS MEXICO CITY.') + + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='What is the largest city in the user country?', + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[ + ToolCallPart(tool_name='get_user_country', args='{}', tool_call_id='call_J1YabdC7G7kzEZNbbZopwenH') + ], + usage=Usage( + requests=1, + request_tokens=42, + response_tokens=11, + total_tokens=53, + details={ + 'accepted_prediction_tokens': 0, + 'audio_tokens': 0, + 'reasoning_tokens': 0, + 'rejected_prediction_tokens': 0, + 'cached_tokens': 0, + }, + ), + model_name='gpt-4o-2024-08-06', + timestamp=IsDatetime(), + vendor_id='chatcmpl-BgeDFS85bfHosRFEEAvq8reaCPCZ8', + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='get_user_country', + content='Mexico', + tool_call_id='call_J1YabdC7G7kzEZNbbZopwenH', + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[TextPart(content='The largest city in Mexico is Mexico City.')], + usage=Usage( + requests=1, + request_tokens=63, + response_tokens=10, + total_tokens=73, + details={ + 'accepted_prediction_tokens': 0, + 'audio_tokens': 0, + 'reasoning_tokens': 0, + 'rejected_prediction_tokens': 0, + 'cached_tokens': 0, + }, + ), + model_name='gpt-4o-2024-08-06', + timestamp=IsDatetime(), + vendor_id='chatcmpl-BgeDGX9eDyVrEI56aP2vtIHahBzFH', + ), + ] + ) + + +@pytest.mark.vcr() +async def test_openai_json_schema_output(allow_model_requests: None, openai_api_key: str): + m = OpenAIModel('gpt-4o', provider=OpenAIProvider(api_key=openai_api_key)) + + class CityLocation(BaseModel): + """A city and its country.""" + + city: str + country: str + + agent = Agent(m, output_type=JsonSchemaOutput(CityLocation)) + + @agent.tool_plain + async def get_user_country() -> str: + return 'Mexico' + + result = await agent.run('What is the largest city in the user country?') + assert result.output == snapshot(CityLocation(city='Mexico City', country='Mexico')) + + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='What is the largest city in the user country?', + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[ + ToolCallPart(tool_name='get_user_country', args='{}', tool_call_id='call_PkRGedQNRFUzJp2R7dO7avWR') + ], + usage=Usage( + requests=1, + request_tokens=71, + response_tokens=12, + total_tokens=83, + details={ + 'accepted_prediction_tokens': 0, + 'audio_tokens': 0, + 'reasoning_tokens': 0, + 'rejected_prediction_tokens': 0, + 'cached_tokens': 0, + }, + ), + model_name='gpt-4o-2024-08-06', + timestamp=IsDatetime(), + vendor_id='chatcmpl-BSXjyBwGuZrtuuSzNCeaWMpGv2MZ3', + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='get_user_country', + content='Mexico', + tool_call_id='call_PkRGedQNRFUzJp2R7dO7avWR', + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[TextPart(content='{"city":"Mexico City","country":"Mexico"}')], + usage=Usage( + requests=1, + request_tokens=92, + response_tokens=15, + total_tokens=107, + details={ + 'accepted_prediction_tokens': 0, + 'audio_tokens': 0, + 'reasoning_tokens': 0, + 'rejected_prediction_tokens': 0, + 'cached_tokens': 0, + }, + ), + model_name='gpt-4o-2024-08-06', + timestamp=IsDatetime(), + vendor_id='chatcmpl-BSXjzYGu67dhTy5r8KmjJvQ4HhDVO', + ), + ] + ) + + +@pytest.mark.vcr() +async def test_openai_json_schema_output_multiple(allow_model_requests: None, openai_api_key: str): + m = OpenAIModel('gpt-4o', provider=OpenAIProvider(api_key=openai_api_key)) + + class CityLocation(BaseModel): + city: str + country: str + + class CountryLanguage(BaseModel): + country: str + language: str + + agent = Agent(m, output_type=JsonSchemaOutput([CityLocation, CountryLanguage])) + + @agent.tool_plain + async def get_user_country() -> str: + return 'Mexico' + + result = await agent.run('What is the largest city in the user country?') + assert result.output == snapshot(CityLocation(city='Mexico City', country='Mexico')) + + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='What is the largest city in the user country?', + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[ + ToolCallPart(tool_name='get_user_country', args='{}', tool_call_id='call_SIttSeiOistt33Htj4oiHOOX') + ], + usage=Usage( + requests=1, + request_tokens=160, + response_tokens=11, + total_tokens=171, + details={ + 'accepted_prediction_tokens': 0, + 'audio_tokens': 0, + 'reasoning_tokens': 0, + 'rejected_prediction_tokens': 0, + 'cached_tokens': 0, + }, + ), + model_name='gpt-4o-2024-08-06', + timestamp=IsDatetime(), + vendor_id='chatcmpl-Bgg5utuCSXMQ38j0n2qgfdQKcR9VD', + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='get_user_country', + content='Mexico', + tool_call_id='call_SIttSeiOistt33Htj4oiHOOX', + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[ + TextPart( + content='{"result":{"kind":"CityLocation","data":{"city":"Mexico City","country":"Mexico"}}}' + ) + ], + usage=Usage( + requests=1, + request_tokens=181, + response_tokens=25, + total_tokens=206, + details={ + 'accepted_prediction_tokens': 0, + 'audio_tokens': 0, + 'reasoning_tokens': 0, + 'rejected_prediction_tokens': 0, + 'cached_tokens': 0, + }, + ), + model_name='gpt-4o-2024-08-06', + timestamp=IsDatetime(), + vendor_id='chatcmpl-Bgg5vrxUtCDlvgMreoxYxPaKxANmd', + ), + ] + ) + + +@pytest.mark.vcr() +async def test_openai_prompted_json_output(allow_model_requests: None, openai_api_key: str): + m = OpenAIModel('gpt-4o', provider=OpenAIProvider(api_key=openai_api_key)) + + class CityLocation(BaseModel): + city: str + country: str + + agent = Agent(m, output_type=PromptedJsonOutput(CityLocation)) + + @agent.tool_plain + async def get_user_country() -> str: + return 'Mexico' + + result = await agent.run('What is the largest city in the user country?') + assert result.output == snapshot(CityLocation(city='Mexico City', country='Mexico')) + + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='What is the largest city in the user country?', + timestamp=IsDatetime(), + ) + ], + instructions="""\ +Always respond with a JSON object that's compatible with this schema: + +{"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} + +Don't include any text or Markdown fencing before or after.\ +""", + ), + ModelResponse( + parts=[ + ToolCallPart(tool_name='get_user_country', args='{}', tool_call_id='call_s7oT9jaLAsEqTgvxZTmFh0wB') + ], + usage=Usage( + requests=1, + request_tokens=109, + response_tokens=11, + total_tokens=120, + details={ + 'accepted_prediction_tokens': 0, + 'audio_tokens': 0, + 'reasoning_tokens': 0, + 'rejected_prediction_tokens': 0, + 'cached_tokens': 0, + }, + ), + model_name='gpt-4o-2024-08-06', + timestamp=IsDatetime(), + vendor_id='chatcmpl-Bgh27PeOaFW6qmF04qC5uI2H9mviw', + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='get_user_country', + content='Mexico', + tool_call_id='call_s7oT9jaLAsEqTgvxZTmFh0wB', + timestamp=IsDatetime(), + ) + ], + instructions="""\ +Always respond with a JSON object that's compatible with this schema: + +{"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} + +Don't include any text or Markdown fencing before or after.\ +""", + ), + ModelResponse( + parts=[TextPart(content='{"city":"Mexico City","country":"Mexico"}')], + usage=Usage( + requests=1, + request_tokens=130, + response_tokens=11, + total_tokens=141, + details={ + 'accepted_prediction_tokens': 0, + 'audio_tokens': 0, + 'reasoning_tokens': 0, + 'rejected_prediction_tokens': 0, + 'cached_tokens': 0, + }, + ), + model_name='gpt-4o-2024-08-06', + timestamp=IsDatetime(), + vendor_id='chatcmpl-Bgh28advCSFhGHPnzUevVS6g6Uwg0', + ), + ] + ) + + +@pytest.mark.vcr() +async def test_openai_prompted_json_output_multiple(allow_model_requests: None, openai_api_key: str): + m = OpenAIModel('gpt-4o', provider=OpenAIProvider(api_key=openai_api_key)) + + class CityLocation(BaseModel): + city: str + country: str + + class CountryLanguage(BaseModel): + country: str + language: str + + agent = Agent(m, output_type=PromptedJsonOutput([CityLocation, CountryLanguage])) + + @agent.tool_plain + async def get_user_country() -> str: + return 'Mexico' + + result = await agent.run('What is the largest city in the user country?') + assert result.output == snapshot(CityLocation(city='Mexico City', country='Mexico')) + + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='What is the largest city in the user country?', + timestamp=IsDatetime(), + ) + ], + instructions="""\ +Always respond with a JSON object that's compatible with this schema: + +{"type": "object", "properties": {"result": {"anyOf": [{"type": "object", "properties": {"kind": {"type": "string", "const": "CityLocation"}, "data": {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "type": "object"}}, "required": ["kind", "data"], "additionalProperties": false, "title": "CityLocation"}, {"type": "object", "properties": {"kind": {"type": "string", "const": "CountryLanguage"}, "data": {"properties": {"country": {"type": "string"}, "language": {"type": "string"}}, "required": ["country", "language"], "type": "object"}}, "required": ["kind", "data"], "additionalProperties": false, "title": "CountryLanguage"}]}}, "required": ["result"], "additionalProperties": false} + +Don't include any text or Markdown fencing before or after.\ +""", + ), + ModelResponse( + parts=[ + ToolCallPart(tool_name='get_user_country', args='{}', tool_call_id='call_wJD14IyJ4KKVtjCrGyNCHO09') + ], + usage=Usage( + requests=1, + request_tokens=273, + response_tokens=11, + total_tokens=284, + details={ + 'accepted_prediction_tokens': 0, + 'audio_tokens': 0, + 'reasoning_tokens': 0, + 'rejected_prediction_tokens': 0, + 'cached_tokens': 0, + }, + ), + model_name='gpt-4o-2024-08-06', + timestamp=IsDatetime(), + vendor_id='chatcmpl-Bgh2AW2NXGgMc7iS639MJXNRgtatR', + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='get_user_country', + content='Mexico', + tool_call_id='call_wJD14IyJ4KKVtjCrGyNCHO09', + timestamp=IsDatetime(), + ) + ], + instructions="""\ +Always respond with a JSON object that's compatible with this schema: + +{"type": "object", "properties": {"result": {"anyOf": [{"type": "object", "properties": {"kind": {"type": "string", "const": "CityLocation"}, "data": {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "type": "object"}}, "required": ["kind", "data"], "additionalProperties": false, "title": "CityLocation"}, {"type": "object", "properties": {"kind": {"type": "string", "const": "CountryLanguage"}, "data": {"properties": {"country": {"type": "string"}, "language": {"type": "string"}}, "required": ["country", "language"], "type": "object"}}, "required": ["kind", "data"], "additionalProperties": false, "title": "CountryLanguage"}]}}, "required": ["result"], "additionalProperties": false} + +Don't include any text or Markdown fencing before or after.\ +""", + ), + ModelResponse( + parts=[ + TextPart( + content='{"result":{"kind":"CityLocation","data":{"city":"Mexico City","country":"Mexico"}}}' + ) + ], + usage=Usage( + requests=1, + request_tokens=294, + response_tokens=21, + total_tokens=315, + details={ + 'accepted_prediction_tokens': 0, + 'audio_tokens': 0, + 'reasoning_tokens': 0, + 'rejected_prediction_tokens': 0, + 'cached_tokens': 0, + }, + ), + model_name='gpt-4o-2024-08-06', + timestamp=IsDatetime(), + vendor_id='chatcmpl-Bgh2BthuopRnSqCuUgMbBnOqgkDHC', + ), + ] + ) diff --git a/tests/models/test_openai_responses.py b/tests/models/test_openai_responses.py index effce8bb6..fd7155513 100644 --- a/tests/models/test_openai_responses.py +++ b/tests/models/test_openai_responses.py @@ -3,6 +3,7 @@ import pytest from inline_snapshot import snapshot +from pydantic import BaseModel from typing_extensions import TypedDict from pydantic_ai.agent import Agent @@ -20,6 +21,7 @@ UserPromptPart, ) from pydantic_ai.profiles.openai import openai_model_profile +from pydantic_ai.result import JsonSchemaOutput, PromptedJsonOutput, TextOutput, ToolOutput from pydantic_ai.tools import ToolDefinition from pydantic_ai.usage import Usage @@ -515,3 +517,461 @@ async def test_reasoning_model_with_temperature(allow_model_requests: None, open assert result.output == snapshot( 'The capital of Mexico is Mexico City. It serves as the political, cultural, and economic heart of the country and is one of the largest metropolitan areas in the world.' ) + + +@pytest.mark.vcr() +async def test_tool_output(allow_model_requests: None, openai_api_key: str): + m = OpenAIResponsesModel('gpt-4o', provider=OpenAIProvider(api_key=openai_api_key)) + + class CityLocation(BaseModel): + city: str + country: str + + agent = Agent(m, output_type=ToolOutput(CityLocation)) + + @agent.tool_plain + async def get_user_country() -> str: + return 'Mexico' + + result = await agent.run('What is the largest city in the user country?') + assert result.output == snapshot(CityLocation(city='Mexico City', country='Mexico')) + + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='What is the largest city in the user country?', + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[ + TextPart(content=''), + ToolCallPart(tool_name='get_user_country', args='{}', tool_call_id='call_ZWkVhdUjupo528U9dqgFeRkH'), + ], + usage=Usage( + request_tokens=62, + response_tokens=12, + total_tokens=74, + details={'reasoning_tokens': 0, 'cached_tokens': 0}, + ), + model_name='gpt-4o-2024-08-06', + timestamp=IsDatetime(), + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='get_user_country', + content='Mexico', + tool_call_id='call_ZWkVhdUjupo528U9dqgFeRkH', + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[ + TextPart(content=''), + ToolCallPart( + tool_name='final_result', + args='{"city":"Mexico City","country":"Mexico"}', + tool_call_id='call_iFBd0zULhSZRR908DfH73VwN', + ), + ], + usage=Usage( + request_tokens=85, + response_tokens=20, + total_tokens=105, + details={'reasoning_tokens': 0, 'cached_tokens': 0}, + ), + model_name='gpt-4o-2024-08-06', + timestamp=IsDatetime(), + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='final_result', + content='Final result processed.', + tool_call_id='call_iFBd0zULhSZRR908DfH73VwN', + timestamp=IsDatetime(), + ) + ] + ), + ] + ) + + +@pytest.mark.vcr() +async def test_text_output_function(allow_model_requests: None, openai_api_key: str): + m = OpenAIResponsesModel('gpt-4o', provider=OpenAIProvider(api_key=openai_api_key)) + + def upcase(text: str) -> str: + return text.upper() + + agent = Agent(m, output_type=TextOutput(upcase)) + + @agent.tool_plain + async def get_user_country() -> str: + return 'Mexico' + + result = await agent.run('What is the largest city in the user country?') + assert result.output == snapshot('THE LARGEST CITY IN MEXICO IS MEXICO CITY.') + + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='What is the largest city in the user country?', + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[ + TextPart(content=''), + ToolCallPart(tool_name='get_user_country', args='{}', tool_call_id='call_aTJhYjzmixZaVGqwl5gn2Ncr'), + ], + usage=Usage( + request_tokens=36, + response_tokens=12, + total_tokens=48, + details={'reasoning_tokens': 0, 'cached_tokens': 0}, + ), + model_name='gpt-4o-2024-08-06', + timestamp=IsDatetime(), + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='get_user_country', + content='Mexico', + tool_call_id='call_aTJhYjzmixZaVGqwl5gn2Ncr', + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[TextPart(content='The largest city in Mexico is Mexico City.')], + usage=Usage( + request_tokens=59, + response_tokens=11, + total_tokens=70, + details={'reasoning_tokens': 0, 'cached_tokens': 0}, + ), + model_name='gpt-4o-2024-08-06', + timestamp=IsDatetime(), + ), + ] + ) + + +@pytest.mark.vcr() +async def test_json_schema_output(allow_model_requests: None, openai_api_key: str): + m = OpenAIResponsesModel('gpt-4o', provider=OpenAIProvider(api_key=openai_api_key)) + + class CityLocation(BaseModel): + """A city and its country.""" + + city: str + country: str + + agent = Agent(m, output_type=JsonSchemaOutput(CityLocation)) + + @agent.tool_plain + async def get_user_country() -> str: + return 'Mexico' + + result = await agent.run('What is the largest city in the user country?') + assert result.output == snapshot(CityLocation(city='Mexico City', country='Mexico')) + + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='What is the largest city in the user country?', + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[ + TextPart(content=''), + ToolCallPart(tool_name='get_user_country', args='{}', tool_call_id='call_tTAThu8l2S9hNky2krdwijGP'), + ], + usage=Usage( + request_tokens=66, + response_tokens=12, + total_tokens=78, + details={'reasoning_tokens': 0, 'cached_tokens': 0}, + ), + model_name='gpt-4o-2024-08-06', + timestamp=IsDatetime(), + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='get_user_country', + content='Mexico', + tool_call_id='call_tTAThu8l2S9hNky2krdwijGP', + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[TextPart(content='{"city":"Mexico City","country":"Mexico"}')], + usage=Usage( + request_tokens=89, + response_tokens=16, + total_tokens=105, + details={'reasoning_tokens': 0, 'cached_tokens': 0}, + ), + model_name='gpt-4o-2024-08-06', + timestamp=IsDatetime(), + ), + ] + ) + + +@pytest.mark.vcr() +async def test_json_schema_output_multiple(allow_model_requests: None, openai_api_key: str): + m = OpenAIResponsesModel('gpt-4o', provider=OpenAIProvider(api_key=openai_api_key)) + + class CityLocation(BaseModel): + city: str + country: str + + class CountryLanguage(BaseModel): + country: str + language: str + + agent = Agent(m, output_type=JsonSchemaOutput([CityLocation, CountryLanguage])) + + @agent.tool_plain + async def get_user_country() -> str: + return 'Mexico' + + result = await agent.run('What is the largest city in the user country?') + assert result.output == snapshot(CityLocation(city='Mexico City', country='Mexico')) + + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='What is the largest city in the user country?', + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[ + TextPart(content=''), + ToolCallPart(tool_name='get_user_country', args='{}', tool_call_id='call_UaLahjOtaM2tTyYZLxTCbOaP'), + ], + usage=Usage( + request_tokens=153, + response_tokens=12, + total_tokens=165, + details={'reasoning_tokens': 0, 'cached_tokens': 0}, + ), + model_name='gpt-4o-2024-08-06', + timestamp=IsDatetime(), + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='get_user_country', + content='Mexico', + tool_call_id='call_UaLahjOtaM2tTyYZLxTCbOaP', + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[ + TextPart( + content='{"result":{"kind":"CityLocation","data":{"city":"Mexico City","country":"Mexico"}}}' + ) + ], + usage=Usage( + request_tokens=176, + response_tokens=26, + total_tokens=202, + details={'reasoning_tokens': 0, 'cached_tokens': 0}, + ), + model_name='gpt-4o-2024-08-06', + timestamp=IsDatetime(), + ), + ] + ) + + +@pytest.mark.vcr() +async def test_prompted_json_output(allow_model_requests: None, openai_api_key: str): + m = OpenAIResponsesModel('gpt-4o', provider=OpenAIProvider(api_key=openai_api_key)) + + class CityLocation(BaseModel): + city: str + country: str + + agent = Agent(m, output_type=PromptedJsonOutput(CityLocation)) + + @agent.tool_plain + async def get_user_country() -> str: + return 'Mexico' + + result = await agent.run('What is the largest city in the user country?') + assert result.output == snapshot(CityLocation(city='Mexico City', country='Mexico')) + + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='What is the largest city in the user country?', + timestamp=IsDatetime(), + ) + ], + instructions="""\ +Always respond with a JSON object that's compatible with this schema: + +{"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} + +Don't include any text or Markdown fencing before or after.\ +""", + ), + ModelResponse( + parts=[ + TextPart(content=''), + ToolCallPart(tool_name='get_user_country', args='{}', tool_call_id='call_FrlL4M0CbAy8Dhv4VqF1Shom'), + ], + usage=Usage( + request_tokens=107, + response_tokens=12, + total_tokens=119, + details={'reasoning_tokens': 0, 'cached_tokens': 0}, + ), + model_name='gpt-4o-2024-08-06', + timestamp=IsDatetime(), + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='get_user_country', + content='Mexico', + tool_call_id='call_FrlL4M0CbAy8Dhv4VqF1Shom', + timestamp=IsDatetime(), + ) + ], + instructions="""\ +Always respond with a JSON object that's compatible with this schema: + +{"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "CityLocation", "type": "object"} + +Don't include any text or Markdown fencing before or after.\ +""", + ), + ModelResponse( + parts=[TextPart(content='{"city":"Mexico City","country":"Mexico"}')], + usage=Usage( + request_tokens=130, + response_tokens=12, + total_tokens=142, + details={'reasoning_tokens': 0, 'cached_tokens': 0}, + ), + model_name='gpt-4o-2024-08-06', + timestamp=IsDatetime(), + ), + ] + ) + + +@pytest.mark.vcr() +async def test_prompted_json_output_multiple(allow_model_requests: None, openai_api_key: str): + m = OpenAIResponsesModel('gpt-4o', provider=OpenAIProvider(api_key=openai_api_key)) + + class CityLocation(BaseModel): + city: str + country: str + + class CountryLanguage(BaseModel): + country: str + language: str + + agent = Agent(m, output_type=PromptedJsonOutput([CityLocation, CountryLanguage])) + + @agent.tool_plain + async def get_user_country() -> str: + return 'Mexico' + + result = await agent.run('What is the largest city in the user country?') + assert result.output == snapshot(CityLocation(city='Mexico City', country='Mexico')) + + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='What is the largest city in the user country?', + timestamp=IsDatetime(), + ) + ], + instructions="""\ +Always respond with a JSON object that's compatible with this schema: + +{"type": "object", "properties": {"result": {"anyOf": [{"type": "object", "properties": {"kind": {"type": "string", "const": "CityLocation"}, "data": {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "type": "object"}}, "required": ["kind", "data"], "additionalProperties": false, "title": "CityLocation"}, {"type": "object", "properties": {"kind": {"type": "string", "const": "CountryLanguage"}, "data": {"properties": {"country": {"type": "string"}, "language": {"type": "string"}}, "required": ["country", "language"], "type": "object"}}, "required": ["kind", "data"], "additionalProperties": false, "title": "CountryLanguage"}]}}, "required": ["result"], "additionalProperties": false} + +Don't include any text or Markdown fencing before or after.\ +""", + ), + ModelResponse( + parts=[ + TextPart(content=''), + ToolCallPart(tool_name='get_user_country', args='{}', tool_call_id='call_my4OyoVXRT0m7bLWmsxcaCQI'), + ], + usage=Usage( + request_tokens=283, + response_tokens=12, + total_tokens=295, + details={'reasoning_tokens': 0, 'cached_tokens': 0}, + ), + model_name='gpt-4o-2024-08-06', + timestamp=IsDatetime(), + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='get_user_country', + content='Mexico', + tool_call_id='call_my4OyoVXRT0m7bLWmsxcaCQI', + timestamp=IsDatetime(), + ) + ], + instructions="""\ +Always respond with a JSON object that's compatible with this schema: + +{"type": "object", "properties": {"result": {"anyOf": [{"type": "object", "properties": {"kind": {"type": "string", "const": "CityLocation"}, "data": {"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "type": "object"}}, "required": ["kind", "data"], "additionalProperties": false, "title": "CityLocation"}, {"type": "object", "properties": {"kind": {"type": "string", "const": "CountryLanguage"}, "data": {"properties": {"country": {"type": "string"}, "language": {"type": "string"}}, "required": ["country", "language"], "type": "object"}}, "required": ["kind", "data"], "additionalProperties": false, "title": "CountryLanguage"}]}}, "required": ["result"], "additionalProperties": false} + +Don't include any text or Markdown fencing before or after.\ +""", + ), + ModelResponse( + parts=[ + TextPart( + content='{"result":{"kind":"CityLocation","data":{"city":"Mexico City","country":"Mexico"}}}' + ) + ], + usage=Usage( + request_tokens=306, + response_tokens=22, + total_tokens=328, + details={'reasoning_tokens': 0, 'cached_tokens': 0}, + ), + model_name='gpt-4o-2024-08-06', + timestamp=IsDatetime(), + ), + ] + ) diff --git a/tests/test_agent.py b/tests/test_agent.py index d59848155..28e541852 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -13,6 +13,15 @@ from typing_extensions import Self from pydantic_ai import Agent, ModelRetry, RunContext, UnexpectedModelBehavior, UserError, capture_run_messages +from pydantic_ai._output import ( + JsonSchemaOutput, + OutputSpec, + PromptedJsonOutput, + TextOutput, + TextOutputSchema, + ToolOutput, + ToolOutputSchema, +) from pydantic_ai.agent import AgentRunResult from pydantic_ai.messages import ( BinaryContent, @@ -30,7 +39,8 @@ ) from pydantic_ai.models.function import AgentInfo, FunctionModel from pydantic_ai.models.test import TestModel -from pydantic_ai.result import ToolOutput, Usage +from pydantic_ai.profiles import ModelProfile +from pydantic_ai.result import Usage from pydantic_ai.tools import ToolDefinition from .conftest import IsDatetime, IsNow, IsStr, TestEnv @@ -261,7 +271,7 @@ def return_tuple(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: args_json = '{"response": ["foo", "bar"]}' return ModelResponse(parts=[ToolCallPart(info.output_tools[0].name, args_json)]) - agent = Agent(FunctionModel(return_tuple), output_type=tuple[str, str]) + agent = Agent(FunctionModel(return_tuple), output_type=ToolOutput(tuple[str, str])) result = agent.run_sync('Hello') assert result.output == ('foo', 'bar') @@ -353,7 +363,7 @@ def test_response_tuple(): m = TestModel() agent = Agent(m, output_type=tuple[str, str]) - assert agent._output_schema.allow_text_output is False # pyright: ignore[reportPrivateUsage,reportOptionalMemberAccess] + assert isinstance(agent._output_schema, ToolOutputSchema) # pyright: ignore[reportPrivateUsage] result = agent.run_sync('Hello') assert result.output == snapshot(('a', 'a')) @@ -388,10 +398,28 @@ def test_response_tuple(): ) +def upcase(text: str) -> str: + return text.upper() + + @pytest.mark.parametrize( 'input_union_callable', - [lambda: Union[str, Foo], lambda: Union[Foo, str], lambda: str | Foo, lambda: Foo | str, lambda: [Foo, str]], - ids=['Union[str, Foo]', 'Union[Foo, str]', 'str | Foo', 'Foo | str', '[Foo, str]'], + [ + lambda: Union[str, Foo], + lambda: Union[Foo, str], + lambda: str | Foo, + lambda: Foo | str, + lambda: [Foo, str], + lambda: [TextOutput(upcase), ToolOutput(Foo)], + ], + ids=[ + 'Union[str, Foo]', + 'Union[Foo, str]', + 'str | Foo', + 'Foo | str', + '[Foo, str]', + '[TextOutput(upcase), ToolOutput(Foo)]', + ], ) def test_response_union_allow_str(input_union_callable: Callable[[], Any]): try: @@ -410,10 +438,11 @@ def validate_output(ctx: RunContext[None], o: Any) -> Any: got_tool_call_name = ctx.tool_name return o - assert agent._output_schema.allow_text_output is True # pyright: ignore[reportPrivateUsage,reportOptionalMemberAccess] + assert isinstance(agent._output_schema, TextOutputSchema) # pyright: ignore[reportPrivateUsage] result = agent.run_sync('Hello') - assert result.output == snapshot('success (no tool calls)') + assert isinstance(result.output, str) + assert result.output.lower() == snapshot('success (no tool calls)') assert got_tool_call_name == snapshot(None) assert m.last_model_request_parameters is not None @@ -448,6 +477,7 @@ def validate_output(ctx: RunContext[None], o: Any) -> Any: [ pytest.param('OutputType = Union[Foo, Bar]'), pytest.param('OutputType = [Foo, Bar]'), + pytest.param('OutputType = [ToolOutput(Foo), ToolOutput(Bar)]'), pytest.param('OutputType = Foo | Bar', marks=pytest.mark.skipif(sys.version_info < (3, 10), reason='3.10+')), pytest.param( 'OutputType: TypeAlias = Foo | Bar', @@ -846,6 +876,83 @@ def call_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: ) +def test_output_type_text_output_function_with_retry(): + class Weather(BaseModel): + temperature: float + description: str + + def get_weather(ctx: RunContext[None], city: str) -> Weather: + assert ctx is not None + if city != 'Mexico City': + raise ModelRetry('City not found, I only know Mexico City') + return Weather(temperature=28.7, description='sunny') + + def call_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: + assert info.output_tools is not None + + if len(messages) == 1: + city = 'New York City' + else: + city = 'Mexico City' + + return ModelResponse(parts=[TextPart(content=city)]) + + agent = Agent(FunctionModel(call_tool), output_type=TextOutput(get_weather)) + result = agent.run_sync('New York City') + assert result.output == snapshot(Weather(temperature=28.7, description='sunny')) + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='New York City', + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[TextPart(content='New York City')], + usage=Usage(requests=1, request_tokens=53, response_tokens=3, total_tokens=56), + model_name='function:call_tool:', + timestamp=IsDatetime(), + ), + ModelRequest( + parts=[ + RetryPromptPart( + content='City not found, I only know Mexico City', + tool_call_id=IsStr(), + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[TextPart(content='Mexico City')], + usage=Usage(requests=1, request_tokens=68, response_tokens=5, total_tokens=73), + model_name='function:call_tool:', + timestamp=IsDatetime(), + ), + ] + ) + + +@pytest.mark.parametrize( + 'output_type', + [[str, str], [str, TextOutput(upcase)], [TextOutput(upcase), TextOutput(upcase)]], +) +def test_output_type_multiple_text_output(output_type: OutputSpec[str]): + with pytest.raises(UserError, match='Only one text output is allowed.'): + Agent('test', output_type=output_type) + + +def test_output_type_text_output_invalid(): + def int_func(x: int) -> str: + return str(int) # pragma: no cover + + with pytest.raises(UserError, match='TextOutput must take a function taking a `str`'): + output_type: TextOutput[str] = TextOutput(int_func) # type: ignore + Agent('test', output_type=output_type) + + def test_output_type_async_function(): class Weather(BaseModel): temperature: float @@ -970,6 +1077,33 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: ) +def test_output_type_text_output_function(): + def say_world(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: + return ModelResponse(parts=[TextPart(content='world')]) + + agent = Agent(FunctionModel(say_world), output_type=TextOutput(upcase)) + result = agent.run_sync('hello') + assert result.output == snapshot('WORLD') + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='hello', + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[TextPart(content='world')], + usage=Usage(requests=1, request_tokens=51, response_tokens=1, total_tokens=52), + model_name='function:say_world:', + timestamp=IsDatetime(), + ), + ] + ) + + def test_output_type_handoff_to_agent(): class Weather(BaseModel): temperature: float @@ -1129,6 +1263,257 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: ) +def test_output_type_prompted_json(): + def return_city_location(_: list[ModelMessage], _info: AgentInfo) -> ModelResponse: + text = CityLocation(city='Mexico City', country='Mexico').model_dump_json() + return ModelResponse(parts=[TextPart(content=text)]) + + m = FunctionModel(return_city_location) + + class CityLocation(BaseModel): + """Description from docstring.""" + + city: str + country: str + + agent = Agent( + m, + output_type=PromptedJsonOutput( + CityLocation, name='City & Country', description='Description from PromptedJsonOutput' + ), + ) + + result = agent.run_sync('What is the capital of Mexico?') + assert result.output == snapshot(CityLocation(city='Mexico City', country='Mexico')) + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='What is the capital of Mexico?', + timestamp=IsDatetime(), + ) + ], + instructions="""\ +Always respond with a JSON object that's compatible with this schema: + +{"properties": {"city": {"type": "string"}, "country": {"type": "string"}}, "required": ["city", "country"], "title": "City & Country", "type": "object", "description": "Description from PromptedJsonOutput. Description from docstring."} + +Don't include any text or Markdown fencing before or after.\ +""", + ), + ModelResponse( + parts=[TextPart(content='{"city":"Mexico City","country":"Mexico"}')], + usage=Usage(requests=1, request_tokens=56, response_tokens=7, total_tokens=63), + model_name='function:return_city_location:', + timestamp=IsDatetime(), + ), + ] + ) + + +def test_output_type_prompted_json_with_defs(): + class Foo(BaseModel): + """Foo description""" + + foo: str + + class Bar(BaseModel): + """Bar description""" + + bar: str + + class Baz(BaseModel): + """Baz description""" + + baz: str + + class FooBar(BaseModel): + """FooBar description""" + + foo: Foo + bar: Bar + + class FooBaz(BaseModel): + """FooBaz description""" + + foo: Foo + baz: Baz + + def return_foo_bar(_: list[ModelMessage], _info: AgentInfo) -> ModelResponse: + text = '{"result": {"kind": "FooBar", "data": {"foo": {"foo": "foo"}, "bar": {"bar": "bar"}}}}' + return ModelResponse(parts=[TextPart(content=text)]) + + m = FunctionModel(return_foo_bar) + + agent = Agent( + m, + output_type=PromptedJsonOutput( + [FooBar, FooBaz], name='FooBar or FooBaz', description='FooBar or FooBaz description' + ), + ) + + result = agent.run_sync('What is foo?') + assert result.output == snapshot(FooBar(foo=Foo(foo='foo'), bar=Bar(bar='bar'))) + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='What is foo?', + timestamp=IsDatetime(), + ) + ], + instructions="""\ +Always respond with a JSON object that's compatible with this schema: + +{"type": "object", "properties": {"result": {"anyOf": [{"type": "object", "properties": {"kind": {"type": "string", "const": "FooBar"}, "data": {"properties": {"foo": {"$ref": "#/$defs/Foo"}, "bar": {"$ref": "#/$defs/Bar"}}, "required": ["foo", "bar"], "type": "object"}}, "required": ["kind", "data"], "additionalProperties": false, "title": "FooBar", "description": "FooBar description"}, {"type": "object", "properties": {"kind": {"type": "string", "const": "FooBaz"}, "data": {"properties": {"foo": {"$ref": "#/$defs/Foo"}, "baz": {"$ref": "#/$defs/Baz"}}, "required": ["foo", "baz"], "type": "object"}}, "required": ["kind", "data"], "additionalProperties": false, "title": "FooBaz", "description": "FooBaz description"}]}}, "required": ["result"], "additionalProperties": false, "$defs": {"Bar": {"description": "Bar description", "properties": {"bar": {"type": "string"}}, "required": ["bar"], "title": "Bar", "type": "object"}, "Foo": {"description": "Foo description", "properties": {"foo": {"type": "string"}}, "required": ["foo"], "title": "Foo", "type": "object"}, "Baz": {"description": "Baz description", "properties": {"baz": {"type": "string"}}, "required": ["baz"], "title": "Baz", "type": "object"}}, "title": "FooBar or FooBaz", "description": "FooBaz description"} + +Don't include any text or Markdown fencing before or after.\ +""", + ), + ModelResponse( + parts=[ + TextPart( + content='{"result": {"kind": "FooBar", "data": {"foo": {"foo": "foo"}, "bar": {"bar": "bar"}}}}' + ) + ], + usage=Usage(requests=1, request_tokens=53, response_tokens=17, total_tokens=70), + model_name='function:return_foo_bar:', + timestamp=IsDatetime(), + ), + ] + ) + + +def test_output_type_json_schema(): + def return_city_location(messages: list[ModelMessage], _info: AgentInfo) -> ModelResponse: + if len(messages) == 1: + text = '{"city": "Mexico City"}' + else: + text = '{"city": "Mexico City", "country": "Mexico"}' + return ModelResponse(parts=[TextPart(content=text)]) + + m = FunctionModel(return_city_location, profile=ModelProfile(output_modes={'tool', 'json_schema'})) + + class CityLocation(BaseModel): + city: str + country: str + + agent = Agent( + m, + output_type=JsonSchemaOutput(CityLocation), + ) + + result = agent.run_sync('What is the capital of Mexico?') + assert result.output == snapshot(CityLocation(city='Mexico City', country='Mexico')) + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='What is the capital of Mexico?', + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[TextPart(content='{"city": "Mexico City"}')], + usage=Usage(requests=1, request_tokens=56, response_tokens=5, total_tokens=61), + model_name='function:return_city_location:', + timestamp=IsDatetime(), + ), + ModelRequest( + parts=[ + RetryPromptPart( + content=[ + { + 'type': 'missing', + 'loc': ('country',), + 'msg': 'Field required', + 'input': {'city': 'Mexico City'}, + } + ], + tool_call_id=IsStr(), + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[TextPart(content='{"city": "Mexico City", "country": "Mexico"}')], + usage=Usage(requests=1, request_tokens=85, response_tokens=12, total_tokens=97), + model_name='function:return_city_location:', + timestamp=IsDatetime(), + ), + ] + ) + + +def test_output_type_prompted_json_function_with_retry(): + class Weather(BaseModel): + temperature: float + description: str + + def get_weather(city: str) -> Weather: + if city != 'Mexico City': + raise ModelRetry('City not found, I only know Mexico City') + return Weather(temperature=28.7, description='sunny') + + def call_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: + assert info.output_tools is not None + + if len(messages) == 1: + args_json = '{"city": "New York City"}' + else: + args_json = '{"city": "Mexico City"}' + + return ModelResponse(parts=[TextPart(content=args_json)]) + + agent = Agent(FunctionModel(call_tool), output_type=PromptedJsonOutput(get_weather)) + result = agent.run_sync('New York City') + assert result.output == snapshot(Weather(temperature=28.7, description='sunny')) + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='New York City', + timestamp=IsDatetime(), + ) + ], + instructions="""\ +Always respond with a JSON object that's compatible with this schema: + +{"additionalProperties": false, "properties": {"city": {"type": "string"}}, "required": ["city"], "type": "object", "title": "get_weather"} + +Don't include any text or Markdown fencing before or after.\ +""", + ), + ModelResponse( + parts=[TextPart(content='{"city": "New York City"}')], + usage=Usage(requests=1, request_tokens=53, response_tokens=6, total_tokens=59), + model_name='function:call_tool:', + timestamp=IsDatetime(), + ), + ModelRequest( + parts=[ + RetryPromptPart( + content='City not found, I only know Mexico City', + tool_call_id=IsStr(), + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[TextPart(content='{"city": "Mexico City"}')], + usage=Usage(requests=1, request_tokens=68, response_tokens=11, total_tokens=79), + model_name='function:call_tool:', + timestamp=IsDatetime(), + ), + ] + ) + + def test_run_with_history_new(): m = TestModel() @@ -2486,6 +2871,7 @@ def instructions() -> str: parts=[UserPromptPart(content='Hello', timestamp=IsDatetime())], instructions="""\ You are a helpful assistant. + You are a potato.\ """, ) @@ -2617,3 +3003,13 @@ def foo_tool(foo: Foo) -> int: 'kind': 'request', } ) + + +def test_unsupported_output_mode(): + class Foo(BaseModel): + bar: str + + agent = Agent('test', output_type=JsonSchemaOutput(Foo)) + + with pytest.raises(UserError, match="Output mode 'json_schema' is not among supported modes: 'tool'"): + agent.run_sync('Hello') diff --git a/tests/test_logfire.py b/tests/test_logfire.py index 2d179d193..96de967f9 100644 --- a/tests/test_logfire.py +++ b/tests/test_logfire.py @@ -291,8 +291,10 @@ async def my_ret(x: int) -> str: 'strict': None, } ], - 'allow_text_output': True, + 'output_mode': 'text', 'output_tools': [], + 'output_object': None, + 'allow_text_output': True, } ) ), @@ -472,7 +474,7 @@ async def test_feedback(capfire: CaptureLogfire) -> None: 'gen_ai.operation.name': 'chat', 'gen_ai.system': 'test', 'gen_ai.request.model': 'test', - 'model_request_parameters': '{"function_tools": [], "allow_text_output": true, "output_tools": []}', + 'model_request_parameters': '{"function_tools": [], "output_mode": "text", "output_object": null, "output_tools": [], "allow_text_output": true}', 'logfire.span_type': 'span', 'logfire.msg': 'chat test', 'gen_ai.usage.input_tokens': 51, diff --git a/tests/test_streaming.py b/tests/test_streaming.py index 12d29a9b0..75935d33f 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -13,6 +13,7 @@ from pydantic import BaseModel from pydantic_ai import Agent, UnexpectedModelBehavior, UserError, capture_run_messages +from pydantic_ai._output import PromptedJsonOutput, TextOutput from pydantic_ai.agent import AgentRun from pydantic_ai.messages import ( ModelMessage, @@ -192,6 +193,22 @@ async def test_streamed_text_stream(): ['The ', 'cat ', 'sat ', 'on ', 'the ', 'mat.'] ) + def upcase(text: str) -> str: + return text.upper() + + async with agent.run_stream('Hello', output_type=TextOutput(upcase)) as result: + assert [c async for c in result.stream(debounce_by=None)] == snapshot( + [ + 'THE ', + 'THE CAT ', + 'THE CAT SAT ', + 'THE CAT SAT ON ', + 'THE CAT SAT ON THE ', + 'THE CAT SAT ON THE MAT.', + 'THE CAT SAT ON THE MAT.', + ] + ) + async with agent.run_stream('Hello') as result: assert [c async for c, _is_last in result.stream_structured(debounce_by=None)] == snapshot( [ @@ -921,3 +938,26 @@ def output_validator(data: OutputType | NotOutputType) -> OutputType | NotOutput async for output in stream.stream_output(debounce_by=None): outputs.append(output) assert outputs == [OutputType(value='a (validated)'), OutputType(value='a (validated)')] + + +async def test_stream_output_type_prompted_json(): + class CityLocation(BaseModel): + city: str + country: str | None = None + + m = TestModel(custom_output_text='{"city": "Mexico City", "country": "Mexico"}') + + agent = Agent(m, output_type=PromptedJsonOutput(CityLocation)) + + async with agent.run_stream('') as result: + assert not result.is_complete + assert [c async for c in result.stream(debounce_by=None)] == snapshot( + [ + CityLocation(city='Mexico '), + CityLocation(city='Mexico City'), + CityLocation(city='Mexico City'), + CityLocation(city='Mexico City', country='Mexico'), + CityLocation(city='Mexico City', country='Mexico'), + ] + ) + assert result.is_complete diff --git a/tests/test_tools.py b/tests/test_tools.py index 4c09b31e8..8be113610 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -571,7 +571,7 @@ def test_tool_return_conflict(): # this is also okay Agent('test', tools=[ctx_tool], deps_type=int, output_type=int) # this raises an error - with pytest.raises(UserError, match="Tool name conflicts with result schema name: 'ctx_tool'"): + with pytest.raises(UserError, match="Tool name conflicts with output tool name: 'ctx_tool'"): Agent('test', tools=[ctx_tool], deps_type=int, output_type=ToolOutput(int, name='ctx_tool')) diff --git a/tests/test_utils.py b/tests/test_utils.py index e7d3ddcf3..afd0b7fa9 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -10,7 +10,15 @@ from inline_snapshot import snapshot from pydantic_ai import UserError -from pydantic_ai._utils import UNSET, PeekableAsyncStream, check_object_json_schema, group_by_temporal, run_in_executor +from pydantic_ai._utils import ( + UNSET, + PeekableAsyncStream, + check_object_json_schema, + group_by_temporal, + merge_json_schema_defs, + run_in_executor, + strip_markdown_fences, +) from .models.mock_async_stream import MockAsyncStream @@ -153,3 +161,327 @@ async def test_run_in_executor_with_contextvars() -> None: # show that the old version did not work old_result = asyncio.get_running_loop().run_in_executor(None, ctx_var.get) assert old_result != ctx_var.get() + + +def test_merge_json_schema_defs(): + foo_bar_schema = { + '$defs': { + 'Bar': { + 'description': 'Bar description', + 'properties': {'bar': {'type': 'string'}}, + 'required': ['bar'], + 'title': 'Bar', + 'type': 'object', + }, + 'Foo': { + 'description': 'Foo description', + 'properties': {'foo': {'type': 'string'}}, + 'required': ['foo'], + 'title': 'Foo', + 'type': 'object', + }, + }, + 'properties': {'foo': {'$ref': '#/$defs/Foo'}, 'bar': {'$ref': '#/$defs/Bar'}}, + 'required': ['foo', 'bar'], + 'type': 'object', + 'title': 'FooBar', + } + + foo_bar_baz_schema = { + '$defs': { + 'Baz': { + 'description': 'Baz description', + 'properties': {'baz': {'type': 'string'}}, + 'required': ['baz'], + 'title': 'Baz', + 'type': 'object', + }, + 'Foo': { + 'description': 'Foo description. Note that this is different from the Foo in foo_bar_schema!', + 'properties': {'foo': {'type': 'int'}}, + 'required': ['foo'], + 'title': 'Foo', + 'type': 'object', + }, + 'Bar': { + 'description': 'Bar description', + 'properties': {'bar': {'type': 'string'}}, + 'required': ['bar'], + 'title': 'Bar', + 'type': 'object', + }, + }, + 'properties': {'foo': {'$ref': '#/$defs/Foo'}, 'baz': {'$ref': '#/$defs/Baz'}, 'bar': {'$ref': '#/$defs/Bar'}}, + 'required': ['foo', 'baz', 'bar'], + 'type': 'object', + 'title': 'FooBarBaz', + } + + # A schema with no title that will cause numeric suffixes + no_title_schema = { + '$defs': { + 'Foo': { + 'description': 'Another different Foo', + 'properties': {'foo': {'type': 'boolean'}}, + 'required': ['foo'], + 'title': 'Foo', + 'type': 'object', + }, + 'Bar': { + 'description': 'Another different Bar', + 'properties': {'bar': {'type': 'number'}}, + 'required': ['bar'], + 'title': 'Bar', + 'type': 'object', + }, + }, + 'properties': {'foo': {'$ref': '#/$defs/Foo'}, 'bar': {'$ref': '#/$defs/Bar'}}, + 'required': ['foo', 'bar'], + 'type': 'object', + } + + # Another schema with no title that will cause more numeric suffixes + another_no_title_schema = { + '$defs': { + 'Foo': { + 'description': 'Yet another different Foo', + 'properties': {'foo': {'type': 'array'}}, + 'required': ['foo'], + 'title': 'Foo', + 'type': 'object', + }, + 'Bar': { + 'description': 'Yet another different Bar', + 'properties': {'bar': {'type': 'object'}}, + 'required': ['bar'], + 'title': 'Bar', + 'type': 'object', + }, + }, + 'properties': {'foo': {'$ref': '#/$defs/Foo'}, 'bar': {'$ref': '#/$defs/Bar'}}, + 'required': ['foo', 'bar'], + 'type': 'object', + } + + # Schema with nested properties, array items, prefixItems, and anyOf/oneOf + complex_schema = { + '$defs': { + 'Nested': { + 'description': 'A nested type', + 'properties': {'nested': {'type': 'string'}}, + 'required': ['nested'], + 'title': 'Nested', + 'type': 'object', + }, + 'ArrayItem': { + 'description': 'An array item type', + 'properties': {'item': {'type': 'string'}}, + 'required': ['item'], + 'title': 'ArrayItem', + 'type': 'object', + }, + 'UnionType': { + 'description': 'A union type', + 'properties': {'union': {'type': 'string'}}, + 'required': ['union'], + 'title': 'UnionType', + 'type': 'object', + }, + }, + 'properties': { + 'nested_props': { + 'type': 'object', + 'properties': { + 'deep_nested': {'$ref': '#/$defs/Nested'}, + }, + }, + 'array_with_items': { + 'type': 'array', + 'items': {'$ref': '#/$defs/ArrayItem'}, + }, + 'array_with_prefix': { + 'type': 'array', + 'prefixItems': [ + {'$ref': '#/$defs/ArrayItem'}, + {'$ref': '#/$defs/Nested'}, + ], + }, + 'union_anyOf': { + 'anyOf': [ + {'$ref': '#/$defs/UnionType'}, + {'$ref': '#/$defs/Nested'}, + ], + }, + 'union_oneOf': { + 'oneOf': [ + {'$ref': '#/$defs/UnionType'}, + {'$ref': '#/$defs/ArrayItem'}, + ], + }, + }, + 'type': 'object', + 'title': 'ComplexSchema', + } + + schemas = [foo_bar_schema, foo_bar_baz_schema, no_title_schema, another_no_title_schema, complex_schema] + rewritten_schemas, all_defs = merge_json_schema_defs(schemas) + assert all_defs == snapshot( + { + 'Bar': { + 'description': 'Bar description', + 'properties': {'bar': {'type': 'string'}}, + 'required': ['bar'], + 'title': 'Bar', + 'type': 'object', + }, + 'Foo': { + 'description': 'Foo description', + 'properties': {'foo': {'type': 'string'}}, + 'required': ['foo'], + 'title': 'Foo', + 'type': 'object', + }, + 'Baz': { + 'description': 'Baz description', + 'properties': {'baz': {'type': 'string'}}, + 'required': ['baz'], + 'title': 'Baz', + 'type': 'object', + }, + 'FooBarBaz_Foo_1': { + 'description': 'Foo description. Note that this is different from the Foo in foo_bar_schema!', + 'properties': {'foo': {'type': 'int'}}, + 'required': ['foo'], + 'title': 'Foo', + 'type': 'object', + }, + 'Foo_1': { + 'description': 'Another different Foo', + 'properties': {'foo': {'type': 'boolean'}}, + 'required': ['foo'], + 'title': 'Foo', + 'type': 'object', + }, + 'Bar_1': { + 'description': 'Another different Bar', + 'properties': {'bar': {'type': 'number'}}, + 'required': ['bar'], + 'title': 'Bar', + 'type': 'object', + }, + 'Foo_2': { + 'description': 'Yet another different Foo', + 'properties': {'foo': {'type': 'array'}}, + 'required': ['foo'], + 'title': 'Foo', + 'type': 'object', + }, + 'Bar_2': { + 'description': 'Yet another different Bar', + 'properties': {'bar': {'type': 'object'}}, + 'required': ['bar'], + 'title': 'Bar', + 'type': 'object', + }, + 'Nested': { + 'description': 'A nested type', + 'properties': {'nested': {'type': 'string'}}, + 'required': ['nested'], + 'title': 'Nested', + 'type': 'object', + }, + 'ArrayItem': { + 'description': 'An array item type', + 'properties': {'item': {'type': 'string'}}, + 'required': ['item'], + 'title': 'ArrayItem', + 'type': 'object', + }, + 'UnionType': { + 'description': 'A union type', + 'properties': {'union': {'type': 'string'}}, + 'required': ['union'], + 'title': 'UnionType', + 'type': 'object', + }, + } + ) + assert rewritten_schemas == snapshot( + [ + { + 'properties': {'foo': {'$ref': '#/$defs/Foo'}, 'bar': {'$ref': '#/$defs/Bar'}}, + 'required': ['foo', 'bar'], + 'type': 'object', + 'title': 'FooBar', + }, + { + 'properties': { + 'foo': {'$ref': '#/$defs/FooBarBaz_Foo_1'}, + 'baz': {'$ref': '#/$defs/Baz'}, + 'bar': {'$ref': '#/$defs/Bar'}, + }, + 'required': ['foo', 'baz', 'bar'], + 'type': 'object', + 'title': 'FooBarBaz', + }, + { + 'properties': {'foo': {'$ref': '#/$defs/Foo_1'}, 'bar': {'$ref': '#/$defs/Bar_1'}}, + 'required': ['foo', 'bar'], + 'type': 'object', + }, + { + 'properties': {'foo': {'$ref': '#/$defs/Foo_2'}, 'bar': {'$ref': '#/$defs/Bar_2'}}, + 'required': ['foo', 'bar'], + 'type': 'object', + }, + { + 'properties': { + 'nested_props': { + 'type': 'object', + 'properties': { + 'deep_nested': {'$ref': '#/$defs/Nested'}, + }, + }, + 'array_with_items': { + 'type': 'array', + 'items': {'$ref': '#/$defs/ArrayItem'}, + }, + 'array_with_prefix': { + 'type': 'array', + 'prefixItems': [ + {'$ref': '#/$defs/ArrayItem'}, + {'$ref': '#/$defs/Nested'}, + ], + }, + 'union_anyOf': { + 'anyOf': [ + {'$ref': '#/$defs/UnionType'}, + {'$ref': '#/$defs/Nested'}, + ], + }, + 'union_oneOf': { + 'oneOf': [ + {'$ref': '#/$defs/UnionType'}, + {'$ref': '#/$defs/ArrayItem'}, + ], + }, + }, + 'type': 'object', + 'title': 'ComplexSchema', + }, + ] + ) + + +def test_strip_markdown_fences(): + assert strip_markdown_fences('{"foo": "bar"}') == '{"foo": "bar"}' + assert strip_markdown_fences('```json\n{"foo": "bar"}\n```') == '{"foo": "bar"}' + assert ( + strip_markdown_fences('{"foo": "```json\\n{"foo": "bar"}\\n```"}') + == '{"foo": "```json\\n{"foo": "bar"}\\n```"}' + ) + assert ( + strip_markdown_fences('Here is some beautiful JSON:\n\n```\n{"foo": "bar"}\n``` Nice right?') + == '{"foo": "bar"}' + ) + assert strip_markdown_fences('No JSON to be found') == 'No JSON to be found' diff --git a/tests/typed_agent.py b/tests/typed_agent.py index 180ce2b0d..eaa2c4fa8 100644 --- a/tests/typed_agent.py +++ b/tests/typed_agent.py @@ -1,13 +1,15 @@ """This file is used to test static typing, it's analyzed with pyright and mypy.""" +import re from collections.abc import Awaitable from dataclasses import dataclass +from decimal import Decimal from typing import Callable, TypeAlias, Union from typing_extensions import assert_type from pydantic_ai import Agent, ModelRetry, RunContext, Tool -from pydantic_ai._output import ToolOutput +from pydantic_ai._output import TextOutput, ToolOutput from pydantic_ai.agent import AgentRunResult from pydantic_ai.tools import ToolDefinition @@ -169,21 +171,25 @@ def run_sync3() -> None: assert_type(union_agent2, Agent[None, MyUnion]) -def foobar_ctx(ctx: RunContext[int], x: str, y: int) -> str: - return f'{x} {y}' +def foobar_ctx(ctx: RunContext[int], x: str, y: int) -> Decimal: + return Decimal(x) + y async def foobar_plain(x: int, y: int) -> int: return x * y +def str_to_regex(text: str) -> re.Pattern[str]: + return re.compile(text) + + class MyClass: def my_method(self) -> bool: return True -str_function_agent = Agent(output_type=foobar_ctx) -assert_type(str_function_agent, Agent[None, str]) +decimal_function_agent = Agent(output_type=foobar_ctx) +assert_type(decimal_function_agent, Agent[None, Decimal]) bool_method_agent = Agent(output_type=MyClass().my_method) assert_type(bool_method_agent, Agent[None, bool]) @@ -200,10 +206,12 @@ def my_method(self) -> bool: assert_type(two_scalars_output_agent, Agent[None, int | str]) marker: ToolOutput[bool | tuple[str, int]] = ToolOutput(bool | tuple[str, int]) # type: ignore - complex_output_agent = Agent[None, Foo | Bar | str | int | bool | tuple[str, int]]( - output_type=[Foo, Bar, foobar_ctx, ToolOutput[int](foobar_plain), marker] + complex_output_agent = Agent[None, Foo | Bar | Decimal | int | bool | tuple[str, int] | str | re.Pattern[str]]( + output_type=[str, Foo, Bar, foobar_ctx, ToolOutput[int](foobar_plain), marker, TextOutput(str_to_regex)] + ) + assert_type( + complex_output_agent, Agent[None, Foo | Bar | Decimal | int | bool | tuple[str, int] | str | re.Pattern[str]] ) - assert_type(complex_output_agent, Agent[None, Foo | Bar | str | int | bool | tuple[str, int]]) else: # pyright is able to correctly infer the type here async_int_function_agent = Agent(output_type=foobar_plain) @@ -216,8 +224,12 @@ def my_method(self) -> bool: assert_type(two_scalars_output_agent, Agent[None, int | str]) marker: ToolOutput[bool | tuple[str, int]] = ToolOutput(bool | tuple[str, int]) # type: ignore - complex_output_agent = Agent(output_type=[Foo, Bar, foobar_ctx, ToolOutput(foobar_plain), marker]) - assert_type(complex_output_agent, Agent[None, Foo | Bar | str | int | bool | tuple[str, int]]) + complex_output_agent = Agent( + output_type=[str, Foo, Bar, foobar_ctx, ToolOutput(foobar_plain), marker, TextOutput(str_to_regex)] + ) + assert_type( + complex_output_agent, Agent[None, Foo | Bar | Decimal | int | bool | tuple[str, int] | str | re.Pattern[str]] + ) Tool(foobar_ctx, takes_ctx=True)