Commit 2056539

WIP: More output modes

1 parent e290951

26 files changed: +216 -220 lines

pydantic_ai_slim/pydantic_ai/_agent_graph.py

Lines changed: 23 additions & 43 deletions
@@ -7,7 +7,7 @@
 from contextlib import asynccontextmanager, contextmanager
 from contextvars import ContextVar
 from dataclasses import field
-from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union, cast
+from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union

 from opentelemetry.trace import Tracer
 from typing_extensions import TypeGuard, TypeVar, assert_never
@@ -90,7 +90,7 @@ class GraphAgentDeps(Generic[DepsT, OutputDataT]):
     end_strategy: EndStrategy
     get_instructions: Callable[[RunContext[DepsT]], Awaitable[str | None]]

-    output_schema: _output.OutputSchema[OutputDataT] | None
+    output_schema: _output.OutputSchema[OutputDataT]
     output_validators: list[_output.OutputValidator[DepsT, OutputDataT]]

     function_tools: dict[str, Tool[DepsT]] = dataclasses.field(repr=False)
@@ -264,29 +264,14 @@ async def add_mcp_server_tools(server: MCPServer) -> None:
     function_tool_defs = await ctx.deps.prepare_tools(run_context, function_tool_defs) or []

     output_schema = ctx.deps.output_schema
-    model = ctx.deps.model
-
-    # TODO: This is horrible
-    output_mode = None
-    output_object = None
-    output_tools = []
-    require_tool_use = False
-    if output_schema:
-        output_mode = output_schema.forced_mode or model.default_output_mode
-        output_object = output_schema.object_schema.definition
-        output_tools = output_schema.tool_defs()
-        require_tool_use = output_mode == 'tool' and output_schema.allow_text_output != 'plain'
-
-    supported_modes = model.supported_output_modes
-    if output_mode not in supported_modes:
-        raise exceptions.UserError(f"Output mode '{output_mode}' is not among supported modes: {supported_modes}")
+    assert output_schema.mode is not None  # Should have been set in agent._prepare_output_schema

     return models.ModelRequestParameters(
         function_tools=function_tool_defs,
-        output_mode=output_mode,
-        output_object=output_object,
-        output_tools=output_tools,
-        require_tool_use=require_tool_use,
+        output_mode=output_schema.mode,
+        output_object=output_schema.object_schema.definition if output_schema.object_schema else None,
+        output_tools=output_schema.tool_defs(),
+        allow_text_output=output_schema.allow_text_output == 'plain',
     )
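Note: the last argument is `allow_text_output=output_schema.allow_text_output == 'plain'`, so the flag sent to the model only reflects plain-text output; the `'json'` flavor (text that gets parsed against a schema) and tool-only mode both pass False. A minimal sketch of the mapping, based on the `allow_text_output` property added in `_output.py` below (illustrative, not part of the commit):

    # How each schema mode evaluates under `allow_text_output == 'plain'`:
    for mode, text_kind in {
        'text': 'plain',
        'tool_or_text': 'plain',
        'json_schema': 'json',
        'manual_json': 'json',
        'tool': False,
    }.items():
        print(f'{mode}: allow_text_output={text_kind == "plain"}')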

@@ -471,7 +456,7 @@ async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]:
     # when the model has already returned text along side tool calls
     # in this scenario, if text responses are allowed, we return text from the most recent model
     # response, if any
-    if _output.allow_text_output(ctx.deps.output_schema):
+    if ctx.deps.output_schema.allow_text_output:
         for message in reversed(ctx.state.message_history):
             if isinstance(message, _messages.ModelResponse):
                 last_texts = [p.content for p in message.parts if isinstance(p, _messages.TextPart)]
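Note: the free function `_output.allow_text_output` (deleted near the bottom of `_output.py` below) is replaced by a property that returns `'plain'`, `'json'`, or `False`. Both string values are truthy, so this branch now covers both text flavors while tool-only mode still skips it:

    # Truthiness of the property's three possible values (illustrative):
    [bool(v) for v in ('plain', 'json', False)]  # [True, True, False]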
@@ -497,19 +482,18 @@ async def _handle_tool_calls(
     # first, look for the output tool call
     final_result: result.FinalResult[NodeRunEndT] | None = None
     parts: list[_messages.ModelRequestPart] = []
-    if output_schema is not None:
-        for call, output_tool in output_schema.find_tool(tool_calls):
-            try:
-                result_data = await output_tool.process(call, run_context)
-                result_data = await _validate_output(result_data, ctx, call)
-            except _output.ToolRetryError as e:
-                # TODO: Should only increment retry stuff once per node execution, not for each tool call
-                # Also, should increment the tool-specific retry count rather than the run retry count
-                ctx.state.increment_retries(ctx.deps.max_result_retries, e)
-                parts.append(e.tool_retry)
-            else:
-                final_result = result.FinalResult(result_data, call.tool_name, call.tool_call_id)
-                break
+    for call, output_tool in output_schema.find_tool(tool_calls):
+        try:
+            result_data = await output_tool.process(call, run_context)
+            result_data = await _validate_output(result_data, ctx, call)
+        except _output.ToolRetryError as e:
+            # TODO: Should only increment retry stuff once per node execution, not for each tool call
+            # Also, should increment the tool-specific retry count rather than the run retry count
+            ctx.state.increment_retries(ctx.deps.max_result_retries, e)
+            parts.append(e.tool_retry)
+        else:
+            final_result = result.FinalResult(result_data, call.tool_name, call.tool_call_id)
+            break

     # Then build the other request parts based on end strategy
     tool_responses: list[_messages.ModelRequestPart] = self._tool_responses
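Note: dropping the `if output_schema is not None` guard is safe because `GraphAgentDeps.output_schema` is no longer optional (see the change above), and a schema with no output tools simply matches nothing. A sketch of the assumed `find_tool` behaviour (its real body is not in this diff):

    # Assumed shape of OutputSchema.find_tool: with tools={} (e.g. 'text'
    # mode) the dict lookup never hits, so the caller's loop body never runs.
    def find_tool(self, parts):
        for part in parts:
            if tool := self.tools.get(part.tool_name):
                yield part, tool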
@@ -555,10 +539,7 @@ async def _handle_text_response(

     text = '\n\n'.join(texts)
     try:
-        if output_schema is None or output_schema.allow_text_output == 'plain':
-            # The following cast is safe because we know `str` is an allowed result type
-            result_data = cast(NodeRunEndT, text)
-        elif output_schema.allow_text_output == 'json':
+        if output_schema.allow_text_output:
             run_context = build_run_context(ctx)
             result_data = await output_schema.process(text, run_context)
         else:
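Note: the removed `cast(NodeRunEndT, text)` branch (and the `cast` import dropped at the top of this file) did not disappear; it moved into `OutputSchema.process`, which now short-circuits for `'plain'` text output (see `_output.py` below). Both text flavors therefore funnel through the same call:

    # 'plain': process() returns the text unchanged;
    # 'json': process() validates it against the object schema.
    result_data = await output_schema.process(text, run_context)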
@@ -659,7 +640,7 @@ async def process_function_tools(  # noqa C901
                 yield event
                 call_index_to_event_id[len(calls_to_run)] = event.call_id
                 calls_to_run.append((mcp_tool, call))
-        elif output_schema is not None and call.tool_name in output_schema.tools:
+        elif call.tool_name in output_schema.tools:
             # if tool_name is in output_schema, it means we found a output tool but an error occurred in
             # validation, we don't add another part here
             if output_tool_name is not None:
@@ -788,8 +769,7 @@ def _unknown_tool(
 ) -> _messages.RetryPromptPart:
     ctx.state.increment_retries(ctx.deps.max_result_retries)
     tool_names = list(ctx.deps.function_tools.keys())
-    if output_schema := ctx.deps.output_schema:
-        tool_names.extend(output_schema.tool_names())
+    tool_names.extend(ctx.deps.output_schema.tool_names())

     if tool_names:
         msg = f'Available tools: {", ".join(tool_names)}'

pydantic_ai_slim/pydantic_ai/_output.py

Lines changed: 50 additions & 30 deletions
@@ -13,6 +13,8 @@
 from typing_inspection import typing_objects
 from typing_inspection.introspection import is_union_origin

+from pydantic_ai.profiles import ModelProfile
+
 from . import _function_schema, _utils, messages as _messages
 from .exceptions import ModelRetry
 from .tools import AgentDepsT, GenerateToolJsonSchema, ObjectJsonSchema, RunContext, ToolDefinition
@@ -208,7 +210,7 @@ def __init__(
     )

 # TODO: Add `json_object` for old OpenAI models, or rename `json_schema` to `json` and choose automatically, relying on Pydantic validation
-type OutputMode = Literal['tool', 'json_schema', 'manual_json']
+type OutputMode = Literal['text', 'tool', 'tool_or_text', 'json_schema', 'manual_json']


 @dataclass
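The two new literals extend `OutputMode` from "how to force structured output" to a full description of how a run may end. A hedged summary of what each value appears to mean in this commit, inferred from the `build` and `allow_text_output` logic below:

    # Inferred semantics of the OutputMode values (not documented in the diff):
    MODE_MEANING = {
        'text': "output_type is just str: the model's plain text is the output",
        'tool': 'output must arrive via an output tool call',
        'tool_or_text': 'str plus structured types: a tool call or plain text may end the run',
        'json_schema': 'native structured output validated against the JSON schema',
        'manual_json': 'the model is prompted via instructions to emit matching JSON',
    }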
@@ -218,50 +220,46 @@ class OutputSchema(Generic[OutputDataT]):
     Similar to `Tool` but for the final output of running an agent.
     """

-    forced_mode: OutputMode | None
-    object_schema: OutputObjectSchema[OutputDataT] | OutputUnionSchema[OutputDataT]
-    tools: dict[str, OutputTool[OutputDataT]]
-    allow_text_output: Literal['plain', 'json'] | None = None
+    mode: OutputMode | None
+    object_schema: OutputObjectSchema[OutputDataT] | OutputUnionSchema[OutputDataT] | None = None
+    tools: dict[str, OutputTool[OutputDataT]] = field(default_factory=dict)

     @classmethod
     def build(
         cls: type[OutputSchema[OutputDataT]],
         output_type: OutputType[OutputDataT],
-        name: str | None = None,
-        description: str | None = None,
-        strict: bool | None = None,
-    ) -> OutputSchema[OutputDataT] | None:
+        name: str | None,
+        description: str | None,
+    ) -> OutputSchema[OutputDataT]:
         """Build an OutputSchema dataclass from an output type."""
         if output_type is str:
-            return None
+            return cls(mode='text')

-        forced_mode: OutputMode | None = None
-        allow_text_output: Literal['plain', 'json'] | None = 'plain'
+        mode: OutputMode | None = None
         tools: dict[str, OutputTool[OutputDataT]] = {}
+        strict: bool | None = None

         output_types: Sequence[OutputTypeOrFunction[OutputDataT]]
         if isinstance(output_type, JSONSchemaOutput):
-            forced_mode = 'json_schema'
+            mode = 'json_schema'
             output_types = output_type.output_types
             name = output_type.name  # TODO: If not set, use method arg?
             description = output_type.description
             strict = output_type.strict
-            allow_text_output = 'json'
         elif isinstance(output_type, ManualJSONOutput):
-            forced_mode = 'manual_json'
+            mode = 'manual_json'
             output_types = output_type.output_types
             name = output_type.name
             description = output_type.description
-            allow_text_output = 'json'
         else:
-            # TODO: We can't always force tool mode here, because some models may not support tools but will work with manual_json
             output_types_or_tool_outputs = flatten_output_types(output_type)

             if str in output_types_or_tool_outputs:
-                forced_mode = 'tool'
-                allow_text_output = 'plain'
-                # TODO: What if str is the only item, e.g. `output_type=[str]`
-                output_types_or_tool_outputs = [t for t in output_types_or_tool_outputs if t is not str]
+                if len(output_types_or_tool_outputs) == 1:
+                    return cls(mode='text')
+                else:
+                    mode = 'tool_or_text'
+                    output_types_or_tool_outputs = [t for t in output_types_or_tool_outputs if t is not str]

             multiple = len(output_types_or_tool_outputs) > 1
@@ -275,7 +273,9 @@ def build(
                 tool_description = None
                 tool_strict = None
                 if isinstance(output_type_or_tool_output, ToolOutput):
-                    forced_mode = 'tool'
+                    if mode is None:
+                        mode = 'tool'
+
                     tool_output = output_type_or_tool_output
                     output_type = tool_output.output_type
                     # do we need to error on conflicts here? (DavidM): If this is internal maybe doesn't matter, if public, use overloads
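The new `if mode is None` guard means a `ToolOutput` member no longer overrides a mode chosen earlier, e.g. `'tool_or_text'` from the `str`-union branch above. Hypothetical calls (`Foo` is a stand-in model; `ToolOutput`'s constructor arguments are assumed):

    OutputSchema.build([str, ToolOutput(Foo)], None, None).mode  # stays 'tool_or_text'
    OutputSchema.build(ToolOutput(Foo), None, None).mode         # 'tool' (mode was still None)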
@@ -307,7 +307,6 @@ def build(
             output_types.append(output_type)

         output_types = flatten_output_types(output_types)
-
         if len(output_types) > 1:
             output_object_schema = OutputUnionSchema(
                 output_types=output_types, name=name, description=description, strict=strict
@@ -318,12 +317,30 @@ def build(
             )

         return cls(
-            forced_mode=forced_mode,
+            mode=mode,
             object_schema=output_object_schema,
             tools=tools,
-            allow_text_output=allow_text_output,
         )

+    @property
+    def allow_text_output(self) -> Literal['plain', 'json', False]:
+        """Whether the model allows text output."""
+        if self.mode in ('text', 'tool_or_text'):
+            return 'plain'
+        elif self.mode in ('json_schema', 'manual_json'):
+            return 'json'
+        else:  # tool-only mode
+            return False
+
+    def is_mode_supported(self, profile: ModelProfile) -> bool:
+        """Whether the model supports the output mode."""
+        mode = self.mode
+        if mode in ('text', 'manual_json'):
+            return True
+        if self.mode == 'tool_or_text':
+            mode = 'tool'
+        return mode in profile.output_modes
+
     def find_named_tool(
         self, parts: Iterable[_messages.ModelResponsePart], tool_name: str
     ) -> tuple[_messages.ToolCallPart, OutputTool[OutputDataT]] | None:
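Putting `build` and the new property together, a hypothetical map of inputs to resolved modes (`Foo` and the wrapper constructor arguments are stand-ins, not from this diff):

    OutputSchema.build(str, None, None).mode         # 'text'
    OutputSchema.build([str], None, None).mode       # 'text' (str is the only member)
    OutputSchema.build([str, Foo], None, None).mode  # 'tool_or_text'
    OutputSchema.build(Foo, None, None).mode         # None; per the assert in _agent_graph.py,
                                                     # agent._prepare_output_schema sets it later
    OutputSchema.build(JSONSchemaOutput([Foo]), None, None).mode  # 'json_schema'
    OutputSchema.build(ManualJSONOutput([Foo]), None, None).mode  # 'manual_json'

Note also that `is_mode_supported` checks `'tool_or_text'` against the profile as `'tool'`, and waves `'text'` and `'manual_json'` through unconditionally, since any model can emit plain text.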
@@ -369,16 +386,18 @@ async def process(
         Returns:
             Either the validated output data (left) or a retry message (right).
         """
+        assert self.allow_text_output is not False
+
+        if self.allow_text_output == 'plain':
+            return cast(OutputDataT, data)
+
+        assert self.object_schema is not None
+
         return await self.object_schema.process(
             data, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors
         )


-def allow_text_output(output_schema: OutputSchema[Any] | None) -> bool:
-    # TODO: Add plain/json argument?
-    return output_schema is None or output_schema.allow_text_output is not None
-
-
 @dataclass
 class OutputObjectDefinition:
     name: str
@@ -389,6 +408,7 @@ class OutputObjectDefinition:
     @property
     def manual_json_instructions(self) -> str:
         """Get instructions for model to output manual JSON matching the schema."""
+        # TODO: Move to ModelProfile so it can be tweaked
         description = ': '.join([v for v in [self.name, self.description] if v])
         return DEFAULT_MANUAL_JSON_PROMPT.format(schema=json.dumps(self.json_schema), description=description)
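`DEFAULT_MANUAL_JSON_PROMPT` is defined elsewhere in this module and not shown in the diff; a stand-in with hypothetical wording to illustrate the shape of the rendered instructions:

    import json

    # Hypothetical template; the real DEFAULT_MANUAL_JSON_PROMPT lives elsewhere in _output.py.
    DEFAULT_MANUAL_JSON_PROMPT = 'Always respond with JSON matching this schema:\n{schema}\n{description}'

    schema = {'type': 'object', 'properties': {'city': {'type': 'string'}}}
    print(DEFAULT_MANUAL_JSON_PROMPT.format(schema=json.dumps(schema), description='CityInfo: a city'))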
