Skip to content

Support structured and manual JSON output_type modes in addition to tool calls #1628

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion pydantic_ai_slim/pydantic_ai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
)
from .format_prompt import format_as_xml
from .messages import AudioUrl, BinaryContent, DocumentUrl, ImageUrl, VideoUrl
from .result import ToolOutput
from .result import JSONSchemaOutput, ManualJSONOutput, ToolOutput
from .tools import RunContext, Tool

__all__ = (
Expand Down Expand Up @@ -43,6 +43,8 @@
'RunContext',
# result
'ToolOutput',
'JSONSchemaOutput',
'ManualJSONOutput',
# format_prompt
'format_as_xml',
)
Expand Down
50 changes: 26 additions & 24 deletions pydantic_ai_slim/pydantic_ai/_agent_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from contextlib import asynccontextmanager, contextmanager
from contextvars import ContextVar
from dataclasses import field
from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union, cast
from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union

from opentelemetry.trace import Tracer
from typing_extensions import TypeGuard, TypeVar, assert_never
Expand Down Expand Up @@ -90,7 +90,7 @@ class GraphAgentDeps(Generic[DepsT, OutputDataT]):
end_strategy: EndStrategy
get_instructions: Callable[[RunContext[DepsT]], Awaitable[str | None]]

output_schema: _output.OutputSchema[OutputDataT] | None
output_schema: _output.OutputSchema[OutputDataT]
output_validators: list[_output.OutputValidator[DepsT, OutputDataT]]

function_tools: dict[str, Tool[DepsT]] = dataclasses.field(repr=False)
Expand Down Expand Up @@ -264,10 +264,14 @@ async def add_mcp_server_tools(server: MCPServer) -> None:
function_tool_defs = await ctx.deps.prepare_tools(run_context, function_tool_defs) or []

output_schema = ctx.deps.output_schema
assert output_schema.mode is not None # Should have been set in agent._prepare_output_schema

return models.ModelRequestParameters(
function_tools=function_tool_defs,
allow_text_output=_output.allow_text_output(output_schema),
output_tools=output_schema.tool_defs() if output_schema is not None else [],
output_mode=output_schema.mode,
output_object=output_schema.object_schema.definition if output_schema.object_schema else None,
output_tools=output_schema.tool_defs(),
allow_text_output=output_schema.allow_text_output == 'plain',
)


Expand Down Expand Up @@ -452,7 +456,7 @@ async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]:
# when the model has already returned text along side tool calls
# in this scenario, if text responses are allowed, we return text from the most recent model
# response, if any
if _output.allow_text_output(ctx.deps.output_schema):
if ctx.deps.output_schema.allow_text_output:
for message in reversed(ctx.state.message_history):
if isinstance(message, _messages.ModelResponse):
last_texts = [p.content for p in message.parts if isinstance(p, _messages.TextPart)]
Expand All @@ -478,19 +482,18 @@ async def _handle_tool_calls(
# first, look for the output tool call
final_result: result.FinalResult[NodeRunEndT] | None = None
parts: list[_messages.ModelRequestPart] = []
if output_schema is not None:
for call, output_tool in output_schema.find_tool(tool_calls):
try:
result_data = await output_tool.process(call, run_context)
result_data = await _validate_output(result_data, ctx, call)
except _output.ToolRetryError as e:
# TODO: Should only increment retry stuff once per node execution, not for each tool call
# Also, should increment the tool-specific retry count rather than the run retry count
ctx.state.increment_retries(ctx.deps.max_result_retries, e)
parts.append(e.tool_retry)
else:
final_result = result.FinalResult(result_data, call.tool_name, call.tool_call_id)
break
for call, output_tool in output_schema.find_tool(tool_calls):
try:
result_data = await output_tool.process(call, run_context)
result_data = await _validate_output(result_data, ctx, call)
except _output.ToolRetryError as e:
# TODO: Should only increment retry stuff once per node execution, not for each tool call
# Also, should increment the tool-specific retry count rather than the run retry count
ctx.state.increment_retries(ctx.deps.max_result_retries, e)
parts.append(e.tool_retry)
else:
final_result = result.FinalResult(result_data, call.tool_name, call.tool_call_id)
break

# Then build the other request parts based on end strategy
tool_responses: list[_messages.ModelRequestPart] = self._tool_responses
Expand Down Expand Up @@ -536,9 +539,9 @@ async def _handle_text_response(

text = '\n\n'.join(texts)
try:
if _output.allow_text_output(output_schema):
# The following cast is safe because we know `str` is an allowed result type
result_data = cast(NodeRunEndT, text)
if output_schema.allow_text_output:
run_context = build_run_context(ctx)
result_data = await output_schema.process(text, run_context)
else:
m = _messages.RetryPromptPart(
content='Plain text responses are not permitted, please include your response in a tool call',
Expand Down Expand Up @@ -637,7 +640,7 @@ async def process_function_tools( # noqa C901
yield event
call_index_to_event_id[len(calls_to_run)] = event.call_id
calls_to_run.append((mcp_tool, call))
elif output_schema is not None and call.tool_name in output_schema.tools:
elif call.tool_name in output_schema.tools:
# if tool_name is in output_schema, it means we found a output tool but an error occurred in
# validation, we don't add another part here
if output_tool_name is not None:
Expand Down Expand Up @@ -766,8 +769,7 @@ def _unknown_tool(
) -> _messages.RetryPromptPart:
ctx.state.increment_retries(ctx.deps.max_result_retries)
tool_names = list(ctx.deps.function_tools.keys())
if output_schema := ctx.deps.output_schema:
tool_names.extend(output_schema.tool_names())
tool_names.extend(ctx.deps.output_schema.tool_names())

if tool_names:
msg = f'Available tools: {", ".join(tool_names)}'
Expand Down
Loading
Loading