Skip to content

Commit bab74d5

Browse files
authored
Fix StdioServerTransport.DisposeAsync hang (#235)
- It's not just Unix, it's all platforms that can hang - Windows doesn't hang on Ctrl-C because that triggers the end of the Stream, but that doesn't help in IHostApplicationLifetime.StopAsync() scenarios https://github.com/dotnet/runtime/blob/4e9627e311806330be72b7f7d3660be699878ebd/src/libraries/System.Console/src/System/ConsolePal.Unix.ConsoleStream.cs#L13 https://github.com/dotnet/runtime/blob/4e9627e311806330be72b7f7d3660be699878ebd/src/libraries/System.Console/src/System/ConsolePal.Windows.cs#L1149
1 parent 3e21f35 commit bab74d5

File tree

5 files changed

+114
-8
lines changed

5 files changed

+114
-8
lines changed

samples/TestServerWithHosting/TestServerWithHosting.csproj

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
<PropertyGroup>
44
<OutputType>Exe</OutputType>
5-
<TargetFramework>net9.0</TargetFramework>
5+
<TargetFrameworks>net9.0;net8.0</TargetFrameworks>
66
<ImplicitUsings>enable</ImplicitUsings>
77
<Nullable>enable</Nullable>
88
<!--

src/ModelContextProtocol/Protocol/Transport/StdioServerTransport.cs

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ public StdioServerTransport(McpServerOptions serverOptions, ILoggerFactory? logg
6060
/// </para>
6161
/// </remarks>
6262
public StdioServerTransport(string serverName, ILoggerFactory? loggerFactory = null)
63-
: base(Console.OpenStandardInput(),
63+
: base(new CancellableStdinStream(Console.OpenStandardInput()),
6464
new BufferedStream(Console.OpenStandardOutput()),
6565
serverName ?? throw new ArgumentNullException(nameof(serverName)),
6666
loggerFactory)
@@ -73,4 +73,37 @@ private static string GetServerName(McpServerOptions serverOptions)
7373

7474
return serverOptions.ServerInfo?.Name ?? McpServer.DefaultImplementation.Name;
7575
}
76+
77+
// Neither WindowsConsoleStream nor UnixConsoleStream respect CancellationTokens or cancel any I/O on Dispose.
78+
// WindowsConsoleStream will return an EOS on Ctrl-C, but that is not the only reason the shutdownToken may fire.
79+
private sealed class CancellableStdinStream(Stream stdinStream) : Stream
80+
{
81+
public override bool CanRead => true;
82+
public override bool CanSeek => false;
83+
public override bool CanWrite => false;
84+
85+
public override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
86+
=> stdinStream.ReadAsync(buffer, offset, count, cancellationToken).WaitAsync(cancellationToken);
87+
88+
#if NET
89+
public override ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken cancellationToken = default)
90+
{
91+
ValueTask<int> vt = stdinStream.ReadAsync(buffer, cancellationToken);
92+
return vt.IsCompletedSuccessfully ? vt : new(vt.AsTask().WaitAsync(cancellationToken));
93+
}
94+
#endif
95+
96+
// The McpServer shouldn't call flush on the stdin Stream, but it doesn't need to throw just in case.
97+
public override void Flush() { }
98+
99+
public override long Length => throw new NotSupportedException();
100+
101+
public override long Position { get => throw new NotSupportedException(); set => throw new NotSupportedException(); }
102+
103+
public override int Read(byte[] buffer, int offset, int count) => throw new NotSupportedException();
104+
public override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException();
105+
106+
public override void SetLength(long value) => throw new NotSupportedException();
107+
public override void Write(byte[] buffer, int offset, int count) => throw new NotSupportedException();
108+
}
76109
}

tests/ModelContextProtocol.Tests/ModelContextProtocol.Tests.csproj

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
<ProjectReference Include="..\..\src\ModelContextProtocol\ModelContextProtocol.csproj" />
4848
<ProjectReference Include="..\ModelContextProtocol.TestServer\ModelContextProtocol.TestServer.csproj" />
4949
<ProjectReference Include="..\ModelContextProtocol.TestSseServer\ModelContextProtocol.TestSseServer.csproj" />
50+
<ProjectReference Include="..\..\samples\TestServerWithHosting\TestServerWithHosting.csproj" />
5051
</ItemGroup>
5152

5253
<ItemGroup>
@@ -56,6 +57,12 @@
5657
<Content Condition="$([MSBuild]::IsOSPlatform('Linux'))" Include="$([System.IO.Path]::GetFullPath('$(ArtifactsBinDir)ModelContextProtocol.TestServer\$(Configuration)'))\$(TargetFramework)\TestServer.dll">
5758
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
5859
</Content>
60+
<Content Condition="$([MSBuild]::IsOSPlatform('Windows'))" Include="$([System.IO.Path]::GetFullPath('$(ArtifactsBinDir)TestServerWithHosting\$(Configuration)'))\$(TargetFramework)\TestServerWithHosting.exe">
61+
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
62+
</Content>
63+
<Content Condition="$([MSBuild]::IsOSPlatform('Linux'))" Include="$([System.IO.Path]::GetFullPath('$(ArtifactsBinDir)TestServerWithHosting\$(Configuration)'))\$(TargetFramework)\TestServerWithHosting.dll">
64+
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
65+
</Content>
5966
</ItemGroup>
6067

