13
13
from typing_inspection import typing_objects
14
14
from typing_inspection .introspection import is_union_origin
15
15
16
+ from pydantic_ai .profiles import ModelProfile
17
+
16
18
from . import _function_schema , _utils , messages as _messages
17
19
from .exceptions import ModelRetry
18
20
from .tools import AgentDepsT , GenerateToolJsonSchema , ObjectJsonSchema , RunContext , ToolDefinition
@@ -208,7 +210,7 @@ def __init__(
208
210
)
209
211
210
212
# TODO: Add `json_object` for old OpenAI models, or rename `json_schema` to `json` and choose automatically, relying on Pydantic validation
211
- type OutputMode = Literal ['tool' , 'json_schema' , 'manual_json' ]
213
+ type OutputMode = Literal ['text' , ' tool' , 'tool_or_text ' , 'json_schema' , 'manual_json' ]
212
214
213
215
214
216
@dataclass
@@ -218,50 +220,46 @@ class OutputSchema(Generic[OutputDataT]):
218
220
Similar to `Tool` but for the final output of running an agent.
219
221
"""
220
222
221
- forced_mode : OutputMode | None
222
- object_schema : OutputObjectSchema [OutputDataT ] | OutputUnionSchema [OutputDataT ]
223
- tools : dict [str , OutputTool [OutputDataT ]]
224
- allow_text_output : Literal ['plain' , 'json' ] | None = None
223
+ mode : OutputMode | None
224
+ object_schema : OutputObjectSchema [OutputDataT ] | OutputUnionSchema [OutputDataT ] | None = None
225
+ tools : dict [str , OutputTool [OutputDataT ]] = field (default_factory = dict )
225
226
226
227
@classmethod
227
228
def build (
228
229
cls : type [OutputSchema [OutputDataT ]],
229
230
output_type : OutputType [OutputDataT ],
230
- name : str | None = None ,
231
- description : str | None = None ,
232
- strict : bool | None = None ,
233
- ) -> OutputSchema [OutputDataT ] | None :
231
+ name : str | None ,
232
+ description : str | None ,
233
+ ) -> OutputSchema [OutputDataT ]:
234
234
"""Build an OutputSchema dataclass from an output type."""
235
235
if output_type is str :
236
- return None
236
+ return cls ( mode = 'text' )
237
237
238
- forced_mode : OutputMode | None = None
239
- allow_text_output : Literal ['plain' , 'json' ] | None = 'plain'
238
+ mode : OutputMode | None = None
240
239
tools : dict [str , OutputTool [OutputDataT ]] = {}
240
+ strict : bool | None = None
241
241
242
242
output_types : Sequence [OutputTypeOrFunction [OutputDataT ]]
243
243
if isinstance (output_type , JSONSchemaOutput ):
244
- forced_mode = 'json_schema'
244
+ mode = 'json_schema'
245
245
output_types = output_type .output_types
246
246
name = output_type .name # TODO: If not set, use method arg?
247
247
description = output_type .description
248
248
strict = output_type .strict
249
- allow_text_output = 'json'
250
249
elif isinstance (output_type , ManualJSONOutput ):
251
- forced_mode = 'manual_json'
250
+ mode = 'manual_json'
252
251
output_types = output_type .output_types
253
252
name = output_type .name
254
253
description = output_type .description
255
- allow_text_output = 'json'
256
254
else :
257
- # TODO: We can't always force tool mode here, because some models may not support tools but will work with manual_json
258
255
output_types_or_tool_outputs = flatten_output_types (output_type )
259
256
260
257
if str in output_types_or_tool_outputs :
261
- forced_mode = 'tool'
262
- allow_text_output = 'plain'
263
- # TODO: What if str is the only item, e.g. `output_type=[str]`
264
- output_types_or_tool_outputs = [t for t in output_types_or_tool_outputs if t is not str ]
258
+ if len (output_types_or_tool_outputs ) == 1 :
259
+ return cls (mode = 'text' )
260
+ else :
261
+ mode = 'tool_or_text'
262
+ output_types_or_tool_outputs = [t for t in output_types_or_tool_outputs if t is not str ]
265
263
266
264
multiple = len (output_types_or_tool_outputs ) > 1
267
265
@@ -275,7 +273,9 @@ def build(
275
273
tool_description = None
276
274
tool_strict = None
277
275
if isinstance (output_type_or_tool_output , ToolOutput ):
278
- forced_mode = 'tool'
276
+ if mode is None :
277
+ mode = 'tool'
278
+
279
279
tool_output = output_type_or_tool_output
280
280
output_type = tool_output .output_type
281
281
# do we need to error on conflicts here? (DavidM): If this is internal maybe doesn't matter, if public, use overloads
@@ -307,7 +307,6 @@ def build(
307
307
output_types .append (output_type )
308
308
309
309
output_types = flatten_output_types (output_types )
310
-
311
310
if len (output_types ) > 1 :
312
311
output_object_schema = OutputUnionSchema (
313
312
output_types = output_types , name = name , description = description , strict = strict
@@ -318,12 +317,30 @@ def build(
318
317
)
319
318
320
319
return cls (
321
- forced_mode = forced_mode ,
320
+ mode = mode ,
322
321
object_schema = output_object_schema ,
323
322
tools = tools ,
324
- allow_text_output = allow_text_output ,
325
323
)
326
324
325
+ @property
326
+ def allow_text_output (self ) -> Literal ['plain' , 'json' , False ]:
327
+ """Whether the model allows text output."""
328
+ if self .mode in ('text' , 'tool_or_text' ):
329
+ return 'plain'
330
+ elif self .mode in ('json_schema' , 'manual_json' ):
331
+ return 'json'
332
+ else : # tool-only mode
333
+ return False
334
+
335
+ def is_mode_supported (self , profile : ModelProfile ) -> bool :
336
+ """Whether the model supports the output mode."""
337
+ mode = self .mode
338
+ if mode in ('text' , 'manual_json' ):
339
+ return True
340
+ if self .mode == 'tool_or_text' :
341
+ mode = 'tool'
342
+ return mode in profile .output_modes
343
+
327
344
def find_named_tool (
328
345
self , parts : Iterable [_messages .ModelResponsePart ], tool_name : str
329
346
) -> tuple [_messages .ToolCallPart , OutputTool [OutputDataT ]] | None :
@@ -369,16 +386,18 @@ async def process(
369
386
Returns:
370
387
Either the validated output data (left) or a retry message (right).
371
388
"""
389
+ assert self .allow_text_output is not False
390
+
391
+ if self .allow_text_output == 'plain' :
392
+ return cast (OutputDataT , data )
393
+
394
+ assert self .object_schema is not None
395
+
372
396
return await self .object_schema .process (
373
397
data , run_context , allow_partial = allow_partial , wrap_validation_errors = wrap_validation_errors
374
398
)
375
399
376
400
377
- def allow_text_output (output_schema : OutputSchema [Any ] | None ) -> bool :
378
- # TODO: Add plain/json argument?
379
- return output_schema is None or output_schema .allow_text_output is not None
380
-
381
-
382
401
@dataclass
383
402
class OutputObjectDefinition :
384
403
name : str
@@ -389,6 +408,7 @@ class OutputObjectDefinition:
389
408
@property
390
409
def manual_json_instructions (self ) -> str :
391
410
"""Get instructions for model to output manual JSON matching the schema."""
411
+ # TODO: Move to ModelProfile so it can be tweaked
392
412
description = ': ' .join ([v for v in [self .name , self .description ] if v ])
393
413
return DEFAULT_MANUAL_JSON_PROMPT .format (schema = json .dumps (self .json_schema ), description = description )
394
414
0 commit comments