24
24
result ,
25
25
usage as _usage ,
26
26
)
27
- from .result import OutputDataT , ToolOutput
27
+ from .result import OutputDataT
28
28
from .settings import ModelSettings , merge_model_settings
29
29
from .tools import RunContext , Tool , ToolDefinition
30
30
@@ -240,10 +240,29 @@ async def add_mcp_server_tools(server: MCPServer) -> None:
240
240
)
241
241
242
242
output_schema = ctx .deps .output_schema
243
+ model = ctx .deps .model
244
+
245
+ output_mode = None
246
+ output_object = None
247
+ output_tools = []
248
+ allow_text_output = _output .allow_text_output (output_schema )
249
+ if output_schema :
250
+ output_mode = output_schema .forced_mode or model .default_output_mode
251
+ output_object = output_schema .object_schema .definition
252
+ output_tools = output_schema .tool_defs ()
253
+ if output_mode != 'tool' :
254
+ allow_text_output = False
255
+
256
+ supported_modes = model .supported_output_modes
257
+ if output_mode not in supported_modes :
258
+ raise exceptions .UserError (f"Output mode '{ output_mode } ' is not among supported modes: { supported_modes } " )
259
+
243
260
return models .ModelRequestParameters (
244
261
function_tools = function_tool_defs ,
245
- allow_text_output = allow_text_output (output_schema ),
246
- output_tools = output_schema .tool_defs () if output_schema is not None else [],
262
+ output_mode = output_mode ,
263
+ output_object = output_object ,
264
+ output_tools = output_tools ,
265
+ allow_text_output = allow_text_output ,
247
266
)
248
267
249
268
@@ -394,20 +413,24 @@ async def stream(
394
413
async for _event in stream :
395
414
pass
396
415
397
- async def _run_stream (
416
+ async def _run_stream ( # noqa: C901
398
417
self , ctx : GraphRunContext [GraphAgentState , GraphAgentDeps [DepsT , NodeRunEndT ]]
399
418
) -> AsyncIterator [_messages .HandleResponseEvent ]:
400
419
if self ._events_iterator is None :
401
420
# Ensure that the stream is only run once
402
421
403
422
async def _run_stream () -> AsyncIterator [_messages .HandleResponseEvent ]:
404
423
texts : list [str ] = []
424
+ outputs : list [str ] = []
405
425
tool_calls : list [_messages .ToolCallPart ] = []
406
426
for part in self .model_response .parts :
407
427
if isinstance (part , _messages .TextPart ):
408
428
# ignore empty content for text parts, see #437
409
429
if part .content :
410
430
texts .append (part .content )
431
+ elif isinstance (part , _messages .OutputPart ):
432
+ if part .content :
433
+ outputs .append (part .content )
411
434
elif isinstance (part , _messages .ToolCallPart ):
412
435
tool_calls .append (part )
413
436
else :
@@ -420,6 +443,9 @@ async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]:
420
443
if tool_calls :
421
444
async for event in self ._handle_tool_calls (ctx , tool_calls ):
422
445
yield event
446
+ elif outputs : # TODO: Can we have tool calls and structured output? Should we handle both?
447
+ # No events are emitted during the handling of structured outputs, so we don't need to yield anything
448
+ self ._next_node = await self ._handle_outputs (ctx , outputs )
423
449
elif texts :
424
450
# No events are emitted during the handling of text responses, so we don't need to yield anything
425
451
self ._next_node = await self ._handle_text_response (ctx , texts )
@@ -428,7 +454,7 @@ async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]:
428
454
# when the model has already returned text along side tool calls
429
455
# in this scenario, if text responses are allowed, we return text from the most recent model
430
456
# response, if any
431
- if allow_text_output (ctx .deps .output_schema ):
457
+ if _output . allow_text_output (ctx .deps .output_schema ):
432
458
for message in reversed (ctx .state .message_history ):
433
459
if isinstance (message , _messages .ModelResponse ):
434
460
last_texts = [p .content for p in message .parts if isinstance (p , _messages .TextPart )]
@@ -511,7 +537,7 @@ async def _handle_text_response(
511
537
output_schema = ctx .deps .output_schema
512
538
513
539
text = '\n \n ' .join (texts )
514
- if allow_text_output (output_schema ):
540
+ if _output . allow_text_output (output_schema ):
515
541
# The following cast is safe because we know `str` is an allowed result type
516
542
result_data_input = cast (NodeRunEndT , text )
517
543
try :
@@ -533,6 +559,27 @@ async def _handle_text_response(
533
559
)
534
560
)
535
561
562
async def _handle_outputs(
    self,
    ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
    outputs: list[str],
) -> ModelRequestNode[DepsT, NodeRunEndT] | End[result.FinalResult[NodeRunEndT]]:
    """Turn the structured output of a model response into the next graph node.

    Args:
        ctx: The graph run context; its deps supply the output schema and the
            retry limit, and its state tracks the retry count.
        outputs: Contents of the structured output parts collected from the
            model response. Exactly one entry is expected.

    Returns:
        The value of `_handle_final_result` when the output validates, or a
        `ModelRequestNode` carrying the retry part when validation raises
        `ToolRetryError`.

    Raises:
        exceptions.UnexpectedModelBehavior: If `outputs` does not hold exactly
            one entry, or if no output schema is configured.
    """
    if not outputs:
        # Report the empty case accurately instead of reusing the "multiple outputs" message.
        raise exceptions.UnexpectedModelBehavior('Received no structured output in the response')
    if len(outputs) > 1:
        raise exceptions.UnexpectedModelBehavior('Received multiple structured outputs in a single response')

    output_schema = ctx.deps.output_schema
    if not output_schema:
        # The public parameter is named `output_type` (see `build_agent_graph`), so the message says so.
        raise exceptions.UnexpectedModelBehavior('Must specify a non-str output_type when using structured outputs')

    structured_output = outputs[0]
    try:
        # Schema validation first, then the `_validate_output` pass; either step may request a retry.
        result_data = output_schema.validate(structured_output)
        result_data = await _validate_output(result_data, ctx, None)
    except _output.ToolRetryError as e:
        # Record the failed attempt against the configured limit, then send the retry part back to the model.
        ctx.state.increment_retries(ctx.deps.max_result_retries)
        return ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=[e.tool_retry]))
    else:
        # No tool call produced this result, hence the `None` arguments and the empty list.
        return self._handle_final_result(ctx, result.FinalResult(result_data, None, None), [])
