From bd097bfce5e4d651dd3a43071b8249527b414c14 Mon Sep 17 00:00:00 2001 From: Tim Child Date: Mon, 31 Mar 2025 13:19:16 -0700 Subject: [PATCH 1/8] add test that checks if stdio connection hangs with bad connection params --- tests/client/test_stdio.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tests/client/test_stdio.py b/tests/client/test_stdio.py index 95747ffd1..5ceaca15e 100644 --- a/tests/client/test_stdio.py +++ b/tests/client/test_stdio.py @@ -1,6 +1,7 @@ import shutil import pytest +from anyio import fail_after from mcp.client.stdio import StdioServerParameters, stdio_client from mcp.types import JSONRPCMessage, JSONRPCRequest, JSONRPCResponse @@ -41,3 +42,18 @@ async def test_stdio_client(): assert read_messages[1] == JSONRPCMessage( root=JSONRPCResponse(jsonrpc="2.0", id=2, result={}) ) + + +@pytest.mark.anyio +async def test_stdio_client_bad_path(): + """Check that the connection doesn't hang if process errors.""" + server_parameters = StdioServerParameters( + command="uv", args=["run", "non-existent-file.py"] + ) + + try: + with fail_after(1): + async with stdio_client(server_parameters) as (read_stream, write_stream): + pass + except TimeoutError: + pytest.fail("The connection hung.") From 8558120eac0ed2eaf16aa71e49b4344bbb152a2e Mon Sep 17 00:00:00 2001 From: Tim Child Date: Mon, 31 Mar 2025 14:18:09 -0700 Subject: [PATCH 2/8] fix process hanging on bad stdio connection params --- src/mcp/client/stdio/__init__.py | 31 ++++++++++++++++++++++++++----- tests/client/test_stdio.py | 29 ++++++++++++++++++++++------- 2 files changed, 48 insertions(+), 12 deletions(-) diff --git a/src/mcp/client/stdio/__init__.py b/src/mcp/client/stdio/__init__.py index 83de57a2b..3f0aff658 100644 --- a/src/mcp/client/stdio/__init__.py +++ b/src/mcp/client/stdio/__init__.py @@ -6,6 +6,7 @@ import anyio import anyio.lowlevel +from anyio.abc import Process from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from anyio.streams.text import TextReceiveStream from pydantic import BaseModel, Field @@ -38,6 +39,10 @@ ) +class ProcessTerminatedEarlyError(Exception): + """Raised when a process terminates unexpectedly.""" + + def get_default_environment() -> dict[str, str]: """ Returns a default environment object including only environment variables deemed @@ -110,7 +115,7 @@ async def stdio_client(server: StdioServerParameters, errlog: TextIO = sys.stder command = _get_executable_command(server.command) # Open process with stderr piped for capture - process = await _create_platform_compatible_process( + process: Process = await _create_platform_compatible_process( command=command, args=server.args, env=( @@ -163,20 +168,36 @@ async def stdin_writer(): except anyio.ClosedResourceError: await anyio.lowlevel.checkpoint() + process_error: str | None = None + async with ( anyio.create_task_group() as tg, process, ): tg.start_soon(stdout_reader) tg.start_soon(stdin_writer) + # tg.start_soon(monitor_process, tg.cancel_scope) try: yield read_stream, write_stream finally: - # Clean up process to prevent any dangling orphaned processes - if sys.platform == "win32": - await terminate_windows_process(process) + await read_stream.aclose() + await write_stream.aclose() + await read_stream_writer.aclose() + await write_stream_reader.aclose() + + if process.returncode is not None and process.returncode != 0: + process_error = f"Process exited with code {process.returncode}." else: - process.terminate() + # Clean up process to prevent any dangling orphaned processes + if sys.platform == "win32": + await terminate_windows_process(process) + else: + process.terminate() + + if process_error: + # Raise outside the task group so that the error is not wrapped in an + # ExceptionGroup + raise ProcessTerminatedEarlyError(process_error) def _get_executable_command(command: str) -> str: diff --git a/tests/client/test_stdio.py b/tests/client/test_stdio.py index 5ceaca15e..2799f838d 100644 --- a/tests/client/test_stdio.py +++ b/tests/client/test_stdio.py @@ -3,7 +3,11 @@ import pytest from anyio import fail_after -from mcp.client.stdio import StdioServerParameters, stdio_client +from mcp.client.stdio import ( + ProcessTerminatedEarlyError, + StdioServerParameters, + stdio_client, +) from mcp.types import JSONRPCMessage, JSONRPCRequest, JSONRPCResponse tee: str = shutil.which("tee") # type: ignore @@ -51,9 +55,20 @@ async def test_stdio_client_bad_path(): command="uv", args=["run", "non-existent-file.py"] ) - try: - with fail_after(1): - async with stdio_client(server_parameters) as (read_stream, write_stream): - pass - except TimeoutError: - pytest.fail("The connection hung.") + with pytest.raises(ProcessTerminatedEarlyError): + try: + with fail_after(1): + async with stdio_client(server_parameters) as ( + read_stream, + _, + ): + # Try waiting for read_stream so that we don't exit before the + # process fails. + async with read_stream: + async for message in read_stream: + if isinstance(message, Exception): + raise message + + pass + except TimeoutError: + pytest.fail("The connection hung.") From be41b81c52dda1cdaed6ea34874b09b527dbf387 Mon Sep 17 00:00:00 2001 From: Tim Child Date: Mon, 31 Mar 2025 14:19:43 -0700 Subject: [PATCH 3/8] make sure test only runs if `uv` available --- tests/client/test_stdio.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/client/test_stdio.py b/tests/client/test_stdio.py index 2799f838d..2dbc38b43 100644 --- a/tests/client/test_stdio.py +++ b/tests/client/test_stdio.py @@ -11,6 +11,7 @@ from mcp.types import JSONRPCMessage, JSONRPCRequest, JSONRPCResponse tee: str = shutil.which("tee") # type: ignore +uv: str = shutil.which("uv") # type: ignore @pytest.mark.anyio @@ -49,6 +50,7 @@ async def test_stdio_client(): @pytest.mark.anyio +@pytest.mark.skipif(uv is None, reason="could not find uv command") async def test_stdio_client_bad_path(): """Check that the connection doesn't hang if process errors.""" server_parameters = StdioServerParameters( From 90f224d6e19521e08246c3362bfa4a1bc7b0f080 Mon Sep 17 00:00:00 2001 From: Tim Child Date: Mon, 7 Apr 2025 11:59:48 -0700 Subject: [PATCH 4/8] fix detection of failed process --- src/mcp/client/stdio/__init__.py | 26 ++++++++++++++++++-------- tests/client/test_stdio.py | 2 -- 2 files changed, 18 insertions(+), 10 deletions(-) diff --git a/src/mcp/client/stdio/__init__.py b/src/mcp/client/stdio/__init__.py index 3f0aff658..d51d5c125 100644 --- a/src/mcp/client/stdio/__init__.py +++ b/src/mcp/client/stdio/__init__.py @@ -127,7 +127,7 @@ async def stdio_client(server: StdioServerParameters, errlog: TextIO = sys.stder cwd=server.cwd, ) - async def stdout_reader(): + async def stdout_reader(done_event: anyio.Event): assert process.stdout, "Opened process is missing stdout" try: @@ -151,6 +151,7 @@ async def stdout_reader(): await read_stream_writer.send(message) except anyio.ClosedResourceError: await anyio.lowlevel.checkpoint() + done_event.set() async def stdin_writer(): assert process.stdin, "Opened process is missing stdin" @@ -174,21 +175,30 @@ async def stdin_writer(): anyio.create_task_group() as tg, process, ): - tg.start_soon(stdout_reader) + stdout_done_event = anyio.Event() + tg.start_soon(stdout_reader, stdout_done_event) tg.start_soon(stdin_writer) - # tg.start_soon(monitor_process, tg.cancel_scope) try: yield read_stream, write_stream + if stdout_done_event.is_set(): + # The stdout reader exited before the calling code stopped listening + # (e.g. because of process error) + # Give the process a chance to exit if it was the reason for crashing + # so we can get exit code + with anyio.move_on_after(0.1) as scope: + await process.wait() + process_error = f"Process exited with code {process.returncode}." + if scope.cancelled_caught: + process_error = ( + "Stdout reader exited (process did not exit immediately)." + ) finally: await read_stream.aclose() await write_stream.aclose() await read_stream_writer.aclose() await write_stream_reader.aclose() - - if process.returncode is not None and process.returncode != 0: - process_error = f"Process exited with code {process.returncode}." - else: - # Clean up process to prevent any dangling orphaned processes + # Clean up process to prevent any dangling orphaned processes + if process.returncode is None: if sys.platform == "win32": await terminate_windows_process(process) else: diff --git a/tests/client/test_stdio.py b/tests/client/test_stdio.py index 2dbc38b43..ae9689743 100644 --- a/tests/client/test_stdio.py +++ b/tests/client/test_stdio.py @@ -70,7 +70,5 @@ async def test_stdio_client_bad_path(): async for message in read_stream: if isinstance(message, Exception): raise message - - pass except TimeoutError: pytest.fail("The connection hung.") From acce290cb4a64a47b608e9fd6835a246d775ce2a Mon Sep 17 00:00:00 2001 From: ihrpr Date: Wed, 28 May 2025 21:37:18 +0100 Subject: [PATCH 5/8] change tests --- src/mcp/client/stdio/__init__.py | 12 +++++--- tests/client/test_stdio.py | 52 ++++++++++++++++++++------------ 2 files changed, 40 insertions(+), 24 deletions(-) diff --git a/src/mcp/client/stdio/__init__.py b/src/mcp/client/stdio/__init__.py index 6d815b43a..5af82b47d 100644 --- a/src/mcp/client/stdio/__init__.py +++ b/src/mcp/client/stdio/__init__.py @@ -177,10 +177,14 @@ async def stdin_writer(): yield read_stream, write_stream finally: # Clean up process to prevent any dangling orphaned processes - if sys.platform == "win32": - await terminate_windows_process(process) - else: - process.terminate() + try: + if sys.platform == "win32": + await terminate_windows_process(process) + else: + process.terminate() + except ProcessLookupError: + # Process already exited, which is fine + pass await read_stream.aclose() await write_stream.aclose() diff --git a/tests/client/test_stdio.py b/tests/client/test_stdio.py index d8a07bfec..6777e8d04 100644 --- a/tests/client/test_stdio.py +++ b/tests/client/test_stdio.py @@ -1,17 +1,18 @@ import shutil import pytest -from anyio import fail_after +from mcp.client.session import ClientSession from mcp.client.stdio import ( StdioServerParameters, stdio_client, ) +from mcp.shared.exceptions import McpError from mcp.shared.message import SessionMessage -from mcp.types import JSONRPCMessage, JSONRPCRequest, JSONRPCResponse +from mcp.types import CONNECTION_CLOSED, JSONRPCMessage, JSONRPCRequest, JSONRPCResponse tee: str = shutil.which("tee") # type: ignore -uv: str = shutil.which("uv") # type: ignore +python: str = shutil.which("python") # type: ignore @pytest.mark.anyio @@ -58,25 +59,36 @@ async def test_stdio_client(): @pytest.mark.anyio -@pytest.mark.skipif(uv is None, reason="could not find uv command") async def test_stdio_client_bad_path(): """Check that the connection doesn't hang if process errors.""" - server_parameters = StdioServerParameters( - command="uv", args=["run", "non-existent-file.py"] + server_params = StdioServerParameters( + command="python", args=["-c", "non-existent-file.py"] ) + async with stdio_client(server_params) as (read_stream, write_stream): + async with ClientSession(read_stream, write_stream) as session: + # The session should raise an error when the connection closes + with pytest.raises(McpError) as exc_info: + await session.initialize() + # Check that we got a connection closed error + assert exc_info.value.error.code == CONNECTION_CLOSED + assert "Connection closed" in exc_info.value.error.message + + +@pytest.mark.anyio +async def test_stdio_client_nonexistent_command(): + """Test that stdio_client raises an error for non-existent commands.""" + # Create a server with a non-existent command + server_params = StdioServerParameters( + command="/path/to/nonexistent/command", + args=["--help"], + ) + + # Should raise an error when trying to start the process with pytest.raises(Exception) as exc_info: - try: - with fail_after(1): - async with stdio_client(server_parameters) as ( - read_stream, - _, - ): - # Try waiting for read_stream so that we don't exit before the - # process fails. - async with read_stream: - async for message in read_stream: - if isinstance(message, Exception): - raise message - except TimeoutError: - pytest.fail("The connection hung.") + async with stdio_client(server_params) as (_, _): + pass + + # The error should indicate the command was not found + error_message = str(exc_info.value) + assert "nonexistent" in error_message or "not found" in error_message.lower() From 1fb54485ecaafe7bb0864df31ae3ee4b54e6ef6a Mon Sep 17 00:00:00 2001 From: ihrpr Date: Wed, 28 May 2025 21:45:56 +0100 Subject: [PATCH 6/8] close unclosed streams that were masked by not cought exceptions --- src/mcp/client/stdio/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/mcp/client/stdio/__init__.py b/src/mcp/client/stdio/__init__.py index 5af82b47d..b6d5cfcb7 100644 --- a/src/mcp/client/stdio/__init__.py +++ b/src/mcp/client/stdio/__init__.py @@ -187,6 +187,8 @@ async def stdin_writer(): pass await read_stream.aclose() await write_stream.aclose() + await read_stream_writer.aclose() + await write_stream_reader.aclose() def _get_executable_command(command: str) -> str: From ba385fa7f75c212d94148905738c7fa0dc1cdbc7 Mon Sep 17 00:00:00 2001 From: ihrpr Date: Wed, 28 May 2025 21:54:43 +0100 Subject: [PATCH 7/8] catch OSError as it's casing unclosed resources --- src/mcp/client/stdio/__init__.py | 36 +++++++++++++++++++------------- 1 file changed, 22 insertions(+), 14 deletions(-) diff --git a/src/mcp/client/stdio/__init__.py b/src/mcp/client/stdio/__init__.py index b6d5cfcb7..fce605633 100644 --- a/src/mcp/client/stdio/__init__.py +++ b/src/mcp/client/stdio/__init__.py @@ -108,20 +108,28 @@ async def stdio_client(server: StdioServerParameters, errlog: TextIO = sys.stder read_stream_writer, read_stream = anyio.create_memory_object_stream(0) write_stream, write_stream_reader = anyio.create_memory_object_stream(0) - command = _get_executable_command(server.command) - - # Open process with stderr piped for capture - process = await _create_platform_compatible_process( - command=command, - args=server.args, - env=( - {**get_default_environment(), **server.env} - if server.env is not None - else get_default_environment() - ), - errlog=errlog, - cwd=server.cwd, - ) + try: + command = _get_executable_command(server.command) + + # Open process with stderr piped for capture + process = await _create_platform_compatible_process( + command=command, + args=server.args, + env=( + {**get_default_environment(), **server.env} + if server.env is not None + else get_default_environment() + ), + errlog=errlog, + cwd=server.cwd, + ) + except OSError: + # Clean up streams if process creation fails + await read_stream.aclose() + await write_stream.aclose() + await read_stream_writer.aclose() + await write_stream_reader.aclose() + raise async def stdout_reader(): assert process.stdout, "Opened process is missing stdout" From a94b5490b0272899145d731422c12873e6fbb3d0 Mon Sep 17 00:00:00 2001 From: ihrpr Date: Wed, 28 May 2025 21:58:59 +0100 Subject: [PATCH 8/8] win error --- tests/client/test_stdio.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/client/test_stdio.py b/tests/client/test_stdio.py index 6777e8d04..1c6ffe000 100644 --- a/tests/client/test_stdio.py +++ b/tests/client/test_stdio.py @@ -91,4 +91,8 @@ async def test_stdio_client_nonexistent_command(): # The error should indicate the command was not found error_message = str(exc_info.value) - assert "nonexistent" in error_message or "not found" in error_message.lower() + assert ( + "nonexistent" in error_message + or "not found" in error_message.lower() + or "cannot find the file" in error_message.lower() # Windows error message + )