diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index fe90716e2..276f6582d 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -17,14 +17,14 @@ class SamplingFnT(Protocol): async def __call__( self, - context: RequestContext["ClientSession", Any], + context: RequestContext["ClientSession", Any, Any], params: types.CreateMessageRequestParams, ) -> types.CreateMessageResult | types.ErrorData: ... class ListRootsFnT(Protocol): async def __call__( - self, context: RequestContext["ClientSession", Any] + self, context: RequestContext["ClientSession", Any, Any] ) -> types.ListRootsResult | types.ErrorData: ... @@ -53,7 +53,7 @@ async def _default_message_handler( async def _default_sampling_callback( - context: RequestContext["ClientSession", Any], + context: RequestContext["ClientSession", Any, Any], params: types.CreateMessageRequestParams, ) -> types.CreateMessageResult | types.ErrorData: return types.ErrorData( @@ -63,7 +63,7 @@ async def _default_sampling_callback( async def _default_list_roots_callback( - context: RequestContext["ClientSession", Any], + context: RequestContext["ClientSession", Any, Any], ) -> types.ListRootsResult | types.ErrorData: return types.ErrorData( code=types.INVALID_REQUEST, @@ -367,7 +367,7 @@ async def send_roots_list_changed(self) -> None: async def _received_request( self, responder: RequestResponder[types.ServerRequest, types.ClientResult] ) -> None: - ctx = RequestContext[ClientSession, Any]( + ctx = RequestContext[ClientSession, Any, Any]( request_id=responder.request_id, meta=responder.request_meta, session=self, diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index 3282baae6..2d15d8145 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -124,9 +124,11 @@ class Settings(BaseSettings, Generic[LifespanResultT]): def lifespan_wrapper( app: FastMCP, lifespan: Callable[[FastMCP], AbstractAsyncContextManager[LifespanResultT]], -) -> Callable[[MCPServer[LifespanResultT]], AbstractAsyncContextManager[object]]: +) -> Callable[ + [MCPServer[LifespanResultT, Request]], AbstractAsyncContextManager[object] +]: @asynccontextmanager - async def wrap(s: MCPServer[LifespanResultT]) -> AsyncIterator[object]: + async def wrap(s: MCPServer[LifespanResultT, Request]) -> AsyncIterator[object]: async with lifespan(app) as context: yield context @@ -684,6 +686,7 @@ async def handle_sse(scope: Scope, receive: Receive, send: Send): streams[0], streams[1], self._mcp_server.create_initialization_options(), + request=request, ) return Response() @@ -927,13 +930,14 @@ def my_tool(x: int, ctx: Context) -> str: The context is optional - tools that don't need it can omit the parameter. """ - _request_context: RequestContext[ServerSessionT, LifespanContextT] | None + _request_context: RequestContext[ServerSessionT, LifespanContextT, Request] | None _fastmcp: FastMCP | None def __init__( self, *, - request_context: RequestContext[ServerSessionT, LifespanContextT] | None = None, + request_context: RequestContext[ServerSessionT, LifespanContextT, Request] + | None = None, fastmcp: FastMCP | None = None, **kwargs: Any, ): @@ -949,7 +953,9 @@ def fastmcp(self) -> FastMCP: return self._fastmcp @property - def request_context(self) -> RequestContext[ServerSessionT, LifespanContextT]: + def request_context( + self, + ) -> RequestContext[ServerSessionT, LifespanContextT, Request]: """Access to the underlying request context.""" if self._request_context is None: raise ValueError("Context is not available outside of a request") diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 876aef817..2b1a5b7cf 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -83,7 +83,7 @@ async def main(): from mcp.server.models import InitializationOptions from mcp.server.session import ServerSession from mcp.server.stdio import stdio_server as stdio_server -from mcp.shared.context import RequestContext +from mcp.shared.context import RequestContext, RequestT from mcp.shared.exceptions import McpError from mcp.shared.message import SessionMessage from mcp.shared.session import RequestResponder @@ -93,7 +93,7 @@ async def main(): LifespanResultT = TypeVar("LifespanResultT") # This will be properly typed in each Server instance's context -request_ctx: contextvars.ContextVar[RequestContext[ServerSession, Any]] = ( +request_ctx: contextvars.ContextVar[RequestContext[ServerSession, Any, Any]] = ( contextvars.ContextVar("request_ctx") ) @@ -111,7 +111,7 @@ def __init__( @asynccontextmanager -async def lifespan(server: Server[LifespanResultT]) -> AsyncIterator[object]: +async def lifespan(server: Server[LifespanResultT, RequestT]) -> AsyncIterator[object]: """Default lifespan context manager that does nothing. Args: @@ -123,14 +123,15 @@ async def lifespan(server: Server[LifespanResultT]) -> AsyncIterator[object]: yield {} -class Server(Generic[LifespanResultT]): +class Server(Generic[LifespanResultT, RequestT]): def __init__( self, name: str, version: str | None = None, instructions: str | None = None, lifespan: Callable[ - [Server[LifespanResultT]], AbstractAsyncContextManager[LifespanResultT] + [Server[LifespanResultT, RequestT]], + AbstractAsyncContextManager[LifespanResultT], ] = lifespan, ): self.name = name @@ -215,7 +216,9 @@ def get_capabilities( ) @property - def request_context(self) -> RequestContext[ServerSession, LifespanResultT]: + def request_context( + self, + ) -> RequestContext[ServerSession, LifespanResultT, RequestT]: """If called outside of a request context, this will raise a LookupError.""" return request_ctx.get() @@ -486,6 +489,7 @@ async def run( # but also make tracing exceptions much easier during testing and when using # in-process servers. raise_exceptions: bool = False, + request: RequestT | None = None, # When True, the server is stateless and # clients can perform initialization with any node. The client must still follow # the initialization lifecycle, but can do so with any available node @@ -513,6 +517,7 @@ async def run( session, lifespan_context, raise_exceptions, + request, ) async def _handle_message( @@ -523,6 +528,7 @@ async def _handle_message( session: ServerSession, lifespan_context: LifespanResultT, raise_exceptions: bool = False, + request: RequestT | None = None, ): with warnings.catch_warnings(record=True) as w: # TODO(Marcelo): We should be checking if message is Exception here. @@ -532,7 +538,12 @@ async def _handle_message( ): with responder: await self._handle_request( - message, req, session, lifespan_context, raise_exceptions + message, + req, + session, + lifespan_context, + raise_exceptions, + request, ) case types.ClientNotification(root=notify): await self._handle_notification(notify) @@ -547,6 +558,7 @@ async def _handle_request( session: ServerSession, lifespan_context: LifespanResultT, raise_exceptions: bool, + request: RequestT | None, ): logger.info(f"Processing request of type {type(req).__name__}") if type(req) in self.request_handlers: @@ -563,6 +575,7 @@ async def _handle_request( message.request_meta, session, lifespan_context, + request=request, ) ) response = await handler(req) diff --git a/src/mcp/shared/context.py b/src/mcp/shared/context.py index ae85d3a19..2b6735655 100644 --- a/src/mcp/shared/context.py +++ b/src/mcp/shared/context.py @@ -8,11 +8,13 @@ SessionT = TypeVar("SessionT", bound=BaseSession[Any, Any, Any, Any, Any]) LifespanContextT = TypeVar("LifespanContextT") +RequestT = TypeVar("RequestT") @dataclass -class RequestContext(Generic[SessionT, LifespanContextT]): +class RequestContext(Generic[SessionT, LifespanContextT, RequestT]): request_id: RequestId meta: RequestParams.Meta | None session: SessionT lifespan_context: LifespanContextT + request: RequestT | None = None diff --git a/src/mcp/shared/memory.py b/src/mcp/shared/memory.py index b53f8dd63..7d2a73777 100644 --- a/src/mcp/shared/memory.py +++ b/src/mcp/shared/memory.py @@ -60,7 +60,7 @@ async def create_client_server_memory_streams() -> ( @asynccontextmanager async def create_connected_server_and_client_session( - server: Server[Any], + server: Server[Any, Any], read_timeout_seconds: timedelta | None = None, sampling_callback: SamplingFnT | None = None, list_roots_callback: ListRootsFnT | None = None, diff --git a/src/mcp/shared/progress.py b/src/mcp/shared/progress.py index 856a8d3b6..35092b5be 100644 --- a/src/mcp/shared/progress.py +++ b/src/mcp/shared/progress.py @@ -1,7 +1,7 @@ from collections.abc import Generator from contextlib import contextmanager from dataclasses import dataclass, field -from typing import Generic +from typing import Any, Generic from pydantic import BaseModel @@ -62,6 +62,7 @@ def progress( ReceiveNotificationT, ], LifespanContextT, + Any, ], total: float | None = None, ) -> Generator[ diff --git a/tests/client/test_list_roots_callback.py b/tests/client/test_list_roots_callback.py index f5b598218..6f3d3576e 100644 --- a/tests/client/test_list_roots_callback.py +++ b/tests/client/test_list_roots_callback.py @@ -30,7 +30,7 @@ async def test_list_roots_callback(): ) async def list_roots_callback( - context: RequestContext[ClientSession, None], + context: RequestContext[ClientSession, None, None], ) -> ListRootsResult: return callback_return diff --git a/tests/client/test_sampling_callback.py b/tests/client/test_sampling_callback.py index ba586d4a8..ee09cc083 100644 --- a/tests/client/test_sampling_callback.py +++ b/tests/client/test_sampling_callback.py @@ -29,7 +29,7 @@ async def test_sampling_callback(): ) async def sampling_callback( - context: RequestContext[ClientSession, None], + context: RequestContext[ClientSession, None, None], params: CreateMessageRequestParams, ) -> CreateMessageResult: return callback_return