From fddc22d0c4f5429be3a03f8690cb5057497f458e Mon Sep 17 00:00:00 2001 From: andrewmjc Date: Tue, 8 Apr 2025 09:03:33 -0600 Subject: [PATCH 1/2] headers in handshake --- src/mcp/client/websocket.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/mcp/client/websocket.py b/src/mcp/client/websocket.py index 2c2ed38b9..0aa61dda9 100644 --- a/src/mcp/client/websocket.py +++ b/src/mcp/client/websocket.py @@ -2,6 +2,7 @@ import logging from collections.abc import AsyncGenerator from contextlib import asynccontextmanager +from typing import Any import anyio from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream @@ -17,6 +18,7 @@ @asynccontextmanager async def websocket_client( url: str, + headers: dict[str, Any] | None = None, ) -> AsyncGenerator[ tuple[ MemoryObjectReceiveStream[types.JSONRPCMessage | Exception], @@ -48,7 +50,7 @@ async def websocket_client( write_stream, write_stream_reader = anyio.create_memory_object_stream(0) # Connect using websockets, requesting the "mcp" subprotocol - async with ws_connect(url, subprotocols=[Subprotocol("mcp")]) as ws: + async with ws_connect(url, subprotocols=[Subprotocol("mcp")], additional_headers=headers) as ws: async def ws_reader(): """ From 7f7c29ce8bb513cb6494f11c3b55c9fc466ab854 Mon Sep 17 00:00:00 2001 From: andrewmjc Date: Tue, 8 Apr 2025 10:07:44 -0600 Subject: [PATCH 2/2] format --- src/mcp/client/websocket.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/mcp/client/websocket.py b/src/mcp/client/websocket.py index 0aa61dda9..6c0f30128 100644 --- a/src/mcp/client/websocket.py +++ b/src/mcp/client/websocket.py @@ -50,7 +50,9 @@ async def websocket_client( write_stream, write_stream_reader = anyio.create_memory_object_stream(0) # Connect using websockets, requesting the "mcp" subprotocol - async with ws_connect(url, subprotocols=[Subprotocol("mcp")], additional_headers=headers) as ws: + async with ws_connect( + url, subprotocols=[Subprotocol("mcp")], additional_headers=headers + ) as ws: async def ws_reader(): """