diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index 8f4a1f512..a94cc2834 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -397,7 +397,8 @@ async def _handle_post_request( await response(scope, receive, send) # Process the message after sending the response - session_message = SessionMessage(message) + metadata = ServerMessageMetadata(request_context=request) + session_message = SessionMessage(message, metadata=metadata) await writer.send(session_message) return @@ -412,7 +413,8 @@ async def _handle_post_request( if self.is_json_response_enabled: # Process the message - session_message = SessionMessage(message) + metadata = ServerMessageMetadata(request_context=request) + session_message = SessionMessage(message, metadata=metadata) await writer.send(session_message) try: # Process messages from the request-specific stream @@ -511,7 +513,8 @@ async def sse_writer(): async with anyio.create_task_group() as tg: tg.start_soon(response, scope, receive, send) # Then send the message to be processed by the server - session_message = SessionMessage(message) + metadata = ServerMessageMetadata(request_context=request) + session_message = SessionMessage(message, metadata=metadata) await writer.send(session_message) except Exception: logger.exception("SSE response error") diff --git a/tests/server/fastmcp/test_integration.py b/tests/server/fastmcp/test_integration.py index 761e810bc..121492bc6 100644 --- a/tests/server/fastmcp/test_integration.py +++ b/tests/server/fastmcp/test_integration.py @@ -24,7 +24,6 @@ from mcp.client.streamable_http import streamablehttp_client from mcp.server.fastmcp import FastMCP from mcp.server.fastmcp.resources import FunctionResource -from mcp.server.fastmcp.server import Context from mcp.shared.context import RequestContext from mcp.types import ( CreateMessageRequestParams, @@ -196,6 +195,33 @@ def complex_prompt(user_query: str, context: str = "general") -> str: # Since FastMCP doesn't support system messages in the same way return f"Context: {context}. Query: {user_query}" + # Tool that echoes request headers from context + @mcp.tool(description="Echo request headers from context") + def echo_headers(ctx: Context[Any, Any, Request]) -> str: + """Returns the request headers as JSON.""" + headers_info = {} + if ctx.request_context.request: + # Now the type system knows request is a Starlette Request object + headers_info = dict(ctx.request_context.request.headers) + return json.dumps(headers_info) + + # Tool that returns full request context + @mcp.tool(description="Echo request context with custom data") + def echo_context(custom_request_id: str, ctx: Context[Any, Any, Request]) -> str: + """Returns request context including headers and custom data.""" + context_data = { + "custom_request_id": custom_request_id, + "headers": {}, + "method": None, + "path": None, + } + if ctx.request_context.request: + request = ctx.request_context.request + context_data["headers"] = dict(request.headers) + context_data["method"] = request.method + context_data["path"] = request.url.path + return json.dumps(context_data) + return mcp @@ -432,174 +458,6 @@ async def test_fastmcp_without_auth(server: None, server_url: str) -> None: assert tool_result.content[0].text == "Echo: hello" -def make_fastmcp_with_context_app(): - """Create a FastMCP server that can access request context.""" - - mcp = FastMCP(name="ContextServer") - - # Tool that echoes request headers - @mcp.tool(description="Echo request headers from context") - def echo_headers(ctx: Context[Any, Any, Request]) -> str: - """Returns the request headers as JSON.""" - headers_info = {} - if ctx.request_context.request: - # Now the type system knows request is a Starlette Request object - headers_info = dict(ctx.request_context.request.headers) - return json.dumps(headers_info) - - # Tool that returns full request context - @mcp.tool(description="Echo request context with custom data") - def echo_context(custom_request_id: str, ctx: Context[Any, Any, Request]) -> str: - """Returns request context including headers and custom data.""" - context_data = { - "custom_request_id": custom_request_id, - "headers": {}, - "method": None, - "path": None, - } - if ctx.request_context.request: - request = ctx.request_context.request - context_data["headers"] = dict(request.headers) - context_data["method"] = request.method - context_data["path"] = request.url.path - return json.dumps(context_data) - - # Create the SSE app - app = mcp.sse_app() - return mcp, app - - -def run_context_server(server_port: int) -> None: - """Run the context-aware FastMCP server.""" - _, app = make_fastmcp_with_context_app() - server = uvicorn.Server( - config=uvicorn.Config( - app=app, host="127.0.0.1", port=server_port, log_level="error" - ) - ) - print(f"Starting context server on port {server_port}") - server.run() - - -@pytest.fixture() -def context_aware_server(server_port: int) -> Generator[None, None, None]: - """Start the context-aware server in a separate process.""" - proc = multiprocessing.Process( - target=run_context_server, args=(server_port,), daemon=True - ) - print("Starting context-aware server process") - proc.start() - - # Wait for server to be running - max_attempts = 20 - attempt = 0 - print("Waiting for context-aware server to start") - while attempt < max_attempts: - try: - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.connect(("127.0.0.1", server_port)) - break - except ConnectionRefusedError: - time.sleep(0.1) - attempt += 1 - else: - raise RuntimeError( - f"Context server failed to start after {max_attempts} attempts" - ) - - yield - - print("Killing context-aware server") - proc.kill() - proc.join(timeout=2) - if proc.is_alive(): - print("Context server process failed to terminate") - - -@pytest.mark.anyio -async def test_fast_mcp_with_request_context( - context_aware_server: None, server_url: str -) -> None: - """Test that FastMCP properly propagates request context to tools.""" - # Test with custom headers - custom_headers = { - "Authorization": "Bearer fastmcp-test-token", - "X-Custom-Header": "fastmcp-value", - "X-Request-Id": "req-123", - } - - async with sse_client(server_url + "/sse", headers=custom_headers) as streams: - async with ClientSession(*streams) as session: - # Initialize the session - result = await session.initialize() - assert isinstance(result, InitializeResult) - assert result.serverInfo.name == "ContextServer" - - # Test 1: Call tool that echoes headers - headers_result = await session.call_tool("echo_headers", {}) - assert len(headers_result.content) == 1 - assert isinstance(headers_result.content[0], TextContent) - - headers_data = json.loads(headers_result.content[0].text) - assert headers_data.get("authorization") == "Bearer fastmcp-test-token" - assert headers_data.get("x-custom-header") == "fastmcp-value" - assert headers_data.get("x-request-id") == "req-123" - - # Test 2: Call tool that returns full context - context_result = await session.call_tool( - "echo_context", {"custom_request_id": "test-123"} - ) - assert len(context_result.content) == 1 - assert isinstance(context_result.content[0], TextContent) - - context_data = json.loads(context_result.content[0].text) - assert context_data["custom_request_id"] == "test-123" - assert ( - context_data["headers"].get("authorization") - == "Bearer fastmcp-test-token" - ) - assert context_data["method"] == "POST" # - - -@pytest.mark.anyio -async def test_fast_mcp_request_context_isolation( - context_aware_server: None, server_url: str -) -> None: - """Test that request contexts are isolated between different FastMCP clients.""" - contexts = [] - - # Create multiple clients with different headers - for i in range(3): - headers = { - "Authorization": f"Bearer token-{i}", - "X-Request-Id": f"fastmcp-req-{i}", - "X-Custom-Value": f"value-{i}", - } - - async with sse_client(server_url + "/sse", headers=headers) as streams: - async with ClientSession(*streams) as session: - await session.initialize() - - # Call the tool that returns context - tool_result = await session.call_tool( - "echo_context", {"custom_request_id": f"test-req-{i}"} - ) - - # Parse and store the result - assert len(tool_result.content) == 1 - assert isinstance(tool_result.content[0], TextContent) - context_data = json.loads(tool_result.content[0].text) - contexts.append(context_data) - - # Verify each request had its own isolated context - assert len(contexts) == 3 - for i, ctx in enumerate(contexts): - assert ctx["custom_request_id"] == f"test-req-{i}" - assert ctx["headers"].get("authorization") == f"Bearer token-{i}" - assert ctx["headers"].get("x-request-id") == f"fastmcp-req-{i}" - assert ctx["headers"].get("x-custom-value") == f"value-{i}" - - @pytest.mark.anyio async def test_fastmcp_streamable_http( streamable_http_server: None, http_server_url: str @@ -967,6 +825,30 @@ async def progress_callback( assert isinstance(complex_result, GetPromptResult) assert len(complex_result.messages) >= 1 + # Test request context propagation (only works when headers are available) + + headers_result = await session.call_tool("echo_headers", {}) + assert len(headers_result.content) == 1 + assert isinstance(headers_result.content[0], TextContent) + + # If we got headers, verify they exist + headers_data = json.loads(headers_result.content[0].text) + # The headers depend on the transport and test setup + print(f"Received headers: {headers_data}") + + # Test 6: Call tool that returns full context + context_result = await session.call_tool( + "echo_context", {"custom_request_id": "test-123"} + ) + assert len(context_result.content) == 1 + assert isinstance(context_result.content[0], TextContent) + + context_data = json.loads(context_result.content[0].text) + assert context_data["custom_request_id"] == "test-123" + # The method should be POST for most transports + if context_data["method"]: + assert context_data["method"] == "POST" + async def sampling_callback( context: RequestContext[ClientSession, None], diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index f1c7ef809..5cf346e1a 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -4,6 +4,7 @@ Contains tests for both server and client sides of the StreamableHTTP transport. """ +import json import multiprocessing import socket import time @@ -17,6 +18,7 @@ import uvicorn from pydantic import AnyUrl from starlette.applications import Starlette +from starlette.requests import Request from starlette.routing import Mount import mcp.types as types @@ -1223,3 +1225,203 @@ async def sampling_callback( captured_message_params.messages[0].content.text == "Server needs client sampling" ) + + +# Context-aware server implementation for testing request context propagation +class ContextAwareServerTest(Server): + def __init__(self): + super().__init__("ContextAwareServer") + + @self.list_tools() + async def handle_list_tools() -> list[Tool]: + return [ + Tool( + name="echo_headers", + description="Echo request headers from context", + inputSchema={"type": "object", "properties": {}}, + ), + Tool( + name="echo_context", + description="Echo request context with custom data", + inputSchema={ + "type": "object", + "properties": { + "request_id": {"type": "string"}, + }, + "required": ["request_id"], + }, + ), + ] + + @self.call_tool() + async def handle_call_tool(name: str, args: dict) -> list[TextContent]: + ctx = self.request_context + + if name == "echo_headers": + # Access the request object from context + headers_info = {} + if ctx.request and isinstance(ctx.request, Request): + headers_info = dict(ctx.request.headers) + return [ + TextContent( + type="text", + text=json.dumps(headers_info), + ) + ] + + elif name == "echo_context": + # Return full context information + context_data = { + "request_id": args.get("request_id"), + "headers": {}, + "method": None, + "path": None, + } + if ctx.request and isinstance(ctx.request, Request): + request = ctx.request + context_data["headers"] = dict(request.headers) + context_data["method"] = request.method + context_data["path"] = request.url.path + return [ + TextContent( + type="text", + text=json.dumps(context_data), + ) + ] + + return [TextContent(type="text", text=f"Unknown tool: {name}")] + + +# Server runner for context-aware testing +def run_context_aware_server(port: int): + """Run the context-aware test server.""" + server = ContextAwareServerTest() + + session_manager = StreamableHTTPSessionManager( + app=server, + event_store=None, + json_response=False, + ) + + app = Starlette( + debug=True, + routes=[ + Mount("/mcp", app=session_manager.handle_request), + ], + lifespan=lambda app: session_manager.run(), + ) + + server_instance = uvicorn.Server( + config=uvicorn.Config( + app=app, + host="127.0.0.1", + port=port, + log_level="error", + ) + ) + server_instance.run() + + +@pytest.fixture +def context_aware_server(basic_server_port: int) -> Generator[None, None, None]: + """Start the context-aware server in a separate process.""" + proc = multiprocessing.Process( + target=run_context_aware_server, args=(basic_server_port,), daemon=True + ) + proc.start() + + # Wait for server to be running + max_attempts = 20 + attempt = 0 + while attempt < max_attempts: + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.connect(("127.0.0.1", basic_server_port)) + break + except ConnectionRefusedError: + time.sleep(0.1) + attempt += 1 + else: + raise RuntimeError( + f"Context-aware server failed to start after {max_attempts} attempts" + ) + + yield + + proc.kill() + proc.join(timeout=2) + if proc.is_alive(): + print("Context-aware server process failed to terminate") + + +@pytest.mark.anyio +async def test_streamablehttp_request_context_propagation( + context_aware_server: None, basic_server_url: str +) -> None: + """Test that request context is properly propagated through StreamableHTTP.""" + custom_headers = { + "Authorization": "Bearer test-token", + "X-Custom-Header": "test-value", + "X-Trace-Id": "trace-123", + } + + async with streamablehttp_client( + f"{basic_server_url}/mcp", headers=custom_headers + ) as (read_stream, write_stream, _): + async with ClientSession(read_stream, write_stream) as session: + result = await session.initialize() + assert isinstance(result, InitializeResult) + assert result.serverInfo.name == "ContextAwareServer" + + # Call the tool that echoes headers back + tool_result = await session.call_tool("echo_headers", {}) + + # Parse the JSON response + assert len(tool_result.content) == 1 + assert isinstance(tool_result.content[0], TextContent) + headers_data = json.loads(tool_result.content[0].text) + + # Verify headers were propagated + assert headers_data.get("authorization") == "Bearer test-token" + assert headers_data.get("x-custom-header") == "test-value" + assert headers_data.get("x-trace-id") == "trace-123" + + +@pytest.mark.anyio +async def test_streamablehttp_request_context_isolation( + context_aware_server: None, basic_server_url: str +) -> None: + """Test that request contexts are isolated between StreamableHTTP clients.""" + contexts = [] + + # Create multiple clients with different headers + for i in range(3): + headers = { + "X-Request-Id": f"request-{i}", + "X-Custom-Value": f"value-{i}", + "Authorization": f"Bearer token-{i}", + } + + async with streamablehttp_client( + f"{basic_server_url}/mcp", headers=headers + ) as (read_stream, write_stream, _): + async with ClientSession(read_stream, write_stream) as session: + await session.initialize() + + # Call the tool that echoes context + tool_result = await session.call_tool( + "echo_context", {"request_id": f"request-{i}"} + ) + + assert len(tool_result.content) == 1 + assert isinstance(tool_result.content[0], TextContent) + context_data = json.loads(tool_result.content[0].text) + contexts.append(context_data) + + # Verify each request had its own context + assert len(contexts) == 3 + for i, ctx in enumerate(contexts): + assert ctx["request_id"] == f"request-{i}" + assert ctx["headers"].get("x-request-id") == f"request-{i}" + assert ctx["headers"].get("x-custom-value") == f"value-{i}" + assert ctx["headers"].get("authorization") == f"Bearer token-{i}"