Skip to content

Client sampling and roots capabilities set to None if not implemented #802

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 13 additions & 6 deletions src/mcp/client/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,13 +116,20 @@ def __init__(
self._message_handler = message_handler or _default_message_handler

async def initialize(self) -> types.InitializeResult:
sampling = types.SamplingCapability()
roots = types.RootsCapability(
# TODO: Should this be based on whether we
# _will_ send notifications, or only whether
# they're supported?
listChanged=True,
sampling = (
types.SamplingCapability()
if self._sampling_callback is not _default_sampling_callback
else None
)
if self._list_roots_callback is _default_list_roots_callback:
roots = None
else:
roots = types.RootsCapability(
# TODO: Should this be based on whether we
# _will_ send notifications, or only whether
# they're supported?
listChanged=True,
)

result = await self.send_request(
types.ClientRequest(
Expand Down
1 change: 0 additions & 1 deletion src/mcp/client/session_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,6 @@ async def __aexit__(
for exit_stack in self._session_exit_stacks.values():
tg.start_soon(exit_stack.aclose)


@property
def sessions(self) -> list[mcp.ClientSession]:
"""Returns the list of sessions being managed."""
Expand Down
2 changes: 1 addition & 1 deletion src/mcp/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ class RootsCapability(BaseModel):


class SamplingCapability(BaseModel):
"""Capability for logging operations."""
"""Capability for sampling operations."""

model_config = ConfigDict(extra="allow")

Expand Down
167 changes: 167 additions & 0 deletions tests/client/test_session.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from typing import Any

import anyio
import pytest

import mcp.types as types
from mcp.client.session import DEFAULT_CLIENT_INFO, ClientSession
from mcp.shared.context import RequestContext
from mcp.shared.message import SessionMessage
from mcp.shared.session import RequestResponder
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
Expand Down Expand Up @@ -380,3 +383,167 @@ async def mock_server():
# Should raise RuntimeError for unsupported version
with pytest.raises(RuntimeError, match="Unsupported protocol version"):
await session.initialize()


@pytest.mark.anyio
async def test_client_capabilities_default():
"""Test that client capabilities are properly set with default callbacks"""
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
SessionMessage
](1)
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
SessionMessage
](1)

received_capabilities = None

async def mock_server():
nonlocal received_capabilities

session_message = await client_to_server_receive.receive()
jsonrpc_request = session_message.message
assert isinstance(jsonrpc_request.root, JSONRPCRequest)
request = ClientRequest.model_validate(
jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True)
)
assert isinstance(request.root, InitializeRequest)
received_capabilities = request.root.params.capabilities

result = ServerResult(
InitializeResult(
protocolVersion=LATEST_PROTOCOL_VERSION,
capabilities=ServerCapabilities(),
serverInfo=Implementation(name="mock-server", version="0.1.0"),
)
)

async with server_to_client_send:
await server_to_client_send.send(
SessionMessage(
JSONRPCMessage(
JSONRPCResponse(
jsonrpc="2.0",
id=jsonrpc_request.root.id,
result=result.model_dump(
by_alias=True, mode="json", exclude_none=True
),
)
)
)
)
# Receive initialized notification
await client_to_server_receive.receive()

async with (
ClientSession(
server_to_client_receive,
client_to_server_send,
) as session,
anyio.create_task_group() as tg,
client_to_server_send,
client_to_server_receive,
server_to_client_send,
server_to_client_receive,
):
tg.start_soon(mock_server)
await session.initialize()

# Assert that capabilities are properly set with defaults
assert received_capabilities is not None
assert received_capabilities.sampling is None # No custom sampling callback
assert received_capabilities.roots is None # No custom list_roots callback


@pytest.mark.anyio
async def test_client_capabilities_with_custom_callbacks():
"""Test that client capabilities are properly set with custom callbacks"""
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
SessionMessage
](1)
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
SessionMessage
](1)

received_capabilities = None

async def custom_sampling_callback(
context: RequestContext["ClientSession", Any],
params: types.CreateMessageRequestParams,
) -> types.CreateMessageResult | types.ErrorData:
return types.CreateMessageResult(
role="assistant",
content=types.TextContent(type="text", text="test"),
model="test-model",
)

async def custom_list_roots_callback(
context: RequestContext["ClientSession", Any],
) -> types.ListRootsResult | types.ErrorData:
return types.ListRootsResult(roots=[])

async def mock_server():
nonlocal received_capabilities

session_message = await client_to_server_receive.receive()
jsonrpc_request = session_message.message
assert isinstance(jsonrpc_request.root, JSONRPCRequest)
request = ClientRequest.model_validate(
jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True)
)
assert isinstance(request.root, InitializeRequest)
received_capabilities = request.root.params.capabilities

result = ServerResult(
InitializeResult(
protocolVersion=LATEST_PROTOCOL_VERSION,
capabilities=ServerCapabilities(),
serverInfo=Implementation(name="mock-server", version="0.1.0"),
)
)

async with server_to_client_send:
await server_to_client_send.send(
SessionMessage(
JSONRPCMessage(
JSONRPCResponse(
jsonrpc="2.0",
id=jsonrpc_request.root.id,
result=result.model_dump(
by_alias=True, mode="json", exclude_none=True
),
)
)
)
)
# Receive initialized notification
await client_to_server_receive.receive()

async with (
ClientSession(
server_to_client_receive,
client_to_server_send,
sampling_callback=custom_sampling_callback,
list_roots_callback=custom_list_roots_callback,
) as session,
anyio.create_task_group() as tg,
client_to_server_send,
client_to_server_receive,
server_to_client_send,
server_to_client_receive,
):
tg.start_soon(mock_server)
await session.initialize()

# Assert that capabilities are properly set with custom callbacks
assert received_capabilities is not None
assert (
received_capabilities.sampling is not None
) # Custom sampling callback provided
assert isinstance(received_capabilities.sampling, types.SamplingCapability)
assert (
received_capabilities.roots is not None
) # Custom list_roots callback provided
assert isinstance(received_capabilities.roots, types.RootsCapability)
assert (
received_capabilities.roots.listChanged is True
) # Should be True for custom callback
Loading