Skip to content

Commit b892fc8

Browse files
committed
WIP: Remove OutputPart, work around allow_text_output instead
1 parent 76ee86d commit b892fc8

22 files changed

+132
-237
lines changed

pydantic_ai_slim/pydantic_ai/_agent_graph.py

Lines changed: 13 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -245,13 +245,12 @@ async def add_mcp_server_tools(server: MCPServer) -> None:
245245
output_mode = None
246246
output_object = None
247247
output_tools = []
248-
allow_text_output = _output.allow_text_output(output_schema)
248+
require_tool_use = False
249249
if output_schema:
250250
output_mode = output_schema.forced_mode or model.default_output_mode
251251
output_object = output_schema.object_schema.definition
252252
output_tools = output_schema.tool_defs()
253-
if output_mode != 'tool':
254-
allow_text_output = False
253+
require_tool_use = output_mode == 'tool' and not output_schema.allow_plain_text_output
255254

256255
supported_modes = model.supported_output_modes
257256
if output_mode not in supported_modes:
@@ -262,7 +261,7 @@ async def add_mcp_server_tools(server: MCPServer) -> None:
262261
output_mode=output_mode,
263262
output_object=output_object,
264263
output_tools=output_tools,
265-
allow_text_output=allow_text_output,
264+
require_tool_use=require_tool_use,
266265
)
267266

268267