582
+
536
583
537
584
def build_run_context (ctx : GraphRunContext [GraphAgentState , GraphAgentDeps [DepsT , Any ]]) -> RunContext [DepsT ]:
538
585
"""Build a `RunContext` object from the current agent graph run context."""
@@ -773,11 +820,6 @@ async def _validate_output(
773
820
return result_data
774
821
775
822
776
- def allow_text_output (output_schema : _output .OutputSchema [Any ] | None ) -> bool :
777
- """Check if the result schema allows text results."""
778
- return output_schema is None or output_schema .allow_text_output
779
-
780
-
781
823
@dataclasses .dataclass
782
824
class _RunMessages :
783
825
messages : list [_messages .ModelMessage ]
@@ -827,7 +869,9 @@ def get_captured_run_messages() -> _RunMessages:
827
869
828
870
829
871
def build_agent_graph (
830
- name : str | None , deps_type : type [DepsT ], output_type : type [OutputT ] | ToolOutput [OutputT ]
872
+ name : str | None ,
873
+ deps_type : type [DepsT ],
874
+ output_type : _output .OutputType [OutputT ],
831
875
) -> Graph [GraphAgentState , GraphAgentDeps [DepsT , result .FinalResult [OutputT ]], result .FinalResult [OutputT ]]:
832
876
"""Build the execution [Graph][pydantic_graph.Graph] for a given agent."""
833
877
nodes = (
0 commit comments