diff --git a/docs/mcp/client.md b/docs/mcp/client.md index 4b643d493..b1a5994d8 100644 --- a/docs/mcp/client.md +++ b/docs/mcp/client.md @@ -18,24 +18,23 @@ pip/uv-add "pydantic-ai-slim[mcp]" PydanticAI comes with two ways to connect to MCP servers: -- [`MCPServerHTTP`][pydantic_ai.mcp.MCPServerHTTP] which connects to an MCP server using the [HTTP SSE](https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#http-with-sse) transport +- [`MCPServerHTTP`][pydantic_ai.mcp.MCPServerHTTP] which connects to an MCP server using the [Streamable HTTP transport](https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#streamable-http) transport - [`MCPServerStdio`][pydantic_ai.mcp.MCPServerStdio] which runs the server as a subprocess and connects to it using the [stdio](https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#stdio) transport Examples of both are shown below; [mcp-run-python](run-python.md) is used as the MCP server in both examples. -### SSE Client +### HTTP Client -[`MCPServerHTTP`][pydantic_ai.mcp.MCPServerHTTP] connects over HTTP using the [HTTP + Server Sent Events transport](https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#http-with-sse) to a server. +[`MCPServerHTTP`][pydantic_ai.mcp.MCPServerHTTP] connects over HTTP using the [Streamable HTTP transport](https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#streamable-http) to a server. !!! note [`MCPServerHTTP`][pydantic_ai.mcp.MCPServerHTTP] requires an MCP server to be running and accepting HTTP connections before calling [`agent.run_mcp_servers()`][pydantic_ai.Agent.run_mcp_servers]. Running the server is not managed by PydanticAI. -The name "HTTP" is used since this implemented will be adapted in future to use the new -[Streamable HTTP](https://github.com/modelcontextprotocol/specification/pull/206) currently in development. +The StreamableHTTP Transport is able to connect to both stateless HTTP and older Server Sent Events (SSE) servers. -Before creating the SSE client, we need to run the server (docs [here](run-python.md)): +Before creating the HTTP client, we need to run the server (docs [here](run-python.md)): -```bash {title="terminal (run sse server)"} +```bash {title="terminal (run http server)"} deno run \ -N -R=node_modules -W=node_modules --node-modules-dir=auto \ jsr:@pydantic/mcp-run-python sse @@ -56,7 +55,7 @@ async def main(): #> There are 9,208 days between January 1, 2000, and March 18, 2025. ``` -1. Define the MCP server with the URL used to connect. +1. Define the MCP server with the URL used to connect. This will typically end in `/mcp` for HTTP servers and `/sse` for SSE. 2. Create an agent with the MCP server attached. 3. Create a client session to connect to the server. diff --git a/pydantic_ai_slim/pydantic_ai/mcp.py b/pydantic_ai_slim/pydantic_ai/mcp.py index 3be1146ad..d4402f125 100644 --- a/pydantic_ai_slim/pydantic_ai/mcp.py +++ b/pydantic_ai_slim/pydantic_ai/mcp.py @@ -2,20 +2,22 @@ import base64 import json +import warnings from abc import ABC, abstractmethod from collections.abc import AsyncIterator, Sequence from contextlib import AsyncExitStack, asynccontextmanager from dataclasses import dataclass +from datetime import timedelta from pathlib import Path from types import TracebackType from typing import Any from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +from mcp.shared.message import SessionMessage from mcp.types import ( BlobResourceContents, EmbeddedResource, ImageContent, - JSONRPCMessage, LoggingLevel, TextContent, TextResourceContents, @@ -28,8 +30,8 @@ try: from mcp.client.session import ClientSession - from mcp.client.sse import sse_client from mcp.client.stdio import StdioServerParameters, stdio_client + from mcp.client.streamable_http import streamablehttp_client except ImportError as _import_error: raise ImportError( 'Please install the `mcp` package to use the MCP server, ' @@ -55,8 +57,8 @@ class MCPServer(ABC): """ _client: ClientSession - _read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception] - _write_stream: MemoryObjectSendStream[JSONRPCMessage] + _read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] + _write_stream: MemoryObjectSendStream[SessionMessage] _exit_stack: AsyncExitStack @abstractmethod @@ -64,10 +66,7 @@ class MCPServer(ABC): async def client_streams( self, ) -> AsyncIterator[ - tuple[ - MemoryObjectReceiveStream[JSONRPCMessage | Exception], - MemoryObjectSendStream[JSONRPCMessage], - ] + tuple[MemoryObjectReceiveStream[SessionMessage | Exception], MemoryObjectSendStream[SessionMessage]] ]: """Create the streams for the MCP server.""" raise NotImplementedError('MCP Server subclasses must implement this method.') @@ -256,10 +255,7 @@ async def main(): async def client_streams( self, ) -> AsyncIterator[ - tuple[ - MemoryObjectReceiveStream[JSONRPCMessage | Exception], - MemoryObjectSendStream[JSONRPCMessage], - ] + tuple[MemoryObjectReceiveStream[SessionMessage | Exception], MemoryObjectSendStream[SessionMessage]] ]: server = StdioServerParameters(command=self.command, args=list(self.args), env=self.env, cwd=self.cwd) async with stdio_client(server=server) as (read_stream, write_stream): @@ -276,11 +272,11 @@ def __repr__(self) -> str: class MCPServerHTTP(MCPServer): """An MCP server that connects over streamable HTTP connections. - This class implements the SSE transport from the MCP specification. - See for more information. + This class implements the Streamable HTTP transport from the MCP specification. + See for more information. - The name "HTTP" is used since this implemented will be adapted in future to use the new - [Streamable HTTP](https://github.com/modelcontextprotocol/specification/pull/206) currently in development. + The Streamable HTTP transport is intended to replace the SSE transport from the previous protocol, but it is fully + backwards compatible with SSE-based servers. !!! note Using this class as an async context manager will create a new pool of HTTP connections to connect @@ -291,7 +287,7 @@ class MCPServerHTTP(MCPServer): from pydantic_ai import Agent from pydantic_ai.mcp import MCPServerHTTP - server = MCPServerHTTP('http://localhost:3001/sse') # (1)! + server = MCPServerHTTP('http://localhost:3001/mcp') # (1)! agent = Agent('openai:gpt-4o', mcp_servers=[server]) async def main(): @@ -304,27 +300,27 @@ async def main(): """ url: str - """The URL of the SSE endpoint on the MCP server. + """The URL of the SSE or MCP endpoint on the MCP server. - For example for a server running locally, this might be `http://localhost:3001/sse`. + For example for a server running locally, this might be `http://localhost:3001/mcp`. """ headers: dict[str, Any] | None = None - """Optional HTTP headers to be sent with each request to the SSE endpoint. + """Optional HTTP headers to be sent with each request to the endpoint. These headers will be passed directly to the underlying `httpx.AsyncClient`. Useful for authentication, custom headers, or other HTTP-specific configurations. """ - timeout: float = 5 - """Initial connection timeout in seconds for establishing the SSE connection. + timeout: timedelta | float = timedelta(seconds=5) + """Initial connection timeout as a timedelta for establishing the connection. This timeout applies to the initial connection setup and handshake. If the connection cannot be established within this time, the operation will fail. """ - sse_read_timeout: float = 60 * 5 - """Maximum time in seconds to wait for new SSE messages before timing out. + sse_read_timeout: timedelta | float = timedelta(minutes=5) + """Maximum time as a timedelta to wait for new SSE messages before timing out. This timeout applies to the long-lived SSE connection after it's established. If no new messages are received within this time, the connection will be considered stale @@ -346,21 +342,48 @@ async def main(): For example, if `tool_prefix='foo'`, then a tool named `bar` will be registered as `foo_bar` """ + def __post_init__(self): + if not isinstance(self.timeout, timedelta): + warnings.warn( + 'Passing timeout as a float has been deprecated, please use a timedelta instead.', + DeprecationWarning, + stacklevel=2, + ) + self.timeout = timedelta(seconds=self.timeout) + + if not isinstance(self.sse_read_timeout, timedelta): + warnings.warn( + 'Passing sse_read_timeout as a float has been deprecated, please use a timedelta instead.', + DeprecationWarning, + stacklevel=2, + ) + self.sse_read_timeout = timedelta(seconds=self.sse_read_timeout) + @asynccontextmanager async def client_streams( self, ) -> AsyncIterator[ - tuple[ - MemoryObjectReceiveStream[JSONRPCMessage | Exception], - MemoryObjectSendStream[JSONRPCMessage], - ] + tuple[MemoryObjectReceiveStream[SessionMessage | Exception], MemoryObjectSendStream[SessionMessage]] ]: # pragma: no cover - async with sse_client( - url=self.url, - headers=self.headers, - timeout=self.timeout, - sse_read_timeout=self.sse_read_timeout, - ) as (read_stream, write_stream): + if not isinstance(self.timeout, timedelta): + warnings.warn( + 'Passing timeout as a float has been deprecated, please use a timedelta instead.', + DeprecationWarning, + stacklevel=2, + ) + self.timeout = timedelta(seconds=self.timeout) + + if not isinstance(self.sse_read_timeout, timedelta): + warnings.warn( + 'Passing sse_read_timeout as a float has been deprecated, please use a timedelta instead.', + DeprecationWarning, + stacklevel=2, + ) + self.sse_read_timeout = timedelta(seconds=self.sse_read_timeout) + + async with streamablehttp_client( + url=self.url, headers=self.headers, timeout=self.timeout, sse_read_timeout=self.sse_read_timeout + ) as (read_stream, write_stream, _): yield read_stream, write_stream def _get_log_level(self) -> LoggingLevel | None: diff --git a/pydantic_ai_slim/pyproject.toml b/pydantic_ai_slim/pyproject.toml index 15d0bb838..e83c01cc2 100644 --- a/pydantic_ai_slim/pyproject.toml +++ b/pydantic_ai_slim/pyproject.toml @@ -75,7 +75,7 @@ tavily = ["tavily-python>=0.5.0"] # CLI cli = ["rich>=13", "prompt-toolkit>=3", "argcomplete>=3.5.0"] # MCP -mcp = ["mcp>=1.6.0; python_version >= '3.10'"] +mcp = ["mcp>=1.8.0; python_version >= '3.10'"] # Evals evals = ["pydantic-evals=={{ version }}"] # A2A diff --git a/tests/test_mcp.py b/tests/test_mcp.py index 735f76e55..ab3d9b9db 100644 --- a/tests/test_mcp.py +++ b/tests/test_mcp.py @@ -1,6 +1,7 @@ """Tests for the MCP (Model Context Protocol) server implementation.""" import re +from datetime import timedelta from pathlib import Path import pytest @@ -70,25 +71,41 @@ async def test_stdio_server_with_cwd(): assert len(tools) == 10 -def test_sse_server(): - sse_server = MCPServerHTTP(url='http://localhost:8000/sse') - assert sse_server.url == 'http://localhost:8000/sse' - assert sse_server._get_log_level() is None # pyright: ignore[reportPrivateUsage] +def test_http_server(): + http_server = MCPServerHTTP(url='http://localhost:8000/sse') + assert http_server.url == 'http://localhost:8000/sse' + assert http_server._get_log_level() is None # pyright: ignore[reportPrivateUsage] -def test_sse_server_with_header_and_timeout(): - sse_server = MCPServerHTTP( +def test_http_server_with_header_and_timeout(): + http_server = MCPServerHTTP( url='http://localhost:8000/sse', headers={'my-custom-header': 'my-header-value'}, - timeout=10, - sse_read_timeout=100, + timeout=timedelta(seconds=10), + sse_read_timeout=timedelta(seconds=100), log_level='info', ) - assert sse_server.url == 'http://localhost:8000/sse' - assert sse_server.headers is not None and sse_server.headers['my-custom-header'] == 'my-header-value' - assert sse_server.timeout == 10 - assert sse_server.sse_read_timeout == 100 - assert sse_server._get_log_level() == 'info' # pyright: ignore[reportPrivateUsage] + assert http_server.url == 'http://localhost:8000/sse' + assert http_server.headers is not None and http_server.headers['my-custom-header'] == 'my-header-value' + assert http_server.timeout == timedelta(seconds=10) + assert http_server.sse_read_timeout == timedelta(seconds=100) + assert http_server._get_log_level() == 'info' # pyright: ignore[reportPrivateUsage] + + +def test_http_server_with_deprecated_arguments(): + with pytest.warns(DeprecationWarning): + http_server = MCPServerHTTP( + url='http://localhost:8000/sse', + headers={'my-custom-header': 'my-header-value'}, + timeout=10, + sse_read_timeout=100, + log_level='info', + ) + assert http_server.url == 'http://localhost:8000/sse' + assert http_server.headers is not None and http_server.headers['my-custom-header'] == 'my-header-value' + assert http_server.timeout == timedelta(seconds=10) + assert http_server.sse_read_timeout == timedelta(seconds=100) + assert http_server._get_log_level() == 'info' # pyright: ignore[reportPrivateUsage] @pytest.mark.vcr() diff --git a/uv.lock b/uv.lock index 3fb6c45b1..9646a5025 100644 --- a/uv.lock +++ b/uv.lock @@ -1748,7 +1748,7 @@ wheels = [ [[package]] name = "mcp" -version = "1.6.0" +version = "1.9.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "anyio", marker = "python_full_version >= '3.10'" }, @@ -1756,13 +1756,14 @@ dependencies = [ { name = "httpx-sse", marker = "python_full_version >= '3.10'" }, { name = "pydantic", marker = "python_full_version >= '3.10'" }, { name = "pydantic-settings", marker = "python_full_version >= '3.10'" }, + { name = "python-multipart", marker = "python_full_version >= '3.10'" }, { name = "sse-starlette", marker = "python_full_version >= '3.10'" }, { name = "starlette", marker = "python_full_version >= '3.10'" }, - { name = "uvicorn", marker = "python_full_version >= '3.10'" }, + { name = "uvicorn", marker = "python_full_version >= '3.10' and sys_platform != 'emscripten'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/95/d2/f587cb965a56e992634bebc8611c5b579af912b74e04eb9164bd49527d21/mcp-1.6.0.tar.gz", hash = "sha256:d9324876de2c5637369f43161cd71eebfd803df5a95e46225cab8d280e366723", size = 200031, upload-time = "2025-03-27T16:46:32.336Z" } +sdist = { url = "https://files.pythonhosted.org/packages/bc/8d/0f4468582e9e97b0a24604b585c651dfd2144300ecffd1c06a680f5c8861/mcp-1.9.0.tar.gz", hash = "sha256:905d8d208baf7e3e71d70c82803b89112e321581bcd2530f9de0fe4103d28749", size = 281432, upload-time = "2025-05-15T18:51:06.615Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/10/30/20a7f33b0b884a9d14dd3aa94ff1ac9da1479fe2ad66dd9e2736075d2506/mcp-1.6.0-py3-none-any.whl", hash = "sha256:7bd24c6ea042dbec44c754f100984d186620d8b841ec30f1b19eda9b93a634d0", size = 76077, upload-time = "2025-03-27T16:46:29.919Z" }, + { url = "https://files.pythonhosted.org/packages/a5/d5/22e36c95c83c80eb47c83f231095419cf57cf5cca5416f1c960032074c78/mcp-1.9.0-py3-none-any.whl", hash = "sha256:9dfb89c8c56f742da10a5910a1f64b0d2ac2c3ed2bd572ddb1cfab7f35957178", size = 125082, upload-time = "2025-05-15T18:51:04.916Z" }, ] [package.optional-dependencies] @@ -3084,7 +3085,7 @@ requires-dist = [ { name = "groq", marker = "extra == 'groq'", specifier = ">=0.15.0" }, { name = "httpx", specifier = ">=0.27" }, { name = "logfire", marker = "extra == 'logfire'", specifier = ">=3.11.0" }, - { name = "mcp", marker = "python_full_version >= '3.10' and extra == 'mcp'", specifier = ">=1.6.0" }, + { name = "mcp", marker = "python_full_version >= '3.10' and extra == 'mcp'", specifier = ">=1.8.0" }, { name = "mistralai", marker = "extra == 'mistral'", specifier = ">=1.2.5" }, { name = "openai", marker = "extra == 'openai'", specifier = ">=1.75.0" }, { name = "opentelemetry-api", specifier = ">=1.28.0" },