Skip to content

Commit 848c9df

Browse files
authored
Allow notification handlers post-creation temporary registration (#223)
- We still allow notification handlers to be configured pre-creation, eliminating any race conditions that may result from missed notifications. But now we also allow for post-creation notification handler registration. - Handlers may now also be removed. This permits temporary registrations that may be created in a scoped manner. - Registration handling is thread-safe. - Handlers are now cancelable. - Exceptions from notification handlers trigger all normal exception handling in the dispatch pipeline. This also: - Adds a WithStreamServerTransport as a counterpart to WithStdioServerTransport. The latter was used in several tests in a hacky way, and that's now simplified via the former's existence.
1 parent f135355 commit 848c9df

20 files changed

+704
-172
lines changed

src/ModelContextProtocol/Client/McpClient.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ public McpClient(IClientTransport clientTransport, McpClientOptions? options, Mc
4949
{
5050
if (capabilities.NotificationHandlers is { } notificationHandlers)
5151
{
52-
NotificationHandlers.AddRange(notificationHandlers);
52+
NotificationHandlers.RegisterRange(notificationHandlers);
5353
}
5454

5555
if (capabilities.Sampling is { } samplingCapability)

src/ModelContextProtocol/Client/McpClientExtensions.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -600,7 +600,7 @@ internal static CreateMessageResult ToCreateMessageResult(this ChatResponse chat
600600
var progressToken = requestParams.Meta?.ProgressToken;
601601

602602
List<ChatResponseUpdate> updates = [];
603-
await foreach (var update in chatClient.GetStreamingResponseAsync(messages, options, cancellationToken))
603+
await foreach (var update in chatClient.GetStreamingResponseAsync(messages, options, cancellationToken).ConfigureAwait(false))
604604
{
605605
updates.Add(update);
606606

src/ModelContextProtocol/Configuration/McpServerBuilderExtensions.cs

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -349,7 +349,37 @@ public static IMcpServerBuilder WithStdioServerTransport(this IMcpServerBuilder
349349
Throw.IfNull(builder);
350350

351351
builder.Services.AddSingleton<ITransport, StdioServerTransport>();
352-
builder.Services.AddHostedService<StdioMcpServerHostedService>();
352+
builder.Services.AddHostedService<SingleSessionMcpServerHostedService>();
353+
354+
builder.Services.AddSingleton(services =>
355+
{
356+
ITransport serverTransport = services.GetRequiredService<ITransport>();
357+
IOptions<McpServerOptions> options = services.GetRequiredService<IOptions<McpServerOptions>>();
358+
ILoggerFactory? loggerFactory = services.GetService<ILoggerFactory>();
359+
360+
return McpServerFactory.Create(serverTransport, options.Value, loggerFactory, services);
361+
});
362+
363+
return builder;
364+
}
365+
366+
/// <summary>
367+
/// Adds a server transport that uses the specified input and output streams for communication.
368+
/// </summary>
369+
/// <param name="builder">The builder instance.</param>
370+
/// <param name="inputStream">The input <see cref="Stream"/> to use as standard input.</param>
371+
/// <param name="outputStream">The output <see cref="Stream"/> to use as standard output.</param>
372+
public static IMcpServerBuilder WithStreamServerTransport(
373+
this IMcpServerBuilder builder,
374+
Stream inputStream,
375+
Stream outputStream)
376+
{
377+
Throw.IfNull(builder);
378+
Throw.IfNull(inputStream);
379+
Throw.IfNull(outputStream);
380+
381+
builder.Services.AddSingleton<ITransport>(new StreamServerTransport(inputStream, outputStream));
382+
builder.Services.AddHostedService<SingleSessionMcpServerHostedService>();
353383

354384
builder.Services.AddSingleton(services =>
355385
{

src/ModelContextProtocol/Hosting/StdioMcpServerHostedService.cs renamed to src/ModelContextProtocol/Hosting/SingleSessionMcpServerHostedService.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44
namespace ModelContextProtocol.Hosting;
55

66
/// <summary>
7-
/// Hosted service for a single-session (i.e stdio) MCP server.
7+
/// Hosted service for a single-session (e.g. stdio) MCP server.
88
/// </summary>
9-
internal sealed class StdioMcpServerHostedService(IMcpServer session) : BackgroundService
9+
internal sealed class SingleSessionMcpServerHostedService(IMcpServer session) : BackgroundService
1010
{
1111
/// <inheritdoc />
1212
protected override Task ExecuteAsync(CancellationToken stoppingToken) => session.RunAsync(stoppingToken);

src/ModelContextProtocol/IMcpEndpoint.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,4 +15,10 @@ public interface IMcpEndpoint : IAsyncDisposable
1515
/// <param name="message">The message.</param>
1616
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
1717
Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default);
18+
19+
/// <summary>Registers a handler to be invoked when a notification for the specified method is received.</summary>
20+
/// <param name="method">The notification method.</param>
21+
/// <param name="handler">The handler to be invoked.</param>
22+
/// <returns>An <see cref="IDisposable"/> that will remove the registered handler when disposed.</returns>
23+
IAsyncDisposable RegisterNotificationHandler(string method, Func<JsonRpcNotification, CancellationToken, Task> handler);
1824
}

src/ModelContextProtocol/Protocol/Transport/SseClientSessionTransport.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ public override async ValueTask DisposeAsync()
173173
{
174174
try
175175
{
176-
await CloseAsync();
176+
await CloseAsync().ConfigureAwait(false);
177177
}
178178
catch (Exception)
179179
{

src/ModelContextProtocol/Protocol/Transport/SseResponseStreamTransport.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ public async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken ca
7777
}
7878

7979
// Emit redundant "event: message" lines for better compatibility with other SDKs.
80-
await _outgoingSseChannel.Writer.WriteAsync(new SseItem<IJsonRpcMessage?>(message, SseParser.EventTypeDefault), cancellationToken);
80+
await _outgoingSseChannel.Writer.WriteAsync(new SseItem<IJsonRpcMessage?>(message, SseParser.EventTypeDefault), cancellationToken).ConfigureAwait(false);
8181
}
8282

8383
/// <summary>
@@ -94,7 +94,7 @@ public async Task OnMessageReceivedAsync(IJsonRpcMessage message, CancellationTo
9494
throw new InvalidOperationException($"Transport is not connected. Make sure to call {nameof(RunAsync)} first.");
9595
}
9696

97-
await _incomingChannel.Writer.WriteAsync(message, cancellationToken);
97+
await _incomingChannel.Writer.WriteAsync(message, cancellationToken).ConfigureAwait(false);
9898
}
9999

100100
private static Channel<T> CreateBoundedChannel<T>(int capacity = 1) =>

src/ModelContextProtocol/Protocol/Types/Capabilities.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ public class ClientCapabilities
3434
/// The client will not re-enumerate the sequence.
3535
/// </remarks>
3636
[JsonIgnore]
37-
public IEnumerable<KeyValuePair<string, Func<JsonRpcNotification, Task>>>? NotificationHandlers { get; set; }
37+
public IEnumerable<KeyValuePair<string, Func<JsonRpcNotification, CancellationToken, Task>>>? NotificationHandlers { get; set; }
3838
}
3939

4040
/// <summary>

src/ModelContextProtocol/Protocol/Types/ServerCapabilities.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,5 +45,5 @@ public class ServerCapabilities
4545
/// The server will not re-enumerate the sequence.
4646
/// </remarks>
4747
[JsonIgnore]
48-
public IEnumerable<KeyValuePair<string, Func<JsonRpcNotification, Task>>>? NotificationHandlers { get; set; }
48+
public IEnumerable<KeyValuePair<string, Func<JsonRpcNotification, CancellationToken, Task>>>? NotificationHandlers { get; set; }
4949
}

src/ModelContextProtocol/Server/McpServer.cs

Lines changed: 45 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
using ModelContextProtocol.Shared;
66
using ModelContextProtocol.Utils;
77
using ModelContextProtocol.Utils.Json;
8+
using System.Diagnostics;
89

910
namespace ModelContextProtocol.Server;
1011

@@ -44,49 +45,43 @@ public McpServer(ITransport transport, McpServerOptions options, ILoggerFactory?
4445
Services = serviceProvider;
4546
_endpointName = $"Server ({options.ServerInfo?.Name ?? DefaultImplementation.Name} {options.ServerInfo?.Version ?? DefaultImplementation.Version})";
4647

47-
_toolsChangedDelegate = delegate
48-
{
49-
_ = SendMessageAsync(new JsonRpcNotification()
50-
{
51-
Method = NotificationMethods.ToolListChangedNotification,
52-
});
53-
};
54-
_promptsChangedDelegate = delegate
55-
{
56-
_ = SendMessageAsync(new JsonRpcNotification()
57-
{
58-
Method = NotificationMethods.PromptListChangedNotification,
59-
});
60-
};
48+
// Configure all request handlers based on the supplied options.
49+
SetInitializeHandler(options);
50+
SetToolsHandler(options);
51+
SetPromptsHandler(options);
52+
SetResourcesHandler(options);
53+
SetSetLoggingLevelHandler(options);
54+
SetCompletionHandler(options);
55+
SetPingHandler();
6156

62-
NotificationHandlers.Add(NotificationMethods.InitializedNotification, _ =>
57+
// Register any notification handlers that were provided.
58+
if (options.Capabilities?.NotificationHandlers is { } notificationHandlers)
6359
{
64-
if (ServerOptions.Capabilities?.Tools?.ToolCollection is { } tools)
65-
{
66-
tools.Changed += _toolsChangedDelegate;
67-
}
60+
NotificationHandlers.RegisterRange(notificationHandlers);
61+
}
6862

69-
if (ServerOptions.Capabilities?.Prompts?.PromptCollection is { } prompts)
63+
// Now that everything has been configured, subscribe to any necessary notifications.
64+
if (ServerOptions.Capabilities?.Tools?.ToolCollection is { } tools)
65+
{
66+
_toolsChangedDelegate = delegate
7067
{
71-
prompts.Changed += _promptsChangedDelegate;
72-
}
68+
_ = SendMessageAsync(new JsonRpcNotification() { Method = NotificationMethods.ToolListChangedNotification });
69+
};
7370

74-
return Task.CompletedTask;
75-
});
71+
tools.Changed += _toolsChangedDelegate;
72+
}
7673

77-
if (options.Capabilities?.NotificationHandlers is { } notificationHandlers)
74+
if (ServerOptions.Capabilities?.Prompts?.PromptCollection is { } prompts)
7875
{
79-
NotificationHandlers.AddRange(notificationHandlers);
80-
}
76+
_promptsChangedDelegate = delegate
77+
{
78+
_ = SendMessageAsync(new JsonRpcNotification() { Method = NotificationMethods.PromptListChangedNotification });
79+
};
8180

82-
SetToolsHandler(options);
83-
SetInitializeHandler(options);
84-
SetCompletionHandler(options);
85-
SetPingHandler();
86-
SetPromptsHandler(options);
87-
SetResourcesHandler(options);
88-
SetSetLoggingLevelHandler(options);
81+
prompts.Changed += _promptsChangedDelegate;
82+
}
8983

84+
// And start the session.
9085
StartSession(transport);
9186
}
9287

@@ -129,12 +124,14 @@ public async Task RunAsync(CancellationToken cancellationToken = default)
129124

130125
public override async ValueTask DisposeUnsynchronizedAsync()
131126
{
132-
if (ServerOptions.Capabilities?.Tools?.ToolCollection is { } tools)
127+
if (_toolsChangedDelegate is not null &&
128+
ServerOptions.Capabilities?.Tools?.ToolCollection is { } tools)
133129
{
134130
tools.Changed -= _toolsChangedDelegate;
135131
}
136132

137-
if (ServerOptions.Capabilities?.Prompts?.PromptCollection is { } prompts)
133+
if (_promptsChangedDelegate is not null &&
134+
ServerOptions.Capabilities?.Prompts?.PromptCollection is { } prompts)
138135
{
139136
prompts.Changed -= _promptsChangedDelegate;
140137
}
@@ -179,8 +176,8 @@ private void SetCompletionHandler(McpServerOptions options)
179176
// This capability is not optional, so return an empty result if there is no handler.
180177
RequestHandlers.Set(RequestMethods.CompletionComplete,
181178
options.GetCompletionHandler is { } handler ?
182-
(request, ct) => handler(new(this, request), ct) :
183-
(request, ct) => Task.FromResult(new CompleteResult() { Completion = new() { Values = [], Total = 0, HasMore = false } }),
179+
(request, cancellationToken) => handler(new(this, request), cancellationToken) :
180+
(request, cancellationToken) => Task.FromResult(new CompleteResult() { Completion = new() { Values = [], Total = 0, HasMore = false } }),
184181
McpJsonUtilities.JsonContext.Default.CompleteRequestParams,
185182
McpJsonUtilities.JsonContext.Default.CompleteResult);
186183
}
@@ -205,20 +202,20 @@ private void SetResourcesHandler(McpServerOptions options)
205202

206203
RequestHandlers.Set(
207204
RequestMethods.ResourcesList,
208-
(request, ct) => listResourcesHandler(new(this, request), ct),
205+
(request, cancellationToken) => listResourcesHandler(new(this, request), cancellationToken),
209206
McpJsonUtilities.JsonContext.Default.ListResourcesRequestParams,
210207
McpJsonUtilities.JsonContext.Default.ListResourcesResult);
211208

212209
RequestHandlers.Set(
213210
RequestMethods.ResourcesRead,
214-
(request, ct) => readResourceHandler(new(this, request), ct),
211+
(request, cancellationToken) => readResourceHandler(new(this, request), cancellationToken),
215212
McpJsonUtilities.JsonContext.Default.ReadResourceRequestParams,
216213
McpJsonUtilities.JsonContext.Default.ReadResourceResult);
217214

218215
listResourceTemplatesHandler ??= (static (_, _) => Task.FromResult(new ListResourceTemplatesResult()));
219216
RequestHandlers.Set(
220217
RequestMethods.ResourcesTemplatesList,
221-
(request, ct) => listResourceTemplatesHandler(new(this, request), ct),
218+
(request, cancellationToken) => listResourceTemplatesHandler(new(this, request), cancellationToken),
222219
McpJsonUtilities.JsonContext.Default.ListResourceTemplatesRequestParams,
223220
McpJsonUtilities.JsonContext.Default.ListResourceTemplatesResult);
224221

@@ -236,13 +233,13 @@ private void SetResourcesHandler(McpServerOptions options)
236233

237234
RequestHandlers.Set(
238235
RequestMethods.ResourcesSubscribe,
239-
(request, ct) => subscribeHandler(new(this, request), ct),
236+
(request, cancellationToken) => subscribeHandler(new(this, request), cancellationToken),
240237
McpJsonUtilities.JsonContext.Default.SubscribeRequestParams,
241238
McpJsonUtilities.JsonContext.Default.EmptyResult);
242239

243240
RequestHandlers.Set(
244241
RequestMethods.ResourcesUnsubscribe,
245-
(request, ct) => unsubscribeHandler(new(this, request), ct),
242+
(request, cancellationToken) => unsubscribeHandler(new(this, request), cancellationToken),
246243
McpJsonUtilities.JsonContext.Default.UnsubscribeRequestParams,
247244
McpJsonUtilities.JsonContext.Default.EmptyResult);
248245
}
@@ -329,13 +326,13 @@ await originalListPromptsHandler(request, cancellationToken).ConfigureAwait(fals
329326

330327
RequestHandlers.Set(
331328
RequestMethods.PromptsList,
332-
(request, ct) => listPromptsHandler(new(this, request), ct),
329+
(request, cancellationToken) => listPromptsHandler(new(this, request), cancellationToken),
333330
McpJsonUtilities.JsonContext.Default.ListPromptsRequestParams,
334331
McpJsonUtilities.JsonContext.Default.ListPromptsResult);
335332

336333
RequestHandlers.Set(
337334
RequestMethods.PromptsGet,
338-
(request, ct) => getPromptHandler(new(this, request), ct),
335+
(request, cancellationToken) => getPromptHandler(new(this, request), cancellationToken),
339336
McpJsonUtilities.JsonContext.Default.GetPromptRequestParams,
340337
McpJsonUtilities.JsonContext.Default.GetPromptResult);
341338
}
@@ -422,13 +419,13 @@ await originalListToolsHandler(request, cancellationToken).ConfigureAwait(false)
422419

