Skip to content

Add builtin_tools to Agent #1722

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 21 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,6 @@ repos:
rev: v2.3.0
hooks:
- id: codespell
args: ['--skip', 'tests/models/cassettes/*,docs/a2a/fasta2a.md']
args: ['--skip', 'tests/models/cassettes/*,docs/a2a/fasta2a.md,tests/models/test_groq.py']
additional_dependencies:
- tomli
9 changes: 8 additions & 1 deletion pydantic_ai_slim/pydantic_ai/_agent_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from opentelemetry.trace import Tracer
from typing_extensions import TypeGuard, TypeVar, assert_never

from pydantic_ai.builtin_tools import AbstractBuiltinTool
from pydantic_graph import BaseNode, Graph, GraphRunContext
from pydantic_graph.nodes import End, NodeRunEndT

Expand Down Expand Up @@ -94,6 +95,7 @@ class GraphAgentDeps(Generic[DepsT, OutputDataT]):
output_validators: list[_output.OutputValidator[DepsT, OutputDataT]]

function_tools: dict[str, Tool[DepsT]] = dataclasses.field(repr=False)
builtin_tools: list[AbstractBuiltinTool] = dataclasses.field(repr=False)
mcp_servers: Sequence[MCPServer] = dataclasses.field(repr=False)
default_retries: int

Expand Down Expand Up @@ -266,6 +268,7 @@ async def add_mcp_server_tools(server: MCPServer) -> None:
output_schema = ctx.deps.output_schema
return models.ModelRequestParameters(
function_tools=function_tool_defs,
builtin_tools=ctx.deps.builtin_tools,
allow_text_output=_output.allow_text_output(output_schema),
output_tools=output_schema.tool_defs() if output_schema is not None else [],
)
Expand Down Expand Up @@ -418,7 +421,7 @@ async def stream(
async for _event in stream:
pass

async def _run_stream(
async def _run_stream( # noqa C901
self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]]
) -> AsyncIterator[_messages.HandleResponseEvent]:
if self._events_iterator is None:
Expand All @@ -434,6 +437,10 @@ async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]:
texts.append(part.content)
elif isinstance(part, _messages.ToolCallPart):
tool_calls.append(part)
elif isinstance(part, _messages.ServerToolCallPart):
yield _messages.ServerToolCallEvent(part)
elif isinstance(part, _messages.ServerToolReturnPart):
yield _messages.ServerToolResultEvent(part)
else:
assert_never(part)

Expand Down
8 changes: 7 additions & 1 deletion pydantic_ai_slim/pydantic_ai/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,13 @@ def now_utc() -> datetime:
return datetime.now(tz=timezone.utc)


def guard_tool_call_id(t: _messages.ToolCallPart | _messages.ToolReturnPart | _messages.RetryPromptPart) -> str:
def guard_tool_call_id(
t: _messages.ToolCallPart
| _messages.ToolReturnPart
| _messages.RetryPromptPart
| _messages.ServerToolCallPart
| _messages.ServerToolReturnPart,
) -> str:
"""Type guard that either returns the tool call id or generates a new one if it's None."""
return t.tool_call_id or generate_tool_call_id()

Expand Down
15 changes: 15 additions & 0 deletions pydantic_ai_slim/pydantic_ai/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from pydantic.json_schema import GenerateJsonSchema
from typing_extensions import Literal, Never, Self, TypeIs, TypeVar, deprecated

from pydantic_ai.builtin_tools import AbstractBuiltinTool, WebSearchTool
from pydantic_graph import End, Graph, GraphRun, GraphRunContext
from pydantic_graph._utils import get_event_loop

