diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index 8f4a1f512..231f2cf9b 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -137,6 +137,7 @@ def __init__( mcp_session_id: str | None, is_json_response_enabled: bool = False, event_store: EventStore | None = None, + on_session_terminated: Callable[[str], Awaitable[None]] | None = None, ) -> None: """ Initialize a new StreamableHTTP server transport. @@ -149,6 +150,9 @@ def __init__( event_store: Event store for resumability support. If provided, resumability will be enabled, allowing clients to reconnect and resume messages. + on_session_terminated: Optional callback to notify when session is + terminated. Called with the session ID when DELETE + request terminates session. Raises: ValueError: If the session ID contains invalid characters. @@ -163,6 +167,7 @@ def __init__( self.mcp_session_id = mcp_session_id self.is_json_response_enabled = is_json_response_enabled self._event_store = event_store + self._on_session_terminated = on_session_terminated self._request_streams: dict[ RequestId, tuple[ @@ -660,6 +665,13 @@ async def _terminate_session(self) -> None: self._terminated = True logger.info(f"Terminating session: {self.mcp_session_id}") + # Notify the session manager about termination + if self._on_session_terminated and self.mcp_session_id: + try: + await self._on_session_terminated(self.mcp_session_id) + except Exception as e: + logger.warning(f"Error in session termination callback: {e}") + # We need a copy of the keys to avoid modification during iteration request_stream_keys = list(self._request_streams.keys()) diff --git a/src/mcp/server/streamable_http_manager.py b/src/mcp/server/streamable_http_manager.py index e5ef8b4aa..f6abdee61 100644 --- a/src/mcp/server/streamable_http_manager.py +++ b/src/mcp/server/streamable_http_manager.py @@ -117,6 +117,13 @@ async def lifespan(app: Starlette) -> AsyncIterator[None]: # Clear any remaining server instances self._server_instances.clear() + async def _on_session_terminated(self, session_id: str) -> None: + """Callback to clean up terminated sessions from the manager.""" + async with self._session_creation_lock: + if session_id in self._server_instances: + logger.info(f"Removing terminated session from manager: {session_id}") + del self._server_instances[session_id] + async def handle_request( self, scope: Scope, @@ -222,6 +229,7 @@ async def _handle_stateful_request( mcp_session_id=new_session_id, is_json_response_enabled=self.json_response, event_store=self.event_store, # May be None (no resumability) + on_session_terminated=self._on_session_terminated, ) assert http_transport.mcp_session_id is not None @@ -250,9 +258,9 @@ async def run_server( # Handle the HTTP request and return the response await http_transport.handle_request(scope, receive, send) else: - # Invalid session ID + # Invalid session ID response = Response( - "Bad Request: No valid session ID provided", - status_code=HTTPStatus.BAD_REQUEST, + "Not Found: Session has been terminated", + status_code=HTTPStatus.NOT_FOUND, ) await response(scope, receive, send) diff --git a/tests/server/test_streamable_http_manager.py b/tests/server/test_streamable_http_manager.py index 32782e458..bd938b336 100644 --- a/tests/server/test_streamable_http_manager.py +++ b/tests/server/test_streamable_http_manager.py @@ -1,5 +1,7 @@ """Tests for StreamableHTTPSessionManager.""" +import json + import anyio import pytest @@ -70,7 +72,7 @@ async def receive(): return {"type": "http.request", "body": b""} async def send(message): - pass + del message # Suppress unused parameter warning # Should raise error because run() hasn't been called with pytest.raises(RuntimeError) as excinfo: @@ -79,3 +81,98 @@ async def send(message): assert "Task group is not initialized. Make sure to use run()." in str( excinfo.value ) + + +@pytest.mark.anyio +async def test_session_cleanup_on_delete_request(): + """Test sessions are properly cleaned up when DELETE request terminates them.""" + app = Server("test-server") + manager = StreamableHTTPSessionManager(app=app, json_response=True, stateless=False) + + async with manager.run(): + # Create a new session by making a POST request + session_id = None + + # Mock ASGI parameters for POST request (session creation) + post_scope = { + "type": "http", + "method": "POST", + "path": "/test", + "headers": [ + (b"content-type", b"application/json"), + (b"accept", b"application/json, text/event-stream"), + ], + } + + # Mock initialization request + init_request = { + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": { + "protocolVersion": "2024-11-05", + "capabilities": {}, + "clientInfo": {"name": "test-client", "version": "1.0.0"}, + }, + } + + post_body = json.dumps(init_request).encode() + post_request_body_sent = False + + async def post_receive(): + nonlocal post_request_body_sent + if not post_request_body_sent: + post_request_body_sent = True + return {"type": "http.request", "body": post_body} + else: + return {"type": "http.request", "body": b""} + + response_data = {} + + async def post_send(message): + if message["type"] == "http.response.start": + response_data["status"] = message["status"] + response_data["headers"] = dict(message.get("headers", [])) + elif message["type"] == "http.response.body": + response_data["body"] = message.get("body", b"") + + # Make POST request to create session + await manager.handle_request(post_scope, post_receive, post_send) + + # Extract session ID from response headers + session_id = response_data["headers"].get(b"mcp-session-id") + if session_id: + session_id = session_id.decode() + + # Verify session was created + assert session_id is not None + assert session_id in manager._server_instances + + # Now make DELETE request to terminate session + delete_scope = { + "type": "http", + "method": "DELETE", + "path": "/test", + "headers": [(b"mcp-session-id", session_id.encode())], + } + + async def delete_receive(): + return {"type": "http.request", "body": b""} + + delete_response_data = {} + + async def delete_send(message): + if message["type"] == "http.response.start": + delete_response_data["status"] = message["status"] + + # Make DELETE request + await manager.handle_request(delete_scope, delete_receive, delete_send) + + # Verify DELETE request succeeded + assert delete_response_data["status"] == 200 + + # Give a moment for the callback to execute + await anyio.sleep(0.01) + + # Verify session was removed from manager + assert session_id not in manager._server_instances