Skip to content

Commit 688a026

Browse files
authored
Ensure tool call parts with custom argument model validation errors are serializable (#1862)
1 parent 56b98d6 commit 688a026

File tree

3 files changed

+46
-2
lines changed

3 files changed

+46
-2
lines changed

pydantic_ai_slim/pydantic_ai/_output.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -407,7 +407,7 @@ async def process(
407407
if wrap_validation_errors:
408408
m = _messages.RetryPromptPart(
409409
tool_name=tool_call.tool_name,
410-
content=e.errors(include_url=False),
410+
content=e.errors(include_url=False, include_context=False),
411411
tool_call_id=tool_call.tool_call_id,
412412
)
413413
raise ToolRetryError(m) from e

pydantic_ai_slim/pydantic_ai/tools.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -392,7 +392,7 @@ def _on_error(
392392
raise UnexpectedModelBehavior(f'Tool exceeded max retries count of {self.max_retries}') from exc
393393
else:
394394
if isinstance(exc, ValidationError):
395-
content = exc.errors(include_url=False)
395+
content = exc.errors(include_url=False, include_context=False)
396396
else:
397397
content = exc.message
398398
return _messages.RetryPromptPart(

tests/test_agent.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2573,3 +2573,47 @@ def test_agent_repr() -> None:
25732573
assert repr(agent) == snapshot(
25742574
"Agent(model=None, name=None, end_strategy='early', model_settings=None, output_type=<class 'str'>, instrument=None)"
25752575
)
2576+
2577+
2578+
def test_tool_call_with_validation_value_error_serializable():
2579+
def llm(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
2580+
if len(messages) == 1:
2581+
return ModelResponse(parts=[ToolCallPart('foo_tool', {'bar': 0})])
2582+
elif len(messages) == 3:
2583+
return ModelResponse(parts=[ToolCallPart('foo_tool', {'bar': 1})])
2584+
else:
2585+
return ModelResponse(parts=[TextPart('Tool returned 1')])
2586+
2587+
agent = Agent(FunctionModel(llm))
2588+
2589+
class Foo(BaseModel):
2590+
bar: int
2591+
2592+
@field_validator('bar')
2593+
def validate_bar(cls, v: int) -> int:
2594+
if v == 0:
2595+
raise ValueError('bar cannot be 0')
2596+
return v
2597+
2598+
@agent.tool_plain
2599+
def foo_tool(foo: Foo) -> int:
2600+
return foo.bar
2601+
2602+
result = agent.run_sync('Hello')
2603+
assert json.loads(result.all_messages_json())[2] == snapshot(
2604+
{
2605+
'parts': [
2606+
{
2607+
'content': [
2608+
{'type': 'value_error', 'loc': ['bar'], 'msg': 'Value error, bar cannot be 0', 'input': 0}
2609+
],
2610+
'tool_name': 'foo_tool',
2611+
'tool_call_id': IsStr(),
2612+
'timestamp': IsStr(),
2613+
'part_kind': 'retry-prompt',
2614+
}
2615+
],
2616+
'instructions': None,
2617+
'kind': 'request',
2618+
}
2619+
)

0 commit comments

Comments
 (0)