Expand Down Expand Up @@ -174,6 +175,7 @@ def __init__(
retries: int = 1,
output_retries: int | None = None,
tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
builtin_tools: Sequence[Literal['web-search'] | AbstractBuiltinTool] = (),
prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = None,
mcp_servers: Sequence[MCPServer] = (),
defer_model_check: bool = False,
Expand Down Expand Up @@ -203,6 +205,7 @@ def __init__(
result_tool_description: str | None = None,
result_retries: int | None = None,
tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
builtin_tools: Sequence[Literal['web-search'] | AbstractBuiltinTool] = (),
prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = None,
mcp_servers: Sequence[MCPServer] = (),
defer_model_check: bool = False,
Expand All @@ -227,6 +230,7 @@ def __init__(
retries: int = 1,
output_retries: int | None = None,
tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
builtin_tools: Sequence[Literal['web-search'] | AbstractBuiltinTool] = (),
prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = None,
mcp_servers: Sequence[MCPServer] = (),
defer_model_check: bool = False,
Expand Down Expand Up @@ -256,6 +260,8 @@ def __init__(
output_retries: The maximum number of retries to allow for result validation, defaults to `retries`.
tools: Tools to register with the agent, you can also register tools via the decorators
[`@agent.tool`][pydantic_ai.Agent.tool] and [`@agent.tool_plain`][pydantic_ai.Agent.tool_plain].
builtin_tools: The builtin tools that the agent will use. This depends on the model, as some models may not
support certain tools. On models that don't support certain tools, the tool will be ignored.
prepare_tools: custom method to prepare the tool definition of all tools for each step.
This is useful if you want to customize the definition of multiple tools or you want to register
a subset of tools for a given step. See [`ToolsPrepareFunc`][pydantic_ai.tools.ToolsPrepareFunc]
Expand Down Expand Up @@ -342,6 +348,14 @@ def __init__(
self._default_retries = retries
self._max_result_retries = output_retries if output_retries is not None else retries
self._mcp_servers = mcp_servers
self._builtin_tools: list[AbstractBuiltinTool] = []

for tool in builtin_tools:
if tool == 'web-search':
self._builtin_tools.append(WebSearchTool())
else:
self._builtin_tools.append(tool)
Comment on lines +353 to +357
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's easier to not have to handle string on the models, so we already do the transformation here.


self._prepare_tools = prepare_tools
for tool in tools:
if isinstance(tool, Tool):
Expand Down Expand Up @@ -691,6 +705,7 @@ async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None:
output_schema=output_schema,
output_validators=output_validators,
function_tools=self._function_tools,
builtin_tools=self._builtin_tools,
mcp_servers=self._mcp_servers,
default_retries=self._default_retries,
tracer=tracer,
Expand Down
95 changes: 95 additions & 0 deletions pydantic_ai_slim/pydantic_ai/builtin_tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
from __future__ import annotations as _annotations

from abc import ABC
from dataclasses import dataclass
from typing import Any, Literal

from typing_extensions import TypedDict

__all__ = ('AbstractBuiltinTool', 'WebSearchTool', 'UserLocation')


@dataclass
class AbstractBuiltinTool(ABC):
"""A builtin tool that can be used by an agent.

This class is abstract and cannot be instantiated directly.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think worth including a sentence here explaining how the code execution works to make use of them — something like "these are passed to the model as part of the ModelRequestParameters" or whatever. (Not sure if that's true, haven't gotten there yet ..). But I imagine it helping someone who is trying to figure out how they are different from normal tools.


The builtin tools are passed to the model as part of the `ModelRequestParameters`.
"""

def handle_custom_tool_definition(self, model: str) -> Any: ...


@dataclass
class WebSearchTool(AbstractBuiltinTool):
Comment on lines +24 to +25
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From the DX POV, it's nicer for it to be a BaseModel or dataclass.

"""A builtin tool that allows your agent to search the web for information.

The parameters that PydanticAI passes depend on the model, as some parameters may not be supported by certain models.
"""

search_context_size: Literal['low', 'medium', 'high'] = 'medium'
"""The `search_context_size` parameter controls how much context is retrieved from the web to help the tool formulate a response.

Supported by:
* OpenAI
"""

user_location: UserLocation | None = None
"""The `user_location` parameter allows you to localize search results based on a user's location.

Supported by:
* Anthropic
* OpenAI
"""

blocked_domains: list[str] | None = None
"""If provided, these domains will never appear in results.

With Anthropic, you can only use one of `blocked_domains` or `allowed_domains`, not both.

Supported by:
* Anthropic (https://docs.anthropic.com/en/docs/build-with-claude/tool-use/web-search-tool#domain-filtering)
* Groq (https://console.groq.com/docs/agentic-tooling#search-settings)
* MistralAI
"""

allowed_domains: list[str] | None = None
"""If provided, only these domains will be included in results.

With Anthropic, you can only use one of `blocked_domains` or `allowed_domains`, not both.

Supported by:
* Anthropic (https://docs.anthropic.com/en/docs/build-with-claude/tool-use/web-search-tool#domain-filtering)
* Groq (https://console.groq.com/docs/agentic-tooling#search-settings)
"""

max_uses: int | None = None
"""If provided, the tool will stop searching the web after the given number of uses.

Supported by:
* Anthropic
"""


class UserLocation(TypedDict, total=False):
"""Allows you to localize search results based on a user's location.

Supported by:
* Anthropic
* OpenAI
"""

city: str
country: str
region: str
timezone: str


class CodeExecutionTool(AbstractBuiltinTool):
"""A builtin tool that allows your agent to execute code.

Supported by:
* Anthropic
* OpenAI
"""
80 changes: 69 additions & 11 deletions pydantic_ai_slim/pydantic_ai/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,8 +349,8 @@ def otel_event(self, settings: InstrumentationSettings) -> Event:


@dataclass(repr=False)
class ToolReturnPart:
"""A tool return message, this encodes the result of running a tool."""
class BaseToolReturnPart:
"""Base class for tool return parts."""

tool_name: str
"""The name of the "tool" was called."""
Expand All @@ -364,9 +364,6 @@ class ToolReturnPart:
timestamp: datetime = field(default_factory=_now_utc)
"""The timestamp, when the tool returned."""

part_kind: Literal['tool-return'] = 'tool-return'
"""Part type identifier, this is available on all parts as a discriminator."""

def model_response_str(self) -> str:
"""Return a string representation of the content for the model."""
if isinstance(self.content, str):
Expand All @@ -388,9 +385,29 @@ def otel_event(self, _settings: InstrumentationSettings) -> Event:
body={'content': self.content, 'role': 'tool', 'id': self.tool_call_id, 'name': self.tool_name},
)

def has_content(self) -> bool:
"""Return `True` if the tool return has content."""
return self.content is not None

__repr__ = _utils.dataclasses_no_defaults_repr


@dataclass(repr=False)
class ToolReturnPart(BaseToolReturnPart):
"""A tool return message, this encodes the result of running a tool."""

part_kind: Literal['tool-return'] = 'tool-return'
"""Part type identifier, this is available on all parts as a discriminator."""


@dataclass(repr=False)
class ServerToolReturnPart(BaseToolReturnPart):
"""A tool return message from a server tool."""

part_kind: Literal['server-tool-return'] = 'server-tool-return'
"""Part type identifier, this is available on all parts as a discriminator."""


error_details_ta = pydantic.TypeAdapter(list[pydantic_core.ErrorDetails], config=pydantic.ConfigDict(defer_build=True))


Expand Down Expand Up @@ -503,7 +520,7 @@ def has_content(self) -> bool:


@dataclass(repr=False)
class ToolCallPart:
class BaseToolCallPart:
"""A tool call from a model."""

tool_name: str
Expand All @@ -521,9 +538,6 @@ class ToolCallPart:
In case the tool call id is not provided by the model, PydanticAI will generate a random one.
"""

part_kind: Literal['tool-call'] = 'tool-call'
"""Part type identifier, this is available on all parts as a discriminator."""

def args_as_dict(self) -> dict[str, Any]:
"""Return the arguments as a Python dictionary.

Expand Down Expand Up @@ -560,7 +574,28 @@ def has_content(self) -> bool:
__repr__ = _utils.dataclasses_no_defaults_repr


ModelResponsePart = Annotated[Union[TextPart, ToolCallPart], pydantic.Discriminator('part_kind')]
@dataclass(repr=False)
class ToolCallPart(BaseToolCallPart):
"""A tool call from a model."""

part_kind: Literal['tool-call'] = 'tool-call'
"""Part type identifier, this is available on all parts as a discriminator."""


@dataclass(repr=False)
class ServerToolCallPart(BaseToolCallPart):
"""A tool call from a server tool."""

model_name: str | None = None
"""The name of the model that generated the response."""

part_kind: Literal['server-tool-call'] = 'server-tool-call'
"""Part type identifier, this is available on all parts as a discriminator."""


ModelResponsePart = Annotated[
Union[TextPart, ToolCallPart, ServerToolCallPart, ServerToolReturnPart], pydantic.Discriminator('part_kind')
]
"""A message part returned by a model."""


Expand Down Expand Up @@ -883,6 +918,29 @@ class FunctionToolResultEvent:
__repr__ = _utils.dataclasses_no_defaults_repr


@dataclass(repr=False)
class ServerToolCallEvent:
"""An event indicating the start to a call to a server tool."""

part: ServerToolCallPart
"""The server tool call to make."""

event_kind: Literal['server_tool_call'] = 'server_tool_call'
"""Event type identifier, used as a discriminator."""


@dataclass(repr=False)
class ServerToolResultEvent:
"""An event indicating the result of a server tool call."""

result: ServerToolReturnPart
"""The result of the call to the server tool."""

event_kind: Literal['server_tool_result'] = 'server_tool_result'
"""Event type identifier, used as a discriminator."""


HandleResponseEvent = Annotated[
Union[FunctionToolCallEvent, FunctionToolResultEvent], pydantic.Discriminator('event_kind')
Union[FunctionToolCallEvent, FunctionToolResultEvent, ServerToolCallEvent, ServerToolResultEvent],
pydantic.Discriminator('event_kind'),
]
2 changes: 2 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import httpx
from typing_extensions import Literal, TypeAliasType

from pydantic_ai.builtin_tools import AbstractBuiltinTool
from pydantic_ai.profiles import DEFAULT_PROFILE, ModelProfile, ModelProfileSpec

from .._parts_manager import ModelResponsePartsManager
Expand Down Expand Up @@ -292,6 +293,7 @@ class ModelRequestParameters:
"""Configuration for an agent's request to a model, specifically related to tools and output handling."""

function_tools: list[ToolDefinition] = field(default_factory=list)
builtin_tools: list[AbstractBuiltinTool] = field(default_factory=list)
allow_text_output: bool = True
output_tools: list[ToolDefinition] = field(default_factory=list)

Expand Down
Loading
Loading