diff --git a/src/mcp/client/stdio/__init__.py b/src/mcp/client/stdio/__init__.py index 6d815b43a..fce605633 100644 --- a/src/mcp/client/stdio/__init__.py +++ b/src/mcp/client/stdio/__init__.py @@ -108,20 +108,28 @@ async def stdio_client(server: StdioServerParameters, errlog: TextIO = sys.stder read_stream_writer, read_stream = anyio.create_memory_object_stream(0) write_stream, write_stream_reader = anyio.create_memory_object_stream(0) - command = _get_executable_command(server.command) - - # Open process with stderr piped for capture - process = await _create_platform_compatible_process( - command=command, - args=server.args, - env=( - {**get_default_environment(), **server.env} - if server.env is not None - else get_default_environment() - ), - errlog=errlog, - cwd=server.cwd, - ) + try: + command = _get_executable_command(server.command) + + # Open process with stderr piped for capture + process = await _create_platform_compatible_process( + command=command, + args=server.args, + env=( + {**get_default_environment(), **server.env} + if server.env is not None + else get_default_environment() + ), + errlog=errlog, + cwd=server.cwd, + ) + except OSError: + # Clean up streams if process creation fails + await read_stream.aclose() + await write_stream.aclose() + await read_stream_writer.aclose() + await write_stream_reader.aclose() + raise async def stdout_reader(): assert process.stdout, "Opened process is missing stdout" @@ -177,12 +185,18 @@ async def stdin_writer(): yield read_stream, write_stream finally: # Clean up process to prevent any dangling orphaned processes - if sys.platform == "win32": - await terminate_windows_process(process) - else: - process.terminate() + try: + if sys.platform == "win32": + await terminate_windows_process(process) + else: + process.terminate() + except ProcessLookupError: + # Process already exited, which is fine + pass await read_stream.aclose() await write_stream.aclose() + await read_stream_writer.aclose() + await write_stream_reader.aclose() def _get_executable_command(command: str) -> str: diff --git a/tests/client/test_stdio.py b/tests/client/test_stdio.py index 33d90e769..1c6ffe000 100644 --- a/tests/client/test_stdio.py +++ b/tests/client/test_stdio.py @@ -2,11 +2,17 @@ import pytest -from mcp.client.stdio import StdioServerParameters, stdio_client +from mcp.client.session import ClientSession +from mcp.client.stdio import ( + StdioServerParameters, + stdio_client, +) +from mcp.shared.exceptions import McpError from mcp.shared.message import SessionMessage -from mcp.types import JSONRPCMessage, JSONRPCRequest, JSONRPCResponse +from mcp.types import CONNECTION_CLOSED, JSONRPCMessage, JSONRPCRequest, JSONRPCResponse tee: str = shutil.which("tee") # type: ignore +python: str = shutil.which("python") # type: ignore @pytest.mark.anyio @@ -50,3 +56,43 @@ async def test_stdio_client(): assert read_messages[1] == JSONRPCMessage( root=JSONRPCResponse(jsonrpc="2.0", id=2, result={}) ) + + +@pytest.mark.anyio +async def test_stdio_client_bad_path(): + """Check that the connection doesn't hang if process errors.""" + server_params = StdioServerParameters( + command="python", args=["-c", "non-existent-file.py"] + ) + async with stdio_client(server_params) as (read_stream, write_stream): + async with ClientSession(read_stream, write_stream) as session: + # The session should raise an error when the connection closes + with pytest.raises(McpError) as exc_info: + await session.initialize() + + # Check that we got a connection closed error + assert exc_info.value.error.code == CONNECTION_CLOSED + assert "Connection closed" in exc_info.value.error.message + + +@pytest.mark.anyio +async def test_stdio_client_nonexistent_command(): + """Test that stdio_client raises an error for non-existent commands.""" + # Create a server with a non-existent command + server_params = StdioServerParameters( + command="/path/to/nonexistent/command", + args=["--help"], + ) + + # Should raise an error when trying to start the process + with pytest.raises(Exception) as exc_info: + async with stdio_client(server_params) as (_, _): + pass + + # The error should indicate the command was not found + error_message = str(exc_info.value) + assert ( + "nonexistent" in error_message + or "not found" in error_message.lower() + or "cannot find the file" in error_message.lower() # Windows error message + )