diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 4d65bbebb..7bb8821f7 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -261,6 +261,7 @@ async def call_tool( read_timeout_seconds: timedelta | None = None, ) -> types.CallToolResult: """Send a tools/call request.""" + return await self.send_request( types.ClientRequest( types.CallToolRequest( diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 7a8887cd9..ef424e3b3 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -7,207 +7,377 @@ """ import logging +from collections.abc import AsyncGenerator, Awaitable, Callable from contextlib import asynccontextmanager +from dataclasses import dataclass from datetime import timedelta from typing import Any import anyio import httpx -from httpx_sse import EventSource, aconnect_sse +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +from httpx_sse import EventSource, ServerSentEvent, aconnect_sse -from mcp.shared.message import SessionMessage +from mcp.shared.message import ClientMessageMetadata, SessionMessage from mcp.types import ( ErrorData, JSONRPCError, JSONRPCMessage, JSONRPCNotification, JSONRPCRequest, + JSONRPCResponse, + RequestId, ) logger = logging.getLogger(__name__) -# Header names -MCP_SESSION_ID_HEADER = "mcp-session-id" -LAST_EVENT_ID_HEADER = "last-event-id" -# Content types -CONTENT_TYPE_JSON = "application/json" -CONTENT_TYPE_SSE = "text/event-stream" +SessionMessageOrError = SessionMessage | Exception +StreamWriter = MemoryObjectSendStream[SessionMessageOrError] +StreamReader = MemoryObjectReceiveStream[SessionMessage] +GetSessionIdCallback = Callable[[], str | None] +MCP_SESSION_ID = "mcp-session-id" +LAST_EVENT_ID = "last-event-id" +CONTENT_TYPE = "content-type" +ACCEPT = "Accept" -@asynccontextmanager -async def streamablehttp_client( - url: str, - headers: dict[str, Any] | None = None, - timeout: timedelta = timedelta(seconds=30), - sse_read_timeout: timedelta = timedelta(seconds=60 * 5), -): - """ - Client transport for StreamableHTTP. - `sse_read_timeout` determines how long (in seconds) the client will wait for a new - event before disconnecting. All other HTTP operations are controlled by `timeout`. +JSON = "application/json" +SSE = "text/event-stream" - Yields: - Tuple of (read_stream, write_stream, terminate_callback) - """ - read_stream_writer, read_stream = anyio.create_memory_object_stream[ - SessionMessage | Exception - ](0) - write_stream, write_stream_reader = anyio.create_memory_object_stream[ - SessionMessage - ](0) +class StreamableHTTPError(Exception): + """Base exception for StreamableHTTP transport errors.""" - async def get_stream(): - """ - Optional GET stream for server-initiated messages + pass + + +class ResumptionError(StreamableHTTPError): + """Raised when resumption request is invalid.""" + + pass + + +@dataclass +class RequestContext: + """Context for a request operation.""" + + client: httpx.AsyncClient + headers: dict[str, str] + session_id: str | None + session_message: SessionMessage + metadata: ClientMessageMetadata | None + read_stream_writer: StreamWriter + sse_read_timeout: timedelta + + +class StreamableHTTPTransport: + """StreamableHTTP client transport implementation.""" + + def __init__( + self, + url: str, + headers: dict[str, Any] | None = None, + timeout: timedelta = timedelta(seconds=30), + sse_read_timeout: timedelta = timedelta(seconds=60 * 5), + ) -> None: + """Initialize the StreamableHTTP transport. + + Args: + url: The endpoint URL. + headers: Optional headers to include in requests. + timeout: HTTP timeout for regular operations. + sse_read_timeout: Timeout for SSE read operations. """ - nonlocal session_id + self.url = url + self.headers = headers or {} + self.timeout = timeout + self.sse_read_timeout = sse_read_timeout + self.session_id: str | None = None + self.request_headers = { + ACCEPT: f"{JSON}, {SSE}", + CONTENT_TYPE: JSON, + **self.headers, + } + + def _update_headers_with_session( + self, base_headers: dict[str, str] + ) -> dict[str, str]: + """Update headers with session ID if available.""" + headers = base_headers.copy() + if self.session_id: + headers[MCP_SESSION_ID] = self.session_id + return headers + + def _is_initialization_request(self, message: JSONRPCMessage) -> bool: + """Check if the message is an initialization request.""" + return ( + isinstance(message.root, JSONRPCRequest) + and message.root.method == "initialize" + ) + + def _is_initialized_notification(self, message: JSONRPCMessage) -> bool: + """Check if the message is an initialized notification.""" + return ( + isinstance(message.root, JSONRPCNotification) + and message.root.method == "notifications/initialized" + ) + + def _maybe_extract_session_id_from_response( + self, + response: httpx.Response, + ) -> None: + """Extract and store session ID from response headers.""" + new_session_id = response.headers.get(MCP_SESSION_ID) + if new_session_id: + self.session_id = new_session_id + logger.info(f"Received session ID: {self.session_id}") + + async def _handle_sse_event( + self, + sse: ServerSentEvent, + read_stream_writer: StreamWriter, + original_request_id: RequestId | None = None, + resumption_callback: Callable[[str], Awaitable[None]] | None = None, + ) -> bool: + """Handle an SSE event, returning True if the response is complete.""" + if sse.event == "message": + try: + message = JSONRPCMessage.model_validate_json(sse.data) + logger.debug(f"SSE message: {message}") + + # If this is a response and we have original_request_id, replace it + if original_request_id is not None and isinstance( + message.root, JSONRPCResponse | JSONRPCError + ): + message.root.id = original_request_id + + session_message = SessionMessage(message) + await read_stream_writer.send(session_message) + + # Call resumption token callback if we have an ID + if sse.id and resumption_callback: + await resumption_callback(sse.id) + + # If this is a response or error return True indicating completion + # Otherwise, return False to continue listening + return isinstance(message.root, JSONRPCResponse | JSONRPCError) + + except Exception as exc: + logger.error(f"Error parsing SSE message: {exc}") + await read_stream_writer.send(exc) + return False + else: + logger.warning(f"Unknown SSE event: {sse.event}") + return False + + async def handle_get_stream( + self, + client: httpx.AsyncClient, + read_stream_writer: StreamWriter, + ) -> None: + """Handle GET stream for server-initiated messages.""" try: - # Only attempt GET if we have a session ID - if not session_id: + if not self.session_id: return - get_headers = request_headers.copy() - get_headers[MCP_SESSION_ID_HEADER] = session_id + headers = self._update_headers_with_session(self.request_headers) async with aconnect_sse( client, "GET", - url, - headers=get_headers, - timeout=httpx.Timeout(timeout.seconds, read=sse_read_timeout.seconds), + self.url, + headers=headers, + timeout=httpx.Timeout( + self.timeout.seconds, read=self.sse_read_timeout.seconds + ), ) as event_source: event_source.response.raise_for_status() logger.debug("GET SSE connection established") async for sse in event_source.aiter_sse(): - if sse.event == "message": - try: - message = JSONRPCMessage.model_validate_json(sse.data) - logger.debug(f"GET message: {message}") - session_message = SessionMessage(message) - await read_stream_writer.send(session_message) - except Exception as exc: - logger.error(f"Error parsing GET message: {exc}") - await read_stream_writer.send(exc) - else: - logger.warning(f"Unknown SSE event from GET: {sse.event}") + await self._handle_sse_event(sse, read_stream_writer) + except Exception as exc: - # GET stream is optional, so don't propagate errors logger.debug(f"GET stream error (non-fatal): {exc}") - async def post_writer(client: httpx.AsyncClient): - nonlocal session_id + async def _handle_resumption_request(self, ctx: RequestContext) -> None: + """Handle a resumption request using GET with SSE.""" + headers = self._update_headers_with_session(ctx.headers) + if ctx.metadata and ctx.metadata.resumption_token: + headers[LAST_EVENT_ID] = ctx.metadata.resumption_token + else: + raise ResumptionError("Resumption request requires a resumption token") + + # Extract original request ID to map responses + original_request_id = None + if isinstance(ctx.session_message.message.root, JSONRPCRequest): + original_request_id = ctx.session_message.message.root.id + + async with aconnect_sse( + ctx.client, + "GET", + self.url, + headers=headers, + timeout=httpx.Timeout( + self.timeout.seconds, read=ctx.sse_read_timeout.seconds + ), + ) as event_source: + event_source.response.raise_for_status() + logger.debug("Resumption GET SSE connection established") + + async for sse in event_source.aiter_sse(): + is_complete = await self._handle_sse_event( + sse, + ctx.read_stream_writer, + original_request_id, + ctx.metadata.on_resumption_token_update if ctx.metadata else None, + ) + if is_complete: + break + + async def _handle_post_request(self, ctx: RequestContext) -> None: + """Handle a POST request with response processing.""" + headers = self._update_headers_with_session(ctx.headers) + message = ctx.session_message.message + is_initialization = self._is_initialization_request(message) + + async with ctx.client.stream( + "POST", + self.url, + json=message.model_dump(by_alias=True, mode="json", exclude_none=True), + headers=headers, + ) as response: + if response.status_code == 202: + logger.debug("Received 202 Accepted") + return + + if response.status_code == 404: + if isinstance(message.root, JSONRPCRequest): + await self._send_session_terminated_error( + ctx.read_stream_writer, + message.root.id, + ) + return + + response.raise_for_status() + if is_initialization: + self._maybe_extract_session_id_from_response(response) + + content_type = response.headers.get(CONTENT_TYPE, "").lower() + + if content_type.startswith(JSON): + await self._handle_json_response(response, ctx.read_stream_writer) + elif content_type.startswith(SSE): + await self._handle_sse_response(response, ctx) + else: + await self._handle_unexpected_content_type( + content_type, + ctx.read_stream_writer, + ) + + async def _handle_json_response( + self, + response: httpx.Response, + read_stream_writer: StreamWriter, + ) -> None: + """Handle JSON response from the server.""" + try: + content = await response.aread() + message = JSONRPCMessage.model_validate_json(content) + session_message = SessionMessage(message) + await read_stream_writer.send(session_message) + except Exception as exc: + logger.error(f"Error parsing JSON response: {exc}") + await read_stream_writer.send(exc) + + async def _handle_sse_response( + self, response: httpx.Response, ctx: RequestContext + ) -> None: + """Handle SSE response from the server.""" + try: + event_source = EventSource(response) + async for sse in event_source.aiter_sse(): + await self._handle_sse_event( + sse, + ctx.read_stream_writer, + resumption_callback=( + ctx.metadata.on_resumption_token_update + if ctx.metadata + else None + ), + ) + except Exception as e: + logger.exception("Error reading SSE stream:") + await ctx.read_stream_writer.send(e) + + async def _handle_unexpected_content_type( + self, + content_type: str, + read_stream_writer: StreamWriter, + ) -> None: + """Handle unexpected content type in response.""" + error_msg = f"Unexpected content type: {content_type}" + logger.error(error_msg) + await read_stream_writer.send(ValueError(error_msg)) + + async def _send_session_terminated_error( + self, + read_stream_writer: StreamWriter, + request_id: RequestId, + ) -> None: + """Send a session terminated error response.""" + jsonrpc_error = JSONRPCError( + jsonrpc="2.0", + id=request_id, + error=ErrorData(code=32600, message="Session terminated"), + ) + session_message = SessionMessage(JSONRPCMessage(jsonrpc_error)) + await read_stream_writer.send(session_message) + + async def post_writer( + self, + client: httpx.AsyncClient, + write_stream_reader: StreamReader, + read_stream_writer: StreamWriter, + write_stream: MemoryObjectSendStream[SessionMessage], + start_get_stream: Callable[[], None], + ) -> None: + """Handle writing requests to the server.""" try: async with write_stream_reader: async for session_message in write_stream_reader: message = session_message.message - # Add session ID to headers if we have one - post_headers = request_headers.copy() - if session_id: - post_headers[MCP_SESSION_ID_HEADER] = session_id + metadata = ( + session_message.metadata + if isinstance(session_message.metadata, ClientMessageMetadata) + else None + ) + + # Check if this is a resumption request + is_resumption = bool(metadata and metadata.resumption_token) logger.debug(f"Sending client message: {message}") - # Handle initial initialization request - is_initialization = ( - isinstance(message.root, JSONRPCRequest) - and message.root.method == "initialize" + # Handle initialized notification + if self._is_initialized_notification(message): + start_get_stream() + + ctx = RequestContext( + client=client, + headers=self.request_headers, + session_id=self.session_id, + session_message=session_message, + metadata=metadata, + read_stream_writer=read_stream_writer, + sse_read_timeout=self.sse_read_timeout, ) - if ( - isinstance(message.root, JSONRPCNotification) - and message.root.method == "notifications/initialized" - ): - tg.start_soon(get_stream) - - async with client.stream( - "POST", - url, - json=message.model_dump( - by_alias=True, mode="json", exclude_none=True - ), - headers=post_headers, - ) as response: - if response.status_code == 202: - logger.debug("Received 202 Accepted") - continue - # Check for 404 (session expired/invalid) - if response.status_code == 404: - if isinstance(message.root, JSONRPCRequest): - jsonrpc_error = JSONRPCError( - jsonrpc="2.0", - id=message.root.id, - error=ErrorData( - code=32600, - message="Session terminated", - ), - ) - session_message = SessionMessage( - JSONRPCMessage(jsonrpc_error) - ) - await read_stream_writer.send(session_message) - continue - response.raise_for_status() - - # Extract session ID from response headers - if is_initialization: - new_session_id = response.headers.get(MCP_SESSION_ID_HEADER) - if new_session_id: - session_id = new_session_id - logger.info(f"Received session ID: {session_id}") - - # Handle different response types - content_type = response.headers.get("content-type", "").lower() - - if content_type.startswith(CONTENT_TYPE_JSON): - try: - content = await response.aread() - json_message = JSONRPCMessage.model_validate_json( - content - ) - session_message = SessionMessage(json_message) - await read_stream_writer.send(session_message) - except Exception as exc: - logger.error(f"Error parsing JSON response: {exc}") - await read_stream_writer.send(exc) - - elif content_type.startswith(CONTENT_TYPE_SSE): - # Parse SSE events from the response - try: - event_source = EventSource(response) - async for sse in event_source.aiter_sse(): - if sse.event == "message": - try: - message = ( - JSONRPCMessage.model_validate_json( - sse.data - ) - ) - session_message = SessionMessage(message) - await read_stream_writer.send( - session_message - ) - except Exception as exc: - logger.exception("Error parsing message") - await read_stream_writer.send(exc) - else: - logger.warning(f"Unknown event: {sse.event}") - - except Exception as e: - logger.exception("Error reading SSE stream:") - await read_stream_writer.send(e) - - else: - # For 202 Accepted with no body - if response.status_code == 202: - logger.debug("Received 202 Accepted") - continue - - error_msg = f"Unexpected content type: {content_type}" - logger.error(error_msg) - await read_stream_writer.send(ValueError(error_msg)) + + if is_resumption: + await self._handle_resumption_request(ctx) + else: + await self._handle_post_request(ctx) except Exception as exc: logger.error(f"Error in post_writer: {exc}") @@ -215,52 +385,98 @@ async def post_writer(client: httpx.AsyncClient): await read_stream_writer.aclose() await write_stream.aclose() - async def terminate_session(): - """ - Terminate the session by sending a DELETE request. - """ - nonlocal session_id - if not session_id: - return # No session to terminate + async def terminate_session(self, client: httpx.AsyncClient) -> None: + """Terminate the session by sending a DELETE request.""" + if not self.session_id: + return try: - delete_headers = request_headers.copy() - delete_headers[MCP_SESSION_ID_HEADER] = session_id - - response = await client.delete( - url, - headers=delete_headers, - ) + headers = self._update_headers_with_session(self.request_headers) + response = await client.delete(self.url, headers=headers) if response.status_code == 405: - # Server doesn't allow client-initiated termination logger.debug("Server does not allow session termination") elif response.status_code != 200: logger.warning(f"Session termination failed: {response.status_code}") except Exception as exc: logger.warning(f"Session termination failed: {exc}") + def get_session_id(self) -> str | None: + """Get the current session ID.""" + return self.session_id + + +@asynccontextmanager +async def streamablehttp_client( + url: str, + headers: dict[str, Any] | None = None, + timeout: timedelta = timedelta(seconds=30), + sse_read_timeout: timedelta = timedelta(seconds=60 * 5), + terminate_on_close: bool = True, +) -> AsyncGenerator[ + tuple[ + MemoryObjectReceiveStream[SessionMessage | Exception], + MemoryObjectSendStream[SessionMessage], + GetSessionIdCallback, + ], + None, +]: + """ + Client transport for StreamableHTTP. + + `sse_read_timeout` determines how long (in seconds) the client will wait for a new + event before disconnecting. All other HTTP operations are controlled by `timeout`. + + Yields: + Tuple containing: + - read_stream: Stream for reading messages from the server + - write_stream: Stream for sending messages to the server + - get_session_id_callback: Function to retrieve the current session ID + """ + transport = StreamableHTTPTransport(url, headers, timeout, sse_read_timeout) + + read_stream_writer, read_stream = anyio.create_memory_object_stream[ + SessionMessage | Exception + ](0) + write_stream, write_stream_reader = anyio.create_memory_object_stream[ + SessionMessage + ](0) + async with anyio.create_task_group() as tg: try: logger.info(f"Connecting to StreamableHTTP endpoint: {url}") - # Set up headers with required Accept header - request_headers = { - "Accept": f"{CONTENT_TYPE_JSON}, {CONTENT_TYPE_SSE}", - "Content-Type": CONTENT_TYPE_JSON, - **(headers or {}), - } - # Track session ID if provided by server - session_id: str | None = None async with httpx.AsyncClient( - headers=request_headers, - timeout=httpx.Timeout(timeout.seconds, read=sse_read_timeout.seconds), + headers=transport.request_headers, + timeout=httpx.Timeout( + transport.timeout.seconds, read=transport.sse_read_timeout.seconds + ), follow_redirects=True, ) as client: - tg.start_soon(post_writer, client) + # Define callbacks that need access to tg + def start_get_stream() -> None: + tg.start_soon( + transport.handle_get_stream, client, read_stream_writer + ) + + tg.start_soon( + transport.post_writer, + client, + write_stream_reader, + read_stream_writer, + write_stream, + start_get_stream, + ) + try: - yield read_stream, write_stream, terminate_session + yield ( + read_stream, + write_stream, + transport.get_session_id, + ) finally: + if transport.session_id and terminate_on_close: + await transport.terminate_session(client) tg.cancel_scope.cancel() finally: await read_stream_writer.aclose() diff --git a/src/mcp/shared/message.py b/src/mcp/shared/message.py index c9341c364..5583f4795 100644 --- a/src/mcp/shared/message.py +++ b/src/mcp/shared/message.py @@ -5,16 +5,24 @@ to support transport-specific features like resumability. """ +from collections.abc import Awaitable, Callable from dataclasses import dataclass from mcp.types import JSONRPCMessage, RequestId +ResumptionToken = str + +ResumptionTokenUpdateCallback = Callable[[ResumptionToken], Awaitable[None]] + @dataclass class ClientMessageMetadata: """Metadata specific to client messages.""" - resumption_token: str | None = None + resumption_token: ResumptionToken | None = None + on_resumption_token_update: Callable[[ResumptionToken], Awaitable[None]] | None = ( + None + ) @dataclass diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index d74c4d066..cce8b1184 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -12,7 +12,7 @@ from typing_extensions import Self from mcp.shared.exceptions import McpError -from mcp.shared.message import ServerMessageMetadata, SessionMessage +from mcp.shared.message import MessageMetadata, ServerMessageMetadata, SessionMessage from mcp.types import ( CancelledNotification, ClientNotification, @@ -213,6 +213,7 @@ async def send_request( request: SendRequestT, result_type: type[ReceiveResultT], request_read_timeout_seconds: timedelta | None = None, + metadata: MessageMetadata = None, ) -> ReceiveResultT: """ Sends a request and wait for a response. Raises an McpError if the @@ -241,7 +242,9 @@ async def send_request( # TODO: Support progress callbacks await self._write_stream.send( - SessionMessage(message=JSONRPCMessage(jsonrpc_request)) + SessionMessage( + message=JSONRPCMessage(jsonrpc_request), metadata=metadata + ) ) # request read timeout takes precedence over session read timeout diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 7331b392b..f64360229 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -23,16 +23,31 @@ from starlette.responses import Response from starlette.routing import Mount +import mcp.types as types from mcp.client.session import ClientSession from mcp.client.streamable_http import streamablehttp_client from mcp.server import Server from mcp.server.streamable_http import ( MCP_SESSION_ID_HEADER, SESSION_ID_PATTERN, + EventCallback, + EventId, + EventMessage, + EventStore, StreamableHTTPServerTransport, + StreamId, ) from mcp.shared.exceptions import McpError -from mcp.types import InitializeResult, TextContent, TextResourceContents, Tool +from mcp.shared.message import ( + ClientMessageMetadata, +) +from mcp.shared.session import RequestResponder +from mcp.types import ( + InitializeResult, + TextContent, + TextResourceContents, + Tool, +) # Test constants SERVER_NAME = "test_streamable_http_server" @@ -49,6 +64,51 @@ } +# Simple in-memory event store for testing +class SimpleEventStore(EventStore): + """Simple in-memory event store for testing.""" + + def __init__(self): + self._events: list[tuple[StreamId, EventId, types.JSONRPCMessage]] = [] + self._event_id_counter = 0 + + async def store_event( + self, stream_id: StreamId, message: types.JSONRPCMessage + ) -> EventId: + """Store an event and return its ID.""" + self._event_id_counter += 1 + event_id = str(self._event_id_counter) + self._events.append((stream_id, event_id, message)) + return event_id + + async def replay_events_after( + self, + last_event_id: EventId, + send_callback: EventCallback, + ) -> StreamId | None: + """Replay events after the specified ID.""" + # Find the index of the last event ID + start_index = None + for i, (_, event_id, _) in enumerate(self._events): + if event_id == last_event_id: + start_index = i + 1 + break + + if start_index is None: + # If event ID not found, start from beginning + start_index = 0 + + stream_id = None + # Replay events + for _, event_id, message in self._events[start_index:]: + await send_callback(EventMessage(message, event_id)) + # Capture the stream ID from the first replayed event + if stream_id is None and len(self._events) > start_index: + stream_id = self._events[start_index][0] + + return stream_id + + # Test server implementation that follows MCP protocol class ServerTest(Server): def __init__(self): @@ -78,25 +138,57 @@ async def handle_list_tools() -> list[Tool]: description="A test tool that sends a notification", inputSchema={"type": "object", "properties": {}}, ), + Tool( + name="long_running_with_checkpoints", + description="A long-running tool that sends periodic notifications", + inputSchema={"type": "object", "properties": {}}, + ), ] @self.call_tool() async def handle_call_tool(name: str, args: dict) -> list[TextContent]: + ctx = self.request_context + # When the tool is called, send a notification to test GET stream if name == "test_tool_with_standalone_notification": - ctx = self.request_context await ctx.session.send_resource_updated( uri=AnyUrl("http://test_resource") ) + return [TextContent(type="text", text=f"Called {name}")] + + elif name == "long_running_with_checkpoints": + # Send notifications that are part of the response stream + # This simulates a long-running tool that sends logs + + await ctx.session.send_log_message( + level="info", + data="Tool started", + logger="tool", + related_request_id=ctx.request_id, # need for stream association + ) + + await anyio.sleep(0.1) + + await ctx.session.send_log_message( + level="info", + data="Tool is almost done", + logger="tool", + related_request_id=ctx.request_id, + ) + + return [TextContent(type="text", text="Completed!")] return [TextContent(type="text", text=f"Called {name}")] -def create_app(is_json_response_enabled=False) -> Starlette: +def create_app( + is_json_response_enabled=False, event_store: EventStore | None = None +) -> Starlette: """Create a Starlette application for testing that matches the example server. Args: is_json_response_enabled: If True, use JSON responses instead of SSE streams. + event_store: Optional event store for testing resumability. """ # Create server instance server = ServerTest() @@ -139,6 +231,7 @@ async def handle_streamable_http(scope, receive, send): http_transport = StreamableHTTPServerTransport( mcp_session_id=new_session_id, is_json_response_enabled=is_json_response_enabled, + event_store=event_store, ) async with http_transport.connect() as streams: @@ -183,15 +276,18 @@ async def run_server(): return app -def run_server(port: int, is_json_response_enabled=False) -> None: +def run_server( + port: int, is_json_response_enabled=False, event_store: EventStore | None = None +) -> None: """Run the test server. Args: port: Port to listen on. is_json_response_enabled: If True, use JSON responses instead of SSE streams. + event_store: Optional event store for testing resumability. """ - app = create_app(is_json_response_enabled) + app = create_app(is_json_response_enabled, event_store) # Configure server config = uvicorn.Config( app=app, @@ -261,6 +357,53 @@ def basic_server(basic_server_port: int) -> Generator[None, None, None]: proc.join(timeout=2) +@pytest.fixture +def event_store() -> SimpleEventStore: + """Create a test event store.""" + return SimpleEventStore() + + +@pytest.fixture +def event_server_port() -> int: + """Find an available port for the event store server.""" + with socket.socket() as s: + s.bind(("127.0.0.1", 0)) + return s.getsockname()[1] + + +@pytest.fixture +def event_server( + event_server_port: int, event_store: SimpleEventStore +) -> Generator[tuple[SimpleEventStore, str], None, None]: + """Start a server with event store enabled.""" + proc = multiprocessing.Process( + target=run_server, + kwargs={"port": event_server_port, "event_store": event_store}, + daemon=True, + ) + proc.start() + + # Wait for server to be running + max_attempts = 20 + attempt = 0 + while attempt < max_attempts: + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.connect(("127.0.0.1", event_server_port)) + break + except ConnectionRefusedError: + time.sleep(0.1) + attempt += 1 + else: + raise RuntimeError(f"Server failed to start after {max_attempts} attempts") + + yield event_store, f"http://127.0.0.1:{event_server_port}" + + # Clean up + proc.kill() + proc.join(timeout=2) + + @pytest.fixture def json_response_server(json_server_port: int) -> Generator[None, None, None]: """Start a server with JSON response enabled.""" @@ -679,7 +822,7 @@ async def test_streamablehttp_client_tool_invocation(initialized_client_session) """Test client tool invocation.""" # First list tools tools = await initialized_client_session.list_tools() - assert len(tools.tools) == 2 + assert len(tools.tools) == 3 assert tools.tools[0].name == "test_tool" # Call the tool @@ -720,7 +863,7 @@ async def test_streamablehttp_client_session_persistence( # Make multiple requests to verify session persistence tools = await session.list_tools() - assert len(tools.tools) == 2 + assert len(tools.tools) == 3 # Read a resource resource = await session.read_resource(uri=AnyUrl("foobar://test-persist")) @@ -751,7 +894,7 @@ async def test_streamablehttp_client_json_response( # Check tool listing tools = await session.list_tools() - assert len(tools.tools) == 2 + assert len(tools.tools) == 3 # Call a tool and verify JSON response handling result = await session.call_tool("test_tool", {}) @@ -813,25 +956,169 @@ async def test_streamablehttp_client_session_termination( ): """Test client session termination functionality.""" + captured_session_id = None + # Create the streamablehttp_client with a custom httpx client to capture headers async with streamablehttp_client(f"{basic_server_url}/mcp") as ( read_stream, write_stream, - terminate_session, + get_session_id, ): async with ClientSession(read_stream, write_stream) as session: # Initialize the session result = await session.initialize() assert isinstance(result, InitializeResult) + captured_session_id = get_session_id() + assert captured_session_id is not None # Make a request to confirm session is working tools = await session.list_tools() - assert len(tools.tools) == 2 + assert len(tools.tools) == 3 + + headers = {} + if captured_session_id: + headers[MCP_SESSION_ID_HEADER] = captured_session_id - # After exiting ClientSession context, explicitly terminate the session - await terminate_session() + async with streamablehttp_client(f"{basic_server_url}/mcp", headers=headers) as ( + read_stream, + write_stream, + _, + ): + async with ClientSession(read_stream, write_stream) as session: + # Attempt to make a request after termination with pytest.raises( McpError, match="Session terminated", ): await session.list_tools() + + +@pytest.mark.anyio +async def test_streamablehttp_client_resumption(event_server): + """Test client session to resume a long running tool.""" + _, server_url = event_server + + # Variables to track the state + captured_session_id = None + captured_resumption_token = None + captured_notifications = [] + tool_started = False + + async def message_handler( + message: RequestResponder[types.ServerRequest, types.ClientResult] + | types.ServerNotification + | Exception, + ) -> None: + if isinstance(message, types.ServerNotification): + captured_notifications.append(message) + # Look for our special notification that indicates the tool is running + if isinstance(message.root, types.LoggingMessageNotification): + if message.root.params.data == "Tool started": + nonlocal tool_started + tool_started = True + + async def on_resumption_token_update(token: str) -> None: + nonlocal captured_resumption_token + captured_resumption_token = token + + # First, start the client session and begin the long-running tool + async with streamablehttp_client(f"{server_url}/mcp", terminate_on_close=False) as ( + read_stream, + write_stream, + get_session_id, + ): + async with ClientSession( + read_stream, write_stream, message_handler=message_handler + ) as session: + # Initialize the session + result = await session.initialize() + assert isinstance(result, InitializeResult) + captured_session_id = get_session_id() + assert captured_session_id is not None + + # Start a long-running tool in a task + async with anyio.create_task_group() as tg: + + async def run_tool(): + metadata = ClientMessageMetadata( + on_resumption_token_update=on_resumption_token_update, + ) + await session.send_request( + types.ClientRequest( + types.CallToolRequest( + method="tools/call", + params=types.CallToolRequestParams( + name="long_running_with_checkpoints", arguments={} + ), + ) + ), + types.CallToolResult, + metadata=metadata, + ) + + tg.start_soon(run_tool) + + # Wait for the tool to start and at least one notification + # and then kill the task group + while not tool_started or not captured_resumption_token: + await anyio.sleep(0.1) + tg.cancel_scope.cancel() + + # Store pre notifications and clear the captured notifications + # for the post-resumption check + captured_notifications_pre = captured_notifications.copy() + captured_notifications = [] + + # Now resume the session with the same mcp-session-id + headers = {} + if captured_session_id: + headers[MCP_SESSION_ID_HEADER] = captured_session_id + + async with streamablehttp_client(f"{server_url}/mcp", headers=headers) as ( + read_stream, + write_stream, + _, + ): + async with ClientSession( + read_stream, write_stream, message_handler=message_handler + ) as session: + # Don't initialize - just use the existing session + + # Resume the tool with the resumption token + assert captured_resumption_token is not None + + metadata = ClientMessageMetadata( + resumption_token=captured_resumption_token, + ) + result = await session.send_request( + types.ClientRequest( + types.CallToolRequest( + method="tools/call", + params=types.CallToolRequestParams( + name="long_running_with_checkpoints", arguments={} + ), + ) + ), + types.CallToolResult, + metadata=metadata, + ) + + # We should get a complete result + assert len(result.content) == 1 + assert result.content[0].type == "text" + assert "Completed" in result.content[0].text + + # We should have received the remaining notifications + assert len(captured_notifications) > 0 + + # Should not have the first notification + # Check that "Tool started" notification isn't repeated when resuming + assert not any( + isinstance(n.root, types.LoggingMessageNotification) + and n.root.params.data == "Tool started" + for n in captured_notifications + ) + # there is no intersection between pre and post notifications + assert not any( + n in captured_notifications_pre for n in captured_notifications + )