Skip to content

Commit faf12b6

Browse files
authored
Use Kestrel for all in-memory HTTP tests (#225)
* Add route pattern parameter to MapMcp * Add configureOptionsAsync * Use Kestrel for all in-memory HTTP tests * Use ApplicationStopping token for SSE responses * Use RegisterNotificationHandler in MapMcp tests * Work around SocketsHttpHandler bug where it doesn't call DispseAsync on the stream returned by the ConnectCallback
1 parent 848c9df commit faf12b6

File tree

17 files changed

+631
-642
lines changed

17 files changed

+631
-642
lines changed

src/ModelContextProtocol.AspNetCore/McpEndpointRouteBuilderExtensions.cs

Lines changed: 52 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,18 @@
11
using Microsoft.AspNetCore.Http;
22
using Microsoft.AspNetCore.Http.Features;
33
using Microsoft.AspNetCore.Routing;
4+
using Microsoft.AspNetCore.Routing.Patterns;
45
using Microsoft.AspNetCore.WebUtilities;
56
using Microsoft.Extensions.DependencyInjection;
7+
using Microsoft.Extensions.Hosting;
68
using Microsoft.Extensions.Logging;
79
using Microsoft.Extensions.Options;
810
using ModelContextProtocol.Protocol.Messages;
911
using ModelContextProtocol.Protocol.Transport;
1012
using ModelContextProtocol.Server;
1113
using ModelContextProtocol.Utils.Json;
1214
using System.Collections.Concurrent;
15+
using System.Diagnostics.CodeAnalysis;
1316
using System.Security.Cryptography;
1417

1518
namespace Microsoft.AspNetCore.Builder;
@@ -23,53 +26,87 @@ public static class McpEndpointRouteBuilderExtensions
2326
/// Sets up endpoints for handling MCP HTTP Streaming transport.
2427
/// </summary>
2528
/// <param name="endpoints">The web application to attach MCP HTTP endpoints.</param>
26-
/// <param name="runSession">Provides an optional asynchronous callback for handling new MCP sessions.</param>
29+
/// <param name="pattern">The route pattern prefix to map to.</param>
30+
/// <param name="configureOptionsAsync">Configure per-session options.</param>
31+
/// <param name="runSessionAsync">Provides an optional asynchronous callback for handling new MCP sessions.</param>
2732
/// <returns>Returns a builder for configuring additional endpoint conventions like authorization policies.</returns>
28-
public static IEndpointConventionBuilder MapMcp(this IEndpointRouteBuilder endpoints, Func<HttpContext, IMcpServer, CancellationToken, Task>? runSession = null)
33+
public static IEndpointConventionBuilder MapMcp(
34+
this IEndpointRouteBuilder endpoints,
35+
[StringSyntax("Route")] string pattern = "",
36+
Func<HttpContext, McpServerOptions, CancellationToken, Task>? configureOptionsAsync = null,
37+
Func<HttpContext, IMcpServer, CancellationToken, Task>? runSessionAsync = null)
38+
=> endpoints.MapMcp(RoutePatternFactory.Parse(pattern), configureOptionsAsync, runSessionAsync);
39+
40+
/// <summary>
41+
/// Sets up endpoints for handling MCP HTTP Streaming transport.
42+
/// </summary>
43+
/// <param name="endpoints">The web application to attach MCP HTTP endpoints.</param>
44+
/// <param name="pattern">The route pattern prefix to map to.</param>
45+
/// <param name="configureOptionsAsync">Configure per-session options.</param>
46+
/// <param name="runSessionAsync">Provides an optional asynchronous callback for handling new MCP sessions.</param>
47+
/// <returns>Returns a builder for configuring additional endpoint conventions like authorization policies.</returns>
48+
public static IEndpointConventionBuilder MapMcp(this IEndpointRouteBuilder endpoints,
49+
RoutePattern pattern,
50+
Func<HttpContext, McpServerOptions, CancellationToken, Task>? configureOptionsAsync = null,
51+
Func<HttpContext, IMcpServer, CancellationToken, Task>? runSessionAsync = null)
2952
{
3053
ConcurrentDictionary<string, SseResponseStreamTransport> _sessions = new(StringComparer.Ordinal);
3154

3255
var loggerFactory = endpoints.ServiceProvider.GetRequiredService<ILoggerFactory>();
33-
var mcpServerOptions = endpoints.ServiceProvider.GetRequiredService<IOptions<McpServerOptions>>();
56+
var optionsSnapshot = endpoints.ServiceProvider.GetRequiredService<IOptions<McpServerOptions>>();
57+
var optionsFactory = endpoints.ServiceProvider.GetRequiredService<IOptionsFactory<McpServerOptions>>();
58+
var hostApplicationLifetime = endpoints.ServiceProvider.GetRequiredService<IHostApplicationLifetime>();
3459

35-
var routeGroup = endpoints.MapGroup("");
60+
var routeGroup = endpoints.MapGroup(pattern);
3661

3762
routeGroup.MapGet("/sse", async context =>
3863
{
39-
var response = context.Response;
40-
var requestAborted = context.RequestAborted;
64+
// If the server is shutting down, we need to cancel all SSE connections immediately without waiting for HostOptions.ShutdownTimeout
65+
// which defaults to 30 seconds.
66+
using var sseCts = CancellationTokenSource.CreateLinkedTokenSource(context.RequestAborted, hostApplicationLifetime.ApplicationStopping);
67+
var cancellationToken = sseCts.Token;
4168

69+
var response = context.Response;
4270
response.Headers.ContentType = "text/event-stream";
4371
response.Headers.CacheControl = "no-cache,no-store";
4472

73+
// Make sure we disable all response buffering for SSE
74+
context.Response.Headers.ContentEncoding = "identity";
75+
context.Features.GetRequiredFeature<IHttpResponseBodyFeature>().DisableBuffering();
76+
4577
var sessionId = MakeNewSessionId();
4678
await using var transport = new SseResponseStreamTransport(response.Body, $"/message?sessionId={sessionId}");
4779
if (!_sessions.TryAdd(sessionId, transport))
4880
{
4981
throw new Exception($"Unreachable given good entropy! Session with ID '{sessionId}' has already been created.");
5082
}
5183

52-
try
84+
var options = optionsSnapshot.Value;
85+
if (configureOptionsAsync is not null)
5386
{
54-
// Make sure we disable all response buffering for SSE
55-
context.Response.Headers.ContentEncoding = "identity";
56-
context.Features.GetRequiredFeature<IHttpResponseBodyFeature>().DisableBuffering();
87+
options = optionsFactory.Create(Options.DefaultName);
88+
await configureOptionsAsync.Invoke(context, options, cancellationToken);
89+
}
5790

58-
var transportTask = transport.RunAsync(cancellationToken: requestAborted);
59-
await using var server = McpServerFactory.Create(transport, mcpServerOptions.Value, loggerFactory, endpoints.ServiceProvider);
91+
try
92+
{
93+
var transportTask = transport.RunAsync(cancellationToken);
6094

6195
try
6296
{
63-
runSession ??= RunSession;
64-
await runSession(context, server, requestAborted);
97+
await using var mcpServer = McpServerFactory.Create(transport, options, loggerFactory, endpoints.ServiceProvider);
98+
context.Features.Set(mcpServer);
99+
100+
runSessionAsync ??= RunSession;
101+
await runSessionAsync(context, mcpServer, cancellationToken);
65102
}
66103
finally
67104
{
68105
await transport.DisposeAsync();
69106
await transportTask;
70107
}
71108
}
72-
catch (OperationCanceledException) when (requestAborted.IsCancellationRequested)
109+
catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested)
73110
{
74111
// RequestAborted always triggers when the client disconnects before a complete response body is written,
75112
// but this is how SSE connections are typically closed.

src/ModelContextProtocol/Client/McpClient.cs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,10 @@ public async Task ConnectAsync(CancellationToken cancellationToken = default)
106106
{
107107
// Connect transport
108108
_sessionTransport = await _clientTransport.ConnectAsync(cancellationToken).ConfigureAwait(false);
109-
StartSession(_sessionTransport);
109+
InitializeSession(_sessionTransport);
110+
// We don't want the ConnectAsync token to cancel the session after we've successfully connected.
111+
// The base class handles cleaning up the session in DisposeAsync without our help.
112+
StartSession(_sessionTransport, fullSessionCancellationToken: CancellationToken.None);
110113

111114
// Perform initialization sequence
112115
using var initializationCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);

src/ModelContextProtocol/Protocol/Transport/SseClientSessionTransport.cs

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -151,16 +151,14 @@ private async Task CloseAsync()
151151
{
152152
try
153153
{
154-
if (!_connectionCts.IsCancellationRequested)
155-
{
156-
await _connectionCts.CancelAsync().ConfigureAwait(false);
157-
_connectionCts.Dispose();
158-
}
154+
await _connectionCts.CancelAsync().ConfigureAwait(false);
159155

160156
if (_receiveTask != null)
161157
{
162158
await _receiveTask.ConfigureAwait(false);
163159
}
160+
161+
_connectionCts.Dispose();
164162
}
165163
finally
166164
{

src/ModelContextProtocol/Server/McpServer.cs

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ internal sealed class McpServer : McpEndpoint, IMcpServer
1818
Version = DefaultAssemblyName.Version?.ToString() ?? "1.0.0",
1919
};
2020

21+
private readonly ITransport _sessionTransport;
22+
2123
private readonly EventHandler? _toolsChangedDelegate;
2224
private readonly EventHandler? _promptsChangedDelegate;
2325

@@ -41,6 +43,7 @@ public McpServer(ITransport transport, McpServerOptions options, ILoggerFactory?
4143

4244
options ??= new();
4345

46+
_sessionTransport = transport;
4447
ServerOptions = options;
4548
Services = serviceProvider;
4649
_endpointName = $"Server ({options.ServerInfo?.Name ?? DefaultImplementation.Name} {options.ServerInfo?.Version ?? DefaultImplementation.Version})";
@@ -81,8 +84,8 @@ public McpServer(ITransport transport, McpServerOptions options, ILoggerFactory?
8184
prompts.Changed += _promptsChangedDelegate;
8285
}
8386

84-
// And start the session.
85-
StartSession(transport);
87+
// And initialize the session.
88+
InitializeSession(transport);
8689
}
8790

8891
public ServerCapabilities? ServerCapabilities { get; set; }
@@ -112,9 +115,8 @@ public async Task RunAsync(CancellationToken cancellationToken = default)
112115

113116
try
114117
{
115-
using var _ = cancellationToken.Register(static s => ((McpServer)s!).CancelSession(), this);
116-
// The McpServer ctor always calls StartSession, so MessageProcessingTask is always set.
117-
await MessageProcessingTask!.ConfigureAwait(false);
118+
StartSession(_sessionTransport, fullSessionCancellationToken: cancellationToken);
119+
await MessageProcessingTask.ConfigureAwait(false);
118120
}
119121
finally
120122
{

src/ModelContextProtocol/Shared/McpEndpoint.cs

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
using ModelContextProtocol.Protocol.Transport;
66
using ModelContextProtocol.Server;
77
using ModelContextProtocol.Utils;
8+
using System.Diagnostics;
89
using System.Diagnostics.CodeAnalysis;
910
using System.Reflection;
1011

@@ -62,12 +63,16 @@ public IAsyncDisposable RegisterNotificationHandler(string method, Func<JsonRpcN
6263
/// </summary>
6364
protected Task? MessageProcessingTask { get; private set; }
6465

65-
[MemberNotNull(nameof(MessageProcessingTask))]
66-
protected void StartSession(ITransport sessionTransport)
66+
protected void InitializeSession(ITransport sessionTransport)
6767
{
68-
_sessionCts = new CancellationTokenSource();
6968
_session = new McpSession(this is IMcpServer, sessionTransport, EndpointName, RequestHandlers, NotificationHandlers, _logger);
70-
MessageProcessingTask = _session.ProcessMessagesAsync(_sessionCts.Token);
69+
}
70+
71+
[MemberNotNull(nameof(MessageProcessingTask))]
72+
protected void StartSession(ITransport sessionTransport, CancellationToken fullSessionCancellationToken)
73+
{
74+
_sessionCts = CancellationTokenSource.CreateLinkedTokenSource(fullSessionCancellationToken);
75+
MessageProcessingTask = GetSessionOrThrow().ProcessMessagesAsync(_sessionCts.Token);
7176
}
7277

7378
protected void CancelSession() => _sessionCts?.Cancel();
@@ -122,5 +127,5 @@ public virtual async ValueTask DisposeUnsynchronizedAsync()
122127
}
123128

124129
protected McpSession GetSessionOrThrow()
125-
=> _session ?? throw new InvalidOperationException($"This should be unreachable from public API! Call {nameof(StartSession)} before sending messages.");
130+
=> _session ?? throw new InvalidOperationException($"This should be unreachable from public API! Call {nameof(InitializeSession)} before sending messages.");
126131
}

tests/ModelContextProtocol.TestSseServer/Program.cs

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
using ModelContextProtocol.Protocol.Types;
1+
using Microsoft.AspNetCore.Connections;
2+
using ModelContextProtocol.Protocol.Types;
23
using ModelContextProtocol.Server;
34
using Serilog;
45
using System.Text;
@@ -372,18 +373,34 @@ static CreateMessageRequestParams CreateRequestSamplingParams(string context, st
372373
};
373374
}
374375

375-
public static async Task MainAsync(string[] args, ILoggerProvider? loggerProvider = null, CancellationToken cancellationToken = default)
376+
public static async Task MainAsync(string[] args, ILoggerProvider? loggerProvider = null, IConnectionListenerFactory? kestrelTransport = null, CancellationToken cancellationToken = default)
376377
{
377378
Console.WriteLine("Starting server...");
378379

379-
int port = args.Length > 0 && uint.TryParse(args[0], out var parsedPort) ? (int)parsedPort : 3001;
380-
381-
var builder = WebApplication.CreateSlimBuilder(args);
382-
builder.WebHost.ConfigureKestrel(options =>
380+
var builder = WebApplication.CreateEmptyBuilder(new()
383381
{
384-
options.ListenLocalhost(port);
382+
Args = args,
385383
});
386384

385+
if (kestrelTransport is null)
386+
{
387+
int port = args.Length > 0 && uint.TryParse(args[0], out var parsedPort) ? (int)parsedPort : 3001;
388+
builder.WebHost.ConfigureKestrel(options =>
389+
{
390+
options.ListenLocalhost(port);
391+
});
392+
}
393+
else
394+
{
395+
// Add passed-in transport before calling UseKestrelCore() to avoid the SocketsHttpHandler getting added.
396+
builder.Services.AddSingleton(kestrelTransport);
397+
}
398+
399+
builder.WebHost.UseKestrelCore();
400+
builder.Services.AddLogging();
401+
builder.Services.AddRoutingCore();
402+
403+
builder.Logging.AddConsole();
387404
ConfigureSerilog(builder.Logging);
388405
if (loggerProvider is not null)
389406
{
@@ -393,6 +410,8 @@ public static async Task MainAsync(string[] args, ILoggerProvider? loggerProvide
393410
builder.Services.AddMcpServer(ConfigureOptions);
394411

395412
var app = builder.Build();
413+
app.UseRouting();
414+
app.UseEndpoints(_ => { });
396415

397416
app.MapMcp();
398417

0 commit comments

Comments
 (0)