6168
</Project>
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
using ModelContextProtocol.Client;
2+
using ModelContextProtocol.Protocol.Transport;
3+
using ModelContextProtocol.Tests.Utils;
4+
using System.Diagnostics;
5+
using System.Runtime.InteropServices;
6+
7+
namespace ModelContextProtocol.Tests;
8+
9+
public class StdioServerIntegrationTests(ITestOutputHelper testOutputHelper) : LoggedTest(testOutputHelper)
10+
{
11+
public static bool CanSendSigInt { get; } = RuntimeInformation.IsOSPlatform(OSPlatform.Linux) || RuntimeInformation.IsOSPlatform(OSPlatform.OSX);
12+
private const int SIGINT = 2;
13+
14+
[Fact(Skip = "Platform not supported by this test.", SkipUnless = nameof(CanSendSigInt))]
15+
public async Task SigInt_DisposesTestServerWithHosting_Gracefully()
16+
{
17+
using var process = new Process
18+
{
19+
StartInfo = new ProcessStartInfo
20+
{
21+
FileName = "dotnet",
22+
Arguments = "TestServerWithHosting.dll",
23+
RedirectStandardInput = true,
24+
RedirectStandardOutput = true,
25+
UseShellExecute = false,
26+
CreateNoWindow = true,
27+
}
28+
};
29+
30+
process.Start();
31+
32+
await using var streamServerTransport = new StreamServerTransport(
33+
process.StandardOutput.BaseStream,
34+
process.StandardInput.BaseStream,
35+
serverName: "TestServerWithHosting");
36+
37+
await using var client = await McpClientFactory.CreateAsync(
38+
new TestClientTransport(streamServerTransport),
39+
loggerFactory: LoggerFactory,
40+
cancellationToken: TestContext.Current.CancellationToken);
41+
42+
// I considered writing a similar test for Windows using Ctrl-C, then saw that dotnet watch doesn't even send a Ctrl-C
43+
// signal because it's such a pain without support for CREATE_NEW_PROCESS_GROUP in System.Diagnostics.Process.
44+
// https://github.com/dotnet/sdk/blob/43b1c12e3362098a23ca1018503eb56516840b6a/src/BuiltInTools/dotnet-watch/Internal/ProcessRunner.cs#L277-L303
45+
// https://github.com/dotnet/runtime/issues/109432, https://github.com/dotnet/runtime/issues/44944
46+
Assert.Equal(0, kill(process.Id, SIGINT));
47+
48+
using var shutdownCts = CancellationTokenSource.CreateLinkedTokenSource(TestContext.Current.CancellationToken);
49+
shutdownCts.CancelAfter(TimeSpan.FromSeconds(10));
50+
await process.WaitForExitAsync(shutdownCts.Token);
51+
52+
Assert.True(process.HasExited);
53+
Assert.Equal(0, process.ExitCode);
54+
}
55+
56+
[DllImport("libc", SetLastError = true)]
57+
private static extern int kill(int pid, int sig);
58+
59+
private sealed class TestClientTransport(ITransport sessionTransport) : IClientTransport
60+
{
61+
public string Name => nameof(TestClientTransport);
62+
63+
public Task<ITransport> ConnectAsync(CancellationToken cancellationToken = default)
64+
=> Task.FromResult(sessionTransport);
65+
}
66+
}

tests/ModelContextProtocol.Tests/Utils/KestrelInMemoryConnection.cs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@ public sealed class KestrelInMemoryConnection : ConnectionContext
88
{
99
private readonly Pipe _clientToServerPipe = new();
1010
private readonly Pipe _serverToClientPipe = new();
11-
private readonly CancellationTokenSource _connectionClosedCts = new CancellationTokenSource();
12-
private readonly IFeatureCollection _features = new FeatureCollection();
11+
private readonly CancellationTokenSource _connectionClosedCts = new();
12+
private readonly FeatureCollection _features = new();
1313

1414
public KestrelInMemoryConnection()
1515
{
@@ -37,17 +37,17 @@ public KestrelInMemoryConnection()
3737

3838
public override IDictionary<object, object?> Items { get; set; } = new Dictionary<object, object?>();
3939

40-
public override ValueTask DisposeAsync()
40+
public override async ValueTask DisposeAsync()
4141
{
4242
// This is called by Kestrel. The client should dispose the DuplexStream which
4343
// completes the other half of these pipes.
44-
_serverToClientPipe.Writer.Complete();
45-
_serverToClientPipe.Reader.Complete();
44+
await _serverToClientPipe.Writer.CompleteAsync();
45+
await _serverToClientPipe.Reader.CompleteAsync();
4646

4747
// Don't bother disposing the _connectionClosedCts, since this is just for testing,
4848
// and it's annoying to synchronize with DuplexStream.
4949

50-
return base.DisposeAsync();
50+
await base.DisposeAsync();
5151
}
5252

5353
private class DuplexPipe : IDuplexPipe

0 commit comments

Comments
 (0)