Commit 76ee86d

WIP: With OutputPart

1 parent: c0afc0d

26 files changed: +1586 -220 lines

pydantic_ai_slim/pydantic_ai/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -12,7 +12,7 @@
 )
 from .format_prompt import format_as_xml
 from .messages import AudioUrl, BinaryContent, DocumentUrl, ImageUrl, VideoUrl
-from .result import ToolOutput
+from .result import JSONSchemaOutput, ToolOutput
 from .tools import RunContext, Tool

 __all__ = (
@@ -43,6 +43,7 @@
     'RunContext',
     # result
     'ToolOutput',
+    'JSONSchemaOutput',
     # format_prompt
     'format_as_xml',
 )

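For orientation, here is a minimal usage sketch of the newly exported JSONSchemaOutput marker. The constructor signature and the model name are assumptions for illustration; the commit only shows the export, not the class definition.

# Hypothetical usage sketch: JSONSchemaOutput(CityInfo) is assumed from this
# WIP commit; the marker's real constructor is defined elsewhere in the diff.
from pydantic import BaseModel

from pydantic_ai import Agent, JSONSchemaOutput


class CityInfo(BaseModel):
    city: str
    country: str


# Request CityInfo via the provider's native JSON-schema output mode instead of
# the default tool-call based structured output.
agent = Agent('openai:gpt-4o', output_type=JSONSchemaOutput(CityInfo))
result = agent.run_sync('Tell me about Paris.')
print(result.output)  # e.g. CityInfo(city='Paris', country='France')
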
pydantic_ai_slim/pydantic_ai/_agent_graph.py

Lines changed: 56 additions & 12 deletions
@@ -24,7 +24,7 @@
     result,
     usage as _usage,
 )
-from .result import OutputDataT, ToolOutput
+from .result import OutputDataT
 from .settings import ModelSettings, merge_model_settings
 from .tools import RunContext, Tool, ToolDefinition

@@ -240,10 +240,29 @@ async def add_mcp_server_tools(server: MCPServer) -> None:
     )

     output_schema = ctx.deps.output_schema
+    model = ctx.deps.model
+
+    output_mode = None
+    output_object = None
+    output_tools = []
+    allow_text_output = _output.allow_text_output(output_schema)
+    if output_schema:
+        output_mode = output_schema.forced_mode or model.default_output_mode
+        output_object = output_schema.object_schema.definition
+        output_tools = output_schema.tool_defs()
+        if output_mode != 'tool':
+            allow_text_output = False
+
+    supported_modes = model.supported_output_modes
+    if output_mode not in supported_modes:
+        raise exceptions.UserError(f"Output mode '{output_mode}' is not among supported modes: {supported_modes}")
+
     return models.ModelRequestParameters(
         function_tools=function_tool_defs,
-        allow_text_output=allow_text_output(output_schema),
-        output_tools=output_schema.tool_defs() if output_schema is not None else [],
+        output_mode=output_mode,
+        output_object=output_object,
+        output_tools=output_tools,
+        allow_text_output=allow_text_output,
     )

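The hunk above reads several attributes that are defined elsewhere in this 26-file commit: output_mode and output_object on ModelRequestParameters, forced_mode on the output schema, and default_output_mode / supported_output_modes on the model. A rough sketch of what that usage implies, offered purely as an assumption; the real definitions live in the other changed files and may differ.

# Sketch only: shapes inferred from how the hunk above uses these names.
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Literal

OutputMode = Literal['tool', 'json_schema']  # assumed mode names


@dataclass
class ModelRequestParameters:
    function_tools: list[Any] = field(default_factory=list)  # ToolDefinition in the real code
    output_mode: OutputMode | None = None                    # new in this commit
    output_object: dict[str, Any] | None = None              # JSON schema for native structured output
    output_tools: list[Any] = field(default_factory=list)
    allow_text_output: bool = True


class Model:
    # Assumed hooks that each provider-specific model class would override.
    @property
    def supported_output_modes(self) -> set[OutputMode]:
        return {'tool'}

    @property
    def default_output_mode(self) -> OutputMode:
        return 'tool'
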
@@ -394,20 +413,24 @@ async def stream(
         async for _event in stream:
             pass

-    async def _run_stream(
+    async def _run_stream(  # noqa: C901
         self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
     ) -> AsyncIterator[_messages.HandleResponseEvent]:
         if self._events_iterator is None:
             # Ensure that the stream is only run once

             async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]:
                 texts: list[str] = []
+                outputs: list[str] = []
                 tool_calls: list[_messages.ToolCallPart] = []
                 for part in self.model_response.parts:
                     if isinstance(part, _messages.TextPart):
                         # ignore empty content for text parts, see #437
                         if part.content:
                             texts.append(part.content)
+                    elif isinstance(part, _messages.OutputPart):
+                        if part.content:
+                            outputs.append(part.content)
                     elif isinstance(part, _messages.ToolCallPart):
                         tool_calls.append(part)
                     else:
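OutputPart, the message part named in the commit title, is defined in messages.py elsewhere in this commit. The loop above only relies on it having a content attribute; a minimal sketch under the assumption that it mirrors TextPart:

# Sketch only: the real OutputPart lives in pydantic_ai/messages.py; the
# part_kind discriminator value is an assumption.
from dataclasses import dataclass
from typing import Literal


@dataclass
class OutputPart:
    """A structured-output part of a model response, carrying raw JSON text."""

    content: str
    part_kind: Literal['output'] = 'output'
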
@@ -420,6 +443,9 @@ async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]:
                 if tool_calls:
                     async for event in self._handle_tool_calls(ctx, tool_calls):
                         yield event
+                elif outputs:  # TODO: Can we have tool calls and structured output? Should we handle both?
+                    # No events are emitted during the handling of structured outputs, so we don't need to yield anything
+                    self._next_node = await self._handle_outputs(ctx, outputs)
                 elif texts:
                     # No events are emitted during the handling of text responses, so we don't need to yield anything
                     self._next_node = await self._handle_text_response(ctx, texts)
@@ -428,7 +454,7 @@ async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]:
                     # when the model has already returned text along side tool calls
                     # in this scenario, if text responses are allowed, we return text from the most recent model
                     # response, if any
-                    if allow_text_output(ctx.deps.output_schema):
+                    if _output.allow_text_output(ctx.deps.output_schema):
                         for message in reversed(ctx.state.message_history):
                             if isinstance(message, _messages.ModelResponse):
                                 last_texts = [p.content for p in message.parts if isinstance(p, _messages.TextPart)]
@@ -511,7 +537,7 @@ async def _handle_text_response(
         output_schema = ctx.deps.output_schema

         text = '\n\n'.join(texts)
-        if allow_text_output(output_schema):
+        if _output.allow_text_output(output_schema):
             # The following cast is safe because we know `str` is an allowed result type
             result_data_input = cast(NodeRunEndT, text)
             try:
@@ -533,6 +559,27 @@ async def _handle_text_response(
                 )
             )

+    async def _handle_outputs(
+        self,
+        ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
+        outputs: list[str],
+    ) -> ModelRequestNode[DepsT, NodeRunEndT] | End[result.FinalResult[NodeRunEndT]]:
+        if len(outputs) != 1:
+            raise exceptions.UnexpectedModelBehavior('Received multiple structured outputs in a single response')
+        output_schema = ctx.deps.output_schema
+        if not output_schema:
+            raise exceptions.UnexpectedModelBehavior('Must specify a non-str result_type when using structured outputs')
+
+        structured_output = outputs[0]
+        try:
+            result_data = output_schema.validate(structured_output)
+            result_data = await _validate_output(result_data, ctx, None)
+        except _output.ToolRetryError as e:
+            ctx.state.increment_retries(ctx.deps.max_result_retries)
+            return ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=[e.tool_retry]))
+        else:
+            return self._handle_final_result(ctx, result.FinalResult(result_data, None, None), [])
+

 def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]) -> RunContext[DepsT]:
     """Build a `RunContext` object from the current agent graph run context."""
@@ -773,11 +820,6 @@ async def _validate_output(
     return result_data


-def allow_text_output(output_schema: _output.OutputSchema[Any] | None) -> bool:
-    """Check if the result schema allows text results."""
-    return output_schema is None or output_schema.allow_text_output
-
-
 @dataclasses.dataclass
 class _RunMessages:
     messages: list[_messages.ModelMessage]
@@ -827,7 +869,9 @@ def get_captured_run_messages() -> _RunMessages:


 def build_agent_graph(
-    name: str | None, deps_type: type[DepsT], output_type: type[OutputT] | ToolOutput[OutputT]
+    name: str | None,
+    deps_type: type[DepsT],
+    output_type: _output.OutputType[OutputT],
 ) -> Graph[GraphAgentState, GraphAgentDeps[DepsT, result.FinalResult[OutputT]], result.FinalResult[OutputT]]:
     """Build the execution [Graph][pydantic_graph.Graph] for a given agent."""
     nodes = (

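Finally, the widened output_type parameter refers to an OutputType alias added in _output.py, which is not shown in this diff. A plausible shape, inferred from the exports touched above and stated only as an assumption; the real alias may accept more forms.

# Sketch only: a guess at the _output.OutputType alias; stand-in marker classes
# are used so the snippet is self-contained.
from typing import Generic, TypeAlias, TypeVar

OutputT = TypeVar('OutputT')


class ToolOutput(Generic[OutputT]):  # stand-in for the existing marker
    ...


class JSONSchemaOutput(Generic[OutputT]):  # stand-in for the marker added here
    ...


OutputType: TypeAlias = type[OutputT] | ToolOutput[OutputT] | JSONSchemaOutput[OutputT]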