From 71a9418ae20394a1347ad0f7dd563f3444309565 Mon Sep 17 00:00:00 2001 From: Akshey D Date: Tue, 1 Apr 2025 23:09:21 -0400 Subject: [PATCH 1/9] add message type --- src/mcp/client/session.py | 7 ++++++- src/mcp/server/session.py | 7 ++++++- src/mcp/shared/progress.py | 5 +++-- src/mcp/shared/session.py | 6 +++++- src/mcp/types.py | 5 +++++ 5 files changed, 25 insertions(+), 5 deletions(-) diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 65d5e11e2..9ac1b9798 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -163,7 +163,11 @@ async def send_ping(self) -> types.EmptyResult: ) async def send_progress_notification( - self, progress_token: str | int, progress: float, total: float | None = None + self, + progress_token: str | int, + progress: float, + total: float | None = None, + message: str | None = None, ) -> None: """Send a progress notification.""" await self.send_notification( @@ -174,6 +178,7 @@ async def send_progress_notification( progressToken=progress_token, progress=progress, total=total, + message=message, ), ), ) diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index 568ecd4b9..9d850dc26 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -261,7 +261,11 @@ async def send_ping(self) -> types.EmptyResult: ) async def send_progress_notification( - self, progress_token: str | int, progress: float, total: float | None = None + self, + progress_token: str | int, + progress: float, + total: float | None = None, + message: str | None = None, ) -> None: """Send a progress notification.""" await self.send_notification( @@ -272,6 +276,7 @@ async def send_progress_notification( progressToken=progress_token, progress=progress, total=total, + message=message, ), ) ) diff --git a/src/mcp/shared/progress.py b/src/mcp/shared/progress.py index 52e0017d0..2326fdef7 100644 --- a/src/mcp/shared/progress.py +++ b/src/mcp/shared/progress.py @@ -42,12 +42,13 @@ class ProgressContext( progress_token: ProgressToken total: float | None current: float = field(default=0.0, init=False) + message: str | None async def progress(self, amount: float) -> None: self.current += amount await self.session.send_progress_notification( - self.progress_token, self.current, total=self.total + self.progress_token, self.current, total=self.total, message=self.message ) @@ -77,7 +78,7 @@ def progress( if ctx.meta is None or ctx.meta.progressToken is None: raise ValueError("No progress token provided") - progress_ctx = ProgressContext(ctx.session, ctx.meta.progressToken, total) + progress_ctx = ProgressContext(ctx.session, ctx.meta.progressToken, total, None) try: yield progress_ctx finally: diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 05fd3ce37..56ce24eba 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -377,7 +377,11 @@ async def _received_notification(self, notification: ReceiveNotificationT) -> No """ async def send_progress_notification( - self, progress_token: str | int, progress: float, total: float | None = None + self, + progress_token: str | int, + progress: float, + total: float | None = None, + message: str | None = None, ) -> None: """ Sends a progress notification for a request that is currently being diff --git a/src/mcp/types.py b/src/mcp/types.py index bd71d51f0..d0d7eef90 100644 --- a/src/mcp/types.py +++ b/src/mcp/types.py @@ -337,6 +337,11 @@ class ProgressNotificationParams(NotificationParams): total is unknown. """ total: float | None = None + """ + Message related to progress. This should provide relevant human readble + progress information. + """ + message: str | None = None """Total number of items to process (or total progress required), if known.""" model_config = ConfigDict(extra="allow") From 80ad86285d46afefea52796fac8737b050fc991e Mon Sep 17 00:00:00 2001 From: Akshey D Date: Sat, 5 Apr 2025 16:14:14 -0400 Subject: [PATCH 2/9] more type and test case updates --- src/mcp/server/fastmcp/server.py | 8 +- src/mcp/server/lowlevel/server.py | 12 +- src/mcp/types.py | 2 +- tests/issues/test_176_progress_token.py | 6 +- tests/shared/test_progress_notifications.py | 214 ++++++++++++++++++++ 5 files changed, 233 insertions(+), 9 deletions(-) create mode 100644 tests/shared/test_progress_notifications.py diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index bf0ce880a..84827137b 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -621,13 +621,14 @@ def request_context(self) -> RequestContext[ServerSessionT, LifespanContextT]: return self._request_context async def report_progress( - self, progress: float, total: float | None = None + self, progress: float, total: float | None = None, message: str | None = None ) -> None: """Report progress for the current operation. Args: progress: Current progress value e.g. 24 total: Optional total value e.g. 100 + message: Optional message e.g. Starting render... """ progress_token = ( @@ -640,7 +641,10 @@ async def report_progress( return await self.request_context.session.send_progress_notification( - progress_token=progress_token, progress=progress, total=total + progress_token=progress_token, + progress=progress, + total=total, + message=message, ) async def read_resource(self, uri: str | AnyUrl) -> Iterable[ReadResourceContents]: diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index dbaff3051..9cbc9b03b 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -37,7 +37,8 @@ async def handle_list_resource_templates() -> list[types.ResourceTemplate]: 3. Define notification handlers if needed: @server.progress_notification() async def handle_progress( - progress_token: str | int, progress: float, total: float | None + progress_token: str | int, progress: float, total: float | None, + message: str | None ) -> None: # Implementation @@ -426,13 +427,18 @@ async def handler(req: types.CallToolRequest): def progress_notification(self): def decorator( - func: Callable[[str | int, float, float | None], Awaitable[None]], + func: Callable[ + [str | int, float, float | None, str | None], Awaitable[None] + ], ): logger.debug("Registering handler for ProgressNotification") async def handler(req: types.ProgressNotification): await func( - req.params.progressToken, req.params.progress, req.params.total + req.params.progressToken, + req.params.progress, + req.params.total, + req.params.message, ) self.notification_handlers[types.ProgressNotification] = handler diff --git a/src/mcp/types.py b/src/mcp/types.py index d0d7eef90..72cf0e197 100644 --- a/src/mcp/types.py +++ b/src/mcp/types.py @@ -338,7 +338,7 @@ class ProgressNotificationParams(NotificationParams): """ total: float | None = None """ - Message related to progress. This should provide relevant human readble + Message related to progress. This should provide relevant human readable progress information. """ message: str | None = None diff --git a/tests/issues/test_176_progress_token.py b/tests/issues/test_176_progress_token.py index 7f9131a1e..4ad22f294 100644 --- a/tests/issues/test_176_progress_token.py +++ b/tests/issues/test_176_progress_token.py @@ -39,11 +39,11 @@ async def test_progress_token_zero_first_call(): mock_session.send_progress_notification.call_count == 3 ), "All progress notifications should be sent" mock_session.send_progress_notification.assert_any_call( - progress_token=0, progress=0.0, total=10.0 + progress_token=0, progress=0.0, total=10.0, message=None ) mock_session.send_progress_notification.assert_any_call( - progress_token=0, progress=5.0, total=10.0 + progress_token=0, progress=5.0, total=10.0, message=None ) mock_session.send_progress_notification.assert_any_call( - progress_token=0, progress=10.0, total=10.0 + progress_token=0, progress=10.0, total=10.0, message=None ) diff --git a/tests/shared/test_progress_notifications.py b/tests/shared/test_progress_notifications.py new file mode 100644 index 000000000..de3130660 --- /dev/null +++ b/tests/shared/test_progress_notifications.py @@ -0,0 +1,214 @@ +import anyio +import pytest + +import mcp.types as types +from mcp.client.session import ClientSession +from mcp.server import Server +from mcp.server.lowlevel import NotificationOptions +from mcp.server.models import InitializationOptions +from mcp.server.session import ServerSession +from mcp.shared.session import RequestResponder +from mcp.types import ( + JSONRPCMessage, +) + + +@pytest.mark.anyio +async def test_bidirectional_progress_notifications(): + """Test that both client and server can send progress notifications.""" + # Create memory streams for client/server + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ + JSONRPCMessage + ](5) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ + JSONRPCMessage + ](5) + + # Run a server session so we can send progress updates in tool + async def run_server(): + # Create a server session + async with ServerSession( + client_to_server_receive, + server_to_client_send, + InitializationOptions( + server_name="ProgressTestServer", + server_version="0.1.0", + capabilities=server.get_capabilities(NotificationOptions(), {}), + ), + ) as server_session: + global serv_sesh + + serv_sesh = server_session + async for message in server_session.incoming_messages: + try: + await server._handle_message(message, server_session, ()) + except Exception as e: + raise e + + # Track progress updates + server_progress_updates = [] + client_progress_updates = [] + + # Progress tokens + server_progress_token = "server_token_123" + client_progress_token = "client_token_456" + + # Create a server with progress capability + server = Server(name="ProgressTestServer") + + # Register progress handler + @server.progress_notification() + async def handle_progress( + progress_token: str | int, + progress: float, + total: float | None, + message: str | None, + ): + server_progress_updates.append( + { + "token": progress_token, + "progress": progress, + "total": total, + "message": message, + } + ) + + # Register list tool handler + @server.list_tools() + async def handle_list_tools() -> list[types.Tool]: + return [ + types.Tool( + name="test_tool", + description="A tool that sends progress notifications list: + # Make sure we received a progress token + if name == "test_tool": + if arguments and "_meta" in arguments: + progressToken = arguments["_meta"]["progressToken"] + + if not progressToken: + raise ValueError("Empty progress token received") + + if progressToken != client_progress_token: + raise ValueError("Server sending back incorrect progressToken") + + # Send progress notifications + await serv_sesh.send_progress_notification( + progress_token=progressToken, + progress=0.25, + total=1.0, + message="Server progress 25%", + ) + await anyio.sleep(0.2) + + await serv_sesh.send_progress_notification( + progress_token=progressToken, + progress=0.5, + total=1.0, + message="Server progress 50%", + ) + await anyio.sleep(0.2) + + await serv_sesh.send_progress_notification( + progress_token=progressToken, + progress=1.0, + total=1.0, + message="Server progress 100%", + ) + + else: + raise ValueError("Progress token not sent.") + + return ["Tool executed successfully"] + + raise ValueError(f"Unknown tool: {name}") + + # Client message handler to store progress notifications + async def handle_client_message( + message: RequestResponder[types.ServerRequest, types.ClientResult] + | types.ServerNotification + | Exception, + ) -> None: + if isinstance(message, Exception): + raise message + + if isinstance(message, types.ServerNotification): + if isinstance(message.root, types.ProgressNotification): + params = message.root.params + client_progress_updates.append( + { + "token": params.progressToken, + "progress": params.progress, + "total": params.total, + "message": params.message, + } + ) + + # Test using client + async with ( + ClientSession( + server_to_client_receive, + client_to_server_send, + message_handler=handle_client_message, + ) as client_session, + anyio.create_task_group() as tg, + ): + # Start the server in a background task + tg.start_soon(run_server) + + # Initialize the client connection + await client_session.initialize() + + # Call list_tools with progress token + await client_session.list_tools() + + # Call test_tool with progress token + await client_session.call_tool( + "test_tool", {"_meta": {"progressToken": client_progress_token}} + ) + + # Send progress notifications from client to server + await client_session.send_progress_notification( + progress_token=server_progress_token, + progress=0.33, + total=1.0, + message="Client progress 33%", + ) + + await client_session.send_progress_notification( + progress_token=server_progress_token, + progress=0.66, + total=1.0, + message="Client progress 66%", + ) + + await client_session.send_progress_notification( + progress_token=server_progress_token, + progress=1.0, + total=1.0, + message="Client progress 100%", + ) + + # Wait and exit + await anyio.sleep(1.0) + tg.cancel_scope.cancel() + + # Verify client received progress updates from server + assert len(client_progress_updates) == 3 + assert client_progress_updates[0]["token"] == client_progress_token + assert client_progress_updates[0]["progress"] == 0.25 + assert client_progress_updates[0]["message"] == "Server progress 25%" + assert client_progress_updates[2]["progress"] == 1.0 + + # Verify server received progress updates from client + assert len(server_progress_updates) == 3 + assert server_progress_updates[0]["token"] == server_progress_token + assert server_progress_updates[0]["progress"] == 0.33 + assert server_progress_updates[0]["message"] == "Client progress 33%" + assert server_progress_updates[2]["progress"] == 1.0 From aa92533f4ac3a846bba8fcc7ee3a6446f5ab827c Mon Sep 17 00:00:00 2001 From: Akshey D <131929364+aksheyd@users.noreply.github.com> Date: Wed, 7 May 2025 13:04:37 -0400 Subject: [PATCH 3/9] remove message from context manager Co-authored-by: ihrpr --- src/mcp/shared/progress.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mcp/shared/progress.py b/src/mcp/shared/progress.py index 2326fdef7..0322735e5 100644 --- a/src/mcp/shared/progress.py +++ b/src/mcp/shared/progress.py @@ -78,7 +78,7 @@ def progress( if ctx.meta is None or ctx.meta.progressToken is None: raise ValueError("No progress token provided") - progress_ctx = ProgressContext(ctx.session, ctx.meta.progressToken, total, None) + progress_ctx = ProgressContext(ctx.session, ctx.meta.progressToken, total) try: yield progress_ctx finally: From 380e2c3cc993a0e34f26f7979ecf87fa0dc5d053 Mon Sep 17 00:00:00 2001 From: Akshey D <131929364+aksheyd@users.noreply.github.com> Date: Wed, 7 May 2025 13:05:05 -0400 Subject: [PATCH 4/9] set default for message in progress() method Co-authored-by: ihrpr --- src/mcp/shared/progress.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mcp/shared/progress.py b/src/mcp/shared/progress.py index 0322735e5..594e41e81 100644 --- a/src/mcp/shared/progress.py +++ b/src/mcp/shared/progress.py @@ -44,7 +44,7 @@ class ProgressContext( current: float = field(default=0.0, init=False) message: str | None - async def progress(self, amount: float) -> None: + async def progress(self, amount: float, message: str | None = None) -> None: self.current += amount await self.session.send_progress_notification( From 115a9bf9d718fab731ddb2fa9c04e89fe1c910ee Mon Sep 17 00:00:00 2001 From: Akshey D <131929364+aksheyd@users.noreply.github.com> Date: Wed, 7 May 2025 13:05:20 -0400 Subject: [PATCH 5/9] remove message from context Co-authored-by: ihrpr --- src/mcp/shared/progress.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/mcp/shared/progress.py b/src/mcp/shared/progress.py index 594e41e81..50bf33cc1 100644 --- a/src/mcp/shared/progress.py +++ b/src/mcp/shared/progress.py @@ -42,7 +42,6 @@ class ProgressContext( progress_token: ProgressToken total: float | None current: float = field(default=0.0, init=False) - message: str | None async def progress(self, amount: float, message: str | None = None) -> None: self.current += amount From fcdad5836cc4ba5cb25b2e37121ed1b08739fab5 Mon Sep 17 00:00:00 2001 From: Akshey D <131929364+aksheyd@users.noreply.github.com> Date: Wed, 7 May 2025 13:06:17 -0400 Subject: [PATCH 6/9] use progress() method's message param not context Co-authored-by: ihrpr --- src/mcp/shared/progress.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mcp/shared/progress.py b/src/mcp/shared/progress.py index 50bf33cc1..856a8d3b6 100644 --- a/src/mcp/shared/progress.py +++ b/src/mcp/shared/progress.py @@ -47,7 +47,7 @@ async def progress(self, amount: float, message: str | None = None) -> None: self.current += amount await self.session.send_progress_notification( - self.progress_token, self.current, total=self.total, message=self.message + self.progress_token, self.current, total=self.total, message=message ) From 198f96a744897a2e1e874261558b263d3e88808d Mon Sep 17 00:00:00 2001 From: Akshey D Date: Wed, 7 May 2025 13:16:21 -0400 Subject: [PATCH 7/9] resolve pyright failure --- tests/shared/test_progress_notifications.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/tests/shared/test_progress_notifications.py b/tests/shared/test_progress_notifications.py index de3130660..78d31d897 100644 --- a/tests/shared/test_progress_notifications.py +++ b/tests/shared/test_progress_notifications.py @@ -7,10 +7,7 @@ from mcp.server.lowlevel import NotificationOptions from mcp.server.models import InitializationOptions from mcp.server.session import ServerSession -from mcp.shared.session import RequestResponder -from mcp.types import ( - JSONRPCMessage, -) +from mcp.shared.session import RequestResponder, SessionMessage @pytest.mark.anyio @@ -18,10 +15,10 @@ async def test_bidirectional_progress_notifications(): """Test that both client and server can send progress notifications.""" # Create memory streams for client/server server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ - JSONRPCMessage + SessionMessage ](5) client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ - JSONRPCMessage + SessionMessage ](5) # Run a server session so we can send progress updates in tool From 97d39022f308ca2ecc07bab28adbc5eb61dbfa32 Mon Sep 17 00:00:00 2001 From: Akshey D Date: Wed, 7 May 2025 23:28:17 -0400 Subject: [PATCH 8/9] test progress context manager --- tests/shared/test_progress_notifications.py | 149 +++++++++++++++++++- 1 file changed, 148 insertions(+), 1 deletion(-) diff --git a/tests/shared/test_progress_notifications.py b/tests/shared/test_progress_notifications.py index 78d31d897..a2a7bc0d4 100644 --- a/tests/shared/test_progress_notifications.py +++ b/tests/shared/test_progress_notifications.py @@ -1,3 +1,5 @@ +from typing import Any, cast + import anyio import pytest @@ -7,7 +9,13 @@ from mcp.server.lowlevel import NotificationOptions from mcp.server.models import InitializationOptions from mcp.server.session import ServerSession -from mcp.shared.session import RequestResponder, SessionMessage +from mcp.shared.context import RequestContext +from mcp.shared.progress import progress +from mcp.shared.session import ( + BaseSession, + RequestResponder, + SessionMessage, +) @pytest.mark.anyio @@ -209,3 +217,142 @@ async def handle_client_message( assert server_progress_updates[0]["progress"] == 0.33 assert server_progress_updates[0]["message"] == "Client progress 33%" assert server_progress_updates[2]["progress"] == 1.0 + + +@pytest.mark.anyio +async def test_progress_context_manager(): + """Test client using progress context manager for sending progress notifications.""" + # Create memory streams for client/server + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ + SessionMessage + ](5) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ + SessionMessage + ](5) + + # Track progress updates + server_progress_updates = [] + + server = Server(name="ProgressContextTestServer") + + # Register progress handler + @server.progress_notification() + async def handle_progress( + progress_token: str | int, + progress: float, + total: float | None, + message: str | None, + ): + server_progress_updates.append( + { + "token": progress_token, + "progress": progress, + "total": total, + "message": message, + } + ) + + # Run server session to receive progress updates + async def run_server(): + # Create a server session + async with ServerSession( + client_to_server_receive, + server_to_client_send, + InitializationOptions( + server_name="ProgressContextTestServer", + server_version="0.1.0", + capabilities=server.get_capabilities(NotificationOptions(), {}), + ), + ) as server_session: + async for message in server_session.incoming_messages: + try: + await server._handle_message(message, server_session, ()) + except Exception as e: + raise e + + # Client message handler + async def handle_client_message( + message: RequestResponder[types.ServerRequest, types.ClientResult] + | types.ServerNotification + | Exception, + ) -> None: + if isinstance(message, Exception): + raise message + + # run client session + async with ( + ClientSession( + server_to_client_receive, + client_to_server_send, + message_handler=handle_client_message, + ) as client_session, + anyio.create_task_group() as tg, + ): + tg.start_soon(run_server) + + await client_session.initialize() + + progress_token = "client_token_456" + + # Create request context + meta = types.RequestParams.Meta(progressToken=progress_token) + request_context = RequestContext( + request_id="test-request", + session=client_session, + meta=meta, + lifespan_context=None, + ) + + # cast for type checker + typed_context = cast( + RequestContext[ + BaseSession[Any, Any, Any, Any, Any], + Any, + ], + request_context, + ) + + # Utilize progress context manager + with progress(typed_context, total=100) as p: + await p.progress(10, message="Loading configuration...") + await anyio.sleep(0.1) + + await p.progress(30, message="Connecting to database...") + await anyio.sleep(0.1) + + await p.progress(40, message="Fetching data...") + await anyio.sleep(0.1) + + await p.progress(20, message="Processing results...") + await anyio.sleep(0.1) + + # Wait for all messages to be processed + await anyio.sleep(0.5) + tg.cancel_scope.cancel() + + # Verify progress updates were received by server + assert len(server_progress_updates) == 4 + + # first update + assert server_progress_updates[0]["token"] == progress_token + assert server_progress_updates[0]["progress"] == 10 + assert server_progress_updates[0]["total"] == 100 + assert server_progress_updates[0]["message"] == "Loading configuration..." + + # second update + assert server_progress_updates[1]["token"] == progress_token + assert server_progress_updates[1]["progress"] == 40 + assert server_progress_updates[1]["total"] == 100 + assert server_progress_updates[1]["message"] == "Connecting to database..." + + # third update + assert server_progress_updates[2]["token"] == progress_token + assert server_progress_updates[2]["progress"] == 80 + assert server_progress_updates[2]["total"] == 100 + assert server_progress_updates[2]["message"] == "Fetching data..." + + # final update + assert server_progress_updates[3]["token"] == progress_token + assert server_progress_updates[3]["progress"] == 100 + assert server_progress_updates[3]["total"] == 100 + assert server_progress_updates[3]["message"] == "Processing results..." From e4a71087398ca6217cc07b1272c3856f11230bfb Mon Sep 17 00:00:00 2001 From: Akshey D Date: Thu, 8 May 2025 08:31:28 -0400 Subject: [PATCH 9/9] remove unnecessary sleeps --- tests/shared/test_progress_notifications.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/tests/shared/test_progress_notifications.py b/tests/shared/test_progress_notifications.py index a2a7bc0d4..1e0409e14 100644 --- a/tests/shared/test_progress_notifications.py +++ b/tests/shared/test_progress_notifications.py @@ -110,7 +110,6 @@ async def handle_call_tool(name: str, arguments: dict | None) -> list: total=1.0, message="Server progress 25%", ) - await anyio.sleep(0.2) await serv_sesh.send_progress_notification( progress_token=progressToken, @@ -118,7 +117,6 @@ async def handle_call_tool(name: str, arguments: dict | None) -> list: total=1.0, message="Server progress 50%", ) - await anyio.sleep(0.2) await serv_sesh.send_progress_notification( progress_token=progressToken, @@ -201,7 +199,7 @@ async def handle_client_message( ) # Wait and exit - await anyio.sleep(1.0) + await anyio.sleep(0.5) tg.cancel_scope.cancel() # Verify client received progress updates from server @@ -315,16 +313,9 @@ async def handle_client_message( # Utilize progress context manager with progress(typed_context, total=100) as p: await p.progress(10, message="Loading configuration...") - await anyio.sleep(0.1) - await p.progress(30, message="Connecting to database...") - await anyio.sleep(0.1) - await p.progress(40, message="Fetching data...") - await anyio.sleep(0.1) - await p.progress(20, message="Processing results...") - await anyio.sleep(0.1) # Wait for all messages to be processed await anyio.sleep(0.5)