diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index 3282baae6..e5b6c3acc 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -49,7 +49,7 @@ from mcp.server.stdio import stdio_server from mcp.server.streamable_http import EventStore from mcp.server.streamable_http_manager import StreamableHTTPSessionManager -from mcp.shared.context import LifespanContextT, RequestContext +from mcp.shared.context import LifespanContextT, RequestContext, RequestT from mcp.types import ( AnyFunction, EmbeddedResource, @@ -124,9 +124,11 @@ class Settings(BaseSettings, Generic[LifespanResultT]): def lifespan_wrapper( app: FastMCP, lifespan: Callable[[FastMCP], AbstractAsyncContextManager[LifespanResultT]], -) -> Callable[[MCPServer[LifespanResultT]], AbstractAsyncContextManager[object]]: +) -> Callable[ + [MCPServer[LifespanResultT, Request]], AbstractAsyncContextManager[object] +]: @asynccontextmanager - async def wrap(s: MCPServer[LifespanResultT]) -> AsyncIterator[object]: + async def wrap(s: MCPServer[LifespanResultT, Request]) -> AsyncIterator[object]: async with lifespan(app) as context: yield context @@ -260,7 +262,7 @@ async def list_tools(self) -> list[MCPTool]: for info in tools ] - def get_context(self) -> Context[ServerSession, object]: + def get_context(self) -> Context[ServerSession, object, Request]: """ Returns a Context object. Note that the context will only be valid during a request; outside a request, most methods will error. @@ -893,7 +895,7 @@ def _convert_to_content( return [TextContent(type="text", text=result)] -class Context(BaseModel, Generic[ServerSessionT, LifespanContextT]): +class Context(BaseModel, Generic[ServerSessionT, LifespanContextT, RequestT]): """Context object providing access to MCP capabilities. This provides a cleaner interface to MCP's RequestContext functionality. @@ -927,13 +929,15 @@ def my_tool(x: int, ctx: Context) -> str: The context is optional - tools that don't need it can omit the parameter. """ - _request_context: RequestContext[ServerSessionT, LifespanContextT] | None + _request_context: RequestContext[ServerSessionT, LifespanContextT, RequestT] | None _fastmcp: FastMCP | None def __init__( self, *, - request_context: RequestContext[ServerSessionT, LifespanContextT] | None = None, + request_context: ( + RequestContext[ServerSessionT, LifespanContextT, RequestT] | None + ) = None, fastmcp: FastMCP | None = None, **kwargs: Any, ): @@ -949,7 +953,9 @@ def fastmcp(self) -> FastMCP: return self._fastmcp @property - def request_context(self) -> RequestContext[ServerSessionT, LifespanContextT]: + def request_context( + self, + ) -> RequestContext[ServerSessionT, LifespanContextT, RequestT]: """Access to the underlying request context.""" if self._request_context is None: raise ValueError("Context is not available outside of a request") diff --git a/src/mcp/server/fastmcp/tools/base.py b/src/mcp/server/fastmcp/tools/base.py index 01fedcdc9..f32eb15bd 100644 --- a/src/mcp/server/fastmcp/tools/base.py +++ b/src/mcp/server/fastmcp/tools/base.py @@ -14,7 +14,7 @@ if TYPE_CHECKING: from mcp.server.fastmcp.server import Context from mcp.server.session import ServerSessionT - from mcp.shared.context import LifespanContextT + from mcp.shared.context import LifespanContextT, RequestT class Tool(BaseModel): @@ -85,7 +85,7 @@ def from_function( async def run( self, arguments: dict[str, Any], - context: Context[ServerSessionT, LifespanContextT] | None = None, + context: Context[ServerSessionT, LifespanContextT, RequestT] | None = None, ) -> Any: """Run the tool with arguments.""" try: diff --git a/src/mcp/server/fastmcp/tools/tool_manager.py b/src/mcp/server/fastmcp/tools/tool_manager.py index 6ec4fd151..153249379 100644 --- a/src/mcp/server/fastmcp/tools/tool_manager.py +++ b/src/mcp/server/fastmcp/tools/tool_manager.py @@ -6,7 +6,7 @@ from mcp.server.fastmcp.exceptions import ToolError from mcp.server.fastmcp.tools.base import Tool from mcp.server.fastmcp.utilities.logging import get_logger -from mcp.shared.context import LifespanContextT +from mcp.shared.context import LifespanContextT, RequestT from mcp.types import ToolAnnotations if TYPE_CHECKING: @@ -65,7 +65,7 @@ async def call_tool( self, name: str, arguments: dict[str, Any], - context: Context[ServerSessionT, LifespanContextT] | None = None, + context: Context[ServerSessionT, LifespanContextT, RequestT] | None = None, ) -> Any: """Call a tool by name with arguments.""" tool = self.get_tool(name) diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 876aef817..b98e3dd1a 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -72,11 +72,12 @@ async def main(): import warnings from collections.abc import AsyncIterator, Awaitable, Callable, Iterable from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager -from typing import Any, Generic, TypeVar +from typing import Any, Generic import anyio from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from pydantic import AnyUrl +from typing_extensions import TypeVar import mcp.types as types from mcp.server.lowlevel.helper_types import ReadResourceContents @@ -85,15 +86,16 @@ async def main(): from mcp.server.stdio import stdio_server as stdio_server from mcp.shared.context import RequestContext from mcp.shared.exceptions import McpError -from mcp.shared.message import SessionMessage +from mcp.shared.message import ServerMessageMetadata, SessionMessage from mcp.shared.session import RequestResponder logger = logging.getLogger(__name__) LifespanResultT = TypeVar("LifespanResultT") +RequestT = TypeVar("RequestT", default=Any) # This will be properly typed in each Server instance's context -request_ctx: contextvars.ContextVar[RequestContext[ServerSession, Any]] = ( +request_ctx: contextvars.ContextVar[RequestContext[ServerSession, Any, Any]] = ( contextvars.ContextVar("request_ctx") ) @@ -111,7 +113,7 @@ def __init__( @asynccontextmanager -async def lifespan(server: Server[LifespanResultT]) -> AsyncIterator[object]: +async def lifespan(server: Server[LifespanResultT, RequestT]) -> AsyncIterator[object]: """Default lifespan context manager that does nothing. Args: @@ -123,14 +125,15 @@ async def lifespan(server: Server[LifespanResultT]) -> AsyncIterator[object]: yield {} -class Server(Generic[LifespanResultT]): +class Server(Generic[LifespanResultT, RequestT]): def __init__( self, name: str, version: str | None = None, instructions: str | None = None, lifespan: Callable[ - [Server[LifespanResultT]], AbstractAsyncContextManager[LifespanResultT] + [Server[LifespanResultT, RequestT]], + AbstractAsyncContextManager[LifespanResultT], ] = lifespan, ): self.name = name @@ -215,7 +218,9 @@ def get_capabilities( ) @property - def request_context(self) -> RequestContext[ServerSession, LifespanResultT]: + def request_context( + self, + ) -> RequestContext[ServerSession, LifespanResultT, RequestT]: """If called outside of a request context, this will raise a LookupError.""" return request_ctx.get() @@ -555,6 +560,13 @@ async def _handle_request( token = None try: + # Extract request context from message metadata + request_data = None + if message.message_metadata is not None and isinstance( + message.message_metadata, ServerMessageMetadata + ): + request_data = message.message_metadata.request_context + # Set our global state that can be retrieved via # app.get_request_context() token = request_ctx.set( @@ -563,6 +575,7 @@ async def _handle_request( message.request_meta, session, lifespan_context, + request=request_data, ) ) response = await handler(req) diff --git a/src/mcp/server/sse.py b/src/mcp/server/sse.py index bae2bbf52..192c1290b 100644 --- a/src/mcp/server/sse.py +++ b/src/mcp/server/sse.py @@ -52,7 +52,7 @@ async def handle_sse(request): from starlette.types import Receive, Scope, Send import mcp.types as types -from mcp.shared.message import SessionMessage +from mcp.shared.message import ServerMessageMetadata, SessionMessage logger = logging.getLogger(__name__) @@ -203,7 +203,9 @@ async def handle_post_message( await writer.send(err) return - session_message = SessionMessage(message) + # Pass the ASGI scope for framework-agnostic access to request data + metadata = ServerMessageMetadata(request_context=request) + session_message = SessionMessage(message, metadata=metadata) logger.debug(f"Sending session message to writer: {session_message}") response = Response("Accepted", status_code=202) await response(scope, receive, send) diff --git a/src/mcp/server/streamable_http_manager.py b/src/mcp/server/streamable_http_manager.py index e5ef8b4aa..8188c2f3b 100644 --- a/src/mcp/server/streamable_http_manager.py +++ b/src/mcp/server/streamable_http_manager.py @@ -56,7 +56,7 @@ class StreamableHTTPSessionManager: def __init__( self, - app: MCPServer[Any], + app: MCPServer[Any, Any], event_store: EventStore | None = None, json_response: bool = False, stateless: bool = False, diff --git a/src/mcp/shared/context.py b/src/mcp/shared/context.py index ae85d3a19..f3006e7d5 100644 --- a/src/mcp/shared/context.py +++ b/src/mcp/shared/context.py @@ -8,11 +8,13 @@ SessionT = TypeVar("SessionT", bound=BaseSession[Any, Any, Any, Any, Any]) LifespanContextT = TypeVar("LifespanContextT") +RequestT = TypeVar("RequestT", default=Any) @dataclass -class RequestContext(Generic[SessionT, LifespanContextT]): +class RequestContext(Generic[SessionT, LifespanContextT, RequestT]): request_id: RequestId meta: RequestParams.Meta | None session: SessionT lifespan_context: LifespanContextT + request: RequestT | None = None diff --git a/src/mcp/shared/message.py b/src/mcp/shared/message.py index 5583f4795..6b0233714 100644 --- a/src/mcp/shared/message.py +++ b/src/mcp/shared/message.py @@ -30,6 +30,8 @@ class ServerMessageMetadata: """Metadata specific to server messages.""" related_request_id: RequestId | None = None + # Request-specific context (e.g., headers, auth info) + request_context: object | None = None MessageMetadata = ClientMessageMetadata | ServerMessageMetadata | None diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 90b4eb27c..791c0b138 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -80,10 +80,12 @@ def __init__( ReceiveNotificationT ]""", on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any], + message_metadata: MessageMetadata = None, ) -> None: self.request_id = request_id self.request_meta = request_meta self.request = request + self.message_metadata = message_metadata self._session = session self._completed = False self._cancel_scope = anyio.CancelScope() @@ -364,6 +366,7 @@ async def _receive_loop(self) -> None: request=validated_request, session=self, on_complete=lambda r: self._in_flight.pop(r.request_id, None), + message_metadata=message.metadata, ) self._in_flight[responder.request_id] = responder diff --git a/tests/server/fastmcp/test_integration.py b/tests/server/fastmcp/test_integration.py index 79285ecb1..761e810bc 100644 --- a/tests/server/fastmcp/test_integration.py +++ b/tests/server/fastmcp/test_integration.py @@ -5,14 +5,18 @@ including with and without authentication. """ +import json import multiprocessing import socket import time from collections.abc import Generator +from typing import Any import pytest import uvicorn from pydantic import AnyUrl +from starlette.applications import Starlette +from starlette.requests import Request import mcp.types as types from mcp.client.session import ClientSession @@ -20,6 +24,7 @@ from mcp.client.streamable_http import streamablehttp_client from mcp.server.fastmcp import FastMCP from mcp.server.fastmcp.resources import FunctionResource +from mcp.server.fastmcp.server import Context from mcp.shared.context import RequestContext from mcp.types import ( CreateMessageRequestParams, @@ -78,8 +83,6 @@ def stateless_http_server_url(stateless_http_server_port: int) -> str: # Create a function to make the FastMCP server app def make_fastmcp_app(): """Create a FastMCP server without auth settings.""" - from starlette.applications import Starlette - mcp = FastMCP(name="NoAuthServer") # Add a simple tool @@ -88,7 +91,7 @@ def echo(message: str) -> str: return f"Echo: {message}" # Create the SSE app - app: Starlette = mcp.sse_app() + app = mcp.sse_app() return mcp, app @@ -198,17 +201,14 @@ def complex_prompt(user_query: str, context: str = "general") -> str: def make_everything_fastmcp_app(): """Create a comprehensive FastMCP server with SSE transport.""" - from starlette.applications import Starlette - mcp = make_everything_fastmcp() # Create the SSE app - app: Starlette = mcp.sse_app() + app = mcp.sse_app() return mcp, app def make_fastmcp_streamable_http_app(): """Create a FastMCP server with StreamableHTTP transport.""" - from starlette.applications import Starlette mcp = FastMCP(name="NoAuthServer") @@ -225,8 +225,6 @@ def echo(message: str) -> str: def make_everything_fastmcp_streamable_http_app(): """Create a comprehensive FastMCP server with StreamableHTTP transport.""" - from starlette.applications import Starlette - # Create a new instance with different name for HTTP transport mcp = make_everything_fastmcp() # We can't change the name after creation, so we'll use the same name @@ -237,7 +235,6 @@ def make_everything_fastmcp_streamable_http_app(): def make_fastmcp_stateless_http_app(): """Create a FastMCP server with stateless StreamableHTTP transport.""" - from starlette.applications import Starlette mcp = FastMCP(name="StatelessServer", stateless_http=True) @@ -435,6 +432,174 @@ async def test_fastmcp_without_auth(server: None, server_url: str) -> None: assert tool_result.content[0].text == "Echo: hello" +def make_fastmcp_with_context_app(): + """Create a FastMCP server that can access request context.""" + + mcp = FastMCP(name="ContextServer") + + # Tool that echoes request headers + @mcp.tool(description="Echo request headers from context") + def echo_headers(ctx: Context[Any, Any, Request]) -> str: + """Returns the request headers as JSON.""" + headers_info = {} + if ctx.request_context.request: + # Now the type system knows request is a Starlette Request object + headers_info = dict(ctx.request_context.request.headers) + return json.dumps(headers_info) + + # Tool that returns full request context + @mcp.tool(description="Echo request context with custom data") + def echo_context(custom_request_id: str, ctx: Context[Any, Any, Request]) -> str: + """Returns request context including headers and custom data.""" + context_data = { + "custom_request_id": custom_request_id, + "headers": {}, + "method": None, + "path": None, + } + if ctx.request_context.request: + request = ctx.request_context.request + context_data["headers"] = dict(request.headers) + context_data["method"] = request.method + context_data["path"] = request.url.path + return json.dumps(context_data) + + # Create the SSE app + app = mcp.sse_app() + return mcp, app + + +def run_context_server(server_port: int) -> None: + """Run the context-aware FastMCP server.""" + _, app = make_fastmcp_with_context_app() + server = uvicorn.Server( + config=uvicorn.Config( + app=app, host="127.0.0.1", port=server_port, log_level="error" + ) + ) + print(f"Starting context server on port {server_port}") + server.run() + + +@pytest.fixture() +def context_aware_server(server_port: int) -> Generator[None, None, None]: + """Start the context-aware server in a separate process.""" + proc = multiprocessing.Process( + target=run_context_server, args=(server_port,), daemon=True + ) + print("Starting context-aware server process") + proc.start() + + # Wait for server to be running + max_attempts = 20 + attempt = 0 + print("Waiting for context-aware server to start") + while attempt < max_attempts: + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.connect(("127.0.0.1", server_port)) + break + except ConnectionRefusedError: + time.sleep(0.1) + attempt += 1 + else: + raise RuntimeError( + f"Context server failed to start after {max_attempts} attempts" + ) + + yield + + print("Killing context-aware server") + proc.kill() + proc.join(timeout=2) + if proc.is_alive(): + print("Context server process failed to terminate") + + +@pytest.mark.anyio +async def test_fast_mcp_with_request_context( + context_aware_server: None, server_url: str +) -> None: + """Test that FastMCP properly propagates request context to tools.""" + # Test with custom headers + custom_headers = { + "Authorization": "Bearer fastmcp-test-token", + "X-Custom-Header": "fastmcp-value", + "X-Request-Id": "req-123", + } + + async with sse_client(server_url + "/sse", headers=custom_headers) as streams: + async with ClientSession(*streams) as session: + # Initialize the session + result = await session.initialize() + assert isinstance(result, InitializeResult) + assert result.serverInfo.name == "ContextServer" + + # Test 1: Call tool that echoes headers + headers_result = await session.call_tool("echo_headers", {}) + assert len(headers_result.content) == 1 + assert isinstance(headers_result.content[0], TextContent) + + headers_data = json.loads(headers_result.content[0].text) + assert headers_data.get("authorization") == "Bearer fastmcp-test-token" + assert headers_data.get("x-custom-header") == "fastmcp-value" + assert headers_data.get("x-request-id") == "req-123" + + # Test 2: Call tool that returns full context + context_result = await session.call_tool( + "echo_context", {"custom_request_id": "test-123"} + ) + assert len(context_result.content) == 1 + assert isinstance(context_result.content[0], TextContent) + + context_data = json.loads(context_result.content[0].text) + assert context_data["custom_request_id"] == "test-123" + assert ( + context_data["headers"].get("authorization") + == "Bearer fastmcp-test-token" + ) + assert context_data["method"] == "POST" # + + +@pytest.mark.anyio +async def test_fast_mcp_request_context_isolation( + context_aware_server: None, server_url: str +) -> None: + """Test that request contexts are isolated between different FastMCP clients.""" + contexts = [] + + # Create multiple clients with different headers + for i in range(3): + headers = { + "Authorization": f"Bearer token-{i}", + "X-Request-Id": f"fastmcp-req-{i}", + "X-Custom-Value": f"value-{i}", + } + + async with sse_client(server_url + "/sse", headers=headers) as streams: + async with ClientSession(*streams) as session: + await session.initialize() + + # Call the tool that returns context + tool_result = await session.call_tool( + "echo_context", {"custom_request_id": f"test-req-{i}"} + ) + + # Parse and store the result + assert len(tool_result.content) == 1 + assert isinstance(tool_result.content[0], TextContent) + context_data = json.loads(tool_result.content[0].text) + contexts.append(context_data) + + # Verify each request had its own isolated context + assert len(contexts) == 3 + for i, ctx in enumerate(contexts): + assert ctx["custom_request_id"] == f"test-req-{i}" + assert ctx["headers"].get("authorization") == f"Bearer token-{i}" + assert ctx["headers"].get("x-request-id") == f"fastmcp-req-{i}" + assert ctx["headers"].get("x-custom-value") == f"value-{i}" + + @pytest.mark.anyio async def test_fastmcp_streamable_http( streamable_http_server: None, http_server_url: str diff --git a/tests/server/fastmcp/test_tool_manager.py b/tests/server/fastmcp/test_tool_manager.py index 203a7172b..b45c7ac38 100644 --- a/tests/server/fastmcp/test_tool_manager.py +++ b/tests/server/fastmcp/test_tool_manager.py @@ -9,7 +9,7 @@ from mcp.server.fastmcp.tools import Tool, ToolManager from mcp.server.fastmcp.utilities.func_metadata import ArgModelBase, FuncMetadata from mcp.server.session import ServerSessionT -from mcp.shared.context import LifespanContextT +from mcp.shared.context import LifespanContextT, RequestT from mcp.types import ToolAnnotations @@ -347,7 +347,7 @@ def tool_without_context(x: int) -> str: assert tool.context_kwarg is None def tool_with_parametrized_context( - x: int, ctx: Context[ServerSessionT, LifespanContextT] + x: int, ctx: Context[ServerSessionT, LifespanContextT, RequestT] ) -> str: return str(x) diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index e55983e01..78bbbb235 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -1,3 +1,4 @@ +import json import multiprocessing import socket import time @@ -318,3 +319,187 @@ async def test_sse_client_basic_connection_mounted_app( # Test ping ping_result = await session.send_ping() assert isinstance(ping_result, EmptyResult) + + +# Test server with request context that returns headers in the response +class RequestContextServer(Server[object, Request]): + def __init__(self): + super().__init__("request_context_server") + + @self.call_tool() + async def handle_call_tool(name: str, args: dict) -> list[TextContent]: + headers_info = {} + context = self.request_context + if context.request: + headers_info = dict(context.request.headers) + + if name == "echo_headers": + return [TextContent(type="text", text=json.dumps(headers_info))] + elif name == "echo_context": + context_data = { + "request_id": args.get("request_id"), + "headers": headers_info, + } + return [TextContent(type="text", text=json.dumps(context_data))] + + return [TextContent(type="text", text=f"Called {name}")] + + @self.list_tools() + async def handle_list_tools() -> list[Tool]: + return [ + Tool( + name="echo_headers", + description="Echoes request headers", + inputSchema={"type": "object", "properties": {}}, + ), + Tool( + name="echo_context", + description="Echoes request context", + inputSchema={ + "type": "object", + "properties": {"request_id": {"type": "string"}}, + "required": ["request_id"], + }, + ), + ] + + +def run_context_server(server_port: int) -> None: + """Run a server that captures request context""" + sse = SseServerTransport("/messages/") + context_server = RequestContextServer() + + async def handle_sse(request: Request) -> Response: + async with sse.connect_sse( + request.scope, request.receive, request._send + ) as streams: + await context_server.run( + streams[0], streams[1], context_server.create_initialization_options() + ) + return Response() + + app = Starlette( + routes=[ + Route("/sse", endpoint=handle_sse), + Mount("/messages/", app=sse.handle_post_message), + ] + ) + + server = uvicorn.Server( + config=uvicorn.Config( + app=app, host="127.0.0.1", port=server_port, log_level="error" + ) + ) + print(f"starting context server on {server_port}") + server.run() + + +@pytest.fixture() +def context_server(server_port: int) -> Generator[None, None, None]: + """Fixture that provides a server with request context capture""" + proc = multiprocessing.Process( + target=run_context_server, kwargs={"server_port": server_port}, daemon=True + ) + print("starting context server process") + proc.start() + + # Wait for server to be running + max_attempts = 20 + attempt = 0 + print("waiting for context server to start") + while attempt < max_attempts: + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.connect(("127.0.0.1", server_port)) + break + except ConnectionRefusedError: + time.sleep(0.1) + attempt += 1 + else: + raise RuntimeError( + f"Context server failed to start after {max_attempts} attempts" + ) + + yield + + print("killing context server") + proc.kill() + proc.join(timeout=2) + if proc.is_alive(): + print("context server process failed to terminate") + + +@pytest.mark.anyio +async def test_request_context_propagation( + context_server: None, server_url: str +) -> None: + """Test that request context is properly propagated through SSE transport.""" + # Test with custom headers + custom_headers = { + "Authorization": "Bearer test-token", + "X-Custom-Header": "test-value", + "X-Trace-Id": "trace-123", + } + + async with sse_client(server_url + "/sse", headers=custom_headers) as ( + read_stream, + write_stream, + ): + async with ClientSession(read_stream, write_stream) as session: + # Initialize the session + result = await session.initialize() + assert isinstance(result, InitializeResult) + + # Call the tool that echoes headers back + tool_result = await session.call_tool("echo_headers", {}) + + # Parse the JSON response + + assert len(tool_result.content) == 1 + headers_data = json.loads( + tool_result.content[0].text + if tool_result.content[0].type == "text" + else "{}" + ) + + # Verify headers were propagated + assert headers_data.get("authorization") == "Bearer test-token" + assert headers_data.get("x-custom-header") == "test-value" + assert headers_data.get("x-trace-id") == "trace-123" + + +@pytest.mark.anyio +async def test_request_context_isolation(context_server: None, server_url: str) -> None: + """Test that request contexts are isolated between different SSE clients.""" + contexts = [] + + # Create multiple clients with different headers + for i in range(3): + headers = {"X-Request-Id": f"request-{i}", "X-Custom-Value": f"value-{i}"} + + async with sse_client(server_url + "/sse", headers=headers) as ( + read_stream, + write_stream, + ): + async with ClientSession(read_stream, write_stream) as session: + await session.initialize() + + # Call the tool that echoes context + tool_result = await session.call_tool( + "echo_context", {"request_id": f"request-{i}"} + ) + + assert len(tool_result.content) == 1 + context_data = json.loads( + tool_result.content[0].text + if tool_result.content[0].type == "text" + else "{}" + ) + contexts.append(context_data) + + # Verify each request had its own context + assert len(contexts) == 3 + for i, ctx in enumerate(contexts): + assert ctx["request_id"] == f"request-{i}" + assert ctx["headers"].get("x-request-id") == f"request-{i}" + assert ctx["headers"].get("x-custom-value") == f"value-{i}"