Skip to content

fix: improve misleading warning for progress callback exceptions #775

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 43 additions & 25 deletions src/mcp/shared/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,11 +351,21 @@ async def _receive_loop(self) -> None:
if isinstance(message, Exception):
await self._handle_incoming(message)
elif isinstance(message.message.root, JSONRPCRequest):
validated_request = self._receive_request_type.model_validate(
message.message.root.model_dump(
by_alias=True, mode="json", exclude_none=True
try:
validated_request = self._receive_request_type.model_validate(
message.message.root.model_dump(
by_alias=True, mode="json", exclude_none=True
)
)
)
except Exception as e:
# For other validation errors, log and continue
logging.warning(
"Failed to validate request: %s. Message was: %s",
e,
message.message.root,
)
continue

responder = RequestResponder(
request_id=message.message.root.id,
request_meta=validated_request.root.params.meta
Expand All @@ -379,32 +389,40 @@ async def _receive_loop(self) -> None:
by_alias=True, mode="json", exclude_none=True
)
)
# Handle cancellation notifications
if isinstance(notification.root, CancelledNotification):
cancelled_id = notification.root.params.requestId
if cancelled_id in self._in_flight:
await self._in_flight[cancelled_id].cancel()
else:
# Handle progress notifications callback
if isinstance(notification.root, ProgressNotification):
progress_token = notification.root.params.progressToken
# If there is a progress callback for this token,
# call it with the progress information
if progress_token in self._progress_callbacks:
callback = self._progress_callbacks[progress_token]
except Exception as e:
# For other validation errors, log and continue
logging.warning(
"Failed to validate notification: %s. Message was: %s",
e,
message.message.root,
)
continue
# Handle cancellation notifications
if isinstance(notification.root, CancelledNotification):
cancelled_id = notification.root.params.requestId
if cancelled_id in self._in_flight:
await self._in_flight[cancelled_id].cancel()
else:
# Handle progress notifications callback
if isinstance(notification.root, ProgressNotification):
progress_token = notification.root.params.progressToken
# If there is a progress callback for this token,
# call it with the progress information
if progress_token in self._progress_callbacks:
callback = self._progress_callbacks[progress_token]
try:
await callback(
notification.root.params.progress,
notification.root.params.total,
notification.root.params.message,
)
await self._received_notification(notification)
await self._handle_incoming(notification)
except Exception as e:
# For other validation errors, log and continue
logging.warning(
f"Failed to validate notification: {e}. "
f"Message was: {message.message.root}"
)
except Exception as e:
logging.warning(
"Progress callback raised an exception: %s",
e,
)
await self._received_notification(notification)
await self._handle_incoming(notification)
else: # Response or error
stream = self._response_streams.pop(message.message.root.id, None)
if stream:
Expand Down
80 changes: 80 additions & 0 deletions tests/shared/test_progress_notifications.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Any, cast
from unittest.mock import patch

import anyio
import pytest
Expand All @@ -10,12 +11,16 @@
from mcp.server.models import InitializationOptions
from mcp.server.session import ServerSession
from mcp.shared.context import RequestContext
from mcp.shared.memory import create_connected_server_and_client_session
from mcp.shared.progress import progress
from mcp.shared.session import (
BaseSession,
RequestResponder,
SessionMessage,
)
from mcp.types import (
TextContent,
)


@pytest.mark.anyio
Expand Down Expand Up @@ -347,3 +352,78 @@ async def handle_client_message(
assert server_progress_updates[3]["progress"] == 100
assert server_progress_updates[3]["total"] == 100
assert server_progress_updates[3]["message"] == "Processing results..."


@pytest.mark.anyio
async def test_progress_callback_exception_logging():
"""Test that exceptions in progress callbacks are logged and \
don't crash the session."""
# Track logged warnings
logged_warnings = []

def mock_warning(msg, *args):
logged_warnings.append(msg % args if args else msg)

# Create a progress callback that raises an exception
async def failing_progress_callback(
progress: float, total: float | None, message: str | None
) -> None:
raise ValueError("Progress callback failed!")

# Create a server with a tool that sends progress notifications
server = Server(name="TestProgressServer")

@server.call_tool()
async def handle_call_tool(
name: str, arguments: dict | None
) -> list[types.TextContent]:
if name == "progress_tool":
# Send a progress notification
await server.request_context.session.send_progress_notification(
progress_token=server.request_context.request_id,
progress=50.0,
total=100.0,
message="Halfway done",
)
return [types.TextContent(type="text", text="progress_result")]
raise ValueError(f"Unknown tool: {name}")

@server.list_tools()
async def handle_list_tools() -> list[types.Tool]:
return [
types.Tool(
name="progress_tool",
description="A tool that sends progress notifications",
inputSchema={},
)
]

# Test with mocked logging
with patch("mcp.shared.session.logging.warning", side_effect=mock_warning):
async with create_connected_server_and_client_session(server) as client_session:
# Send a request with a failing progress callback
result = await client_session.send_request(
types.ClientRequest(
types.CallToolRequest(
method="tools/call",
params=types.CallToolRequestParams(
name="progress_tool", arguments={}
),
)
),
types.CallToolResult,
progress_callback=failing_progress_callback,
)

# Verify the request completed successfully despite the callback failure
assert len(result.content) == 1
content = result.content[0]
assert isinstance(content, TextContent)
assert content.text == "progress_result"

# Check that a warning was logged for the progress callback exception
assert len(logged_warnings) > 0
assert any(
"Progress callback raised an exception" in warning
for warning in logged_warnings
)
Loading