Skip to content

fix: Clean up sessions from manager when terminated via DELETE request #791

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions src/mcp/server/streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ def __init__(
mcp_session_id: str | None,
is_json_response_enabled: bool = False,
event_store: EventStore | None = None,
on_session_terminated: Callable[[str], Awaitable[None]] | None = None,
) -> None:
"""
Initialize a new StreamableHTTP server transport.
Expand All @@ -149,6 +150,9 @@ def __init__(
event_store: Event store for resumability support. If provided,
resumability will be enabled, allowing clients to
reconnect and resume messages.
on_session_terminated: Optional callback to notify when session is
terminated. Called with the session ID when DELETE
request terminates session.

Raises:
ValueError: If the session ID contains invalid characters.
Expand All @@ -163,6 +167,7 @@ def __init__(
self.mcp_session_id = mcp_session_id
self.is_json_response_enabled = is_json_response_enabled
self._event_store = event_store
self._on_session_terminated = on_session_terminated
self._request_streams: dict[
RequestId,
tuple[
Expand Down Expand Up @@ -660,6 +665,13 @@ async def _terminate_session(self) -> None:
self._terminated = True
logger.info(f"Terminating session: {self.mcp_session_id}")

# Notify the session manager about termination
if self._on_session_terminated and self.mcp_session_id:
try:
await self._on_session_terminated(self.mcp_session_id)
except Exception as e:
logger.warning(f"Error in session termination callback: {e}")

# We need a copy of the keys to avoid modification during iteration
request_stream_keys = list(self._request_streams.keys())

Expand Down
14 changes: 11 additions & 3 deletions src/mcp/server/streamable_http_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,13 @@ async def lifespan(app: Starlette) -> AsyncIterator[None]:
# Clear any remaining server instances
self._server_instances.clear()

async def _on_session_terminated(self, session_id: str) -> None:
"""Callback to clean up terminated sessions from the manager."""
async with self._session_creation_lock:
if session_id in self._server_instances:
logger.info(f"Removing terminated session from manager: {session_id}")
del self._server_instances[session_id]

async def handle_request(
self,
scope: Scope,
Expand Down Expand Up @@ -222,6 +229,7 @@ async def _handle_stateful_request(
mcp_session_id=new_session_id,
is_json_response_enabled=self.json_response,
event_store=self.event_store, # May be None (no resumability)
on_session_terminated=self._on_session_terminated,
)

assert http_transport.mcp_session_id is not None
Expand Down Expand Up @@ -250,9 +258,9 @@ async def run_server(
# Handle the HTTP request and return the response
await http_transport.handle_request(scope, receive, send)
else:
# Invalid session ID
# Invalid session ID
response = Response(
"Bad Request: No valid session ID provided",
status_code=HTTPStatus.BAD_REQUEST,
"Not Found: Session has been terminated",
status_code=HTTPStatus.NOT_FOUND,
)
await response(scope, receive, send)
99 changes: 98 additions & 1 deletion tests/server/test_streamable_http_manager.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Tests for StreamableHTTPSessionManager."""

import json

import anyio
import pytest

Expand Down Expand Up @@ -70,7 +72,7 @@ async def receive():
return {"type": "http.request", "body": b""}

async def send(message):
pass
del message # Suppress unused parameter warning

# Should raise error because run() hasn't been called
with pytest.raises(RuntimeError) as excinfo:
Expand All @@ -79,3 +81,98 @@ async def send(message):
assert "Task group is not initialized. Make sure to use run()." in str(
excinfo.value
)


@pytest.mark.anyio
async def test_session_cleanup_on_delete_request():
"""Test sessions are properly cleaned up when DELETE request terminates them."""
app = Server("test-server")
manager = StreamableHTTPSessionManager(app=app, json_response=True, stateless=False)

async with manager.run():
# Create a new session by making a POST request
session_id = None

# Mock ASGI parameters for POST request (session creation)
post_scope = {
"type": "http",
"method": "POST",
"path": "/test",
"headers": [
(b"content-type", b"application/json"),
(b"accept", b"application/json, text/event-stream"),
],
}

# Mock initialization request
init_request = {
"jsonrpc": "2.0",
"id": 1,
"method": "initialize",
"params": {
"protocolVersion": "2024-11-05",
"capabilities": {},
"clientInfo": {"name": "test-client", "version": "1.0.0"},
},
}

post_body = json.dumps(init_request).encode()
post_request_body_sent = False

async def post_receive():
nonlocal post_request_body_sent
if not post_request_body_sent:
post_request_body_sent = True
return {"type": "http.request", "body": post_body}
else:
return {"type": "http.request", "body": b""}

response_data = {}

async def post_send(message):
if message["type"] == "http.response.start":
response_data["status"] = message["status"]
response_data["headers"] = dict(message.get("headers", []))
elif message["type"] == "http.response.body":
response_data["body"] = message.get("body", b"")

# Make POST request to create session
await manager.handle_request(post_scope, post_receive, post_send)

# Extract session ID from response headers
session_id = response_data["headers"].get(b"mcp-session-id")
if session_id:
session_id = session_id.decode()

# Verify session was created
assert session_id is not None
assert session_id in manager._server_instances

# Now make DELETE request to terminate session
delete_scope = {
"type": "http",
"method": "DELETE",
"path": "/test",
"headers": [(b"mcp-session-id", session_id.encode())],
}

async def delete_receive():
return {"type": "http.request", "body": b""}

delete_response_data = {}

async def delete_send(message):
if message["type"] == "http.response.start":
delete_response_data["status"] = message["status"]

# Make DELETE request
await manager.handle_request(delete_scope, delete_receive, delete_send)

# Verify DELETE request succeeded
assert delete_response_data["status"] == 200

# Give a moment for the callback to execute
await anyio.sleep(0.01)

# Verify session was removed from manager
assert session_id not in manager._server_instances
Loading