diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 7bb8821f7..1e8ab2042 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -168,7 +168,11 @@ async def send_ping(self) -> types.EmptyResult: ) async def send_progress_notification( - self, progress_token: str | int, progress: float, total: float | None = None + self, + progress_token: str | int, + progress: float, + total: float | None = None, + message: str | None = None, ) -> None: """Send a progress notification.""" await self.send_notification( @@ -179,6 +183,7 @@ async def send_progress_notification( progressToken=progress_token, progress=progress, total=total, + message=message, ), ), ) diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index c31f29d4c..8929eb6fd 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -952,13 +952,14 @@ def request_context(self) -> RequestContext[ServerSessionT, LifespanContextT]: return self._request_context async def report_progress( - self, progress: float, total: float | None = None + self, progress: float, total: float | None = None, message: str | None = None ) -> None: """Report progress for the current operation. Args: progress: Current progress value e.g. 24 total: Optional total value e.g. 100 + message: Optional message e.g. Starting render... """ progress_token = ( @@ -971,7 +972,10 @@ async def report_progress( return await self.request_context.session.send_progress_notification( - progress_token=progress_token, progress=progress, total=total + progress_token=progress_token, + progress=progress, + total=total, + message=message, ) async def read_resource(self, uri: str | AnyUrl) -> Iterable[ReadResourceContents]: diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 4b97b33da..876aef817 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -37,7 +37,8 @@ async def handle_list_resource_templates() -> list[types.ResourceTemplate]: 3. Define notification handlers if needed: @server.progress_notification() async def handle_progress( - progress_token: str | int, progress: float, total: float | None + progress_token: str | int, progress: float, total: float | None, + message: str | None ) -> None: # Implementation @@ -427,13 +428,18 @@ async def handler(req: types.CallToolRequest): def progress_notification(self): def decorator( - func: Callable[[str | int, float, float | None], Awaitable[None]], + func: Callable[ + [str | int, float, float | None, str | None], Awaitable[None] + ], ): logger.debug("Registering handler for ProgressNotification") async def handler(req: types.ProgressNotification): await func( - req.params.progressToken, req.params.progress, req.params.total + req.params.progressToken, + req.params.progress, + req.params.total, + req.params.message, ) self.notification_handlers[types.ProgressNotification] = handler diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index f4e72eac1..4f97c6cd6 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -282,6 +282,7 @@ async def send_progress_notification( progress_token: str | int, progress: float, total: float | None = None, + message: str | None = None, related_request_id: str | None = None, ) -> None: """Send a progress notification.""" @@ -293,6 +294,7 @@ async def send_progress_notification( progressToken=progress_token, progress=progress, total=total, + message=message, ), ) ), diff --git a/src/mcp/shared/progress.py b/src/mcp/shared/progress.py index 52e0017d0..856a8d3b6 100644 --- a/src/mcp/shared/progress.py +++ b/src/mcp/shared/progress.py @@ -43,11 +43,11 @@ class ProgressContext( total: float | None current: float = field(default=0.0, init=False) - async def progress(self, amount: float) -> None: + async def progress(self, amount: float, message: str | None = None) -> None: self.current += amount await self.session.send_progress_notification( - self.progress_token, self.current, total=self.total + self.progress_token, self.current, total=self.total, message=message ) diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index c390386a9..19728e0ec 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -401,7 +401,11 @@ async def _received_notification(self, notification: ReceiveNotificationT) -> No """ async def send_progress_notification( - self, progress_token: str | int, progress: float, total: float | None = None + self, + progress_token: str | int, + progress: float, + total: float | None = None, + message: str | None = None, ) -> None: """ Sends a progress notification for a request that is currently being diff --git a/src/mcp/types.py b/src/mcp/types.py index 6ab7fba5c..e01929b8a 100644 --- a/src/mcp/types.py +++ b/src/mcp/types.py @@ -337,6 +337,11 @@ class ProgressNotificationParams(NotificationParams): total is unknown. """ total: float | None = None + """ + Message related to progress. This should provide relevant human readable + progress information. + """ + message: str | None = None """Total number of items to process (or total progress required), if known.""" model_config = ConfigDict(extra="allow") diff --git a/tests/issues/test_176_progress_token.py b/tests/issues/test_176_progress_token.py index 7f9131a1e..4ad22f294 100644 --- a/tests/issues/test_176_progress_token.py +++ b/tests/issues/test_176_progress_token.py @@ -39,11 +39,11 @@ async def test_progress_token_zero_first_call(): mock_session.send_progress_notification.call_count == 3 ), "All progress notifications should be sent" mock_session.send_progress_notification.assert_any_call( - progress_token=0, progress=0.0, total=10.0 + progress_token=0, progress=0.0, total=10.0, message=None ) mock_session.send_progress_notification.assert_any_call( - progress_token=0, progress=5.0, total=10.0 + progress_token=0, progress=5.0, total=10.0, message=None ) mock_session.send_progress_notification.assert_any_call( - progress_token=0, progress=10.0, total=10.0 + progress_token=0, progress=10.0, total=10.0, message=None ) diff --git a/tests/shared/test_progress_notifications.py b/tests/shared/test_progress_notifications.py new file mode 100644 index 000000000..1e0409e14 --- /dev/null +++ b/tests/shared/test_progress_notifications.py @@ -0,0 +1,349 @@ +from typing import Any, cast + +import anyio +import pytest + +import mcp.types as types +from mcp.client.session import ClientSession +from mcp.server import Server +from mcp.server.lowlevel import NotificationOptions +from mcp.server.models import InitializationOptions +from mcp.server.session import ServerSession +from mcp.shared.context import RequestContext +from mcp.shared.progress import progress +from mcp.shared.session import ( + BaseSession, + RequestResponder, + SessionMessage, +) + + +@pytest.mark.anyio +async def test_bidirectional_progress_notifications(): + """Test that both client and server can send progress notifications.""" + # Create memory streams for client/server + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ + SessionMessage + ](5) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ + SessionMessage + ](5) + + # Run a server session so we can send progress updates in tool + async def run_server(): + # Create a server session + async with ServerSession( + client_to_server_receive, + server_to_client_send, + InitializationOptions( + server_name="ProgressTestServer", + server_version="0.1.0", + capabilities=server.get_capabilities(NotificationOptions(), {}), + ), + ) as server_session: + global serv_sesh + + serv_sesh = server_session + async for message in server_session.incoming_messages: + try: + await server._handle_message(message, server_session, ()) + except Exception as e: + raise e + + # Track progress updates + server_progress_updates = [] + client_progress_updates = [] + + # Progress tokens + server_progress_token = "server_token_123" + client_progress_token = "client_token_456" + + # Create a server with progress capability + server = Server(name="ProgressTestServer") + + # Register progress handler + @server.progress_notification() + async def handle_progress( + progress_token: str | int, + progress: float, + total: float | None, + message: str | None, + ): + server_progress_updates.append( + { + "token": progress_token, + "progress": progress, + "total": total, + "message": message, + } + ) + + # Register list tool handler + @server.list_tools() + async def handle_list_tools() -> list[types.Tool]: + return [ + types.Tool( + name="test_tool", + description="A tool that sends progress notifications list: + # Make sure we received a progress token + if name == "test_tool": + if arguments and "_meta" in arguments: + progressToken = arguments["_meta"]["progressToken"] + + if not progressToken: + raise ValueError("Empty progress token received") + + if progressToken != client_progress_token: + raise ValueError("Server sending back incorrect progressToken") + + # Send progress notifications + await serv_sesh.send_progress_notification( + progress_token=progressToken, + progress=0.25, + total=1.0, + message="Server progress 25%", + ) + + await serv_sesh.send_progress_notification( + progress_token=progressToken, + progress=0.5, + total=1.0, + message="Server progress 50%", + ) + + await serv_sesh.send_progress_notification( + progress_token=progressToken, + progress=1.0, + total=1.0, + message="Server progress 100%", + ) + + else: + raise ValueError("Progress token not sent.") + + return ["Tool executed successfully"] + + raise ValueError(f"Unknown tool: {name}") + + # Client message handler to store progress notifications + async def handle_client_message( + message: RequestResponder[types.ServerRequest, types.ClientResult] + | types.ServerNotification + | Exception, + ) -> None: + if isinstance(message, Exception): + raise message + + if isinstance(message, types.ServerNotification): + if isinstance(message.root, types.ProgressNotification): + params = message.root.params + client_progress_updates.append( + { + "token": params.progressToken, + "progress": params.progress, + "total": params.total, + "message": params.message, + } + ) + + # Test using client + async with ( + ClientSession( + server_to_client_receive, + client_to_server_send, + message_handler=handle_client_message, + ) as client_session, + anyio.create_task_group() as tg, + ): + # Start the server in a background task + tg.start_soon(run_server) + + # Initialize the client connection + await client_session.initialize() + + # Call list_tools with progress token + await client_session.list_tools() + + # Call test_tool with progress token + await client_session.call_tool( + "test_tool", {"_meta": {"progressToken": client_progress_token}} + ) + + # Send progress notifications from client to server + await client_session.send_progress_notification( + progress_token=server_progress_token, + progress=0.33, + total=1.0, + message="Client progress 33%", + ) + + await client_session.send_progress_notification( + progress_token=server_progress_token, + progress=0.66, + total=1.0, + message="Client progress 66%", + ) + + await client_session.send_progress_notification( + progress_token=server_progress_token, + progress=1.0, + total=1.0, + message="Client progress 100%", + ) + + # Wait and exit + await anyio.sleep(0.5) + tg.cancel_scope.cancel() + + # Verify client received progress updates from server + assert len(client_progress_updates) == 3 + assert client_progress_updates[0]["token"] == client_progress_token + assert client_progress_updates[0]["progress"] == 0.25 + assert client_progress_updates[0]["message"] == "Server progress 25%" + assert client_progress_updates[2]["progress"] == 1.0 + + # Verify server received progress updates from client + assert len(server_progress_updates) == 3 + assert server_progress_updates[0]["token"] == server_progress_token + assert server_progress_updates[0]["progress"] == 0.33 + assert server_progress_updates[0]["message"] == "Client progress 33%" + assert server_progress_updates[2]["progress"] == 1.0 + + +@pytest.mark.anyio +async def test_progress_context_manager(): + """Test client using progress context manager for sending progress notifications.""" + # Create memory streams for client/server + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ + SessionMessage + ](5) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ + SessionMessage + ](5) + + # Track progress updates + server_progress_updates = [] + + server = Server(name="ProgressContextTestServer") + + # Register progress handler + @server.progress_notification() + async def handle_progress( + progress_token: str | int, + progress: float, + total: float | None, + message: str | None, + ): + server_progress_updates.append( + { + "token": progress_token, + "progress": progress, + "total": total, + "message": message, + } + ) + + # Run server session to receive progress updates + async def run_server(): + # Create a server session + async with ServerSession( + client_to_server_receive, + server_to_client_send, + InitializationOptions( + server_name="ProgressContextTestServer", + server_version="0.1.0", + capabilities=server.get_capabilities(NotificationOptions(), {}), + ), + ) as server_session: + async for message in server_session.incoming_messages: + try: + await server._handle_message(message, server_session, ()) + except Exception as e: + raise e + + # Client message handler + async def handle_client_message( + message: RequestResponder[types.ServerRequest, types.ClientResult] + | types.ServerNotification + | Exception, + ) -> None: + if isinstance(message, Exception): + raise message + + # run client session + async with ( + ClientSession( + server_to_client_receive, + client_to_server_send, + message_handler=handle_client_message, + ) as client_session, + anyio.create_task_group() as tg, + ): + tg.start_soon(run_server) + + await client_session.initialize() + + progress_token = "client_token_456" + + # Create request context + meta = types.RequestParams.Meta(progressToken=progress_token) + request_context = RequestContext( + request_id="test-request", + session=client_session, + meta=meta, + lifespan_context=None, + ) + + # cast for type checker + typed_context = cast( + RequestContext[ + BaseSession[Any, Any, Any, Any, Any], + Any, + ], + request_context, + ) + + # Utilize progress context manager + with progress(typed_context, total=100) as p: + await p.progress(10, message="Loading configuration...") + await p.progress(30, message="Connecting to database...") + await p.progress(40, message="Fetching data...") + await p.progress(20, message="Processing results...") + + # Wait for all messages to be processed + await anyio.sleep(0.5) + tg.cancel_scope.cancel() + + # Verify progress updates were received by server + assert len(server_progress_updates) == 4 + + # first update + assert server_progress_updates[0]["token"] == progress_token + assert server_progress_updates[0]["progress"] == 10 + assert server_progress_updates[0]["total"] == 100 + assert server_progress_updates[0]["message"] == "Loading configuration..." + + # second update + assert server_progress_updates[1]["token"] == progress_token + assert server_progress_updates[1]["progress"] == 40 + assert server_progress_updates[1]["total"] == 100 + assert server_progress_updates[1]["message"] == "Connecting to database..." + + # third update + assert server_progress_updates[2]["token"] == progress_token + assert server_progress_updates[2]["progress"] == 80 + assert server_progress_updates[2]["total"] == 100 + assert server_progress_updates[2]["message"] == "Fetching data..." + + # final update + assert server_progress_updates[3]["token"] == progress_token + assert server_progress_updates[3]["progress"] == 100 + assert server_progress_updates[3]["total"] == 100 + assert server_progress_updates[3]["message"] == "Processing results..."