423420
RequestHandlers.Set(
424421
RequestMethods.ToolsList,
425-
(request, ct) => listToolsHandler(new(this, request), ct),
422+
(request, cancellationToken) => listToolsHandler(new(this, request), cancellationToken),
426423
McpJsonUtilities.JsonContext.Default.ListToolsRequestParams,
427424
McpJsonUtilities.JsonContext.Default.ListToolsResult);
428425

429426
RequestHandlers.Set(
430427
RequestMethods.ToolsCall,
431-
(request, ct) => callToolHandler(new(this, request), ct),
428+
(request, cancellationToken) => callToolHandler(new(this, request), cancellationToken),
432429
McpJsonUtilities.JsonContext.Default.CallToolRequestParams,
433430
McpJsonUtilities.JsonContext.Default.CallToolResponse);
434431
}
@@ -447,7 +444,7 @@ private void SetSetLoggingLevelHandler(McpServerOptions options)
447444

448445
RequestHandlers.Set(
449446
RequestMethods.LoggingSetLevel,
450-
(request, ct) => setLoggingLevelHandler(new(this, request), ct),
447+
(request, cancellationToken) => setLoggingLevelHandler(new(this, request), cancellationToken),
451448
McpJsonUtilities.JsonContext.Default.SetLevelRequestParams,
452449
McpJsonUtilities.JsonContext.Default.EmptyResult);
453450
}

src/ModelContextProtocol/Shared/McpEndpoint.cs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,14 +41,17 @@ protected McpEndpoint(ILoggerFactory? loggerFactory = null)
4141

4242
protected RequestHandlers RequestHandlers { get; } = [];
4343

44-
protected NotificationHandlers NotificationHandlers { get; } = [];
44+
protected NotificationHandlers NotificationHandlers { get; } = new();
4545

4646
public Task<JsonRpcResponse> SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken = default)
4747
=> GetSessionOrThrow().SendRequestAsync(request, cancellationToken);
4848

4949
public Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default)
5050
=> GetSessionOrThrow().SendMessageAsync(message, cancellationToken);
5151

52+
public IAsyncDisposable RegisterNotificationHandler(string method, Func<JsonRpcNotification, CancellationToken, Task> handler) =>
53+
GetSessionOrThrow().RegisterNotificationHandler(method, handler);
54+
5255
/// <summary>
5356
/// Gets the name of the endpoint for logging and debug purposes.
5457
/// </summary>

0 commit comments

Comments
 (0)