
Add builtin_tools to Agent #1722


Draft · wants to merge 21 commits into main
Changes from 2 commits
3 changes: 3 additions & 0 deletions pydantic_ai_slim/pydantic_ai/_agent_graph.py
@@ -12,6 +12,7 @@
from opentelemetry.trace import Tracer
from typing_extensions import TypeGuard, TypeVar, assert_never

from pydantic_ai.builtin_tools import AbstractBuiltinTool
from pydantic_graph import BaseNode, Graph, GraphRunContext
from pydantic_graph.nodes import End, NodeRunEndT

@@ -92,6 +93,7 @@ class GraphAgentDeps(Generic[DepsT, OutputDataT]):
output_validators: list[_output.OutputValidator[DepsT, OutputDataT]]

function_tools: dict[str, Tool[DepsT]] = dataclasses.field(repr=False)
builtin_tools: list[AbstractBuiltinTool] = dataclasses.field(repr=False)
mcp_servers: Sequence[MCPServer] = dataclasses.field(repr=False)
default_retries: int

@@ -242,6 +244,7 @@ async def add_mcp_server_tools(server: MCPServer) -> None:
output_schema = ctx.deps.output_schema
return models.ModelRequestParameters(
function_tools=function_tool_defs,
builtin_tools=ctx.deps.builtin_tools,
allow_text_output=allow_text_output(output_schema),
output_tools=output_schema.tool_defs() if output_schema is not None else [],
)
15 changes: 15 additions & 0 deletions pydantic_ai_slim/pydantic_ai/agent.py
@@ -14,6 +14,7 @@
from pydantic.json_schema import GenerateJsonSchema
from typing_extensions import Literal, Never, Self, TypeIs, TypeVar, deprecated

from pydantic_ai.builtin_tools import AbstractBuiltinTool, WebSearchTool
from pydantic_graph import End, Graph, GraphRun, GraphRunContext
from pydantic_graph._utils import get_event_loop

@@ -172,6 +173,7 @@ def __init__(
retries: int = 1,
output_retries: int | None = None,
tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
builtin_tools: Sequence[Literal['web-search'] | AbstractBuiltinTool] = (),
mcp_servers: Sequence[MCPServer] = (),
defer_model_check: bool = False,
end_strategy: EndStrategy = 'early',
@@ -200,6 +202,7 @@ def __init__(
result_tool_description: str | None = None,
result_retries: int | None = None,
tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
builtin_tools: Sequence[Literal['web-search'] | AbstractBuiltinTool] = (),
mcp_servers: Sequence[MCPServer] = (),
defer_model_check: bool = False,
end_strategy: EndStrategy = 'early',
@@ -223,6 +226,7 @@ def __init__(
retries: int = 1,
output_retries: int | None = None,
tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
builtin_tools: Sequence[Literal['web-search'] | AbstractBuiltinTool] = (),
mcp_servers: Sequence[MCPServer] = (),
defer_model_check: bool = False,
end_strategy: EndStrategy = 'early',
@@ -251,6 +255,8 @@ def __init__(
output_retries: The maximum number of retries to allow for result validation, defaults to `retries`.
tools: Tools to register with the agent, you can also register tools via the decorators
[`@agent.tool`][pydantic_ai.Agent.tool] and [`@agent.tool_plain`][pydantic_ai.Agent.tool_plain].
builtin_tools: The builtin tools that the agent will use. Support depends on the model;
builtin tools that a model doesn't support are ignored.
mcp_servers: MCP servers to register with the agent. You should register a [`MCPServer`][pydantic_ai.mcp.MCPServer]
for each server you want the agent to connect to.
defer_model_check: by default, if you provide a [named][pydantic_ai.models.KnownModelName] model,
@@ -334,6 +340,14 @@ def __init__(
self._default_retries = retries
self._max_result_retries = output_retries if output_retries is not None else retries
self._mcp_servers = mcp_servers
self._builtin_tools: list[AbstractBuiltinTool] = []

for tool in builtin_tools:
if tool == 'web-search':
self._builtin_tools.append(WebSearchTool())
else:
self._builtin_tools.append(tool)
Review comment (Member Author) on lines +353 to +357:

It's easier to not have to handle strings in the models, so we already do the transformation here.
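A minimal usage sketch of this normalization from the caller's side (the model name below is illustrative, not part of this diff):

```python
from pydantic_ai import Agent
from pydantic_ai.builtin_tools import WebSearchTool

# The string shorthand and the explicit instance end up equivalent:
# Agent.__init__ converts 'web-search' into WebSearchTool(), so model
# classes only ever receive AbstractBuiltinTool instances.
agent_a = Agent('openai:gpt-4o', builtin_tools=['web-search'])
agent_b = Agent('openai:gpt-4o', builtin_tools=[WebSearchTool()])
```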


for tool in tools:
if isinstance(tool, Tool):
self._register_tool(tool)
@@ -688,6 +702,7 @@ async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None:
output_schema=output_schema,
output_validators=output_validators,
function_tools=self._function_tools,
builtin_tools=self._builtin_tools,
mcp_servers=self._mcp_servers,
default_retries=self._default_retries,
tracer=tracer,
75 changes: 75 additions & 0 deletions pydantic_ai_slim/pydantic_ai/builtin_tools.py
@@ -0,0 +1,75 @@
from __future__ import annotations as _annotations

from abc import ABC
from dataclasses import dataclass, field
from typing import Literal

from typing_extensions import TypedDict

__all__ = ('AbstractBuiltinTool', 'WebSearchTool', 'UserLocation')


@dataclass
class AbstractBuiltinTool(ABC):
"""A builtin tool that can be used by an agent.

This class is abstract and cannot be instantiated directly.
Review comment (Contributor):

I think it's worth including a sentence here explaining how these tools are actually used, something like "these are passed to the model as part of the ModelRequestParameters" (not sure if that's true, haven't gotten there yet). But I imagine it would help someone who is trying to figure out how they are different from normal tools.

"""


class UserLocation(TypedDict, total=False):
Review comment (Member Author):

It's easier to handle this in the models if it's a TypedDict, since it matches the type.

"""Allows you to localize search results based on a user's location.

Supported by:
* Anthropic
* OpenAI
"""

city: str
country: str
region: str
timezone: str


@dataclass
class WebSearchTool(AbstractBuiltinTool):
Review comment (Member Author) on lines +24 to +25:

From the DX POV, it's nicer for it to be a BaseModel or dataclass.

"""A builtin tool that allows your agent to search the web for information.

The parameters that PydanticAI passes depend on the model, as some parameters may not be supported by certain models.
"""

search_context_size: Literal['low', 'medium', 'high'] = 'medium'
"""The `search_context_size` parameter controls how much context is retrieved from the web to help the tool formulate a response.

Supported by:
* OpenAI
"""

user_location: UserLocation = field(default_factory=UserLocation)
"""The `user_location` parameter allows you to localize search results based on a user's location.

Supported by:
* Anthropic
* OpenAI
"""

blocked_domains: list[str] | None = None
"""If provided, these domains will never appear in results.

With Anthropic, you can only use one of `blocked_domains` or `allowed_domains`, not both.

Supported by:
* Anthropic (https://docs.anthropic.com/en/docs/build-with-claude/tool-use/web-search-tool#domain-filtering)
* Groq (https://console.groq.com/docs/agentic-tooling#search-settings)
* MistralAI
"""

allowed_domains: list[str] | None = None
"""If provided, only these domains will be included in results.

With Anthropic, you can only use one of `blocked_domains` or `allowed_domains`, not both.

Supported by:
* Anthropic (https://docs.anthropic.com/en/docs/build-with-claude/tool-use/web-search-tool#domain-filtering)
* Groq (https://console.groq.com/docs/agentic-tooling#search-settings)
"""
3 changes: 3 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/__init__.py
@@ -16,6 +16,8 @@
import httpx
from typing_extensions import Literal, TypeAliasType

from pydantic_ai.builtin_tools import AbstractBuiltinTool

from .._parts_manager import ModelResponsePartsManager
from ..exceptions import UserError
from ..messages import ModelMessage, ModelRequest, ModelResponse, ModelResponseStreamEvent
@@ -261,6 +263,7 @@ class ModelRequestParameters:
"""Configuration for an agent's request to a model, specifically related to tools and output handling."""

function_tools: list[ToolDefinition] = field(default_factory=list)
builtin_tools: list[AbstractBuiltinTool] = field(default_factory=list)
allow_text_output: bool = True
output_tools: list[ToolDefinition] = field(default_factory=list)

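As the review comment on AbstractBuiltinTool hints, builtin tools travel to the model inside ModelRequestParameters; a minimal sketch of reading them back out (the helper function is hypothetical, not part of this diff):

```python
from pydantic_ai.builtin_tools import WebSearchTool
from pydantic_ai.models import ModelRequestParameters

def builtin_tool_names(params: ModelRequestParameters) -> list[str]:
    # Each model implementation inspects params.builtin_tools and maps the
    # tools it supports to provider-specific request options; unsupported
    # tools are simply ignored.
    return [type(tool).__name__ for tool in params.builtin_tools]

params = ModelRequestParameters(builtin_tools=[WebSearchTool()])
assert builtin_tool_names(params) == ['WebSearchTool']
```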
33 changes: 33 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/openai.py
@@ -11,6 +11,7 @@

from typing_extensions import assert_never

from pydantic_ai.builtin_tools import WebSearchTool
from pydantic_ai.providers import Provider, infer_provider

from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
@@ -58,6 +59,11 @@
from openai.types.chat.chat_completion_content_part_image_param import ImageURL
from openai.types.chat.chat_completion_content_part_input_audio_param import InputAudio
from openai.types.chat.chat_completion_content_part_param import File, FileFile
from openai.types.chat.completion_create_params import (
WebSearchOptions,
WebSearchOptionsUserLocation,
WebSearchOptionsUserLocationApproximate,
)
from openai.types.responses import ComputerToolParam, FileSearchToolParam, WebSearchToolParam
from openai.types.responses.response_input_param import FunctionCallOutput, Message
from openai.types.shared import ReasoningEffort
@@ -254,6 +260,7 @@ async def _completions_create(
model_request_parameters: ModelRequestParameters,
) -> chat.ChatCompletion | AsyncStream[ChatCompletionChunk]:
tools = self._get_tools(model_request_parameters)
web_search_options = self._get_web_search_options(model_request_parameters)

# standalone function to make it easier to override
if not tools:
@@ -288,6 +295,7 @@
logit_bias=model_settings.get('logit_bias', NOT_GIVEN),
reasoning_effort=model_settings.get('openai_reasoning_effort', NOT_GIVEN),
user=model_settings.get('openai_user', NOT_GIVEN),
web_search_options=web_search_options or NOT_GIVEN,
extra_headers=extra_headers,
extra_body=model_settings.get('extra_body'),
)
@@ -327,6 +335,17 @@ def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[c
tools += [self._map_tool_definition(r) for r in model_request_parameters.output_tools]
return tools

def _get_web_search_options(self, model_request_parameters: ModelRequestParameters) -> WebSearchOptions | None:
for tool in model_request_parameters.builtin_tools:
if isinstance(tool, WebSearchTool):
return WebSearchOptions(
search_context_size=tool.search_context_size,
user_location=WebSearchOptionsUserLocation(
type='approximate',
approximate=WebSearchOptionsUserLocationApproximate(**tool.user_location),
),
)

async def _map_messages(self, messages: list[ModelMessage]) -> list[chat.ChatCompletionMessageParam]:
"""Just maps a `pydantic_ai.Message` to a `openai.types.ChatCompletionMessageParam`."""
openai_messages: list[chat.ChatCompletionMessageParam] = []
@@ -601,6 +620,7 @@ async def _responses_create(
) -> responses.Response | AsyncStream[responses.ResponseStreamEvent]:
tools = self._get_tools(model_request_parameters)
tools = list(model_settings.get('openai_builtin_tools', [])) + tools
tools = self._get_builtin_tools(model_request_parameters) + tools
Review comment (Member Author) on lines 670 to +671:

We should deprecate the openai_builtin_tools setting in this PR.
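A hedged sketch of the migration that deprecation would imply, assuming the existing openai_builtin_tools model setting and the new builtin_tools argument from this PR:

```python
from pydantic_ai import Agent
from pydantic_ai.builtin_tools import WebSearchTool

# Before (Responses API only): raw OpenAI tool params passed via model settings.
# agent = Agent(
#     'openai:gpt-4o',
#     model_settings={'openai_builtin_tools': [{'type': 'web_search_preview'}]},
# )

# After (model-agnostic): declare the builtin tool on the agent itself.
agent = Agent('openai:gpt-4o', builtin_tools=[WebSearchTool()])
```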


# standalone function to make it easier to override
if not tools:
@@ -653,6 +673,19 @@ def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[r
tools += [self._map_tool_definition(r) for r in model_request_parameters.output_tools]
return tools

def _get_builtin_tools(self, model_request_parameters: ModelRequestParameters) -> list[responses.ToolParam]:
tools: list[responses.ToolParam] = []
for tool in model_request_parameters.builtin_tools:
if isinstance(tool, WebSearchTool):
tools.append(
responses.WebSearchToolParam(
type='web_search_preview',
search_context_size=tool.search_context_size,
user_location={'type': 'approximate', **tool.user_location},
)
)
return tools

@staticmethod
def _map_tool_definition(f: ToolDefinition) -> responses.FunctionToolParam:
return {
1 change: 1 addition & 0 deletions tests/models/test_model_request_parameters.py
@@ -7,6 +7,7 @@ def test_model_request_parameters_are_serializable():
params = ModelRequestParameters(function_tools=[], allow_text_output=False, output_tools=[])
assert TypeAdapter(ModelRequestParameters).dump_python(params) == {
'function_tools': [],
'builtin_tools': [],
'allow_text_output': False,
'output_tools': [],
}