diff --git a/src/mcp/server/sse.py b/src/mcp/server/sse.py index cc41a80d6..e03765141 100644 --- a/src/mcp/server/sse.py +++ b/src/mcp/server/sse.py @@ -40,7 +40,6 @@ async def handle_sse(request): import logging from contextlib import asynccontextmanager from typing import Any -from urllib.parse import quote from uuid import UUID, uuid4 import anyio @@ -100,7 +99,7 @@ async def connect_sse(self, scope: Scope, receive: Receive, send: Send): write_stream, write_stream_reader = anyio.create_memory_object_stream(0) session_id = uuid4() - session_uri = f"{quote(self._endpoint)}?session_id={session_id.hex}" + session_uri = f"{self._endpoint}?session_id={session_id.hex}" self._read_stream_writers[session_id] = read_stream_writer logger.debug(f"Created new session with ID: {session_id}") diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index 4558bb88c..ce71104ee 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -119,6 +119,54 @@ def run_server(server_port: int) -> None: time.sleep(0.5) +def make_server_app_with_endpoint(endpoint: str) -> Starlette: + """Create test Starlette app with SSE transport using the specified endpoint""" + sse = SseServerTransport(endpoint) + server = ServerTest() + + async def handle_sse(request: Request) -> Response: + async with sse.connect_sse( + request.scope, request.receive, request._send + ) as streams: + await server.run( + streams[0], streams[1], server.create_initialization_options() + ) + return Response() + + # For absolute URLs, we route all paths + if endpoint.startswith(("http://", "https://")): + route_path = "/sse" + mount_path = "/" + else: + route_path = "/sse" + mount_path = endpoint + + app = Starlette( + routes=[ + Route(route_path, endpoint=handle_sse), + Mount(mount_path, app=sse.handle_post_message), + ] + ) + + return app + + +def run_server_with_endpoint(server_port: int, endpoint: str) -> None: + app = make_server_app_with_endpoint(endpoint) + server = uvicorn.Server( + config=uvicorn.Config( + app=app, host="127.0.0.1", port=server_port, log_level="error" + ) + ) + print(f"starting server on {server_port} with endpoint {endpoint}") + server.run() + + # Give server time to start + while not server.started: + print("waiting for server to start") + time.sleep(0.5) + + @pytest.fixture() def server(server_port: int) -> Generator[None, None, None]: proc = multiprocessing.Process( @@ -159,6 +207,129 @@ async def http_client(server, server_url) -> AsyncGenerator[httpx.AsyncClient, N yield client +@pytest.fixture() +def server_with_relative_endpoint(server_port: int) -> Generator[None, None, None]: + """Setup a server with a relative endpoint path""" + proc = multiprocessing.Process( + target=run_server_with_endpoint, + kwargs={"server_port": server_port, "endpoint": "/messages/"}, + daemon=True, + ) + print("starting process with relative endpoint") + proc.start() + + # Wait for server to be running + max_attempts = 20 + attempt = 0 + print("waiting for server to start") + while attempt < max_attempts: + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.connect(("127.0.0.1", server_port)) + break + except ConnectionRefusedError: + time.sleep(0.1) + attempt += 1 + else: + raise RuntimeError(f"Server failed to start after {max_attempts} attempts") + + yield + + print("killing server") + # Signal the server to stop + proc.kill() + proc.join(timeout=2) + if proc.is_alive(): + print("server process failed to terminate") + + +@pytest.fixture() +def server_with_absolute_endpoint( + server_port: int, server_url: str +) -> Generator[None, None, None]: + """Setup a server with an absolute endpoint URL""" + absolute_endpoint = f"{server_url}/messages/" + proc = multiprocessing.Process( + target=run_server_with_endpoint, + kwargs={"server_port": server_port, "endpoint": absolute_endpoint}, + daemon=True, + ) + print(f"starting process with absolute endpoint: {absolute_endpoint}") + proc.start() + + # Wait for server to be running + max_attempts = 20 + attempt = 0 + print("waiting for server to start") + while attempt < max_attempts: + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.connect(("127.0.0.1", server_port)) + break + except ConnectionRefusedError: + time.sleep(0.1) + attempt += 1 + else: + raise RuntimeError(f"Server failed to start after {max_attempts} attempts") + + yield + + print("killing server") + # Signal the server to stop + proc.kill() + proc.join(timeout=2) + if proc.is_alive(): + print("server process failed to terminate") + + +@pytest.fixture() +async def http_client_with_relative_endpoint( + server_with_relative_endpoint, server_url +) -> AsyncGenerator[httpx.AsyncClient, None]: + """Create test client for server with relative endpoint""" + async with httpx.AsyncClient(base_url=server_url) as client: + yield client + + +@pytest.fixture() +async def http_client_with_absolute_endpoint( + server_with_absolute_endpoint, server_url +) -> AsyncGenerator[httpx.AsyncClient, None]: + """Create test client for server with absolute endpoint""" + async with httpx.AsyncClient(base_url=server_url) as client: + yield client + + +@pytest.fixture +async def initialized_sse_client_session( + server, server_url: str +) -> AsyncGenerator[ClientSession, None]: + async with sse_client(server_url + "/sse", sse_read_timeout=0.5) as streams: + async with ClientSession(*streams) as session: + await session.initialize() + yield session + + +@pytest.fixture +async def initialized_sse_client_session_with_relative_endpoint( + server_with_relative_endpoint, server_url: str +) -> AsyncGenerator[ClientSession, None]: + async with sse_client(server_url + "/sse", sse_read_timeout=0.5) as streams: + async with ClientSession(*streams) as session: + await session.initialize() + yield session + + +@pytest.fixture +async def initialized_sse_client_session_with_absolute_endpoint( + server_with_absolute_endpoint, server_url: str +) -> AsyncGenerator[ClientSession, None]: + async with sse_client(server_url + "/sse", sse_read_timeout=0.5) as streams: + async with ClientSession(*streams) as session: + await session.initialize() + yield session + + # Tests @pytest.mark.anyio async def test_raw_sse_connection(http_client: httpx.AsyncClient) -> None: @@ -202,16 +373,6 @@ async def test_sse_client_basic_connection(server: None, server_url: str) -> Non assert isinstance(ping_result, EmptyResult) -@pytest.fixture -async def initialized_sse_client_session( - server, server_url: str -) -> AsyncGenerator[ClientSession, None]: - async with sse_client(server_url + "/sse", sse_read_timeout=0.5) as streams: - async with ClientSession(*streams) as session: - await session.initialize() - yield session - - @pytest.mark.anyio async def test_sse_client_happy_request_and_response( initialized_sse_client_session: ClientSession, @@ -252,3 +413,99 @@ async def test_sse_client_timeout( return pytest.fail("the client should have timed out and returned an error already") + + +@pytest.mark.anyio +async def test_raw_sse_connection_with_relative_endpoint( + http_client_with_relative_endpoint: httpx.AsyncClient, +) -> None: + """Test the SSE connection establishment with a relative endpoint URL.""" + async with anyio.create_task_group(): + + async def connection_test() -> None: + async with http_client_with_relative_endpoint.stream( + "GET", "/sse" + ) as response: + assert response.status_code == 200 + assert ( + response.headers["content-type"] + == "text/event-stream; charset=utf-8" + ) + + line_number = 0 + async for line in response.aiter_lines(): + if line_number == 0: + assert line == "event: endpoint" + elif line_number == 1: + assert line.startswith("data: /messages/?session_id=") + # Verify it's a relative URL + endpoint_data = line.removeprefix("data: ") + assert not endpoint_data.startswith(("http://", "https://")) + assert endpoint_data.startswith("/messages/?session_id=") + else: + return + line_number += 1 + + # Add timeout to prevent test from hanging if it fails + with anyio.fail_after(3): + await connection_test() + + +@pytest.mark.anyio +async def test_raw_sse_connection_with_absolute_endpoint( + http_client_with_absolute_endpoint: httpx.AsyncClient, +) -> None: + """Test the SSE connection establishment with an absolute endpoint URL.""" + async with anyio.create_task_group(): + + async def connection_test() -> None: + async with http_client_with_absolute_endpoint.stream( + "GET", "/sse" + ) as response: + assert response.status_code == 200 + assert ( + response.headers["content-type"] + == "text/event-stream; charset=utf-8" + ) + + line_number = 0 + async for line in response.aiter_lines(): + if line_number == 0: + assert line == "event: endpoint" + elif line_number == 1: + # Verify it's an absolute URL + assert line.startswith("data: http://") + assert "/messages/?session_id=" in line + else: + return + line_number += 1 + + # Add timeout to prevent test from hanging if it fails + with anyio.fail_after(3): + await connection_test() + + +@pytest.mark.anyio +async def test_sse_client_with_relative_endpoint( + initialized_sse_client_session_with_relative_endpoint: ClientSession, +) -> None: + """Test that a client session works properly with a relative endpoint.""" + session = initialized_sse_client_session_with_relative_endpoint + # Test basic functionality + response = await session.read_resource(uri=AnyUrl("foobar://should-work")) + assert len(response.contents) == 1 + assert isinstance(response.contents[0], TextResourceContents) + assert response.contents[0].text == "Read should-work" + + +@pytest.mark.anyio +async def test_sse_client_with_absolute_endpoint( + initialized_sse_client_session_with_absolute_endpoint: ClientSession, +) -> None: + """Test that a client session works properly with an absolute endpoint.""" + session = initialized_sse_client_session_with_absolute_endpoint + # Test basic functionality + response = await session.read_resource(uri=AnyUrl("foobar://should-work")) + assert len(response.contents) == 1 + assert isinstance(response.contents[0], TextResourceContents) + assert response.contents[0].text == "Read should-work"