From fde04eb85e1487180baa56e271de93a02dac3ed4 Mon Sep 17 00:00:00 2001 From: Christian Glessner Date: Mon, 5 May 2025 13:56:17 +0000 Subject: [PATCH 1/2] Fix SSE server transport to support absolute endpoints This change fixes the endpoint URL handling in the SSE server transport to support both relative and absolute URLs. Some clients like Copilot Studio require absolute URLs. This change aligns with the TypeScript SDK's support for absolute endpoint URLs as in https://github.com/modelcontextprotocol/typescript-sdk/blob/main/src/server/sse.ts The PR: 1. Removes unnecessary URL quoting which would break absolute URLs 2. Adds comprehensive tests for both relative and absolute URL endpoints --- src/mcp/server/sse.py | 3 +- tests/shared/test_sse.py | 269 +++++++++++++++++++++++++++++++++++++-- 2 files changed, 260 insertions(+), 12 deletions(-) diff --git a/src/mcp/server/sse.py b/src/mcp/server/sse.py index cc41a80d6..e03765141 100644 --- a/src/mcp/server/sse.py +++ b/src/mcp/server/sse.py @@ -40,7 +40,6 @@ async def handle_sse(request): import logging from contextlib import asynccontextmanager from typing import Any -from urllib.parse import quote from uuid import UUID, uuid4 import anyio @@ -100,7 +99,7 @@ async def connect_sse(self, scope: Scope, receive: Receive, send: Send): write_stream, write_stream_reader = anyio.create_memory_object_stream(0) session_id = uuid4() - session_uri = f"{quote(self._endpoint)}?session_id={session_id.hex}" + session_uri = f"{self._endpoint}?session_id={session_id.hex}" self._read_stream_writers[session_id] = read_stream_writer logger.debug(f"Created new session with ID: {session_id}") diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index 4558bb88c..a284d47b8 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -119,6 +119,54 @@ def run_server(server_port: int) -> None: time.sleep(0.5) +def make_server_app_with_endpoint(endpoint: str) -> Starlette: + """Create test Starlette app with SSE transport using the specified endpoint""" + sse = SseServerTransport(endpoint) + server = ServerTest() + + async def handle_sse(request: Request) -> Response: + async with sse.connect_sse( + request.scope, request.receive, request._send + ) as streams: + await server.run( + streams[0], streams[1], server.create_initialization_options() + ) + return Response() + + # For absolute URLs, we route all paths + if endpoint.startswith(("http://", "https://")): + route_path = "/sse" + mount_path = "/" + else: + route_path = "/sse" + mount_path = endpoint + + app = Starlette( + routes=[ + Route(route_path, endpoint=handle_sse), + Mount(mount_path, app=sse.handle_post_message), + ] + ) + + return app + + +def run_server_with_endpoint(server_port: int, endpoint: str) -> None: + app = make_server_app_with_endpoint(endpoint) + server = uvicorn.Server( + config=uvicorn.Config( + app=app, host="127.0.0.1", port=server_port, log_level="error" + ) + ) + print(f"starting server on {server_port} with endpoint {endpoint}") + server.run() + + # Give server time to start + while not server.started: + print("waiting for server to start") + time.sleep(0.5) + + @pytest.fixture() def server(server_port: int) -> Generator[None, None, None]: proc = multiprocessing.Process( @@ -159,6 +207,129 @@ async def http_client(server, server_url) -> AsyncGenerator[httpx.AsyncClient, N yield client +@pytest.fixture() +def server_with_relative_endpoint(server_port: int) -> Generator[None, None, None]: + """Setup a server with a relative endpoint path""" + proc = multiprocessing.Process( + target=run_server_with_endpoint, + kwargs={"server_port": server_port, "endpoint": "/messages/"}, + daemon=True, + ) + print("starting process with relative endpoint") + proc.start() + + # Wait for server to be running + max_attempts = 20 + attempt = 0 + print("waiting for server to start") + while attempt < max_attempts: + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.connect(("127.0.0.1", server_port)) + break + except ConnectionRefusedError: + time.sleep(0.1) + attempt += 1 + else: + raise RuntimeError(f"Server failed to start after {max_attempts} attempts") + + yield + + print("killing server") + # Signal the server to stop + proc.kill() + proc.join(timeout=2) + if proc.is_alive(): + print("server process failed to terminate") + + +@pytest.fixture() +def server_with_absolute_endpoint( + server_port: int, server_url: str +) -> Generator[None, None, None]: + """Setup a server with an absolute endpoint URL""" + absolute_endpoint = f"{server_url}/messages/" + proc = multiprocessing.Process( + target=run_server_with_endpoint, + kwargs={"server_port": server_port, "endpoint": absolute_endpoint}, + daemon=True, + ) + print(f"starting process with absolute endpoint: {absolute_endpoint}") + proc.start() + + # Wait for server to be running + max_attempts = 20 + attempt = 0 + print("waiting for server to start") + while attempt < max_attempts: + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.connect(("127.0.0.1", server_port)) + break + except ConnectionRefusedError: + time.sleep(0.1) + attempt += 1 + else: + raise RuntimeError(f"Server failed to start after {max_attempts} attempts") + + yield + + print("killing server") + # Signal the server to stop + proc.kill() + proc.join(timeout=2) + if proc.is_alive(): + print("server process failed to terminate") + + +@pytest.fixture() +async def http_client_with_relative_endpoint( + server_with_relative_endpoint, server_url +) -> AsyncGenerator[httpx.AsyncClient, None]: + """Create test client for server with relative endpoint""" + async with httpx.AsyncClient(base_url=server_url) as client: + yield client + + +@pytest.fixture() +async def http_client_with_absolute_endpoint( + server_with_absolute_endpoint, server_url +) -> AsyncGenerator[httpx.AsyncClient, None]: + """Create test client for server with absolute endpoint""" + async with httpx.AsyncClient(base_url=server_url) as client: + yield client + + +@pytest.fixture +async def initialized_sse_client_session( + server, server_url: str +) -> AsyncGenerator[ClientSession, None]: + async with sse_client(server_url + "/sse", sse_read_timeout=0.5) as streams: + async with ClientSession(*streams) as session: + await session.initialize() + yield session + + +@pytest.fixture +async def initialized_sse_client_session_with_relative_endpoint( + server_with_relative_endpoint, server_url: str +) -> AsyncGenerator[ClientSession, None]: + async with sse_client(server_url + "/sse", sse_read_timeout=0.5) as streams: + async with ClientSession(*streams) as session: + await session.initialize() + yield session + + +@pytest.fixture +async def initialized_sse_client_session_with_absolute_endpoint( + server_with_absolute_endpoint, server_url: str +) -> AsyncGenerator[ClientSession, None]: + async with sse_client(server_url + "/sse", sse_read_timeout=0.5) as streams: + async with ClientSession(*streams) as session: + await session.initialize() + yield session + + # Tests @pytest.mark.anyio async def test_raw_sse_connection(http_client: httpx.AsyncClient) -> None: @@ -202,16 +373,6 @@ async def test_sse_client_basic_connection(server: None, server_url: str) -> Non assert isinstance(ping_result, EmptyResult) -@pytest.fixture -async def initialized_sse_client_session( - server, server_url: str -) -> AsyncGenerator[ClientSession, None]: - async with sse_client(server_url + "/sse", sse_read_timeout=0.5) as streams: - async with ClientSession(*streams) as session: - await session.initialize() - yield session - - @pytest.mark.anyio async def test_sse_client_happy_request_and_response( initialized_sse_client_session: ClientSession, @@ -252,3 +413,91 @@ async def test_sse_client_timeout( return pytest.fail("the client should have timed out and returned an error already") + + +@pytest.mark.anyio +async def test_raw_sse_connection_with_relative_endpoint(http_client_with_relative_endpoint: httpx.AsyncClient) -> None: + """Test the SSE connection establishment with a relative endpoint URL.""" + async with anyio.create_task_group(): + + async def connection_test() -> None: + async with http_client_with_relative_endpoint.stream("GET", "/sse") as response: + assert response.status_code == 200 + assert ( + response.headers["content-type"] + == "text/event-stream; charset=utf-8" + ) + + line_number = 0 + async for line in response.aiter_lines(): + if line_number == 0: + assert line == "event: endpoint" + elif line_number == 1: + assert line.startswith("data: /messages/?session_id=") + # Verify it's a relative URL + endpoint_data = line.removeprefix("data: ") + assert not endpoint_data.startswith(("http://", "https://")) + assert endpoint_data.startswith("/messages/?session_id=") + else: + return + line_number += 1 + + # Add timeout to prevent test from hanging if it fails + with anyio.fail_after(3): + await connection_test() + + +@pytest.mark.anyio +async def test_raw_sse_connection_with_absolute_endpoint(http_client_with_absolute_endpoint: httpx.AsyncClient) -> None: + """Test the SSE connection establishment with an absolute endpoint URL.""" + async with anyio.create_task_group(): + + async def connection_test() -> None: + async with http_client_with_absolute_endpoint.stream("GET", "/sse") as response: + assert response.status_code == 200 + assert ( + response.headers["content-type"] + == "text/event-stream; charset=utf-8" + ) + + line_number = 0 + async for line in response.aiter_lines(): + if line_number == 0: + assert line == "event: endpoint" + elif line_number == 1: + # Verify it's an absolute URL + assert line.startswith("data: http://") + assert "/messages/?session_id=" in line + else: + return + line_number += 1 + + # Add timeout to prevent test from hanging if it fails + with anyio.fail_after(3): + await connection_test() + + +@pytest.mark.anyio +async def test_sse_client_with_relative_endpoint( + initialized_sse_client_session_with_relative_endpoint: ClientSession, +) -> None: + """Test that a client session works properly with a relative endpoint.""" + session = initialized_sse_client_session_with_relative_endpoint + # Test basic functionality + response = await session.read_resource(uri=AnyUrl("foobar://should-work")) + assert len(response.contents) == 1 + assert isinstance(response.contents[0], TextResourceContents) + assert response.contents[0].text == "Read should-work" + + +@pytest.mark.anyio +async def test_sse_client_with_absolute_endpoint( + initialized_sse_client_session_with_absolute_endpoint: ClientSession, +) -> None: + """Test that a client session works properly with an absolute endpoint.""" + session = initialized_sse_client_session_with_absolute_endpoint + # Test basic functionality + response = await session.read_resource(uri=AnyUrl("foobar://should-work")) + assert len(response.contents) == 1 + assert isinstance(response.contents[0], TextResourceContents) + assert response.contents[0].text == "Read should-work" From 4f63efe3557695eb685c2b503662d4c010bf5299 Mon Sep 17 00:00:00 2001 From: Christian Glessner Date: Mon, 5 May 2025 15:08:16 +0000 Subject: [PATCH 2/2] Fix linting issues --- tests/shared/test_sse.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index a284d47b8..ce71104ee 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -416,12 +416,16 @@ async def test_sse_client_timeout( @pytest.mark.anyio -async def test_raw_sse_connection_with_relative_endpoint(http_client_with_relative_endpoint: httpx.AsyncClient) -> None: +async def test_raw_sse_connection_with_relative_endpoint( + http_client_with_relative_endpoint: httpx.AsyncClient, +) -> None: """Test the SSE connection establishment with a relative endpoint URL.""" async with anyio.create_task_group(): async def connection_test() -> None: - async with http_client_with_relative_endpoint.stream("GET", "/sse") as response: + async with http_client_with_relative_endpoint.stream( + "GET", "/sse" + ) as response: assert response.status_code == 200 assert ( response.headers["content-type"] @@ -448,12 +452,16 @@ async def connection_test() -> None: @pytest.mark.anyio -async def test_raw_sse_connection_with_absolute_endpoint(http_client_with_absolute_endpoint: httpx.AsyncClient) -> None: +async def test_raw_sse_connection_with_absolute_endpoint( + http_client_with_absolute_endpoint: httpx.AsyncClient, +) -> None: """Test the SSE connection establishment with an absolute endpoint URL.""" async with anyio.create_task_group(): async def connection_test() -> None: - async with http_client_with_absolute_endpoint.stream("GET", "/sse") as response: + async with http_client_with_absolute_endpoint.stream( + "GET", "/sse" + ) as response: assert response.status_code == 200 assert ( response.headers["content-type"]