@@ -413,24 +412,20 @@ async def stream(
413412
async for _event in stream:
414413
pass
415414

416-
async def _run_stream( # noqa: C901
415+
async def _run_stream(
417416
self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
418417
) -> AsyncIterator[_messages.HandleResponseEvent]:
419418
if self._events_iterator is None:
420419
# Ensure that the stream is only run once
421420

422421
async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]:
423422
texts: list[str] = []
424-
outputs: list[str] = []
425423
tool_calls: list[_messages.ToolCallPart] = []
426424
for part in self.model_response.parts:
427425
if isinstance(part, _messages.TextPart):
428426
# ignore empty content for text parts, see #437
429427
if part.content:
430428
texts.append(part.content)
431-
elif isinstance(part, _messages.OutputPart):
432-
if part.content:
433-
outputs.append(part.content)
434429
elif isinstance(part, _messages.ToolCallPart):
435430
tool_calls.append(part)
436431
else:
@@ -443,9 +438,6 @@ async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]:
443438
if tool_calls:
444439
async for event in self._handle_tool_calls(ctx, tool_calls):
445440
yield event
446-
elif outputs: # TODO: Can we have tool calls and structured output? Should we handle both?
447-
# No events are emitted during the handling of structured outputs, so we don't need to yield anything
448-
self._next_node = await self._handle_outputs(ctx, outputs)
449441
elif texts:
450442
# No events are emitted during the handling of text responses, so we don't need to yield anything
451443
self._next_node = await self._handle_text_response(ctx, texts)
@@ -537,42 +529,18 @@ async def _handle_text_response(
537529
output_schema = ctx.deps.output_schema
538530

539531
text = '\n\n'.join(texts)
540-
if _output.allow_text_output(output_schema):
541-
# The following cast is safe because we know `str` is an allowed result type
542-
result_data_input = cast(NodeRunEndT, text)
543-
try:
544-
result_data = await _validate_output(result_data_input, ctx, None)
545-
except _output.ToolRetryError as e:
546-
ctx.state.increment_retries(ctx.deps.max_result_retries)
547-
return ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=[e.tool_retry]))
532+
try:
533+
if output_schema is None or output_schema.allow_plain_text_output:
534+
# The following cast is safe because we know `str` is an allowed result type
535+
result_data = cast(NodeRunEndT, text)
536+
elif output_schema.allow_json_text_output:
537+
result_data = output_schema.validate(text)
548538
else:
549-
return self._handle_final_result(ctx, result.FinalResult(result_data, None, None), [])
550-
else:
551-
ctx.state.increment_retries(ctx.deps.max_result_retries)
552-
return ModelRequestNode[DepsT, NodeRunEndT](
553-
_messages.ModelRequest(
554-
parts=[
555-
_messages.RetryPromptPart(
556-
content='Plain text responses are not permitted, please include your response in a tool call',
557-
)
558-
]
539+
m = _messages.RetryPromptPart(
540+
content='Plain text responses are not permitted, please include your response in a tool call',
559541
)
560-
)
542+
raise _output.ToolRetryError(m)
561543

562-
async def _handle_outputs(
563-
self,
564-
ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
565-
outputs: list[str],
566-
) -> ModelRequestNode[DepsT, NodeRunEndT] | End[result.FinalResult[NodeRunEndT]]:
567-
if len(outputs) != 1:
568-
raise exceptions.UnexpectedModelBehavior('Received multiple structured outputs in a single response')
569-
output_schema = ctx.deps.output_schema
570-
if not output_schema:
571-
raise exceptions.UnexpectedModelBehavior('Must specify a non-str result_type when using structured outputs')
572-
573-
structured_output = outputs[0]
574-
try:
575-
result_data = output_schema.validate(structured_output)
576544
result_data = await _validate_output(result_data, ctx, None)
577545
except _output.ToolRetryError as e:
578546
ctx.state.increment_retries(ctx.deps.max_result_retries)

pydantic_ai_slim/pydantic_ai/_output.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,8 @@ class OutputSchema(Generic[OutputDataT]):
220220
forced_mode: OutputMode | None
221221
object_schema: OutputObjectSchema[OutputDataT]
222222
tools: dict[str, OutputTool[OutputDataT]]
223-
allow_text_output: bool # TODO: Verify structured output works correctly with string as a union member
223+
allow_plain_text_output: bool
224+
allow_json_text_output: bool # TODO: Turn into allowed_text_output: Literal['plain', 'json'] | None
224225

225226
@classmethod
226227
def build(
@@ -235,11 +236,14 @@ def build(
235236
return None
236237

237238
forced_mode = None
239+
allow_json_text_output = True
240+
allow_plain_text_output = False
238241
tool_output_type = None
239-
allow_text_output = False
240242
if isinstance(output_type, ToolOutput):
241-
# do we need to error on conflicts here? (DavidM): If this is internal maybe doesn't matter, if public, use overloads
242243
forced_mode = 'tool'
244+
allow_json_text_output = False
245+
246+
# do we need to error on conflicts here? (DavidM): If this is internal maybe doesn't matter, if public, use overloads
243247
name = output_type.name
244248
description = output_type.description
245249
output_type_ = output_type.output_type
@@ -255,12 +259,15 @@ def build(
255259
name = output_type.name
256260
description = output_type.description
257261
output_type_ = output_type.output_type
258-
else:
262+
elif output_type_other_than_str := extract_str_from_union(output_type):
263+
forced_mode = 'tool'
259264
output_type_ = output_type
260265

261-
if output_type_other_than_str := extract_str_from_union(output_type):
262-
allow_text_output = True
263-
tool_output_type = output_type_other_than_str.value
266+
allow_json_text_output = False
267+
allow_plain_text_output = True
268+
tool_output_type = output_type_other_than_str.value
269+
else:
270+
output_type_ = output_type
264271

265272
output_object_schema = OutputObjectSchema(
266273
output_type=output_type_, name=name, description=description, strict=strict
@@ -292,7 +299,8 @@ def build(
292299
forced_mode=forced_mode,
293300
object_schema=output_object_schema,
294301
tools=tools,
295-
allow_text_output=allow_text_output,
302+
allow_plain_text_output=allow_plain_text_output,
303+
allow_json_text_output=allow_json_text_output,
296304
)
297305

298306
def find_named_tool(
@@ -341,8 +349,7 @@ def validate(
341349

342350

343351
def allow_text_output(output_schema: OutputSchema[Any] | None) -> bool:
344-
"""Check if the result schema allows text results."""
345-
return output_schema is None or output_schema.allow_text_output
352+
return output_schema is None or output_schema.allow_plain_text_output or output_schema.allow_json_text_output
346353

347354

348355
@dataclass

pydantic_ai_slim/pydantic_ai/agent.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -997,8 +997,6 @@ async def stream_to_final(
997997
elif isinstance(new_part, _messages.ToolCallPart) and output_schema:
998998
for call, _ in output_schema.find_tool([new_part]):
999999
return FinalResult(s, call.tool_name, call.tool_call_id)
1000-
elif isinstance(new_part, _messages.OutputPart) and output_schema:
1001-
return FinalResult(s, None, None)
10021000
return None
10031001

10041002
final_result_details = await stream_to_final(streamed_response)

pydantic_ai_slim/pydantic_ai/messages.py

Lines changed: 2 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -494,21 +494,6 @@ def has_content(self) -> bool:
494494
return bool(self.content)
495495

496496

497-
@dataclass
498-
class OutputPart:
499-
"""An output response from a model."""
500-
501-
content: str
502-
"""The output content of the response as a JSON-serialized string."""
503-
504-
part_kind: Literal['output'] = 'output'
505-
"""Part type identifier, this is available on all parts as a discriminator."""
506-
507-
def has_content(self) -> bool:
508-
"""Return `True` if the output content is non-empty."""
509-
return bool(self.content)
510-
511-
512497
@dataclass
513498
class ToolCallPart:
514499
"""A tool call from a model."""
@@ -563,7 +548,7 @@ def has_content(self) -> bool:
563548
return bool(self.args)
564549

565550

566-
ModelResponsePart = Annotated[Union[TextPart, OutputPart, ToolCallPart], pydantic.Discriminator('part_kind')]
551+
ModelResponsePart = Annotated[Union[TextPart, ToolCallPart], pydantic.Discriminator('part_kind')]
567552
"""A message part returned by a model."""
568553

569554

@@ -659,33 +644,6 @@ def apply(self, part: ModelResponsePart) -> TextPart:
659644
return replace(part, content=part.content + self.content_delta)
660645

661646

662-
@dataclass
663-
class OutputPartDelta:
664-
"""A partial update (delta) for a `OutputPart` to append new structured output content."""
665-
666-
content_delta: str
667-
"""The incremental structured output content to add to the existing `OutputPart` content."""
668-
669-
part_delta_kind: Literal['output'] = 'output'
670-
"""Part delta type identifier, used as a discriminator."""
671-
672-
def apply(self, part: ModelResponsePart) -> OutputPart:
673-
"""Apply this structured output delta to an existing `OutputPart`.
674-
675-
Args:
676-
part: The existing model response part, which must be a `OutputPart`.
677-
678-
Returns:
679-
A new `OutputPart` with updated structured output content.
680-
681-
Raises:
682-
ValueError: If `part` is not a `OutputPart`.
683-
"""
684-
if not isinstance(part, OutputPart):
685-
raise ValueError('Cannot apply OutputPartDeltas to non-OutputParts')
686-
return replace(part, content=part.content + self.content_delta)
687-
688-
689647
@dataclass
690648
class ToolCallPartDelta:
691649
"""A partial update (delta) for a `ToolCallPart` to modify tool name, arguments, or tool call ID."""
@@ -801,9 +759,7 @@ def _apply_to_part(self, part: ToolCallPart) -> ToolCallPart:
801759
return part
802760

803761

804-
ModelResponsePartDelta = Annotated[
805-
Union[TextPartDelta, OutputPartDelta, ToolCallPartDelta], pydantic.Discriminator('part_delta_kind')
806-
]
762+
ModelResponsePartDelta = Annotated[Union[TextPartDelta, ToolCallPartDelta], pydantic.Discriminator('part_delta_kind')]
807763
"""A partial update (delta) for any model response part."""
808764

809765

pydantic_ai_slim/pydantic_ai/models/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,7 @@ class ModelRequestParameters:
266266
output_mode: OutputMode | None
267267
output_object: OutputObjectDefinition | None
268268
output_tools: list[ToolDefinition]
269-
allow_text_output: bool
269+
require_tool_use: bool
270270

271271

272272
class Model(ABC):

pydantic_ai_slim/pydantic_ai/models/anthropic.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
ModelResponse,
2222
ModelResponsePart,
2323
ModelResponseStreamEvent,
24-
OutputPart,
2524
RetryPromptPart,
2625
SystemPromptPart,
2726
TextPart,
@@ -213,7 +212,7 @@ async def _messages_create(
213212
if not tools:
214213
tool_choice = None
215214
else:
216-
if not model_request_parameters.allow_text_output:
215+
if model_request_parameters.require_tool_use:
217216
tool_choice = {'type': 'any'}
218217
else:
219218
tool_choice = {'type': 'auto'}
@@ -321,7 +320,7 @@ async def _map_message(self, messages: list[ModelMessage]) -> tuple[str, list[Me
321320
elif isinstance(m, ModelResponse):
322321
assistant_content_params: list[TextBlockParam | ToolUseBlockParam] = []
323322
for response_part in m.parts:
324-
if isinstance(response_part, (TextPart, OutputPart)):
323+
if isinstance(response_part, TextPart):
325324
assistant_content_params.append(TextBlockParam(text=response_part.content, type='text'))
326325
else:
327326
tool_use_block_param = ToolUseBlockParam(

pydantic_ai_slim/pydantic_ai/models/bedrock.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,7 @@ async def _messages_create(
304304
support_tools_choice = self.model_name.startswith(('anthropic', 'us.anthropic'))
305305
if not tools or not support_tools_choice:
306306
tool_choice: ToolChoiceTypeDef = {}
307-
elif not model_request_parameters.allow_text_output:
307+
elif model_request_parameters.require_tool_use:
308308
tool_choice = {'any': {}}
309309
else:
310310
tool_choice = {'auto': {}}

pydantic_ai_slim/pydantic_ai/models/cohere.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
ModelRequest,
1414
ModelResponse,
1515
ModelResponsePart,
16-
OutputPart,
1716
RetryPromptPart,
1817
SystemPromptPart,
1918
TextPart,
@@ -206,7 +205,7 @@ def _map_messages(self, messages: list[ModelMessage]) -> list[ChatMessageV2]:
206205
texts: list[str] = []
207206
tool_calls: list[ToolCallV2] = []
208207
for item in message.parts:
209-
if isinstance(item, (TextPart, OutputPart)):
208+
if isinstance(item, TextPart):
210209
texts.append(item.content)
211210
elif isinstance(item, ToolCallPart):
212211
tool_calls.append(self._map_tool_call(item))

pydantic_ai_slim/pydantic_ai/models/function.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
ModelRequest,
2222
ModelResponse,
2323
ModelResponseStreamEvent,
24-
OutputPart,
2524
RetryPromptPart,
2625
SystemPromptPart,
2726
TextPart,
@@ -92,7 +91,7 @@ async def request(
9291
) -> ModelResponse:
9392
agent_info = AgentInfo(
9493
model_request_parameters.function_tools,
95-
model_request_parameters.allow_text_output,
94+
not model_request_parameters.require_tool_use,
9695
model_request_parameters.output_tools,
9796
model_settings,
9897
)
@@ -121,7 +120,7 @@ async def request_stream(
121120
) -> AsyncIterator[StreamedResponse]:
122121
agent_info = AgentInfo(
123122
model_request_parameters.function_tools,
124-
model_request_parameters.allow_text_output,
123+
not model_request_parameters.require_tool_use,
125124
model_request_parameters.output_tools,
126125
model_settings,
127126
)
@@ -267,7 +266,7 @@ def _estimate_usage(messages: Iterable[ModelMessage]) -> usage.Usage:
267266
assert_never(part)
268267
elif isinstance(message, ModelResponse):
269268
for part in message.parts:
270-
if isinstance(part, (TextPart, OutputPart)):
269+
if isinstance(part, TextPart):
271270
response_tokens += _estimate_string_tokens(part.content)
272271
elif isinstance(part, ToolCallPart):
273272
call = part

pydantic_ai_slim/pydantic_ai/models/gemini.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
ModelResponse,
2929
ModelResponsePart,
3030
ModelResponseStreamEvent,
31-
OutputPart,
3231
RetryPromptPart,
3332
SystemPromptPart,
3433
TextPart,
@@ -183,7 +182,7 @@ def _customize_output_object_def(o: OutputObjectDefinition):
183182
if model_request_parameters.output_object
184183
else None,
185184
output_tools=[_customize_tool_def(tool) for tool in model_request_parameters.output_tools],
186-
allow_text_output=model_request_parameters.allow_text_output,
185+
require_tool_use=model_request_parameters.require_tool_use,
187186
)
188187

189188
@property
@@ -205,7 +204,7 @@ def _get_tools(self, model_request_parameters: ModelRequestParameters) -> _Gemin
205204
def _get_tool_config(
206205
self, model_request_parameters: ModelRequestParameters, tools: _GeminiTools | None
207206
) -> _GeminiToolConfig | None:
208-
if model_request_parameters.allow_text_output:
207+
if not model_request_parameters.require_tool_use:
209208
return None
210209
elif tools:
211210
return _tool_config([t['name'] for t in tools['function_declarations']])
@@ -559,7 +558,7 @@ def _content_model_response(m: ModelResponse) -> _GeminiContent:
559558
for item in m.parts:
560559
if isinstance(item, ToolCallPart):
561560
parts.append(_function_call_part_from_call(item))
562-
elif isinstance(item, (TextPart, OutputPart)):
561+
elif isinstance(item, TextPart):
563562
if item.content:
564563
parts.append(_GeminiTextPart(text=item.content))
565564
else:

0 commit comments

Comments (0)