From c4ec56de7743d0f358a88d275eac3a0c3cb05542 Mon Sep 17 00:00:00 2001 From: Lorenzo C Date: Sat, 24 May 2025 22:27:20 -0300 Subject: [PATCH 1/2] Client sampling and roots capabilities set to None if not implemented --- src/mcp/client/session.py | 19 +++-- src/mcp/types.py | 2 +- tests/client/test_session.py | 161 +++++++++++++++++++++++++++++++++++ 3 files changed, 175 insertions(+), 7 deletions(-) diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index fe90716e2..4cb1713c4 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -116,13 +116,20 @@ def __init__( self._message_handler = message_handler or _default_message_handler async def initialize(self) -> types.InitializeResult: - sampling = types.SamplingCapability() - roots = types.RootsCapability( - # TODO: Should this be based on whether we - # _will_ send notifications, or only whether - # they're supported? - listChanged=True, + sampling = ( + types.SamplingCapability() + if self._sampling_callback is not _default_sampling_callback + else None ) + if self._list_roots_callback is _default_list_roots_callback: + roots = None + else: + roots = types.RootsCapability( + # TODO: Should this be based on whether we + # _will_ send notifications, or only whether + # they're supported? + listChanged=True, + ) result = await self.send_request( types.ClientRequest( diff --git a/src/mcp/types.py b/src/mcp/types.py index 465fc6ee6..b281cf9bd 100644 --- a/src/mcp/types.py +++ b/src/mcp/types.py @@ -214,7 +214,7 @@ class RootsCapability(BaseModel): class SamplingCapability(BaseModel): - """Capability for logging operations.""" + """Capability for sampling operations.""" model_config = ConfigDict(extra="allow") diff --git a/tests/client/test_session.py b/tests/client/test_session.py index cad89f217..f48955420 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -1,8 +1,11 @@ +from typing import Any + import anyio import pytest import mcp.types as types from mcp.client.session import DEFAULT_CLIENT_INFO, ClientSession +from mcp.shared.context import RequestContext from mcp.shared.message import SessionMessage from mcp.shared.session import RequestResponder from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS @@ -380,3 +383,161 @@ async def mock_server(): # Should raise RuntimeError for unsupported version with pytest.raises(RuntimeError, match="Unsupported protocol version"): await session.initialize() + + +@pytest.mark.anyio +async def test_client_capabilities_default(): + """Test that client capabilities are properly set with default callbacks""" + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ + SessionMessage + ](1) + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ + SessionMessage + ](1) + + received_capabilities = None + + async def mock_server(): + nonlocal received_capabilities + + session_message = await client_to_server_receive.receive() + jsonrpc_request = session_message.message + assert isinstance(jsonrpc_request.root, JSONRPCRequest) + request = ClientRequest.model_validate( + jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) + ) + assert isinstance(request.root, InitializeRequest) + received_capabilities = request.root.params.capabilities + + result = ServerResult( + InitializeResult( + protocolVersion=LATEST_PROTOCOL_VERSION, + capabilities=ServerCapabilities(), + serverInfo=Implementation(name="mock-server", version="0.1.0"), + ) + ) + + async with server_to_client_send: + await server_to_client_send.send( + SessionMessage( + JSONRPCMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.root.id, + result=result.model_dump( + by_alias=True, mode="json", exclude_none=True + ), + ) + ) + ) + ) + # Receive initialized notification + await client_to_server_receive.receive() + + async with ( + ClientSession( + server_to_client_receive, + client_to_server_send, + ) as session, + anyio.create_task_group() as tg, + client_to_server_send, + client_to_server_receive, + server_to_client_send, + server_to_client_receive, + ): + tg.start_soon(mock_server) + await session.initialize() + + # Assert that capabilities are properly set with defaults + assert received_capabilities is not None + assert received_capabilities.sampling is None # No custom sampling callback + assert received_capabilities.roots is None # No custom list_roots callback + + +@pytest.mark.anyio +async def test_client_capabilities_with_custom_callbacks(): + """Test that client capabilities are properly set with custom callbacks""" + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ + SessionMessage + ](1) + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ + SessionMessage + ](1) + + received_capabilities = None + + async def custom_sampling_callback( + context: RequestContext["ClientSession", Any], + params: types.CreateMessageRequestParams, + ) -> types.CreateMessageResult | types.ErrorData: + return types.CreateMessageResult( + role="assistant", + content=types.TextContent(type="text", text="test"), + model="test-model", + ) + + async def custom_list_roots_callback( + context: RequestContext["ClientSession", Any], + ) -> types.ListRootsResult | types.ErrorData: + return types.ListRootsResult(roots=[]) + + async def mock_server(): + nonlocal received_capabilities + + session_message = await client_to_server_receive.receive() + jsonrpc_request = session_message.message + assert isinstance(jsonrpc_request.root, JSONRPCRequest) + request = ClientRequest.model_validate( + jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) + ) + assert isinstance(request.root, InitializeRequest) + received_capabilities = request.root.params.capabilities + + result = ServerResult( + InitializeResult( + protocolVersion=LATEST_PROTOCOL_VERSION, + capabilities=ServerCapabilities(), + serverInfo=Implementation(name="mock-server", version="0.1.0"), + ) + ) + + async with server_to_client_send: + await server_to_client_send.send( + SessionMessage( + JSONRPCMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.root.id, + result=result.model_dump( + by_alias=True, mode="json", exclude_none=True + ), + ) + ) + ) + ) + # Receive initialized notification + await client_to_server_receive.receive() + + async with ( + ClientSession( + server_to_client_receive, + client_to_server_send, + sampling_callback=custom_sampling_callback, + list_roots_callback=custom_list_roots_callback, + ) as session, + anyio.create_task_group() as tg, + client_to_server_send, + client_to_server_receive, + server_to_client_send, + server_to_client_receive, + ): + tg.start_soon(mock_server) + await session.initialize() + + # Assert that capabilities are properly set with custom callbacks + assert received_capabilities is not None + assert received_capabilities.sampling is not None # Custom sampling callback provided + assert isinstance(received_capabilities.sampling, types.SamplingCapability) + assert received_capabilities.roots is not None # Custom list_roots callback provided + assert isinstance(received_capabilities.roots, types.RootsCapability) + assert received_capabilities.roots.listChanged is True # Should be True for custom callback From 8afaefd851e4d87c5dab51e5be5fc70cc89e4aa1 Mon Sep 17 00:00:00 2001 From: Lorenzo C Date: Sat, 24 May 2025 22:30:04 -0300 Subject: [PATCH 2/2] refactor: apply formatting --- src/mcp/client/session_group.py | 1 - tests/client/test_session.py | 14 ++++++++++---- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/src/mcp/client/session_group.py b/src/mcp/client/session_group.py index a430533b3..a77dc7a1e 100644 --- a/src/mcp/client/session_group.py +++ b/src/mcp/client/session_group.py @@ -154,7 +154,6 @@ async def __aexit__( for exit_stack in self._session_exit_stacks.values(): tg.start_soon(exit_stack.aclose) - @property def sessions(self) -> list[mcp.ClientSession]: """Returns the list of sessions being managed.""" diff --git a/tests/client/test_session.py b/tests/client/test_session.py index f48955420..72b4413d2 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -451,7 +451,7 @@ async def mock_server(): # Assert that capabilities are properly set with defaults assert received_capabilities is not None assert received_capabilities.sampling is None # No custom sampling callback - assert received_capabilities.roots is None # No custom list_roots callback + assert received_capabilities.roots is None # No custom list_roots callback @pytest.mark.anyio @@ -536,8 +536,14 @@ async def mock_server(): # Assert that capabilities are properly set with custom callbacks assert received_capabilities is not None - assert received_capabilities.sampling is not None # Custom sampling callback provided + assert ( + received_capabilities.sampling is not None + ) # Custom sampling callback provided assert isinstance(received_capabilities.sampling, types.SamplingCapability) - assert received_capabilities.roots is not None # Custom list_roots callback provided + assert ( + received_capabilities.roots is not None + ) # Custom list_roots callback provided assert isinstance(received_capabilities.roots, types.RootsCapability) - assert received_capabilities.roots.listChanged is True # Should be True for custom callback + assert ( + received_capabilities.roots.listChanged is True + ) # Should be True for custom callback