-
Notifications
You must be signed in to change notification settings - Fork 902
Add builtin_tools
to Agent
#1722
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 2 commits
57e568b
97ab44b
e3dda9d
3ad6d38
0b43f65
fa7fd11
32324fa
f33e568
13d7433
ac85205
c93633f
3a8b640
360de87
cb4e539
4e3769a
ebb536f
c8bb611
6bcc1a8
97ff651
1d47e1e
5f89444
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
from __future__ import annotations as _annotations | ||
|
||
from abc import ABC | ||
from dataclasses import dataclass, field | ||
from typing import Literal | ||
|
||
from typing_extensions import TypedDict | ||
|
||
__all__ = ('AbstractBuiltinTool', 'WebSearchTool', 'UserLocation') | ||
|
||
|
||
@dataclass | ||
class AbstractBuiltinTool(ABC): | ||
"""A builtin tool that can be used by an agent. | ||
|
||
This class is abstract and cannot be instantiated directly. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think worth including a sentence here explaining how the code execution works to make use of them — something like "these are passed to the model as part of the ModelRequestParameters" or whatever. (Not sure if that's true, haven't gotten there yet ..). But I imagine it helping someone who is trying to figure out how they are different from normal tools. |
||
""" | ||
|
||
|
||
class UserLocation(TypedDict, total=False): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's easier to handle this in the models if it's a |
||
"""Allows you to localize search results based on a user's location. | ||
|
||
Supported by: | ||
* Anthropic | ||
* OpenAI | ||
""" | ||
|
||
city: str | ||
country: str | ||
region: str | ||
timezone: str | ||
|
||
|
||
@dataclass | ||
class WebSearchTool(AbstractBuiltinTool): | ||
Comment on lines
+24
to
+25
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. From the DX POV, it's nicer for it to be a |
||
"""A builtin tool that allows your agent to search the web for information. | ||
|
||
The parameters that PydanticAI passes depend on the model, as some parameters may not be supported by certain models. | ||
""" | ||
|
||
search_context_size: Literal['low', 'medium', 'high'] = 'medium' | ||
"""The `search_context_size` parameter controls how much context is retrieved from the web to help the tool formulate a response. | ||
|
||
Supported by: | ||
* OpenAI | ||
""" | ||
|
||
user_location: UserLocation = field(default_factory=UserLocation) | ||
"""The `user_location` parameter allows you to localize search results based on a user's location. | ||
|
||
Supported by: | ||
* Anthropic | ||
* OpenAI | ||
""" | ||
|
||
blocked_domains: list[str] | None = None | ||
"""If provided, these domains will never appear in results. | ||
|
||
With Anthropic, you can only use one of `blocked_domains` or `allowed_domains`, not both. | ||
|
||
Supported by: | ||
* Anthropic (https://docs.anthropic.com/en/docs/build-with-claude/tool-use/web-search-tool#domain-filtering) | ||
* Groq (https://console.groq.com/docs/agentic-tooling#search-settings) | ||
* MistralAI | ||
""" | ||
|
||
allowed_domains: list[str] | None = None | ||
"""If provided, only these domains will be included in results. | ||
|
||
With Anthropic, you can only use one of `blocked_domains` or `allowed_domains`, not both. | ||
|
||
Supported by: | ||
* Anthropic (https://docs.anthropic.com/en/docs/build-with-claude/tool-use/web-search-tool#domain-filtering) | ||
* Groq (https://console.groq.com/docs/agentic-tooling#search-settings) | ||
""" |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,6 +11,7 @@ | |
|
||
from typing_extensions import assert_never | ||
|
||
from pydantic_ai.builtin_tools import WebSearchTool | ||
from pydantic_ai.providers import Provider, infer_provider | ||
|
||
from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage | ||
|
@@ -58,6 +59,11 @@ | |
from openai.types.chat.chat_completion_content_part_image_param import ImageURL | ||
from openai.types.chat.chat_completion_content_part_input_audio_param import InputAudio | ||
from openai.types.chat.chat_completion_content_part_param import File, FileFile | ||
from openai.types.chat.completion_create_params import ( | ||
WebSearchOptions, | ||
WebSearchOptionsUserLocation, | ||
WebSearchOptionsUserLocationApproximate, | ||
) | ||
from openai.types.responses import ComputerToolParam, FileSearchToolParam, WebSearchToolParam | ||
from openai.types.responses.response_input_param import FunctionCallOutput, Message | ||
from openai.types.shared import ReasoningEffort | ||
|
@@ -254,6 +260,7 @@ async def _completions_create( | |
model_request_parameters: ModelRequestParameters, | ||
) -> chat.ChatCompletion | AsyncStream[ChatCompletionChunk]: | ||
tools = self._get_tools(model_request_parameters) | ||
web_search_options = self._get_web_search_options(model_request_parameters) | ||
|
||
# standalone function to make it easier to override | ||
if not tools: | ||
|
@@ -288,6 +295,7 @@ async def _completions_create( | |
logit_bias=model_settings.get('logit_bias', NOT_GIVEN), | ||
reasoning_effort=model_settings.get('openai_reasoning_effort', NOT_GIVEN), | ||
user=model_settings.get('openai_user', NOT_GIVEN), | ||
web_search_options=web_search_options or NOT_GIVEN, | ||
extra_headers=extra_headers, | ||
extra_body=model_settings.get('extra_body'), | ||
) | ||
|
@@ -327,6 +335,17 @@ def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[c | |
tools += [self._map_tool_definition(r) for r in model_request_parameters.output_tools] | ||
return tools | ||
|
||
def _get_web_search_options(self, model_request_parameters: ModelRequestParameters) -> WebSearchOptions | None: | ||
for tool in model_request_parameters.builtin_tools: | ||
if isinstance(tool, WebSearchTool): | ||
return WebSearchOptions( | ||
search_context_size=tool.search_context_size, | ||
user_location=WebSearchOptionsUserLocation( | ||
type='approximate', | ||
approximate=WebSearchOptionsUserLocationApproximate(**tool.user_location), | ||
), | ||
) | ||
|
||
async def _map_messages(self, messages: list[ModelMessage]) -> list[chat.ChatCompletionMessageParam]: | ||
"""Just maps a `pydantic_ai.Message` to a `openai.types.ChatCompletionMessageParam`.""" | ||
openai_messages: list[chat.ChatCompletionMessageParam] = [] | ||
|
@@ -601,6 +620,7 @@ async def _responses_create( | |
) -> responses.Response | AsyncStream[responses.ResponseStreamEvent]: | ||
tools = self._get_tools(model_request_parameters) | ||
tools = list(model_settings.get('openai_builtin_tools', [])) + tools | ||
tools = self._get_builtin_tools(model_request_parameters) + tools | ||
Comment on lines
670
to
+671
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should deprecate the |
||
|
||
# standalone function to make it easier to override | ||
if not tools: | ||
|
@@ -653,6 +673,19 @@ def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[r | |
tools += [self._map_tool_definition(r) for r in model_request_parameters.output_tools] | ||
return tools | ||
|
||
def _get_builtin_tools(self, model_request_parameters: ModelRequestParameters) -> list[responses.ToolParam]: | ||
tools: list[responses.ToolParam] = [] | ||
for tool in model_request_parameters.builtin_tools: | ||
if isinstance(tool, WebSearchTool): | ||
tools.append( | ||
responses.WebSearchToolParam( | ||
type='web_search_preview', | ||
search_context_size=tool.search_context_size, | ||
user_location={'type': 'approximate', **tool.user_location}, | ||
) | ||
) | ||
return tools | ||
|
||
@staticmethod | ||
def _map_tool_definition(f: ToolDefinition) -> responses.FunctionToolParam: | ||
return { | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's easier to not have to handle string on the models, so we already do the transformation here.