Skip to content

Fix StdioServerTransport.DisposeAsync hang #235

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Apr 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion samples/TestServerWithHosting/TestServerWithHosting.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

<PropertyGroup>
<OutputType>Exe</OutputType>
<TargetFramework>net9.0</TargetFramework>
<TargetFrameworks>net9.0;net8.0</TargetFrameworks>
<ImplicitUsings>enable</ImplicitUsings>
<Nullable>enable</Nullable>
<!--
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ public StdioServerTransport(McpServerOptions serverOptions, ILoggerFactory? logg
/// </para>
/// </remarks>
public StdioServerTransport(string serverName, ILoggerFactory? loggerFactory = null)
: base(Console.OpenStandardInput(),
: base(new CancellableStdinStream(Console.OpenStandardInput()),
new BufferedStream(Console.OpenStandardOutput()),
serverName ?? throw new ArgumentNullException(nameof(serverName)),
loggerFactory)
Expand All @@ -73,4 +73,37 @@ private static string GetServerName(McpServerOptions serverOptions)

return serverOptions.ServerInfo?.Name ?? McpServer.DefaultImplementation.Name;
}

// Neither WindowsConsoleStream nor UnixConsoleStream respect CancellationTokens or cancel any I/O on Dispose.
// WindowsConsoleStream will return an EOS on Ctrl-C, but that is not the only reason the shutdownToken may fire.
private sealed class CancellableStdinStream(Stream stdinStream) : Stream
{
public override bool CanRead => true;
public override bool CanSeek => false;
public override bool CanWrite => false;

public override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
=> stdinStream.ReadAsync(buffer, offset, count, cancellationToken).WaitAsync(cancellationToken);

#if NET
public override ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken cancellationToken = default)
{
ValueTask<int> vt = stdinStream.ReadAsync(buffer, cancellationToken);
return vt.IsCompletedSuccessfully ? vt : new(vt.AsTask().WaitAsync(cancellationToken));
}
#endif

// The McpServer shouldn't call flush on the stdin Stream, but it doesn't need to throw just in case.
public override void Flush() { }

public override long Length => throw new NotSupportedException();

public override long Position { get => throw new NotSupportedException(); set => throw new NotSupportedException(); }

public override int Read(byte[] buffer, int offset, int count) => throw new NotSupportedException();
public override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException();

public override void SetLength(long value) => throw new NotSupportedException();
public override void Write(byte[] buffer, int offset, int count) => throw new NotSupportedException();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
<ProjectReference Include="..\..\src\ModelContextProtocol\ModelContextProtocol.csproj" />
<ProjectReference Include="..\ModelContextProtocol.TestServer\ModelContextProtocol.TestServer.csproj" />
<ProjectReference Include="..\ModelContextProtocol.TestSseServer\ModelContextProtocol.TestSseServer.csproj" />
<ProjectReference Include="..\..\samples\TestServerWithHosting\TestServerWithHosting.csproj" />
</ItemGroup>

<ItemGroup>
Expand All @@ -56,6 +57,12 @@
<Content Condition="$([MSBuild]::IsOSPlatform('Linux'))" Include="$([System.IO.Path]::GetFullPath('$(ArtifactsBinDir)ModelContextProtocol.TestServer\$(Configuration)'))\$(TargetFramework)\TestServer.dll">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
</Content>
<Content Condition="$([MSBuild]::IsOSPlatform('Windows'))" Include="$([System.IO.Path]::GetFullPath('$(ArtifactsBinDir)TestServerWithHosting\$(Configuration)'))\$(TargetFramework)\TestServerWithHosting.exe">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
</Content>
<Content Condition="$([MSBuild]::IsOSPlatform('Linux'))" Include="$([System.IO.Path]::GetFullPath('$(ArtifactsBinDir)TestServerWithHosting\$(Configuration)'))\$(TargetFramework)\TestServerWithHosting.dll">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
</Content>
</ItemGroup>

</Project>
66 changes: 66 additions & 0 deletions tests/ModelContextProtocol.Tests/StdioServerIntegrationTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
using ModelContextProtocol.Client;
using ModelContextProtocol.Protocol.Transport;
using ModelContextProtocol.Tests.Utils;
using System.Diagnostics;
using System.Runtime.InteropServices;

namespace ModelContextProtocol.Tests;

public class StdioServerIntegrationTests(ITestOutputHelper testOutputHelper) : LoggedTest(testOutputHelper)
{
public static bool CanSendSigInt { get; } = RuntimeInformation.IsOSPlatform(OSPlatform.Linux) || RuntimeInformation.IsOSPlatform(OSPlatform.OSX);
private const int SIGINT = 2;

[Fact(Skip = "Platform not supported by this test.", SkipUnless = nameof(CanSendSigInt))]
public async Task SigInt_DisposesTestServerWithHosting_Gracefully()
{
using var process = new Process
{
StartInfo = new ProcessStartInfo
{
FileName = "dotnet",
Arguments = "TestServerWithHosting.dll",
RedirectStandardInput = true,
RedirectStandardOutput = true,
UseShellExecute = false,
CreateNoWindow = true,
}
};

process.Start();

await using var streamServerTransport = new StreamServerTransport(
process.StandardOutput.BaseStream,
process.StandardInput.BaseStream,
serverName: "TestServerWithHosting");

await using var client = await McpClientFactory.CreateAsync(
new TestClientTransport(streamServerTransport),
loggerFactory: LoggerFactory,
cancellationToken: TestContext.Current.CancellationToken);

// I considered writing a similar test for Windows using Ctrl-C, then saw that dotnet watch doesn't even send a Ctrl-C
// signal because it's such a pain without support for CREATE_NEW_PROCESS_GROUP in System.Diagnostics.Process.
// https://github.com/dotnet/sdk/blob/43b1c12e3362098a23ca1018503eb56516840b6a/src/BuiltInTools/dotnet-watch/Internal/ProcessRunner.cs#L277-L303
// https://github.com/dotnet/runtime/issues/109432, https://github.com/dotnet/runtime/issues/44944
Assert.Equal(0, kill(process.Id, SIGINT));

using var shutdownCts = CancellationTokenSource.CreateLinkedTokenSource(TestContext.Current.CancellationToken);
shutdownCts.CancelAfter(TimeSpan.FromSeconds(10));
await process.WaitForExitAsync(shutdownCts.Token);

Assert.True(process.HasExited);
Assert.Equal(0, process.ExitCode);
}

[DllImport("libc", SetLastError = true)]
private static extern int kill(int pid, int sig);

private sealed class TestClientTransport(ITransport sessionTransport) : IClientTransport
{
public string Name => nameof(TestClientTransport);

public Task<ITransport> ConnectAsync(CancellationToken cancellationToken = default)
=> Task.FromResult(sessionTransport);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ public sealed class KestrelInMemoryConnection : ConnectionContext
{
private readonly Pipe _clientToServerPipe = new();
private readonly Pipe _serverToClientPipe = new();
private readonly CancellationTokenSource _connectionClosedCts = new CancellationTokenSource();
private readonly IFeatureCollection _features = new FeatureCollection();
private readonly CancellationTokenSource _connectionClosedCts = new();
private readonly FeatureCollection _features = new();

public KestrelInMemoryConnection()
{
Expand Down Expand Up @@ -37,17 +37,17 @@ public KestrelInMemoryConnection()

public override IDictionary<object, object?> Items { get; set; } = new Dictionary<object, object?>();

public override ValueTask DisposeAsync()
public override async ValueTask DisposeAsync()
{
// This is called by Kestrel. The client should dispose the DuplexStream which
// completes the other half of these pipes.
_serverToClientPipe.Writer.Complete();
_serverToClientPipe.Reader.Complete();
await _serverToClientPipe.Writer.CompleteAsync();
await _serverToClientPipe.Reader.CompleteAsync();

// Don't bother disposing the _connectionClosedCts, since this is just for testing,
// and it's annoying to synchronize with DuplexStream.

return base.DisposeAsync();
await base.DisposeAsync();
}

private class DuplexPipe : IDuplexPipe
Expand Down