From 120c5df2473cc20190931eff07546b5af755f966 Mon Sep 17 00:00:00 2001 From: ihrpr Date: Thu, 15 May 2025 08:30:32 +0100 Subject: [PATCH 1/9] add tests --- src/mcp/server/session.py | 3 + tests/server/fastmcp/test_integration.py | 710 ++++++++++++++++++++++- 2 files changed, 712 insertions(+), 1 deletion(-) diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index 4f97c6cd6..9177a1609 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -286,6 +286,9 @@ async def send_progress_notification( related_request_id: str | None = None, ) -> None: """Send a progress notification.""" + print( + f"Sending progress notification: {progress_token}, {progress}, {total}, {message}" + ) await self.send_notification( types.ServerNotification( types.ProgressNotification( diff --git a/tests/server/fastmcp/test_integration.py b/tests/server/fastmcp/test_integration.py index 67911e9e7..cc17a6d2b 100644 --- a/tests/server/fastmcp/test_integration.py +++ b/tests/server/fastmcp/test_integration.py @@ -10,6 +10,7 @@ import time from collections.abc import Generator +import anyio import pytest import uvicorn @@ -17,7 +18,9 @@ from mcp.client.sse import sse_client from mcp.client.streamable_http import streamablehttp_client from mcp.server.fastmcp import FastMCP -from mcp.types import InitializeResult, TextContent +import mcp.types as types +from mcp.types import InitializeResult, TextContent, TextResourceContents +from pydantic import AnyUrl @pytest.fixture @@ -80,6 +83,125 @@ def echo(message: str) -> str: return mcp, app +def make_comprehensive_fastmcp() -> FastMCP: + """Create a FastMCP server with all features enabled for testing.""" + from mcp.server.fastmcp import Context + + mcp = FastMCP(name="AllFeaturesServer") + + # Tool with context for logging and progress + @mcp.tool(description="A tool that demonstrates logging and progress") + async def tool_with_context(message: str, ctx: Context, steps: int = 3) -> str: + await ctx.info(f"Starting processing of '{message}' with {steps} steps") + + # Send progress notifications + for i in range(steps): + progress_value = (i + 1) / steps + await ctx.report_progress( + progress=progress_value, + total=1.0, + message=f"Processing step {i + 1} of {steps}", + ) + await ctx.debug(f"Completed step {i + 1}") + + return f"Processed '{message}' in {steps} steps" + + # Simple tool for basic functionality + @mcp.tool(description="A simple echo tool") + def echo(message: str) -> str: + return f"Echo: {message}" + + # Tool with sampling capability + @mcp.tool(description="A tool that uses sampling to generate content") + async def sampling_tool(prompt: str, ctx: Context) -> str: + from mcp.types import SamplingMessage, TextContent + + await ctx.info(f"Requesting sampling for prompt: {prompt}") + + # Request sampling from the client + result = await ctx.session.create_message( + messages=[ + SamplingMessage( + role="user", content=TextContent(type="text", text=prompt) + ) + ], + max_tokens=100, + temperature=0.7, + ) + + await ctx.info(f"Received sampling result from model: {result.model}") + # Handle different content types safely + if hasattr(result.content, "text"): + return f"Sampling result: {result.content.text[:100]}..." + else: + return f"Sampling result: {str(result.content)[:100]}..." + + # Tool that sends notifications and logging + @mcp.tool(description="A tool that demonstrates notifications and logging") + async def notification_tool(message: str, ctx: Context) -> str: + # Send different log levels + await ctx.debug("Debug: Starting notification tool") + await ctx.info(f"Info: Processing message '{message}'") + await ctx.warning("Warning: This is a test warning") + + # Send resource change notifications + await ctx.session.send_resource_list_changed() + await ctx.session.send_tool_list_changed() + + await ctx.info("Completed notification tool successfully") + return f"Sent notifications and logs for: {message}" + + # Resource - static + from pydantic import AnyUrl + + from mcp.server.fastmcp.resources import FunctionResource + + def get_static_info() -> str: + return "This is static resource content" + + static_resource = FunctionResource( + uri=AnyUrl("resource://static/info"), + name="Static Info", + description="Static information resource", + fn=get_static_info, + ) + mcp.add_resource(static_resource) + + # Resource - dynamic function + @mcp.resource("resource://dynamic/{category}") + def dynamic_resource(category: str) -> str: + return f"Dynamic resource content for category: {category}" + + # Resource template + @mcp.resource("resource://template/{id}/data") + def template_resource(id: str) -> str: + return f"Template resource data for ID: {id}" + + # Prompt - simple + @mcp.prompt(description="A simple prompt") + def simple_prompt(topic: str) -> str: + return f"Tell me about {topic}" + + # Prompt - complex with multiple messages + @mcp.prompt(description="Complex prompt with context") + def complex_prompt(user_query: str, context: str = "general") -> str: + # For simplicity, return a single string that incorporates the context + # Since FastMCP doesn't support system messages in the same way + return f"Context: {context}. Query: {user_query}" + + return mcp + + +def make_comprehensive_fastmcp_app(): + """Create a comprehensive FastMCP server with SSE transport.""" + from starlette.applications import Starlette + + mcp = make_comprehensive_fastmcp() + # Create the SSE app + app: Starlette = mcp.sse_app() + return mcp, app + + def make_fastmcp_streamable_http_app(): """Create a FastMCP server with StreamableHTTP transport.""" from starlette.applications import Starlette @@ -97,6 +219,18 @@ def echo(message: str) -> str: return mcp, app +def make_comprehensive_fastmcp_streamable_http_app(): + """Create a comprehensive FastMCP server with StreamableHTTP transport.""" + from starlette.applications import Starlette + + # Create a new instance with different name for HTTP transport + mcp = make_comprehensive_fastmcp() + # We can't change the name after creation, so we'll use the same name + # Create the StreamableHTTP app + app: Starlette = mcp.streamable_http_app() + return mcp, app + + def make_fastmcp_stateless_http_app(): """Create a FastMCP server with stateless StreamableHTTP transport.""" from starlette.applications import Starlette @@ -126,6 +260,18 @@ def run_server(server_port: int) -> None: server.run() +def run_comprehensive_server(server_port: int) -> None: + """Run the comprehensive server with all features.""" + _, app = make_comprehensive_fastmcp_app() + server = uvicorn.Server( + config=uvicorn.Config( + app=app, host="127.0.0.1", port=server_port, log_level="error" + ) + ) + print(f"Starting comprehensive server on port {server_port}") + server.run() + + def run_streamable_http_server(server_port: int) -> None: """Run the StreamableHTTP server.""" _, app = make_fastmcp_streamable_http_app() @@ -138,6 +284,18 @@ def run_streamable_http_server(server_port: int) -> None: server.run() +def run_comprehensive_streamable_http_server(server_port: int) -> None: + """Run the comprehensive StreamableHTTP server with all features.""" + _, app = make_comprehensive_fastmcp_streamable_http_app() + server = uvicorn.Server( + config=uvicorn.Config( + app=app, host="127.0.0.1", port=server_port, log_level="error" + ) + ) + print(f"Starting comprehensive StreamableHTTP server on port {server_port}") + server.run() + + def run_stateless_http_server(server_port: int) -> None: """Run the stateless StreamableHTTP server.""" _, app = make_fastmcp_stateless_http_app() @@ -323,3 +481,553 @@ async def test_fastmcp_stateless_streamable_http( assert len(tool_result.content) == 1 assert isinstance(tool_result.content[0], TextContent) assert tool_result.content[0].text == f"Echo: test_{i}" + + +# Fixtures for comprehensive servers +@pytest.fixture +def comprehensive_server_port() -> int: + """Get a free port for testing the comprehensive server.""" + with socket.socket() as s: + s.bind(("127.0.0.1", 0)) + return s.getsockname()[1] + + +@pytest.fixture +def comprehensive_server_url(comprehensive_server_port: int) -> str: + """Get the comprehensive server URL for testing.""" + return f"http://127.0.0.1:{comprehensive_server_port}" + + +@pytest.fixture +def comprehensive_http_server_port() -> int: + """Get a free port for testing the comprehensive StreamableHTTP server.""" + with socket.socket() as s: + s.bind(("127.0.0.1", 0)) + return s.getsockname()[1] + + +@pytest.fixture +def comprehensive_http_server_url(comprehensive_http_server_port: int) -> str: + """Get the comprehensive StreamableHTTP server URL for testing.""" + return f"http://127.0.0.1:{comprehensive_http_server_port}" + + +@pytest.fixture() +def comprehensive_server(comprehensive_server_port: int) -> Generator[None, None, None]: + """Start the comprehensive server in a separate process and clean up after.""" + proc = multiprocessing.Process( + target=run_comprehensive_server, args=(comprehensive_server_port,), daemon=True + ) + print("Starting comprehensive server process") + proc.start() + + # Wait for server to be running + max_attempts = 20 + attempt = 0 + print("Waiting for comprehensive server to start") + while attempt < max_attempts: + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.connect(("127.0.0.1", comprehensive_server_port)) + break + except ConnectionRefusedError: + time.sleep(0.1) + attempt += 1 + else: + raise RuntimeError( + f"Comprehensive server failed to start after {max_attempts} attempts" + ) + + yield + + print("Killing comprehensive server") + proc.kill() + proc.join(timeout=2) + if proc.is_alive(): + print("Comprehensive server process failed to terminate") + + +@pytest.fixture() +def comprehensive_streamable_http_server( + comprehensive_http_server_port: int, +) -> Generator[None, None, None]: + """Start the comprehensive StreamableHTTP server in a separate process.""" + proc = multiprocessing.Process( + target=run_comprehensive_streamable_http_server, + args=(comprehensive_http_server_port,), + daemon=True, + ) + print("Starting comprehensive StreamableHTTP server process") + proc.start() + + # Wait for server to be running + max_attempts = 20 + attempt = 0 + print("Waiting for comprehensive StreamableHTTP server to start") + while attempt < max_attempts: + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.connect(("127.0.0.1", comprehensive_http_server_port)) + break + except ConnectionRefusedError: + time.sleep(0.1) + attempt += 1 + else: + raise RuntimeError( + f"Comprehensive StreamableHTTP server failed to start after " + f"{max_attempts} attempts" + ) + + yield + + print("Killing comprehensive StreamableHTTP server") + proc.kill() + proc.join(timeout=2) + if proc.is_alive(): + print("Comprehensive StreamableHTTP server process failed to terminate") + + +class NotificationCollector: + def __init__(self): + self.progress_notifications: list = [] + self.log_messages: list = [] + self.resource_notifications: list = [] + self.tool_notifications: list = [] + + async def handle_progress(self, params) -> None: + self.progress_notifications.append(params) + + async def handle_log(self, params) -> None: + self.log_messages.append(params) + + async def handle_resource_list_changed(self, params) -> None: + self.resource_notifications.append(params) + + async def handle_tool_list_changed(self, params) -> None: + self.tool_notifications.append(params) + + async def handle_generic_notification(self, message) -> None: + # Check if this is a ServerNotification + if isinstance(message, types.ServerNotification): + # Check the specific notification type + if isinstance(message.root, types.ProgressNotification): + await self.handle_progress(message.root.params) + elif isinstance(message.root, types.LoggingMessageNotification): + await self.handle_log(message.root.params) + elif isinstance(message.root, types.ResourceListChangedNotification): + await self.handle_resource_list_changed(message.root.params) + elif isinstance(message.root, types.ToolListChangedNotification): + await self.handle_tool_list_changed(message.root.params) + + +@pytest.mark.anyio +async def test_fastmcp_all_features_sse( + comprehensive_server: None, comprehensive_server_url: str +) -> None: + """Test all MCP features work correctly with SSE transport.""" + from mcp.types import ( + GetPromptResult, + ReadResourceResult, + CreateMessageResult, + CreateMessageRequestParams, + SamplingMessage, + TextContent, + ) + from mcp.shared.context import RequestContext + + # Create notification collector + collector = NotificationCollector() + + # Create a sampling callback that simulates an LLM + async def sampling_callback( + context: RequestContext[ClientSession, None], + params: CreateMessageRequestParams, + ) -> CreateMessageResult: + # Simulate LLM response based on the input + if params.messages and isinstance(params.messages[0].content, TextContent): + input_text = params.messages[0].content.text + else: + input_text = "No input" + response_text = f"This is a simulated LLM response to: {input_text}" + + return CreateMessageResult( + role="assistant", + content=TextContent(type="text", text=response_text), + model="test-llm-model", + stopReason="endTurn", + ) + + # Connect to the server with callbacks + async with sse_client(comprehensive_server_url + "/sse") as streams: + # Set up message handler to capture notifications + async def message_handler(message): + print(f"Received message: {message}") + await collector.handle_generic_notification(message) + if isinstance(message, Exception): + raise message + + async with ClientSession( + *streams, + sampling_callback=sampling_callback, + message_handler=message_handler, + ) as session: + # Test initialization + result = await session.initialize() + assert isinstance(result, InitializeResult) + assert result.serverInfo.name == "AllFeaturesServer" + + # Check server features are reported + assert result.capabilities.prompts is not None + assert result.capabilities.resources is not None + assert result.capabilities.tools is not None + # Note: logging capability may be None if no tools use context logging + + # Test tools + # 1. Simple echo tool + tool_result = await session.call_tool("echo", {"message": "hello"}) + assert len(tool_result.content) == 1 + assert isinstance(tool_result.content[0], TextContent) + assert tool_result.content[0].text == "Echo: hello" + + # 2. Tool with context (logging and progress) + # Test with progress token to capture progress notifications + tool_result = await session.call_tool( + "tool_with_context", + { + "message": "test", + "steps": 3, + "_meta": {"progressToken": "sse_test_token"}, + }, + ) + assert len(tool_result.content) == 1 + assert len(collector.progress_notifications) > 0 + assert isinstance(tool_result.content[0], TextContent) + assert "Processed 'test' in 3 steps" in tool_result.content[0].text + + # Verify we received log messages from the tool + # Note: Progress notifications require special handling in the MCP client + # that's not implemented by default, so we focus on testing logging + assert len(collector.log_messages) > 0 + + # 3. Test sampling tool + sampling_result = await session.call_tool( + "sampling_tool", {"prompt": "What is the meaning of life?"} + ) + assert len(sampling_result.content) == 1 + assert isinstance(sampling_result.content[0], TextContent) + assert "Sampling result:" in sampling_result.content[0].text + assert "This is a simulated LLM response" in sampling_result.content[0].text + + # Give time for log messages + await anyio.sleep(0.1) + + # Verify we received log messages from the sampling tool + assert len(collector.log_messages) > 0 + assert any( + "Requesting sampling for prompt" in msg.data + for msg in collector.log_messages + ) + assert any( + "Received sampling result from model" in msg.data + for msg in collector.log_messages + ) + + # 4. Test notification tool + notification_result = await session.call_tool( + "notification_tool", {"message": "test_notifications"} + ) + assert len(notification_result.content) == 1 + assert isinstance(notification_result.content[0], TextContent) + assert "Sent notifications and logs" in notification_result.content[0].text + + # Give time for notifications + await anyio.sleep(0.1) + + # Verify we received various notification types + assert len(collector.log_messages) > 3 # Should have logs from both tools + assert len(collector.resource_notifications) > 0 + assert len(collector.tool_notifications) > 0 + + # Check that we got different log levels + log_levels = [msg.level for msg in collector.log_messages] + assert "debug" in log_levels + assert "info" in log_levels + assert "warning" in log_levels + + # Test resources + # 1. Static resource + resources = await session.list_resources() + # Try using string comparison since AnyUrl might not match directly + static_resource = next( + ( + r + for r in resources.resources + if str(r.uri) == "resource://static/info" + ), + None, + ) + assert static_resource is not None + assert static_resource.name == "Static Info" + + static_content = await session.read_resource( + AnyUrl("resource://static/info") + ) + assert isinstance(static_content, ReadResourceResult) + assert len(static_content.contents) == 1 + assert isinstance(static_content.contents[0], TextResourceContents) + assert static_content.contents[0].text == "This is static resource content" + + # 2. Dynamic resource + dynamic_content = await session.read_resource( + AnyUrl("resource://dynamic/test") + ) + assert isinstance(dynamic_content, ReadResourceResult) + assert len(dynamic_content.contents) == 1 + assert isinstance(dynamic_content.contents[0], TextResourceContents) + assert ( + "Dynamic resource content for category: test" + in dynamic_content.contents[0].text + ) + + # 3. Template resource + template_content = await session.read_resource( + AnyUrl("resource://template/123/data") + ) + assert isinstance(template_content, ReadResourceResult) + assert len(template_content.contents) == 1 + assert isinstance(template_content.contents[0], TextResourceContents) + assert ( + "Template resource data for ID: 123" + in template_content.contents[0].text + ) + + # Test prompts + # 1. Simple prompt + prompts = await session.list_prompts() + simple_prompt = next( + (p for p in prompts.prompts if p.name == "simple_prompt"), None + ) + assert simple_prompt is not None + + prompt_result = await session.get_prompt("simple_prompt", {"topic": "AI"}) + assert isinstance(prompt_result, GetPromptResult) + assert len(prompt_result.messages) >= 1 + # The actual message structure depends on the prompt implementation + + # 2. Complex prompt + complex_prompt = next( + (p for p in prompts.prompts if p.name == "complex_prompt"), None + ) + assert complex_prompt is not None + + complex_result = await session.get_prompt( + "complex_prompt", {"user_query": "What is AI?", "context": "technical"} + ) + assert isinstance(complex_result, GetPromptResult) + assert len(complex_result.messages) >= 1 + + +@pytest.mark.anyio +async def test_fastmcp_all_features_streamable_http( + comprehensive_streamable_http_server: None, comprehensive_http_server_url: str +) -> None: + """Test all MCP features work correctly with StreamableHTTP transport.""" + from mcp.types import ( + GetPromptResult, + ReadResourceResult, + CreateMessageResult, + CreateMessageRequestParams, + SamplingMessage, + TextContent, + ) + from mcp.shared.context import RequestContext + + # Create notification collector + collector = NotificationCollector() + + # Create a sampling callback that simulates an LLM + async def sampling_callback( + context: RequestContext[ClientSession, None], + params: CreateMessageRequestParams, + ) -> CreateMessageResult: + # Simulate LLM response + if params.messages and isinstance(params.messages[0].content, TextContent): + input_text = params.messages[0].content.text + else: + input_text = "No input" + response_text = f"This is a simulated LLM response to: {input_text}" + + return CreateMessageResult( + role="assistant", + content=TextContent(type="text", text=response_text), + model="test-llm-model-http", + stopReason="endTurn", + ) + + # Connect to the server using StreamableHTTP + async with streamablehttp_client(comprehensive_http_server_url + "/mcp") as ( + read_stream, + write_stream, + _, + ): + # Set up message handler to capture notifications + async def message_handler(message): + print(f"Received message: {message}") + await collector.handle_generic_notification(message) + if isinstance(message, Exception): + raise message + + async with ClientSession( + read_stream, + write_stream, + sampling_callback=sampling_callback, + message_handler=message_handler, + ) as session: + # Test initialization + result = await session.initialize() + assert isinstance(result, InitializeResult) + assert result.serverInfo.name == "AllFeaturesServer" + + # Check server features are reported + assert result.capabilities.prompts is not None + assert result.capabilities.resources is not None + assert result.capabilities.tools is not None + # Note: logging capability may be None if no tools use context logging + + # Test tools + # 1. Simple echo tool + tool_result = await session.call_tool("echo", {"message": "hello"}) + assert len(tool_result.content) == 1 + assert isinstance(tool_result.content[0], TextContent) + assert tool_result.content[0].text == "Echo: hello" + + # 2. Tool with context (logging and progress) + # Test with progress token to capture progress notifications + await session.call_tool( + "tool_with_context", + { + "message": "http_test", + "steps": 2, + }, + ) + + # Verify we received progress notifications + assert len(collector.progress_notifications) > 0 + assert any( + p.progressToken == "http_test_token" + for p in collector.progress_notifications + ) + + # 3. Test sampling tool + sampling_result = await session.call_tool( + "sampling_tool", {"prompt": "Explain quantum computing"} + ) + assert len(sampling_result.content) == 1 + assert isinstance(sampling_result.content[0], TextContent) + assert "Sampling result:" in sampling_result.content[0].text + assert "This is a simulated LLM response" in sampling_result.content[0].text + + # Give time for log messages + await anyio.sleep(0.1) + + # Verify we received log messages + assert len(collector.log_messages) > 0 + assert any( + "Requesting sampling for prompt" in msg.data + for msg in collector.log_messages + ) + + # Test resources + # 1. Static resource + resources = await session.list_resources() + # Try using string comparison since AnyUrl might not match directly + static_resource = next( + ( + r + for r in resources.resources + if str(r.uri) == "resource://static/info" + ), + None, + ) + assert static_resource is not None + assert static_resource.name == "Static Info" + + static_content = await session.read_resource( + AnyUrl("resource://static/info") + ) + assert isinstance(static_content, ReadResourceResult) + assert len(static_content.contents) == 1 + assert isinstance(static_content.contents[0], TextResourceContents) + assert static_content.contents[0].text == "This is static resource content" + + # 2. Dynamic resource + dynamic_content = await session.read_resource( + AnyUrl("resource://dynamic/http") + ) + assert isinstance(dynamic_content, ReadResourceResult) + assert len(dynamic_content.contents) == 1 + assert isinstance(dynamic_content.contents[0], TextResourceContents) + assert ( + "Dynamic resource content for category: http" + in dynamic_content.contents[0].text + ) + + # 3. Template resource + template_content = await session.read_resource( + AnyUrl("resource://template/456/data") + ) + assert isinstance(template_content, ReadResourceResult) + assert len(template_content.contents) == 1 + assert isinstance(template_content.contents[0], TextResourceContents) + assert ( + "Template resource data for ID: 456" + in template_content.contents[0].text + ) + + # Test prompts + # 1. Simple prompt + prompts = await session.list_prompts() + simple_prompt = next( + (p for p in prompts.prompts if p.name == "simple_prompt"), None + ) + assert simple_prompt is not None + + prompt_result = await session.get_prompt("simple_prompt", {"topic": "HTTP"}) + assert isinstance(prompt_result, GetPromptResult) + assert len(prompt_result.messages) >= 1 + # The actual message structure depends on the prompt implementation + + # 2. Complex prompt + complex_prompt = next( + (p for p in prompts.prompts if p.name == "complex_prompt"), None + ) + assert complex_prompt is not None + + complex_result = await session.get_prompt( + "complex_prompt", {"user_query": "What is HTTP?", "context": "web"} + ) + assert isinstance(complex_result, GetPromptResult) + assert len(complex_result.messages) >= 1 + + # Test that all features work in sequence (integration test) + # This tests that the different transport doesn't affect feature interaction + for i in range(3): + # Call tool + tool_result = await session.call_tool( + "echo", {"message": f"iteration_{i}"} + ) + assert isinstance(tool_result.content[0], TextContent) + assert f"iteration_{i}" in tool_result.content[0].text + + # Read resource + resource_result = await session.read_resource( + AnyUrl(f"resource://dynamic/{i}") + ) + assert isinstance(resource_result.contents[0], TextResourceContents) + assert f"category: {i}" in resource_result.contents[0].text + + # Get prompt + prompt_result = await session.get_prompt( + "simple_prompt", {"topic": f"topic_{i}"} + ) + assert len(prompt_result.messages) >= 1 From 52e066aed0ec7b57aa3b86fcc2fb69b9f0feee8d Mon Sep 17 00:00:00 2001 From: ihrpr Date: Thu, 15 May 2025 15:22:04 +0100 Subject: [PATCH 2/9] add progress notificaiton to the client --- src/mcp/client/session.py | 11 ++- src/mcp/server/fastmcp/server.py | 1 - src/mcp/shared/session.py | 52 ++++++++++++-- tests/server/fastmcp/test_integration.py | 91 +++++++++++++++--------- 4 files changed, 111 insertions(+), 44 deletions(-) diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 15e8809c1..9ab1c7b87 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -8,7 +8,7 @@ import mcp.types as types from mcp.shared.context import RequestContext from mcp.shared.message import SessionMessage -from mcp.shared.session import BaseSession, RequestResponder +from mcp.shared.session import BaseSession, ProgressCallbackFnT, RequestResponder from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS DEFAULT_CLIENT_INFO = types.Implementation(name="mcp", version="0.1.0") @@ -270,18 +270,23 @@ async def call_tool( name: str, arguments: dict[str, Any] | None = None, read_timeout_seconds: timedelta | None = None, + progress_callback: ProgressCallbackFnT | None = None, ) -> types.CallToolResult: - """Send a tools/call request.""" + """Send a tools/call request with optional progress callback support.""" return await self.send_request( types.ClientRequest( types.CallToolRequest( method="tools/call", - params=types.CallToolRequestParams(name=name, arguments=arguments), + params=types.CallToolRequestParams( + name=name, + arguments=arguments, + ), ) ), types.CallToolResult, request_read_timeout_seconds=read_timeout_seconds, + progress_callback=progress_callback, ) async def list_prompts(self, cursor: str | None = None) -> types.ListPromptsResult: diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index 09896dc6d..21c31b0b3 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -963,7 +963,6 @@ async def report_progress( total: Optional total value e.g. 100 message: Optional message e.g. Starting render... """ - progress_token = ( self.request_context.meta.progressToken if self.request_context.meta diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 19728e0ec..62801e3da 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -3,7 +3,7 @@ from contextlib import AsyncExitStack from datetime import timedelta from types import TracebackType -from typing import Any, Generic, TypeVar +from typing import Any, Generic, Protocol, TypeVar import anyio import httpx @@ -24,6 +24,7 @@ JSONRPCNotification, JSONRPCRequest, JSONRPCResponse, + ProgressNotification, RequestParams, ServerNotification, ServerRequest, @@ -42,6 +43,22 @@ RequestId = str | int +class ProgressCallbackFnT(Protocol): + """Protocol for progress notification callbacks.""" + + def __call__( + self, progress: float, total: float | None, message: str | None + ) -> None: + """Called when progress updates are received. + + Args: + progress: Current progress value + total: Total progress value (if known), None if indeterminate + message: Optional progress message + """ + ... + + class RequestResponder(Generic[ReceiveRequestT, SendResultT]): """Handles responding to MCP requests and manages request lifecycle. @@ -169,6 +186,7 @@ class BaseSession( ] _request_id: int _in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]] + _progress_callbacks: dict[RequestId, ProgressCallbackFnT] def __init__( self, @@ -187,6 +205,7 @@ def __init__( self._receive_notification_type = receive_notification_type self._session_read_timeout_seconds = read_timeout_seconds self._in_flight = {} + self._progress_callbacks = {} self._exit_stack = AsyncExitStack() async def __aenter__(self) -> Self: @@ -214,6 +233,7 @@ async def send_request( result_type: type[ReceiveResultT], request_read_timeout_seconds: timedelta | None = None, metadata: MessageMetadata = None, + progress_callback: ProgressCallbackFnT | None = None, ) -> ReceiveResultT: """ Sends a request and wait for a response. Raises an McpError if the @@ -231,15 +251,25 @@ async def send_request( ](1) self._response_streams[request_id] = response_stream + # Set up progress token if progress callback is provided + request_data = request.model_dump(by_alias=True, mode="json", exclude_none=True) + if progress_callback is not None: + # Use request_id as progress token + if "params" not in request_data: + request_data["params"] = {} + if "_meta" not in request_data["params"]: + request_data["params"]["_meta"] = {} + request_data["params"]["_meta"]["progressToken"] = request_id + # Store the callback for this request + self._progress_callbacks[request_id] = progress_callback + try: jsonrpc_request = JSONRPCRequest( jsonrpc="2.0", id=request_id, - **request.model_dump(by_alias=True, mode="json", exclude_none=True), + **request_data, ) - # TODO: Support progress callbacks - await self._write_stream.send( SessionMessage( message=JSONRPCMessage(jsonrpc_request), metadata=metadata @@ -275,6 +305,7 @@ async def send_request( finally: self._response_streams.pop(request_id, None) + self._progress_callbacks.pop(request_id, None) await response_stream.aclose() await response_stream_reader.aclose() @@ -333,7 +364,6 @@ async def _receive_loop(self) -> None: by_alias=True, mode="json", exclude_none=True ) ) - responder = RequestResponder( request_id=message.message.root.id, request_meta=validated_request.root.params.meta @@ -362,6 +392,18 @@ async def _receive_loop(self) -> None: cancelled_id = notification.root.params.requestId if cancelled_id in self._in_flight: await self._in_flight[cancelled_id].cancel() + # Handle progress notifications + elif isinstance(notification.root, ProgressNotification): + progress_token = notification.root.params.progressToken + # If there is a progress callback for this token, + # call it with the progress information + if progress_token in self._progress_callbacks: + callback = self._progress_callbacks[progress_token] + callback( + notification.root.params.progress, + notification.root.params.total, + notification.root.params.message, + ) else: await self._received_notification(notification) await self._handle_incoming(notification) diff --git a/tests/server/fastmcp/test_integration.py b/tests/server/fastmcp/test_integration.py index cc17a6d2b..3d15a5c96 100644 --- a/tests/server/fastmcp/test_integration.py +++ b/tests/server/fastmcp/test_integration.py @@ -13,14 +13,23 @@ import anyio import pytest import uvicorn +from pydantic import AnyUrl +import mcp.types as types from mcp.client.session import ClientSession from mcp.client.sse import sse_client from mcp.client.streamable_http import streamablehttp_client from mcp.server.fastmcp import FastMCP -import mcp.types as types -from mcp.types import InitializeResult, TextContent, TextResourceContents -from pydantic import AnyUrl +from mcp.shared.context import RequestContext +from mcp.types import ( + CreateMessageRequestParams, + CreateMessageResult, + GetPromptResult, + InitializeResult, + ReadResourceResult, + TextContent, + TextResourceContents, +) @pytest.fixture @@ -625,15 +634,6 @@ async def test_fastmcp_all_features_sse( comprehensive_server: None, comprehensive_server_url: str ) -> None: """Test all MCP features work correctly with SSE transport.""" - from mcp.types import ( - GetPromptResult, - ReadResourceResult, - CreateMessageResult, - CreateMessageRequestParams, - SamplingMessage, - TextContent, - ) - from mcp.shared.context import RequestContext # Create notification collector collector = NotificationCollector() @@ -690,20 +690,38 @@ async def message_handler(message): assert tool_result.content[0].text == "Echo: hello" # 2. Tool with context (logging and progress) - # Test with progress token to capture progress notifications + # Test progress callback functionality + progress_updates = [] + + def progress_callback( + progress: float, total: float | None, message: str | None + ) -> None: + """Collect progress updates for testing.""" + progress_updates.append((progress, total, message)) + print(f"Progress: {progress}/{total} - {message}") + + params = { + "message": "test", + "steps": 3, + } tool_result = await session.call_tool( "tool_with_context", - { - "message": "test", - "steps": 3, - "_meta": {"progressToken": "sse_test_token"}, - }, + params, + progress_callback=progress_callback, ) assert len(tool_result.content) == 1 - assert len(collector.progress_notifications) > 0 assert isinstance(tool_result.content[0], TextContent) assert "Processed 'test' in 3 steps" in tool_result.content[0].text + # Verify progress callback was called + assert len(progress_updates) == 3 + for i, (progress, total, message) in enumerate(progress_updates): + expected_progress = (i + 1) / 3 + assert abs(progress - expected_progress) < 0.01 + assert total == 1.0 + assert message is not None + assert f"step {i + 1} of 3" in message + # Verify we received log messages from the tool # Note: Progress notifications require special handling in the MCP client # that's not implemented by default, so we focus on testing logging @@ -832,15 +850,6 @@ async def test_fastmcp_all_features_streamable_http( comprehensive_streamable_http_server: None, comprehensive_http_server_url: str ) -> None: """Test all MCP features work correctly with StreamableHTTP transport.""" - from mcp.types import ( - GetPromptResult, - ReadResourceResult, - CreateMessageResult, - CreateMessageRequestParams, - SamplingMessage, - TextContent, - ) - from mcp.shared.context import RequestContext # Create notification collector collector = NotificationCollector() @@ -902,21 +911,33 @@ async def message_handler(message): assert tool_result.content[0].text == "Echo: hello" # 2. Tool with context (logging and progress) - # Test with progress token to capture progress notifications + # Test progress callback functionality over HTTP + progress_updates_http = [] + + def progress_callback_http( + progress: float, total: float | None, message: str | None + ) -> None: + """Collect progress updates for HTTP testing.""" + progress_updates_http.append((progress, total, message)) + print(f"HTTP Progress: {progress}/{total} - {message}") + await session.call_tool( "tool_with_context", { "message": "http_test", "steps": 2, }, + progress_callback=progress_callback_http, ) - # Verify we received progress notifications - assert len(collector.progress_notifications) > 0 - assert any( - p.progressToken == "http_test_token" - for p in collector.progress_notifications - ) + # Verify progress callback was called over HTTP + assert len(progress_updates_http) == 2 + for i, (progress, total, message) in enumerate(progress_updates_http): + expected_progress = (i + 1) / 2 + assert abs(progress - expected_progress) < 0.01 + assert total == 1.0 + assert message is not None + assert f"step {i + 1} of 2" in message # 3. Test sampling tool sampling_result = await session.call_tool( From bf08774bc997d183ca4067447addaa12d9c957bf Mon Sep 17 00:00:00 2001 From: ihrpr Date: Thu, 15 May 2025 15:31:44 +0100 Subject: [PATCH 3/9] rename --- src/mcp/client/session.py | 4 ++-- src/mcp/server/session.py | 3 --- src/mcp/shared/session.py | 16 ++++------------ 3 files changed, 6 insertions(+), 17 deletions(-) diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 9ab1c7b87..c714c44bb 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -8,7 +8,7 @@ import mcp.types as types from mcp.shared.context import RequestContext from mcp.shared.message import SessionMessage -from mcp.shared.session import BaseSession, ProgressCallbackFnT, RequestResponder +from mcp.shared.session import BaseSession, ProgressFnT, RequestResponder from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS DEFAULT_CLIENT_INFO = types.Implementation(name="mcp", version="0.1.0") @@ -270,7 +270,7 @@ async def call_tool( name: str, arguments: dict[str, Any] | None = None, read_timeout_seconds: timedelta | None = None, - progress_callback: ProgressCallbackFnT | None = None, + progress_callback: ProgressFnT | None = None, ) -> types.CallToolResult: """Send a tools/call request with optional progress callback support.""" diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index 9177a1609..4f97c6cd6 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -286,9 +286,6 @@ async def send_progress_notification( related_request_id: str | None = None, ) -> None: """Send a progress notification.""" - print( - f"Sending progress notification: {progress_token}, {progress}, {total}, {message}" - ) await self.send_notification( types.ServerNotification( types.ProgressNotification( diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 62801e3da..7fc08d679 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -43,20 +43,12 @@ RequestId = str | int -class ProgressCallbackFnT(Protocol): +class ProgressFnT(Protocol): """Protocol for progress notification callbacks.""" def __call__( self, progress: float, total: float | None, message: str | None - ) -> None: - """Called when progress updates are received. - - Args: - progress: Current progress value - total: Total progress value (if known), None if indeterminate - message: Optional progress message - """ - ... + ) -> None: ... class RequestResponder(Generic[ReceiveRequestT, SendResultT]): @@ -186,7 +178,7 @@ class BaseSession( ] _request_id: int _in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]] - _progress_callbacks: dict[RequestId, ProgressCallbackFnT] + _progress_callbacks: dict[RequestId, ProgressFnT] def __init__( self, @@ -233,7 +225,7 @@ async def send_request( result_type: type[ReceiveResultT], request_read_timeout_seconds: timedelta | None = None, metadata: MessageMetadata = None, - progress_callback: ProgressCallbackFnT | None = None, + progress_callback: ProgressFnT | None = None, ) -> ReceiveResultT: """ Sends a request and wait for a response. Raises an McpError if the From 7356a3f1589428823e50151a4be4b7e21b67a3ec Mon Sep 17 00:00:00 2001 From: ihrpr Date: Thu, 15 May 2025 15:37:58 +0100 Subject: [PATCH 4/9] name fixes --- tests/server/fastmcp/test_integration.py | 26 ++++++++++++------------ 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/tests/server/fastmcp/test_integration.py b/tests/server/fastmcp/test_integration.py index 3d15a5c96..813db3118 100644 --- a/tests/server/fastmcp/test_integration.py +++ b/tests/server/fastmcp/test_integration.py @@ -92,15 +92,15 @@ def echo(message: str) -> str: return mcp, app -def make_comprehensive_fastmcp() -> FastMCP: +def make_everything_fastmcp() -> FastMCP: """Create a FastMCP server with all features enabled for testing.""" from mcp.server.fastmcp import Context - mcp = FastMCP(name="AllFeaturesServer") + mcp = FastMCP(name="EverythingServer") # Tool with context for logging and progress @mcp.tool(description="A tool that demonstrates logging and progress") - async def tool_with_context(message: str, ctx: Context, steps: int = 3) -> str: + async def tool_with_progress(message: str, ctx: Context, steps: int = 3) -> str: await ctx.info(f"Starting processing of '{message}' with {steps} steps") # Send progress notifications @@ -139,8 +139,8 @@ async def sampling_tool(prompt: str, ctx: Context) -> str: ) await ctx.info(f"Received sampling result from model: {result.model}") - # Handle different content types safely - if hasattr(result.content, "text"): + # Handle different content types + if result.content.type == "text": return f"Sampling result: {result.content.text[:100]}..." else: return f"Sampling result: {str(result.content)[:100]}..." @@ -201,11 +201,11 @@ def complex_prompt(user_query: str, context: str = "general") -> str: return mcp -def make_comprehensive_fastmcp_app(): +def make_everything_fastmcp_app(): """Create a comprehensive FastMCP server with SSE transport.""" from starlette.applications import Starlette - mcp = make_comprehensive_fastmcp() + mcp = make_everything_fastmcp() # Create the SSE app app: Starlette = mcp.sse_app() return mcp, app @@ -228,12 +228,12 @@ def echo(message: str) -> str: return mcp, app -def make_comprehensive_fastmcp_streamable_http_app(): +def make_everything_fastmcp_streamable_http_app(): """Create a comprehensive FastMCP server with StreamableHTTP transport.""" from starlette.applications import Starlette # Create a new instance with different name for HTTP transport - mcp = make_comprehensive_fastmcp() + mcp = make_everything_fastmcp() # We can't change the name after creation, so we'll use the same name # Create the StreamableHTTP app app: Starlette = mcp.streamable_http_app() @@ -271,7 +271,7 @@ def run_server(server_port: int) -> None: def run_comprehensive_server(server_port: int) -> None: """Run the comprehensive server with all features.""" - _, app = make_comprehensive_fastmcp_app() + _, app = make_everything_fastmcp_app() server = uvicorn.Server( config=uvicorn.Config( app=app, host="127.0.0.1", port=server_port, log_level="error" @@ -295,7 +295,7 @@ def run_streamable_http_server(server_port: int) -> None: def run_comprehensive_streamable_http_server(server_port: int) -> None: """Run the comprehensive StreamableHTTP server with all features.""" - _, app = make_comprehensive_fastmcp_streamable_http_app() + _, app = make_everything_fastmcp_streamable_http_app() server = uvicorn.Server( config=uvicorn.Config( app=app, host="127.0.0.1", port=server_port, log_level="error" @@ -705,7 +705,7 @@ def progress_callback( "steps": 3, } tool_result = await session.call_tool( - "tool_with_context", + "tool_with_progress", params, progress_callback=progress_callback, ) @@ -922,7 +922,7 @@ def progress_callback_http( print(f"HTTP Progress: {progress}/{total} - {message}") await session.call_tool( - "tool_with_context", + "tool_with_progress", { "message": "http_test", "steps": 2, From d8fc0e7ba34a10c35f6e99f56d82062646ea272c Mon Sep 17 00:00:00 2001 From: ihrpr Date: Thu, 15 May 2025 15:47:19 +0100 Subject: [PATCH 5/9] names --- tests/server/fastmcp/test_integration.py | 38 ++++++++++++------------ 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/tests/server/fastmcp/test_integration.py b/tests/server/fastmcp/test_integration.py index 813db3118..5239a648b 100644 --- a/tests/server/fastmcp/test_integration.py +++ b/tests/server/fastmcp/test_integration.py @@ -269,7 +269,7 @@ def run_server(server_port: int) -> None: server.run() -def run_comprehensive_server(server_port: int) -> None: +def run_everything_server(server_port: int) -> None: """Run the comprehensive server with all features.""" _, app = make_everything_fastmcp_app() server = uvicorn.Server( @@ -494,7 +494,7 @@ async def test_fastmcp_stateless_streamable_http( # Fixtures for comprehensive servers @pytest.fixture -def comprehensive_server_port() -> int: +def everything_server_port() -> int: """Get a free port for testing the comprehensive server.""" with socket.socket() as s: s.bind(("127.0.0.1", 0)) @@ -502,13 +502,13 @@ def comprehensive_server_port() -> int: @pytest.fixture -def comprehensive_server_url(comprehensive_server_port: int) -> str: +def everything_server_url(everything_server_port: int) -> str: """Get the comprehensive server URL for testing.""" - return f"http://127.0.0.1:{comprehensive_server_port}" + return f"http://127.0.0.1:{everything_server_port}" @pytest.fixture -def comprehensive_http_server_port() -> int: +def everything_http_server_port() -> int: """Get a free port for testing the comprehensive StreamableHTTP server.""" with socket.socket() as s: s.bind(("127.0.0.1", 0)) @@ -516,16 +516,16 @@ def comprehensive_http_server_port() -> int: @pytest.fixture -def comprehensive_http_server_url(comprehensive_http_server_port: int) -> str: +def everything_http_server_url(everything_http_server_port: int) -> str: """Get the comprehensive StreamableHTTP server URL for testing.""" - return f"http://127.0.0.1:{comprehensive_http_server_port}" + return f"http://127.0.0.1:{everything_http_server_port}" @pytest.fixture() -def comprehensive_server(comprehensive_server_port: int) -> Generator[None, None, None]: +def everything_server(everything_server_port: int) -> Generator[None, None, None]: """Start the comprehensive server in a separate process and clean up after.""" proc = multiprocessing.Process( - target=run_comprehensive_server, args=(comprehensive_server_port,), daemon=True + target=run_everything_server, args=(everything_server_port,), daemon=True ) print("Starting comprehensive server process") proc.start() @@ -537,7 +537,7 @@ def comprehensive_server(comprehensive_server_port: int) -> Generator[None, None while attempt < max_attempts: try: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.connect(("127.0.0.1", comprehensive_server_port)) + s.connect(("127.0.0.1", everything_server_port)) break except ConnectionRefusedError: time.sleep(0.1) @@ -558,12 +558,12 @@ def comprehensive_server(comprehensive_server_port: int) -> Generator[None, None @pytest.fixture() def comprehensive_streamable_http_server( - comprehensive_http_server_port: int, + everything_http_server_port: int, ) -> Generator[None, None, None]: """Start the comprehensive StreamableHTTP server in a separate process.""" proc = multiprocessing.Process( target=run_comprehensive_streamable_http_server, - args=(comprehensive_http_server_port,), + args=(everything_http_server_port,), daemon=True, ) print("Starting comprehensive StreamableHTTP server process") @@ -576,7 +576,7 @@ def comprehensive_streamable_http_server( while attempt < max_attempts: try: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.connect(("127.0.0.1", comprehensive_http_server_port)) + s.connect(("127.0.0.1", everything_http_server_port)) break except ConnectionRefusedError: time.sleep(0.1) @@ -631,7 +631,7 @@ async def handle_generic_notification(self, message) -> None: @pytest.mark.anyio async def test_fastmcp_all_features_sse( - comprehensive_server: None, comprehensive_server_url: str + everything_server: None, everything_server_url: str ) -> None: """Test all MCP features work correctly with SSE transport.""" @@ -658,7 +658,7 @@ async def sampling_callback( ) # Connect to the server with callbacks - async with sse_client(comprehensive_server_url + "/sse") as streams: + async with sse_client(everything_server_url + "/sse") as streams: # Set up message handler to capture notifications async def message_handler(message): print(f"Received message: {message}") @@ -674,7 +674,7 @@ async def message_handler(message): # Test initialization result = await session.initialize() assert isinstance(result, InitializeResult) - assert result.serverInfo.name == "AllFeaturesServer" + assert result.serverInfo.name == "EverythingServer" # Check server features are reported assert result.capabilities.prompts is not None @@ -847,7 +847,7 @@ def progress_callback( @pytest.mark.anyio async def test_fastmcp_all_features_streamable_http( - comprehensive_streamable_http_server: None, comprehensive_http_server_url: str + comprehensive_streamable_http_server: None, everything_http_server_url: str ) -> None: """Test all MCP features work correctly with StreamableHTTP transport.""" @@ -874,7 +874,7 @@ async def sampling_callback( ) # Connect to the server using StreamableHTTP - async with streamablehttp_client(comprehensive_http_server_url + "/mcp") as ( + async with streamablehttp_client(everything_http_server_url + "/mcp") as ( read_stream, write_stream, _, @@ -895,7 +895,7 @@ async def message_handler(message): # Test initialization result = await session.initialize() assert isinstance(result, InitializeResult) - assert result.serverInfo.name == "AllFeaturesServer" + assert result.serverInfo.name == "EverythingServer" # Check server features are reported assert result.capabilities.prompts is not None From 7e090f532e6bcfbcb8d2b0c39254e4812db94997 Mon Sep 17 00:00:00 2001 From: ihrpr Date: Thu, 15 May 2025 16:16:25 +0100 Subject: [PATCH 6/9] refactor --- tests/server/fastmcp/test_integration.py | 570 ++++++++--------------- 1 file changed, 202 insertions(+), 368 deletions(-) diff --git a/tests/server/fastmcp/test_integration.py b/tests/server/fastmcp/test_integration.py index 5239a648b..e7e314e72 100644 --- a/tests/server/fastmcp/test_integration.py +++ b/tests/server/fastmcp/test_integration.py @@ -629,6 +629,204 @@ async def handle_generic_notification(self, message) -> None: await self.handle_tool_list_changed(message.root.params) +async def call_all_mcp_features( + session: ClientSession, collector: NotificationCollector +) -> None: + """ + Test all MCP features using the provided session. + + Args: + session: The MCP client session to test with + collector: Notification collector for capturing server notifications + """ + # Test initialization + result = await session.initialize() + assert isinstance(result, InitializeResult) + assert result.serverInfo.name == "EverythingServer" + + # Check server features are reported + assert result.capabilities.prompts is not None + assert result.capabilities.resources is not None + assert result.capabilities.tools is not None + # Note: logging capability may be None if no tools use context logging + + # Test tools + # 1. Simple echo tool + tool_result = await session.call_tool("echo", {"message": "hello"}) + assert len(tool_result.content) == 1 + assert isinstance(tool_result.content[0], TextContent) + assert tool_result.content[0].text == "Echo: hello" + + # 2. Tool with context (logging and progress) + # Test progress callback functionality + progress_updates = [] + + def progress_callback( + progress: float, total: float | None, message: str | None + ) -> None: + """Collect progress updates for testing.""" + progress_updates.append((progress, total, message)) + print(f"Progress: {progress}/{total} - {message}") + + test_message = "test" + steps = 3 + params = { + "message": test_message, + "steps": steps, + } + tool_result = await session.call_tool( + "tool_with_progress", + params, + progress_callback=progress_callback, + ) + assert len(tool_result.content) == 1 + assert isinstance(tool_result.content[0], TextContent) + assert f"Processed '{test_message}' in {steps} steps" in tool_result.content[0].text + + # Verify progress callback was called + assert len(progress_updates) == steps + for i, (progress, total, message) in enumerate(progress_updates): + expected_progress = (i + 1) / steps + assert abs(progress - expected_progress) < 0.01 + assert total == 1.0 + assert message is not None + assert f"step {i + 1} of {steps}" in message + + # Verify we received log messages from the tool + # Note: Progress notifications require special handling in the MCP client + # that's not implemented by default, so we focus on testing logging + assert len(collector.log_messages) > 0 + + # 3. Test sampling tool + prompt = "What is the meaning of life?" + sampling_result = await session.call_tool("sampling_tool", {"prompt": prompt}) + assert len(sampling_result.content) == 1 + assert isinstance(sampling_result.content[0], TextContent) + assert "Sampling result:" in sampling_result.content[0].text + assert "This is a simulated LLM response" in sampling_result.content[0].text + + # Verify we received log messages from the sampling tool + assert len(collector.log_messages) > 0 + assert any( + "Requesting sampling for prompt" in msg.data for msg in collector.log_messages + ) + assert any( + "Received sampling result from model" in msg.data + for msg in collector.log_messages + ) + + # 4. Test notification tool + notification_message = "test_notifications" + notification_result = await session.call_tool( + "notification_tool", {"message": notification_message} + ) + assert len(notification_result.content) == 1 + assert isinstance(notification_result.content[0], TextContent) + assert "Sent notifications and logs" in notification_result.content[0].text + + # Verify we received various notification types + assert len(collector.log_messages) > 3 # Should have logs from both tools + assert len(collector.resource_notifications) > 0 + assert len(collector.tool_notifications) > 0 + + # Check that we got different log levels + log_levels = [msg.level for msg in collector.log_messages] + assert "debug" in log_levels + assert "info" in log_levels + assert "warning" in log_levels + + # Test resources + # 1. Static resource + resources = await session.list_resources() + # Try using string comparison since AnyUrl might not match directly + static_resource = next( + (r for r in resources.resources if str(r.uri) == "resource://static/info"), + None, + ) + assert static_resource is not None + assert static_resource.name == "Static Info" + + static_content = await session.read_resource(AnyUrl("resource://static/info")) + assert isinstance(static_content, ReadResourceResult) + assert len(static_content.contents) == 1 + assert isinstance(static_content.contents[0], TextResourceContents) + assert static_content.contents[0].text == "This is static resource content" + + # 2. Dynamic resource + resource_category = "test" + dynamic_content = await session.read_resource( + AnyUrl(f"resource://dynamic/{resource_category}") + ) + assert isinstance(dynamic_content, ReadResourceResult) + assert len(dynamic_content.contents) == 1 + assert isinstance(dynamic_content.contents[0], TextResourceContents) + assert ( + f"Dynamic resource content for category: {resource_category}" + in dynamic_content.contents[0].text + ) + + # 3. Template resource + resource_id = "456" + template_content = await session.read_resource( + AnyUrl(f"resource://template/{resource_id}/data") + ) + assert isinstance(template_content, ReadResourceResult) + assert len(template_content.contents) == 1 + assert isinstance(template_content.contents[0], TextResourceContents) + assert ( + f"Template resource data for ID: {resource_id}" + in template_content.contents[0].text + ) + + # Test prompts + # 1. Simple prompt + prompts = await session.list_prompts() + simple_prompt = next( + (p for p in prompts.prompts if p.name == "simple_prompt"), None + ) + assert simple_prompt is not None + + prompt_topic = "AI" + prompt_result = await session.get_prompt("simple_prompt", {"topic": prompt_topic}) + assert isinstance(prompt_result, GetPromptResult) + assert len(prompt_result.messages) >= 1 + # The actual message structure depends on the prompt implementation + + # 2. Complex prompt + complex_prompt = next( + (p for p in prompts.prompts if p.name == "complex_prompt"), None + ) + assert complex_prompt is not None + + query = "What is AI?" + context = "technical" + complex_result = await session.get_prompt( + "complex_prompt", {"user_query": query, "context": context} + ) + assert isinstance(complex_result, GetPromptResult) + assert len(complex_result.messages) >= 1 + + +async def sampling_callback( + context: RequestContext[ClientSession, None], + params: CreateMessageRequestParams, +) -> CreateMessageResult: + # Simulate LLM response based on the input + if params.messages and isinstance(params.messages[0].content, TextContent): + input_text = params.messages[0].content.text + else: + input_text = "No input" + response_text = f"This is a simulated LLM response to: {input_text}" + + model_name = "test-llm-model" + return CreateMessageResult( + role="assistant", + content=TextContent(type="text", text=response_text), + model=model_name, + stopReason="endTurn", + ) + + @pytest.mark.anyio async def test_fastmcp_all_features_sse( everything_server: None, everything_server_url: str @@ -639,23 +837,6 @@ async def test_fastmcp_all_features_sse( collector = NotificationCollector() # Create a sampling callback that simulates an LLM - async def sampling_callback( - context: RequestContext[ClientSession, None], - params: CreateMessageRequestParams, - ) -> CreateMessageResult: - # Simulate LLM response based on the input - if params.messages and isinstance(params.messages[0].content, TextContent): - input_text = params.messages[0].content.text - else: - input_text = "No input" - response_text = f"This is a simulated LLM response to: {input_text}" - - return CreateMessageResult( - role="assistant", - content=TextContent(type="text", text=response_text), - model="test-llm-model", - stopReason="endTurn", - ) # Connect to the server with callbacks async with sse_client(everything_server_url + "/sse") as streams: @@ -671,178 +852,8 @@ async def message_handler(message): sampling_callback=sampling_callback, message_handler=message_handler, ) as session: - # Test initialization - result = await session.initialize() - assert isinstance(result, InitializeResult) - assert result.serverInfo.name == "EverythingServer" - - # Check server features are reported - assert result.capabilities.prompts is not None - assert result.capabilities.resources is not None - assert result.capabilities.tools is not None - # Note: logging capability may be None if no tools use context logging - - # Test tools - # 1. Simple echo tool - tool_result = await session.call_tool("echo", {"message": "hello"}) - assert len(tool_result.content) == 1 - assert isinstance(tool_result.content[0], TextContent) - assert tool_result.content[0].text == "Echo: hello" - - # 2. Tool with context (logging and progress) - # Test progress callback functionality - progress_updates = [] - - def progress_callback( - progress: float, total: float | None, message: str | None - ) -> None: - """Collect progress updates for testing.""" - progress_updates.append((progress, total, message)) - print(f"Progress: {progress}/{total} - {message}") - - params = { - "message": "test", - "steps": 3, - } - tool_result = await session.call_tool( - "tool_with_progress", - params, - progress_callback=progress_callback, - ) - assert len(tool_result.content) == 1 - assert isinstance(tool_result.content[0], TextContent) - assert "Processed 'test' in 3 steps" in tool_result.content[0].text - - # Verify progress callback was called - assert len(progress_updates) == 3 - for i, (progress, total, message) in enumerate(progress_updates): - expected_progress = (i + 1) / 3 - assert abs(progress - expected_progress) < 0.01 - assert total == 1.0 - assert message is not None - assert f"step {i + 1} of 3" in message - - # Verify we received log messages from the tool - # Note: Progress notifications require special handling in the MCP client - # that's not implemented by default, so we focus on testing logging - assert len(collector.log_messages) > 0 - - # 3. Test sampling tool - sampling_result = await session.call_tool( - "sampling_tool", {"prompt": "What is the meaning of life?"} - ) - assert len(sampling_result.content) == 1 - assert isinstance(sampling_result.content[0], TextContent) - assert "Sampling result:" in sampling_result.content[0].text - assert "This is a simulated LLM response" in sampling_result.content[0].text - - # Give time for log messages - await anyio.sleep(0.1) - - # Verify we received log messages from the sampling tool - assert len(collector.log_messages) > 0 - assert any( - "Requesting sampling for prompt" in msg.data - for msg in collector.log_messages - ) - assert any( - "Received sampling result from model" in msg.data - for msg in collector.log_messages - ) - - # 4. Test notification tool - notification_result = await session.call_tool( - "notification_tool", {"message": "test_notifications"} - ) - assert len(notification_result.content) == 1 - assert isinstance(notification_result.content[0], TextContent) - assert "Sent notifications and logs" in notification_result.content[0].text - - # Give time for notifications - await anyio.sleep(0.1) - - # Verify we received various notification types - assert len(collector.log_messages) > 3 # Should have logs from both tools - assert len(collector.resource_notifications) > 0 - assert len(collector.tool_notifications) > 0 - - # Check that we got different log levels - log_levels = [msg.level for msg in collector.log_messages] - assert "debug" in log_levels - assert "info" in log_levels - assert "warning" in log_levels - - # Test resources - # 1. Static resource - resources = await session.list_resources() - # Try using string comparison since AnyUrl might not match directly - static_resource = next( - ( - r - for r in resources.resources - if str(r.uri) == "resource://static/info" - ), - None, - ) - assert static_resource is not None - assert static_resource.name == "Static Info" - - static_content = await session.read_resource( - AnyUrl("resource://static/info") - ) - assert isinstance(static_content, ReadResourceResult) - assert len(static_content.contents) == 1 - assert isinstance(static_content.contents[0], TextResourceContents) - assert static_content.contents[0].text == "This is static resource content" - - # 2. Dynamic resource - dynamic_content = await session.read_resource( - AnyUrl("resource://dynamic/test") - ) - assert isinstance(dynamic_content, ReadResourceResult) - assert len(dynamic_content.contents) == 1 - assert isinstance(dynamic_content.contents[0], TextResourceContents) - assert ( - "Dynamic resource content for category: test" - in dynamic_content.contents[0].text - ) - - # 3. Template resource - template_content = await session.read_resource( - AnyUrl("resource://template/123/data") - ) - assert isinstance(template_content, ReadResourceResult) - assert len(template_content.contents) == 1 - assert isinstance(template_content.contents[0], TextResourceContents) - assert ( - "Template resource data for ID: 123" - in template_content.contents[0].text - ) - - # Test prompts - # 1. Simple prompt - prompts = await session.list_prompts() - simple_prompt = next( - (p for p in prompts.prompts if p.name == "simple_prompt"), None - ) - assert simple_prompt is not None - - prompt_result = await session.get_prompt("simple_prompt", {"topic": "AI"}) - assert isinstance(prompt_result, GetPromptResult) - assert len(prompt_result.messages) >= 1 - # The actual message structure depends on the prompt implementation - - # 2. Complex prompt - complex_prompt = next( - (p for p in prompts.prompts if p.name == "complex_prompt"), None - ) - assert complex_prompt is not None - - complex_result = await session.get_prompt( - "complex_prompt", {"user_query": "What is AI?", "context": "technical"} - ) - assert isinstance(complex_result, GetPromptResult) - assert len(complex_result.messages) >= 1 + # Run the common test suite + await call_all_mcp_features(session, collector) @pytest.mark.anyio @@ -854,25 +865,6 @@ async def test_fastmcp_all_features_streamable_http( # Create notification collector collector = NotificationCollector() - # Create a sampling callback that simulates an LLM - async def sampling_callback( - context: RequestContext[ClientSession, None], - params: CreateMessageRequestParams, - ) -> CreateMessageResult: - # Simulate LLM response - if params.messages and isinstance(params.messages[0].content, TextContent): - input_text = params.messages[0].content.text - else: - input_text = "No input" - response_text = f"This is a simulated LLM response to: {input_text}" - - return CreateMessageResult( - role="assistant", - content=TextContent(type="text", text=response_text), - model="test-llm-model-http", - stopReason="endTurn", - ) - # Connect to the server using StreamableHTTP async with streamablehttp_client(everything_http_server_url + "/mcp") as ( read_stream, @@ -892,163 +884,5 @@ async def message_handler(message): sampling_callback=sampling_callback, message_handler=message_handler, ) as session: - # Test initialization - result = await session.initialize() - assert isinstance(result, InitializeResult) - assert result.serverInfo.name == "EverythingServer" - - # Check server features are reported - assert result.capabilities.prompts is not None - assert result.capabilities.resources is not None - assert result.capabilities.tools is not None - # Note: logging capability may be None if no tools use context logging - - # Test tools - # 1. Simple echo tool - tool_result = await session.call_tool("echo", {"message": "hello"}) - assert len(tool_result.content) == 1 - assert isinstance(tool_result.content[0], TextContent) - assert tool_result.content[0].text == "Echo: hello" - - # 2. Tool with context (logging and progress) - # Test progress callback functionality over HTTP - progress_updates_http = [] - - def progress_callback_http( - progress: float, total: float | None, message: str | None - ) -> None: - """Collect progress updates for HTTP testing.""" - progress_updates_http.append((progress, total, message)) - print(f"HTTP Progress: {progress}/{total} - {message}") - - await session.call_tool( - "tool_with_progress", - { - "message": "http_test", - "steps": 2, - }, - progress_callback=progress_callback_http, - ) - - # Verify progress callback was called over HTTP - assert len(progress_updates_http) == 2 - for i, (progress, total, message) in enumerate(progress_updates_http): - expected_progress = (i + 1) / 2 - assert abs(progress - expected_progress) < 0.01 - assert total == 1.0 - assert message is not None - assert f"step {i + 1} of 2" in message - - # 3. Test sampling tool - sampling_result = await session.call_tool( - "sampling_tool", {"prompt": "Explain quantum computing"} - ) - assert len(sampling_result.content) == 1 - assert isinstance(sampling_result.content[0], TextContent) - assert "Sampling result:" in sampling_result.content[0].text - assert "This is a simulated LLM response" in sampling_result.content[0].text - - # Give time for log messages - await anyio.sleep(0.1) - - # Verify we received log messages - assert len(collector.log_messages) > 0 - assert any( - "Requesting sampling for prompt" in msg.data - for msg in collector.log_messages - ) - - # Test resources - # 1. Static resource - resources = await session.list_resources() - # Try using string comparison since AnyUrl might not match directly - static_resource = next( - ( - r - for r in resources.resources - if str(r.uri) == "resource://static/info" - ), - None, - ) - assert static_resource is not None - assert static_resource.name == "Static Info" - - static_content = await session.read_resource( - AnyUrl("resource://static/info") - ) - assert isinstance(static_content, ReadResourceResult) - assert len(static_content.contents) == 1 - assert isinstance(static_content.contents[0], TextResourceContents) - assert static_content.contents[0].text == "This is static resource content" - - # 2. Dynamic resource - dynamic_content = await session.read_resource( - AnyUrl("resource://dynamic/http") - ) - assert isinstance(dynamic_content, ReadResourceResult) - assert len(dynamic_content.contents) == 1 - assert isinstance(dynamic_content.contents[0], TextResourceContents) - assert ( - "Dynamic resource content for category: http" - in dynamic_content.contents[0].text - ) - - # 3. Template resource - template_content = await session.read_resource( - AnyUrl("resource://template/456/data") - ) - assert isinstance(template_content, ReadResourceResult) - assert len(template_content.contents) == 1 - assert isinstance(template_content.contents[0], TextResourceContents) - assert ( - "Template resource data for ID: 456" - in template_content.contents[0].text - ) - - # Test prompts - # 1. Simple prompt - prompts = await session.list_prompts() - simple_prompt = next( - (p for p in prompts.prompts if p.name == "simple_prompt"), None - ) - assert simple_prompt is not None - - prompt_result = await session.get_prompt("simple_prompt", {"topic": "HTTP"}) - assert isinstance(prompt_result, GetPromptResult) - assert len(prompt_result.messages) >= 1 - # The actual message structure depends on the prompt implementation - - # 2. Complex prompt - complex_prompt = next( - (p for p in prompts.prompts if p.name == "complex_prompt"), None - ) - assert complex_prompt is not None - - complex_result = await session.get_prompt( - "complex_prompt", {"user_query": "What is HTTP?", "context": "web"} - ) - assert isinstance(complex_result, GetPromptResult) - assert len(complex_result.messages) >= 1 - - # Test that all features work in sequence (integration test) - # This tests that the different transport doesn't affect feature interaction - for i in range(3): - # Call tool - tool_result = await session.call_tool( - "echo", {"message": f"iteration_{i}"} - ) - assert isinstance(tool_result.content[0], TextContent) - assert f"iteration_{i}" in tool_result.content[0].text - - # Read resource - resource_result = await session.read_resource( - AnyUrl(f"resource://dynamic/{i}") - ) - assert isinstance(resource_result.contents[0], TextResourceContents) - assert f"category: {i}" in resource_result.contents[0].text - - # Get prompt - prompt_result = await session.get_prompt( - "simple_prompt", {"topic": f"topic_{i}"} - ) - assert len(prompt_result.messages) >= 1 + # Run the common test suite with HTTP-specific test suffix + await call_all_mcp_features(session, collector) From a0c4b1888c9ba508cee588de8b89d2e081fac20a Mon Sep 17 00:00:00 2001 From: ihrpr Date: Thu, 15 May 2025 16:42:13 +0100 Subject: [PATCH 7/9] ruff --- .../__main__.py | 2 +- src/mcp/shared/session.py | 24 +++++++++---------- tests/client/test_list_methods_cursor.py | 2 +- tests/server/fastmcp/test_integration.py | 18 +++++--------- 4 files changed, 20 insertions(+), 26 deletions(-) diff --git a/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/__main__.py b/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/__main__.py index c72c76f40..1664737e3 100644 --- a/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/__main__.py +++ b/examples/servers/simple-streamablehttp-stateless/mcp_simple_streamablehttp_stateless/__main__.py @@ -3,5 +3,5 @@ if __name__ == "__main__": # Click will handle CLI arguments import sys - + sys.exit(main()) # type: ignore[call-arg] diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 7fc08d679..3292f0778 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -384,19 +384,19 @@ async def _receive_loop(self) -> None: cancelled_id = notification.root.params.requestId if cancelled_id in self._in_flight: await self._in_flight[cancelled_id].cancel() - # Handle progress notifications - elif isinstance(notification.root, ProgressNotification): - progress_token = notification.root.params.progressToken - # If there is a progress callback for this token, - # call it with the progress information - if progress_token in self._progress_callbacks: - callback = self._progress_callbacks[progress_token] - callback( - notification.root.params.progress, - notification.root.params.total, - notification.root.params.message, - ) else: + # Handle progress notifications callback + if isinstance(notification.root, ProgressNotification): + progress_token = notification.root.params.progressToken + # If there is a progress callback for this token, + # call it with the progress information + if progress_token in self._progress_callbacks: + callback = self._progress_callbacks[progress_token] + callback( + notification.root.params.progress, + notification.root.params.total, + notification.root.params.message, + ) await self._received_notification(notification) await self._handle_incoming(notification) except Exception as e: diff --git a/tests/client/test_list_methods_cursor.py b/tests/client/test_list_methods_cursor.py index f07473f4c..b0d6e36b8 100644 --- a/tests/client/test_list_methods_cursor.py +++ b/tests/client/test_list_methods_cursor.py @@ -11,7 +11,7 @@ async def test_list_tools_cursor_parameter(): """Test that the cursor parameter is accepted for list_tools. - + Note: FastMCP doesn't currently implement pagination, so this test only verifies that the cursor parameter is accepted by the client. """ diff --git a/tests/server/fastmcp/test_integration.py b/tests/server/fastmcp/test_integration.py index e7e314e72..8a150c89d 100644 --- a/tests/server/fastmcp/test_integration.py +++ b/tests/server/fastmcp/test_integration.py @@ -10,7 +10,6 @@ import time from collections.abc import Generator -import anyio import pytest import uvicorn from pydantic import AnyUrl @@ -20,6 +19,7 @@ from mcp.client.sse import sse_client from mcp.client.streamable_http import streamablehttp_client from mcp.server.fastmcp import FastMCP +from mcp.server.fastmcp.resources import FunctionResource from mcp.shared.context import RequestContext from mcp.types import ( CreateMessageRequestParams, @@ -27,6 +27,7 @@ GetPromptResult, InitializeResult, ReadResourceResult, + SamplingMessage, TextContent, TextResourceContents, ) @@ -123,8 +124,6 @@ def echo(message: str) -> str: # Tool with sampling capability @mcp.tool(description="A tool that uses sampling to generate content") async def sampling_tool(prompt: str, ctx: Context) -> str: - from mcp.types import SamplingMessage, TextContent - await ctx.info(f"Requesting sampling for prompt: {prompt}") # Request sampling from the client @@ -161,10 +160,6 @@ async def notification_tool(message: str, ctx: Context) -> str: return f"Sent notifications and logs for: {message}" # Resource - static - from pydantic import AnyUrl - - from mcp.server.fastmcp.resources import FunctionResource - def get_static_info() -> str: return "This is static resource content" @@ -293,7 +288,7 @@ def run_streamable_http_server(server_port: int) -> None: server.run() -def run_comprehensive_streamable_http_server(server_port: int) -> None: +def run_everything_streamable_http_server(server_port: int) -> None: """Run the comprehensive StreamableHTTP server with all features.""" _, app = make_everything_fastmcp_streamable_http_app() server = uvicorn.Server( @@ -492,7 +487,6 @@ async def test_fastmcp_stateless_streamable_http( assert tool_result.content[0].text == f"Echo: test_{i}" -# Fixtures for comprehensive servers @pytest.fixture def everything_server_port() -> int: """Get a free port for testing the comprehensive server.""" @@ -557,12 +551,12 @@ def everything_server(everything_server_port: int) -> Generator[None, None, None @pytest.fixture() -def comprehensive_streamable_http_server( +def everything_streamable_http_server( everything_http_server_port: int, ) -> Generator[None, None, None]: """Start the comprehensive StreamableHTTP server in a separate process.""" proc = multiprocessing.Process( - target=run_comprehensive_streamable_http_server, + target=run_everything_streamable_http_server, args=(everything_http_server_port,), daemon=True, ) @@ -858,7 +852,7 @@ async def message_handler(message): @pytest.mark.anyio async def test_fastmcp_all_features_streamable_http( - comprehensive_streamable_http_server: None, everything_http_server_url: str + everything_streamable_http_server: None, everything_http_server_url: str ) -> None: """Test all MCP features work correctly with StreamableHTTP transport.""" From e2852ad58f899e03d3de502779523c6858294a84 Mon Sep 17 00:00:00 2001 From: ihrpr Date: Thu, 15 May 2025 17:28:53 +0100 Subject: [PATCH 8/9] make callback async --- src/mcp/shared/session.py | 4 ++-- tests/server/fastmcp/test_integration.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 3292f0778..90b4eb27c 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -46,7 +46,7 @@ class ProgressFnT(Protocol): """Protocol for progress notification callbacks.""" - def __call__( + async def __call__( self, progress: float, total: float | None, message: str | None ) -> None: ... @@ -392,7 +392,7 @@ async def _receive_loop(self) -> None: # call it with the progress information if progress_token in self._progress_callbacks: callback = self._progress_callbacks[progress_token] - callback( + await callback( notification.root.params.progress, notification.root.params.total, notification.root.params.message, diff --git a/tests/server/fastmcp/test_integration.py b/tests/server/fastmcp/test_integration.py index 8a150c89d..0224a1726 100644 --- a/tests/server/fastmcp/test_integration.py +++ b/tests/server/fastmcp/test_integration.py @@ -655,10 +655,10 @@ async def call_all_mcp_features( # Test progress callback functionality progress_updates = [] - def progress_callback( + async def progress_callback( progress: float, total: float | None, message: str | None ) -> None: - """Collect progress updates for testing.""" + """Collect progress updates for testing (async version).""" progress_updates.append((progress, total, message)) print(f"Progress: {progress}/{total} - {message}") From a1b2b10081f620923594e95b7cb7133d08ac09a1 Mon Sep 17 00:00:00 2001 From: ihrpr Date: Thu, 15 May 2025 17:31:05 +0100 Subject: [PATCH 9/9] change names --- tests/server/fastmcp/test_integration.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/server/fastmcp/test_integration.py b/tests/server/fastmcp/test_integration.py index 0224a1726..79285ecb1 100644 --- a/tests/server/fastmcp/test_integration.py +++ b/tests/server/fastmcp/test_integration.py @@ -264,7 +264,7 @@ def run_server(server_port: int) -> None: server.run() -def run_everything_server(server_port: int) -> None: +def run_everything_legacy_sse_http_server(server_port: int) -> None: """Run the comprehensive server with all features.""" _, app = make_everything_fastmcp_app() server = uvicorn.Server( @@ -288,7 +288,7 @@ def run_streamable_http_server(server_port: int) -> None: server.run() -def run_everything_streamable_http_server(server_port: int) -> None: +def run_everything_server(server_port: int) -> None: """Run the comprehensive StreamableHTTP server with all features.""" _, app = make_everything_fastmcp_streamable_http_app() server = uvicorn.Server( @@ -519,7 +519,9 @@ def everything_http_server_url(everything_http_server_port: int) -> str: def everything_server(everything_server_port: int) -> Generator[None, None, None]: """Start the comprehensive server in a separate process and clean up after.""" proc = multiprocessing.Process( - target=run_everything_server, args=(everything_server_port,), daemon=True + target=run_everything_legacy_sse_http_server, + args=(everything_server_port,), + daemon=True, ) print("Starting comprehensive server process") proc.start() @@ -556,7 +558,7 @@ def everything_streamable_http_server( ) -> Generator[None, None, None]: """Start the comprehensive StreamableHTTP server in a separate process.""" proc = multiprocessing.Process( - target=run_everything_streamable_http_server, + target=run_everything_server, args=(everything_http_server_port,), daemon=True, )