diff --git a/src/mcp/client/sse.py b/src/mcp/client/sse.py index ff04d2f96..fcdc7b2df 100644 --- a/src/mcp/client/sse.py +++ b/src/mcp/client/sse.py @@ -1,7 +1,7 @@ import logging from contextlib import asynccontextmanager from typing import Any -from urllib.parse import urljoin, urlparse +from urllib.parse import urljoin, urlparse, urlunparse import anyio import httpx @@ -19,6 +19,31 @@ def remove_request_params(url: str) -> str: return urljoin(url, urlparse(url).path) +def custom_url_join(base_url: str, endpoint: str) -> str: + """ + Custom URL join function to handle the case where the endpoint is relative + to a base URL. This function ensures that the base URL and endpoint are + combined correctly, even if the endpoint is not a full URL. + """ + # Parse the base URL + parsed_base = urlparse(base_url) + + # Get the path prefix (e.g., '/weather') + path_prefix = "/".join(parsed_base.path.split("/")[:-1]) + + # Remove any leading slash from the endpoint + clean_endpoint = endpoint.lstrip("/") + + # Create the new path by joining prefix and endpoint + new_path = f"{path_prefix}/{clean_endpoint}" + + # Create a new parsed URL with the updated path + parsed_new = parsed_base._replace(path=new_path) + + # Convert back to a string URL + return urlunparse(parsed_new) + + @asynccontextmanager async def sse_client( url: str, @@ -44,6 +69,15 @@ async def sse_client( async with anyio.create_task_group() as tg: try: logger.info(f"Connecting to SSE endpoint: {remove_request_params(url)}") + + # extract MCP server name from URL, for example: https://mcp.example.com/weather/sse + # will extract 'weather' + path_tokens = urlparse(url).path.split("/") + optional_mcp_server_name = ( + path_tokens[1:-1] if len(path_tokens) > 2 else None + ) + logger.debug(f"MCP Server name (optional): {optional_mcp_server_name}") + async with httpx.AsyncClient(headers=headers) as client: async with aconnect_sse( client, @@ -62,7 +96,12 @@ async def sse_reader( logger.debug(f"Received SSE event: {sse.event}") match sse.event: case "endpoint": - endpoint_url = urljoin(url, sse.data) + if optional_mcp_server_name: + endpoint_url = custom_url_join( + url, sse.data + ) + else: + endpoint_url = urljoin(url, sse.data) logger.info( f"Received endpoint URL: {endpoint_url}" ) diff --git a/tests/client/test_sse.py b/tests/client/test_sse.py new file mode 100644 index 000000000..5b3104fda --- /dev/null +++ b/tests/client/test_sse.py @@ -0,0 +1,30 @@ +import pytest + +from mcp.client.sse import custom_url_join + + +@pytest.mark.parametrize( + "base_url,endpoint,expected", + [ + # Additional test cases to verify behavior with different URL structures + ( + "https://mcp.example.com/weather/sse", + "/messages/?session_id=616df71373444d76bd566df4377c9629", + "https://mcp.example.com/weather/messages/?session_id=616df71373444d76bd566df4377c9629", + ), + ( + "https://mcp.example.com/weather/clarksburg/sse", + "/messages/?session_id=616df71373444d76bd566df4377c9629", + "https://mcp.example.com/weather/clarksburg/messages/?session_id=616df71373444d76bd566df4377c9629", + ), + ( + "https://mcp.example.com/sse", + "/messages/?session_id=616df71373444d76bd566df4377c9629", + "https://mcp.example.com/messages/?session_id=616df71373444d76bd566df4377c9629", + ), + ], +) +def test_custom_url_join(base_url, endpoint, expected): + """Test the custom_url_join function with messages endpoint and session ID.""" + result = custom_url_join(base_url, endpoint) + assert result == expected