Skip to content

fix for SSE URL handling when a server name is specified #597

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 41 additions & 2 deletions src/mcp/client/sse.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
from contextlib import asynccontextmanager
from typing import Any
from urllib.parse import urljoin, urlparse
from urllib.parse import urljoin, urlparse, urlunparse

import anyio
import httpx
Expand All @@ -19,6 +19,31 @@ def remove_request_params(url: str) -> str:
return urljoin(url, urlparse(url).path)


def custom_url_join(base_url: str, endpoint: str) -> str:
"""
Custom URL join function to handle the case where the endpoint is relative
to a base URL. This function ensures that the base URL and endpoint are
combined correctly, even if the endpoint is not a full URL.
"""
# Parse the base URL
parsed_base = urlparse(base_url)

# Get the path prefix (e.g., '/weather')
path_prefix = "/".join(parsed_base.path.split("/")[:-1])

# Remove any leading slash from the endpoint
clean_endpoint = endpoint.lstrip("/")

# Create the new path by joining prefix and endpoint
new_path = f"{path_prefix}/{clean_endpoint}"

# Create a new parsed URL with the updated path
parsed_new = parsed_base._replace(path=new_path)

# Convert back to a string URL
return urlunparse(parsed_new)


@asynccontextmanager
async def sse_client(
url: str,
Expand All @@ -44,6 +69,15 @@ async def sse_client(
async with anyio.create_task_group() as tg:
try:
logger.info(f"Connecting to SSE endpoint: {remove_request_params(url)}")

# extract MCP server name from URL, for example: https://mcp.example.com/weather/sse
# will extract 'weather'
path_tokens = urlparse(url).path.split("/")
optional_mcp_server_name = (
path_tokens[1:-1] if len(path_tokens) > 2 else None
)
logger.debug(f"MCP Server name (optional): {optional_mcp_server_name}")

async with httpx.AsyncClient(headers=headers) as client:
async with aconnect_sse(
client,
Expand All @@ -62,7 +96,12 @@ async def sse_reader(
logger.debug(f"Received SSE event: {sse.event}")
match sse.event:
case "endpoint":
endpoint_url = urljoin(url, sse.data)
if optional_mcp_server_name:
endpoint_url = custom_url_join(
url, sse.data
)
else:
endpoint_url = urljoin(url, sse.data)
logger.info(
f"Received endpoint URL: {endpoint_url}"
)
Expand Down
30 changes: 30 additions & 0 deletions tests/client/test_sse.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import pytest

from mcp.client.sse import custom_url_join


@pytest.mark.parametrize(
"base_url,endpoint,expected",
[
# Additional test cases to verify behavior with different URL structures
(
"https://mcp.example.com/weather/sse",
"/messages/?session_id=616df71373444d76bd566df4377c9629",
"https://mcp.example.com/weather/messages/?session_id=616df71373444d76bd566df4377c9629",
),
(
"https://mcp.example.com/weather/clarksburg/sse",
"/messages/?session_id=616df71373444d76bd566df4377c9629",
"https://mcp.example.com/weather/clarksburg/messages/?session_id=616df71373444d76bd566df4377c9629",
),
(
"https://mcp.example.com/sse",
"/messages/?session_id=616df71373444d76bd566df4377c9629",
"https://mcp.example.com/messages/?session_id=616df71373444d76bd566df4377c9629",
),
],
)
def test_custom_url_join(base_url, endpoint, expected):
"""Test the custom_url_join function with messages endpoint and session ID."""
result = custom_url_join(base_url, endpoint)
assert result == expected
Loading