From 0c1b040fed6bddfed9ed0369e24c3b9f715f31ed Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Sun, 6 Apr 2025 07:10:07 -0400 Subject: [PATCH] Allow notification handlers post-creation temporary registration - We still allow notification handlers to be configured pre-creation, eliminating any race conditions that may result from missed notifications. But now we also allow for post-creation notification handler registration. - Handlers may now also be removed. This permits temporary registrations that may be created in a scoped manner. - Registration handling is thread-safe. - Handlers are now cancelable. - Exceptions from notification handlers trigger all normal exception handling in the dispatch pipeline. This also: - Adds a WithStreamServerTransport as a counterpart to WithStdioServerTransport. The latter was used in several tests in a hacky way, and that's now simplified via the former's existence. --- src/ModelContextProtocol/Client/McpClient.cs | 2 +- .../Client/McpClientExtensions.cs | 2 +- .../McpServerBuilderExtensions.cs | 32 +- ...=> SingleSessionMcpServerHostedService.cs} | 4 +- src/ModelContextProtocol/IMcpEndpoint.cs | 6 + .../Transport/SseClientSessionTransport.cs | 2 +- .../Transport/SseResponseStreamTransport.cs | 4 +- .../Protocol/Types/Capabilities.cs | 2 +- .../Protocol/Types/ServerCapabilities.cs | 2 +- src/ModelContextProtocol/Server/McpServer.cs | 93 +++--- .../Shared/McpEndpoint.cs | 5 +- src/ModelContextProtocol/Shared/McpSession.cs | 33 +- .../Shared/NotificationHandlers.cs | 287 +++++++++++++++++- .../Client/McpClientExtensionsTests.cs | 4 +- .../ClientIntegrationTests.cs | 6 +- .../McpServerBuilderExtensionsPromptsTests.cs | 48 ++- .../McpServerBuilderExtensionsToolsTests.cs | 88 +++--- .../Protocol/NotificationHandlerTests.cs | 248 +++++++++++++++ .../Server/McpServerTests.cs | 4 +- .../SseIntegrationTests.cs | 4 +- 20 files changed, 704 insertions(+), 172 deletions(-) rename src/ModelContextProtocol/Hosting/{StdioMcpServerHostedService.cs => SingleSessionMcpServerHostedService.cs} (63%) create mode 100644 tests/ModelContextProtocol.Tests/Protocol/NotificationHandlerTests.cs diff --git a/src/ModelContextProtocol/Client/McpClient.cs b/src/ModelContextProtocol/Client/McpClient.cs index 64a89680..9ab22b54 100644 --- a/src/ModelContextProtocol/Client/McpClient.cs +++ b/src/ModelContextProtocol/Client/McpClient.cs @@ -49,7 +49,7 @@ public McpClient(IClientTransport clientTransport, McpClientOptions? options, Mc { if (capabilities.NotificationHandlers is { } notificationHandlers) { - NotificationHandlers.AddRange(notificationHandlers); + NotificationHandlers.RegisterRange(notificationHandlers); } if (capabilities.Sampling is { } samplingCapability) diff --git a/src/ModelContextProtocol/Client/McpClientExtensions.cs b/src/ModelContextProtocol/Client/McpClientExtensions.cs index 37ffb135..db3801d1 100644 --- a/src/ModelContextProtocol/Client/McpClientExtensions.cs +++ b/src/ModelContextProtocol/Client/McpClientExtensions.cs @@ -600,7 +600,7 @@ internal static CreateMessageResult ToCreateMessageResult(this ChatResponse chat var progressToken = requestParams.Meta?.ProgressToken; List updates = []; - await foreach (var update in chatClient.GetStreamingResponseAsync(messages, options, cancellationToken)) + await foreach (var update in chatClient.GetStreamingResponseAsync(messages, options, cancellationToken).ConfigureAwait(false)) { updates.Add(update); diff --git a/src/ModelContextProtocol/Configuration/McpServerBuilderExtensions.cs b/src/ModelContextProtocol/Configuration/McpServerBuilderExtensions.cs index a564fd39..b0ba472c 100644 --- a/src/ModelContextProtocol/Configuration/McpServerBuilderExtensions.cs +++ b/src/ModelContextProtocol/Configuration/McpServerBuilderExtensions.cs @@ -349,7 +349,37 @@ public static IMcpServerBuilder WithStdioServerTransport(this IMcpServerBuilder Throw.IfNull(builder); builder.Services.AddSingleton(); - builder.Services.AddHostedService(); + builder.Services.AddHostedService(); + + builder.Services.AddSingleton(services => + { + ITransport serverTransport = services.GetRequiredService(); + IOptions options = services.GetRequiredService>(); + ILoggerFactory? loggerFactory = services.GetService(); + + return McpServerFactory.Create(serverTransport, options.Value, loggerFactory, services); + }); + + return builder; + } + + /// + /// Adds a server transport that uses the specified input and output streams for communication. + /// + /// The builder instance. + /// The input to use as standard input. + /// The output to use as standard output. + public static IMcpServerBuilder WithStreamServerTransport( + this IMcpServerBuilder builder, + Stream inputStream, + Stream outputStream) + { + Throw.IfNull(builder); + Throw.IfNull(inputStream); + Throw.IfNull(outputStream); + + builder.Services.AddSingleton(new StreamServerTransport(inputStream, outputStream)); + builder.Services.AddHostedService(); builder.Services.AddSingleton(services => { diff --git a/src/ModelContextProtocol/Hosting/StdioMcpServerHostedService.cs b/src/ModelContextProtocol/Hosting/SingleSessionMcpServerHostedService.cs similarity index 63% rename from src/ModelContextProtocol/Hosting/StdioMcpServerHostedService.cs rename to src/ModelContextProtocol/Hosting/SingleSessionMcpServerHostedService.cs index ae13e19d..42937759 100644 --- a/src/ModelContextProtocol/Hosting/StdioMcpServerHostedService.cs +++ b/src/ModelContextProtocol/Hosting/SingleSessionMcpServerHostedService.cs @@ -4,9 +4,9 @@ namespace ModelContextProtocol.Hosting; /// -/// Hosted service for a single-session (i.e stdio) MCP server. +/// Hosted service for a single-session (e.g. stdio) MCP server. /// -internal sealed class StdioMcpServerHostedService(IMcpServer session) : BackgroundService +internal sealed class SingleSessionMcpServerHostedService(IMcpServer session) : BackgroundService { /// protected override Task ExecuteAsync(CancellationToken stoppingToken) => session.RunAsync(stoppingToken); diff --git a/src/ModelContextProtocol/IMcpEndpoint.cs b/src/ModelContextProtocol/IMcpEndpoint.cs index 95d6dbc5..6ef704b5 100644 --- a/src/ModelContextProtocol/IMcpEndpoint.cs +++ b/src/ModelContextProtocol/IMcpEndpoint.cs @@ -15,4 +15,10 @@ public interface IMcpEndpoint : IAsyncDisposable /// The message. /// The to monitor for cancellation requests. The default is . Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default); + + /// Registers a handler to be invoked when a notification for the specified method is received. + /// The notification method. + /// The handler to be invoked. + /// An that will remove the registered handler when disposed. + IAsyncDisposable RegisterNotificationHandler(string method, Func handler); } diff --git a/src/ModelContextProtocol/Protocol/Transport/SseClientSessionTransport.cs b/src/ModelContextProtocol/Protocol/Transport/SseClientSessionTransport.cs index 49b2fe40..854b0402 100644 --- a/src/ModelContextProtocol/Protocol/Transport/SseClientSessionTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/SseClientSessionTransport.cs @@ -173,7 +173,7 @@ public override async ValueTask DisposeAsync() { try { - await CloseAsync(); + await CloseAsync().ConfigureAwait(false); } catch (Exception) { diff --git a/src/ModelContextProtocol/Protocol/Transport/SseResponseStreamTransport.cs b/src/ModelContextProtocol/Protocol/Transport/SseResponseStreamTransport.cs index de1cb711..b635bb30 100644 --- a/src/ModelContextProtocol/Protocol/Transport/SseResponseStreamTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/SseResponseStreamTransport.cs @@ -77,7 +77,7 @@ public async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken ca } // Emit redundant "event: message" lines for better compatibility with other SDKs. - await _outgoingSseChannel.Writer.WriteAsync(new SseItem(message, SseParser.EventTypeDefault), cancellationToken); + await _outgoingSseChannel.Writer.WriteAsync(new SseItem(message, SseParser.EventTypeDefault), cancellationToken).ConfigureAwait(false); } /// @@ -94,7 +94,7 @@ public async Task OnMessageReceivedAsync(IJsonRpcMessage message, CancellationTo throw new InvalidOperationException($"Transport is not connected. Make sure to call {nameof(RunAsync)} first."); } - await _incomingChannel.Writer.WriteAsync(message, cancellationToken); + await _incomingChannel.Writer.WriteAsync(message, cancellationToken).ConfigureAwait(false); } private static Channel CreateBoundedChannel(int capacity = 1) => diff --git a/src/ModelContextProtocol/Protocol/Types/Capabilities.cs b/src/ModelContextProtocol/Protocol/Types/Capabilities.cs index a29a530a..40a0be37 100644 --- a/src/ModelContextProtocol/Protocol/Types/Capabilities.cs +++ b/src/ModelContextProtocol/Protocol/Types/Capabilities.cs @@ -34,7 +34,7 @@ public class ClientCapabilities /// The client will not re-enumerate the sequence. /// [JsonIgnore] - public IEnumerable>>? NotificationHandlers { get; set; } + public IEnumerable>>? NotificationHandlers { get; set; } } /// diff --git a/src/ModelContextProtocol/Protocol/Types/ServerCapabilities.cs b/src/ModelContextProtocol/Protocol/Types/ServerCapabilities.cs index 9f8e3ac8..3b328e58 100644 --- a/src/ModelContextProtocol/Protocol/Types/ServerCapabilities.cs +++ b/src/ModelContextProtocol/Protocol/Types/ServerCapabilities.cs @@ -45,5 +45,5 @@ public class ServerCapabilities /// The server will not re-enumerate the sequence. /// [JsonIgnore] - public IEnumerable>>? NotificationHandlers { get; set; } + public IEnumerable>>? NotificationHandlers { get; set; } } diff --git a/src/ModelContextProtocol/Server/McpServer.cs b/src/ModelContextProtocol/Server/McpServer.cs index 9b6898d8..22e1584c 100644 --- a/src/ModelContextProtocol/Server/McpServer.cs +++ b/src/ModelContextProtocol/Server/McpServer.cs @@ -5,6 +5,7 @@ using ModelContextProtocol.Shared; using ModelContextProtocol.Utils; using ModelContextProtocol.Utils.Json; +using System.Diagnostics; namespace ModelContextProtocol.Server; @@ -44,49 +45,43 @@ public McpServer(ITransport transport, McpServerOptions options, ILoggerFactory? Services = serviceProvider; _endpointName = $"Server ({options.ServerInfo?.Name ?? DefaultImplementation.Name} {options.ServerInfo?.Version ?? DefaultImplementation.Version})"; - _toolsChangedDelegate = delegate - { - _ = SendMessageAsync(new JsonRpcNotification() - { - Method = NotificationMethods.ToolListChangedNotification, - }); - }; - _promptsChangedDelegate = delegate - { - _ = SendMessageAsync(new JsonRpcNotification() - { - Method = NotificationMethods.PromptListChangedNotification, - }); - }; + // Configure all request handlers based on the supplied options. + SetInitializeHandler(options); + SetToolsHandler(options); + SetPromptsHandler(options); + SetResourcesHandler(options); + SetSetLoggingLevelHandler(options); + SetCompletionHandler(options); + SetPingHandler(); - NotificationHandlers.Add(NotificationMethods.InitializedNotification, _ => + // Register any notification handlers that were provided. + if (options.Capabilities?.NotificationHandlers is { } notificationHandlers) { - if (ServerOptions.Capabilities?.Tools?.ToolCollection is { } tools) - { - tools.Changed += _toolsChangedDelegate; - } + NotificationHandlers.RegisterRange(notificationHandlers); + } - if (ServerOptions.Capabilities?.Prompts?.PromptCollection is { } prompts) + // Now that everything has been configured, subscribe to any necessary notifications. + if (ServerOptions.Capabilities?.Tools?.ToolCollection is { } tools) + { + _toolsChangedDelegate = delegate { - prompts.Changed += _promptsChangedDelegate; - } + _ = SendMessageAsync(new JsonRpcNotification() { Method = NotificationMethods.ToolListChangedNotification }); + }; - return Task.CompletedTask; - }); + tools.Changed += _toolsChangedDelegate; + } - if (options.Capabilities?.NotificationHandlers is { } notificationHandlers) + if (ServerOptions.Capabilities?.Prompts?.PromptCollection is { } prompts) { - NotificationHandlers.AddRange(notificationHandlers); - } + _promptsChangedDelegate = delegate + { + _ = SendMessageAsync(new JsonRpcNotification() { Method = NotificationMethods.PromptListChangedNotification }); + }; - SetToolsHandler(options); - SetInitializeHandler(options); - SetCompletionHandler(options); - SetPingHandler(); - SetPromptsHandler(options); - SetResourcesHandler(options); - SetSetLoggingLevelHandler(options); + prompts.Changed += _promptsChangedDelegate; + } + // And start the session. StartSession(transport); } @@ -129,12 +124,14 @@ public async Task RunAsync(CancellationToken cancellationToken = default) public override async ValueTask DisposeUnsynchronizedAsync() { - if (ServerOptions.Capabilities?.Tools?.ToolCollection is { } tools) + if (_toolsChangedDelegate is not null && + ServerOptions.Capabilities?.Tools?.ToolCollection is { } tools) { tools.Changed -= _toolsChangedDelegate; } - if (ServerOptions.Capabilities?.Prompts?.PromptCollection is { } prompts) + if (_promptsChangedDelegate is not null && + ServerOptions.Capabilities?.Prompts?.PromptCollection is { } prompts) { prompts.Changed -= _promptsChangedDelegate; } @@ -179,8 +176,8 @@ private void SetCompletionHandler(McpServerOptions options) // This capability is not optional, so return an empty result if there is no handler. RequestHandlers.Set(RequestMethods.CompletionComplete, options.GetCompletionHandler is { } handler ? - (request, ct) => handler(new(this, request), ct) : - (request, ct) => Task.FromResult(new CompleteResult() { Completion = new() { Values = [], Total = 0, HasMore = false } }), + (request, cancellationToken) => handler(new(this, request), cancellationToken) : + (request, cancellationToken) => Task.FromResult(new CompleteResult() { Completion = new() { Values = [], Total = 0, HasMore = false } }), McpJsonUtilities.JsonContext.Default.CompleteRequestParams, McpJsonUtilities.JsonContext.Default.CompleteResult); } @@ -205,20 +202,20 @@ private void SetResourcesHandler(McpServerOptions options) RequestHandlers.Set( RequestMethods.ResourcesList, - (request, ct) => listResourcesHandler(new(this, request), ct), + (request, cancellationToken) => listResourcesHandler(new(this, request), cancellationToken), McpJsonUtilities.JsonContext.Default.ListResourcesRequestParams, McpJsonUtilities.JsonContext.Default.ListResourcesResult); RequestHandlers.Set( RequestMethods.ResourcesRead, - (request, ct) => readResourceHandler(new(this, request), ct), + (request, cancellationToken) => readResourceHandler(new(this, request), cancellationToken), McpJsonUtilities.JsonContext.Default.ReadResourceRequestParams, McpJsonUtilities.JsonContext.Default.ReadResourceResult); listResourceTemplatesHandler ??= (static (_, _) => Task.FromResult(new ListResourceTemplatesResult())); RequestHandlers.Set( RequestMethods.ResourcesTemplatesList, - (request, ct) => listResourceTemplatesHandler(new(this, request), ct), + (request, cancellationToken) => listResourceTemplatesHandler(new(this, request), cancellationToken), McpJsonUtilities.JsonContext.Default.ListResourceTemplatesRequestParams, McpJsonUtilities.JsonContext.Default.ListResourceTemplatesResult); @@ -236,13 +233,13 @@ private void SetResourcesHandler(McpServerOptions options) RequestHandlers.Set( RequestMethods.ResourcesSubscribe, - (request, ct) => subscribeHandler(new(this, request), ct), + (request, cancellationToken) => subscribeHandler(new(this, request), cancellationToken), McpJsonUtilities.JsonContext.Default.SubscribeRequestParams, McpJsonUtilities.JsonContext.Default.EmptyResult); RequestHandlers.Set( RequestMethods.ResourcesUnsubscribe, - (request, ct) => unsubscribeHandler(new(this, request), ct), + (request, cancellationToken) => unsubscribeHandler(new(this, request), cancellationToken), McpJsonUtilities.JsonContext.Default.UnsubscribeRequestParams, McpJsonUtilities.JsonContext.Default.EmptyResult); } @@ -329,13 +326,13 @@ await originalListPromptsHandler(request, cancellationToken).ConfigureAwait(fals RequestHandlers.Set( RequestMethods.PromptsList, - (request, ct) => listPromptsHandler(new(this, request), ct), + (request, cancellationToken) => listPromptsHandler(new(this, request), cancellationToken), McpJsonUtilities.JsonContext.Default.ListPromptsRequestParams, McpJsonUtilities.JsonContext.Default.ListPromptsResult); RequestHandlers.Set( RequestMethods.PromptsGet, - (request, ct) => getPromptHandler(new(this, request), ct), + (request, cancellationToken) => getPromptHandler(new(this, request), cancellationToken), McpJsonUtilities.JsonContext.Default.GetPromptRequestParams, McpJsonUtilities.JsonContext.Default.GetPromptResult); } @@ -422,13 +419,13 @@ await originalListToolsHandler(request, cancellationToken).ConfigureAwait(false) RequestHandlers.Set( RequestMethods.ToolsList, - (request, ct) => listToolsHandler(new(this, request), ct), + (request, cancellationToken) => listToolsHandler(new(this, request), cancellationToken), McpJsonUtilities.JsonContext.Default.ListToolsRequestParams, McpJsonUtilities.JsonContext.Default.ListToolsResult); RequestHandlers.Set( RequestMethods.ToolsCall, - (request, ct) => callToolHandler(new(this, request), ct), + (request, cancellationToken) => callToolHandler(new(this, request), cancellationToken), McpJsonUtilities.JsonContext.Default.CallToolRequestParams, McpJsonUtilities.JsonContext.Default.CallToolResponse); } @@ -447,7 +444,7 @@ private void SetSetLoggingLevelHandler(McpServerOptions options) RequestHandlers.Set( RequestMethods.LoggingSetLevel, - (request, ct) => setLoggingLevelHandler(new(this, request), ct), + (request, cancellationToken) => setLoggingLevelHandler(new(this, request), cancellationToken), McpJsonUtilities.JsonContext.Default.SetLevelRequestParams, McpJsonUtilities.JsonContext.Default.EmptyResult); } diff --git a/src/ModelContextProtocol/Shared/McpEndpoint.cs b/src/ModelContextProtocol/Shared/McpEndpoint.cs index c26af4b1..f24e9134 100644 --- a/src/ModelContextProtocol/Shared/McpEndpoint.cs +++ b/src/ModelContextProtocol/Shared/McpEndpoint.cs @@ -41,7 +41,7 @@ protected McpEndpoint(ILoggerFactory? loggerFactory = null) protected RequestHandlers RequestHandlers { get; } = []; - protected NotificationHandlers NotificationHandlers { get; } = []; + protected NotificationHandlers NotificationHandlers { get; } = new(); public Task SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken = default) => GetSessionOrThrow().SendRequestAsync(request, cancellationToken); @@ -49,6 +49,9 @@ public Task SendRequestAsync(JsonRpcRequest request, Cancellati public Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default) => GetSessionOrThrow().SendMessageAsync(message, cancellationToken); + public IAsyncDisposable RegisterNotificationHandler(string method, Func handler) => + GetSessionOrThrow().RegisterNotificationHandler(method, handler); + /// /// Gets the name of the endpoint for logging and debug purposes. /// diff --git a/src/ModelContextProtocol/Shared/McpSession.cs b/src/ModelContextProtocol/Shared/McpSession.cs index 3d06f7fc..0e94c056 100644 --- a/src/ModelContextProtocol/Shared/McpSession.cs +++ b/src/ModelContextProtocol/Shared/McpSession.cs @@ -213,7 +213,7 @@ private async Task HandleMessageAsync(IJsonRpcMessage message, CancellationToken break; case JsonRpcNotification notification: - await HandleNotification(notification).ConfigureAwait(false); + await HandleNotification(notification, cancellationToken).ConfigureAwait(false); break; case IJsonRpcMessageWithId messageWithId: @@ -236,7 +236,7 @@ private async Task HandleMessageAsync(IJsonRpcMessage message, CancellationToken } } - private async Task HandleNotification(JsonRpcNotification notification) + private async Task HandleNotification(JsonRpcNotification notification, CancellationToken cancellationToken) { // Special-case cancellation to cancel a pending operation. (We'll still subsequently invoke a user-specified handler if one exists.) if (notification.Method == NotificationMethods.CancelledNotification) @@ -257,21 +257,7 @@ private async Task HandleNotification(JsonRpcNotification notification) } // Handle user-defined notifications. - if (_notificationHandlers.TryGetValue(notification.Method, out var handlers)) - { - foreach (var notificationHandler in handlers) - { - try - { - await notificationHandler(notification).ConfigureAwait(false); - } - catch (Exception ex) - { - // Log handler error but continue processing - _logger.NotificationHandlerError(EndpointName, notification.Method, ex); - } - } - } + await _notificationHandlers.InvokeHandlers(notification.Method, notification, cancellationToken).ConfigureAwait(false); } private void HandleMessageWithId(IJsonRpcMessage message, IJsonRpcMessageWithId messageWithId) @@ -310,6 +296,14 @@ await _transport.SendMessageAsync(new JsonRpcResponse }, cancellationToken).ConfigureAwait(false); } + public IAsyncDisposable RegisterNotificationHandler(string method, Func handler) + { + Throw.IfNullOrWhiteSpace(method); + Throw.IfNull(handler); + + return _notificationHandlers.Register(method, handler); + } + /// /// Sends a JSON-RPC request to the server. /// It is strongly recommended use the capability-specific methods instead of this one. @@ -525,6 +519,11 @@ private static void AddRpcRequestTags(ref TagList tags, Activity? activity, Json private static void AddExceptionTags(ref TagList tags, Exception e) { + if (e is AggregateException ae && ae.InnerException is not null and not AggregateException) + { + e = ae.InnerException; + } + tags.Add("error.type", e.GetType().FullName); tags.Add("rpc.jsonrpc.error_code", (e as McpException)?.ErrorCode is int errorCode ? errorCode : diff --git a/src/ModelContextProtocol/Shared/NotificationHandlers.cs b/src/ModelContextProtocol/Shared/NotificationHandlers.cs index c3b5dfe7..2962a272 100644 --- a/src/ModelContextProtocol/Shared/NotificationHandlers.cs +++ b/src/ModelContextProtocol/Shared/NotificationHandlers.cs @@ -1,28 +1,293 @@ using ModelContextProtocol.Protocol.Messages; +using System.Diagnostics; namespace ModelContextProtocol.Shared; -internal sealed class NotificationHandlers : Dictionary>> +/// Provides thread-safe storage for notification handlers. +internal sealed class NotificationHandlers { + /// A dictionary of linked lists of registrations, indexed by the notification method. + private readonly Dictionary _handlers = []; + + /// Gets the object to be used for all synchronization. + private object SyncObj => _handlers; + + /// Registers all of the specified handlers. + /// The handlers to register. + /// + /// Registrations completed with this method are non-removable. + /// + public void RegisterRange(IEnumerable>> handlers) + { + foreach (var entry in handlers) + { + _ = Register(entry.Key, entry.Value, temporary: false); + } + } + /// Adds a notification handler as part of configuring the endpoint. - /// This method is not thread-safe and should only be used serially as part of configuring the instance. - public void Add(string method, Func handler) + /// The notification method for which the handler is being registered. + /// The handler being registered. + /// + /// if the registration can be removed later; if it cannot. + /// If , the registration will be permanent: calling + /// on the returned instance will not unregister the handler. + /// + public IAsyncDisposable Register( + string method, Func handler, bool temporary = true) + { + // Create the new registration instance. + Registration reg = new(this, method, handler, temporary); + + // Store the registration into the dictionary. If there's not currently a registration for the method, + // then this registration instance just becomes the single value. If there is currently a registration, + // then this new registration becomes the new head of the linked list, and the old head becomes the next + // item in the list. + lock (SyncObj) + { + if (_handlers.TryGetValue(method, out var existingHandlerHead)) + { + reg.Next = existingHandlerHead; + existingHandlerHead.Prev = reg; + } + + _handlers[method] = reg; + } + + // Return the new registration. It must be disposed of when no longer used, or it will end up being + // leaked into the list. This is the same as with CancellationToken.Register. + return reg; + } + + public async Task InvokeHandlers(string method, JsonRpcNotification notification, CancellationToken cancellationToken) { - if (!TryGetValue(method, out var handlers)) + // If there are no handlers registered for this method, we're done. + Registration? reg; + lock (SyncObj) { - this[method] = handlers = []; + if (!_handlers.TryGetValue(method, out reg)) + { + return; + } } - handlers.Add(handler); + // Invoke each handler in the list. We guarantee that we'll try to invoke + // any handlers that were in the list when the list was fetched from the dictionary, + // which is why DisposeAsync doesn't modify the Prev/Next of the registration being + // disposed; if those were nulled out, we'd be unable to walk around it in the list + // if we happened to be on that item when it was disposed. + List? exceptions = null; + while (reg is not null) + { + try + { + await reg.InvokeAsync(notification, cancellationToken).ConfigureAwait(false); + } + catch (Exception e) + { + (exceptions ??= []).Add(e); + } + + lock (SyncObj) + { + reg = reg.Next; + } + } + + if (exceptions is not null) + { + throw new AggregateException(exceptions); + } } - /// Adds notification handlers as part of configuring the endpoint. - /// This method is not thread-safe and should only be used serially as part of configuring the instance. - public void AddRange(IEnumerable>> handlers) + /// Provides storage for a handler registration. + private sealed class Registration( + NotificationHandlers handlers, string method, Func handler, bool unregisterable) : IAsyncDisposable { - foreach (var handler in handlers) + /// Used to prevent deadlocks during disposal. + /// + /// The task returned from does not complete until all invocations of the handler + /// have completed and no more will be performed, so that the consumer can then trust that any resources accessed + /// by that handler are no longer in use and may be cleaned up. If were to be invoked + /// and its task awaited from within the invocation of the handler, however, that would result in deadlock, since + /// the task wouldn't complete until the invocation completed, and the invocation wouldn't complete until the task + /// completed. To circument that, we track via an in-flight invocations. If + /// detects it's being invoked from within an invocation, it will avoid waiting. For + /// simplicity, we don't require that it's the same handler. + /// + private static readonly AsyncLocal s_invokingAncestor = new(); + + /// The parent to which this registration belongs. + private readonly NotificationHandlers _handlers = handlers; + + /// The method with which this registration is associated. + private readonly string _method = method; + + /// The handler this registration represents. + private readonly Func _handler = handler; + + /// true if this instance is temporary; false if it's permanent + private readonly bool _temporary = unregisterable; + + /// Provides a task that can await to know when all in-flight invocations have completed. + /// + /// This will only be initialized if sees in-flight invocations, in which case it'll initialize + /// this and then await its task. The task will be completed when the last + /// in-flight notification completes. + /// + private TaskCompletionSource? _disposeTcs; + + /// The number of remaining references to this registration. + /// + /// The ref count starts life at 1 to represent the whole registration; that ref count will be subtracted when + /// the instance is disposed. Every invocation then temporarily increases the ref count before invocation and + /// decrements it after. When is called, it decrements the ref count. In the common + /// case, that'll bring the count down to 0, in which case the instance will never be subsequently invoked. + /// If, however, after that decrement the count is still positive, then there are in-flight invocations; the last + /// one of those to complete will end up decrementing the ref count to 0. + /// + private int _refCount = 1; + + /// Tracks whether has ever been invoked. + /// + /// It's rare but possible is called multiple times. Only the first + /// should decrement the initial ref count, but they all must wait until all invocations have quiesced. + /// + private bool _disposedCalled = false; + + /// The next registration in the linked list. + public Registration? Next; + /// The previous registration in the linked list. + public Registration? Prev; + + /// Removes the registration. + public async ValueTask DisposeAsync() + { + if (!_temporary) + { + return; + } + + lock (_handlers.SyncObj) + { + // If DisposeAsync was previously called, we don't want to do all of the work again + // to remove the registration from the list, and we must not do the work again to + // decrement the ref count and possibly initialize the _disposeTcs. + if (!_disposedCalled) + { + _disposedCalled = true; + + // If this handler is the head of the list for this method, we need to update + // the dictionary, either to point to a different head, or if this is the only + // item in the list, to remove the entry from the dictionary entirely. + if (_handlers._handlers.TryGetValue(_method, out var handlers) && handlers == this) + { + if (Next is not null) + { + _handlers._handlers[_method] = Next; + } + else + { + _handlers._handlers.Remove(_method); + } + } + + // Remove the registration from the linked list by routing the nodes around it + // to point past this one. Importantly, we do not modify this node's Next or Prev. + // We want to ensure that an enumeration through all of the registrations can still + // progress through this one. + if (Prev is not null) + { + Prev.Next = Next; + } + if (Next is not null) + { + Next.Prev = Prev; + } + + // Decrement the ref count. In the common case, there's no in-flight invocation for + // this handler. However, in the uncommon case that there is, we need to wait for + // that invocation to complete. To do that, initialize the _disposeTcs. It's created + // with RunContinuationsAsynchronously so that completing it doesn't run the continuation + // under any held locks. + if (--_refCount != 0) + { + Debug.Assert(_disposeTcs is null); + _disposeTcs = new(TaskCreationOptions.RunContinuationsAsynchronously); + } + } + } + + // Ensure that DisposeAsync doesn't complete until all in-flight invocations have completed, + // unless our call chain includes one of those in-flight invocations, in which case waiting + // would deadlock. + if (_disposeTcs is not null && s_invokingAncestor.Value == 0) + { + await _disposeTcs.Task.ConfigureAwait(false); + } + } + + /// Invoke the handler associated with the registration. + public Task InvokeAsync(JsonRpcNotification notification, CancellationToken cancellationToken) { - Add(handler.Key, handler.Value); + // For permanent registrations, skip all the tracking overhead and just invoke the handler. + if (!_temporary) + { + return _handler(notification, cancellationToken); + } + + // For temporary registrations, track the invocation and coordinate with disposal. + return InvokeTemporaryAsync(notification, cancellationToken); + } + + /// Invoke the handler associated with the temporary registration. + private async Task InvokeTemporaryAsync(JsonRpcNotification notification, CancellationToken cancellationToken) + { + // Check whether we need to handle this registration. If DisposeAsync has been called, + // then even if there are in-flight invocations for it, we avoid adding more. + // If DisposeAsync has not been called, then we need to increment the ref count to + // signal that there's another in-flight invocation. + lock (_handlers.SyncObj) + { + Debug.Assert(_refCount != 0 || _disposedCalled, $"Expected {nameof(_disposedCalled)} == true when {nameof(_refCount)} == 0"); + if (_disposedCalled) + { + return; + } + + Debug.Assert(_refCount > 0); + _refCount++; + } + + // Ensure that if DisposeAsync is called from within the handler, it won't deadlock by waiting + // for the in-flight invocation to complete. + s_invokingAncestor.Value++; + + try + { + // Invoke the handler. + await _handler(notification, cancellationToken).ConfigureAwait(false); + } + finally + { + // Undo the in-flight tracking. + s_invokingAncestor.Value--; + + // Now decrement the ref count we previously incremented. If that brings the ref count to 0, + // DisposeAsync must have been called while this was in-flight, which also means it's now + // waiting on _disposeTcs; unblock it. + lock (_handlers.SyncObj) + { + _refCount--; + if (_refCount == 0) + { + Debug.Assert(_disposedCalled); + Debug.Assert(_disposeTcs is not null); + bool completed = _disposeTcs!.TrySetResult(true); + Debug.Assert(completed); + } + } + } } } } diff --git a/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs b/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs index 1a451a2d..07ea3d18 100644 --- a/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs +++ b/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs @@ -26,9 +26,7 @@ public McpClientExtensionsTests(ITestOutputHelper outputHelper) { ServiceCollection sc = new(); sc.AddSingleton(LoggerFactory); - sc.AddMcpServer().WithStdioServerTransport(); - // Call WithStdioServerTransport to get the IMcpServer registration, then overwrite default transport with a pipe transport. - sc.AddSingleton(new StreamServerTransport(_clientToServerPipe.Reader.AsStream(), _serverToClientPipe.Writer.AsStream())); + sc.AddMcpServer().WithStreamServerTransport(_clientToServerPipe.Reader.AsStream(), _serverToClientPipe.Writer.AsStream()); for (int f = 0; f < 10; f++) { string name = $"Method{f}"; diff --git a/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs b/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs index 6f0c80f1..be148397 100644 --- a/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs +++ b/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs @@ -259,7 +259,7 @@ public async Task SubscribeResource_Stdio() { NotificationHandlers = [ - new(NotificationMethods.ResourceUpdatedNotification, notification => + new(NotificationMethods.ResourceUpdatedNotification, (notification, cancellationToken) => { var notificationParams = JsonSerializer.Deserialize(notification.Params); tcs.TrySetResult(true); @@ -289,7 +289,7 @@ public async Task UnsubscribeResource_Stdio() { NotificationHandlers = [ - new(NotificationMethods.ResourceUpdatedNotification, (notification) => + new(NotificationMethods.ResourceUpdatedNotification, (notification, cancellationToken) => { var notificationParams = JsonSerializer.Deserialize(notification.Params); receivedNotification.TrySetResult(true); @@ -560,7 +560,7 @@ public async Task SetLoggingLevel_ReceivesLoggingMessages(string clientId) { NotificationHandlers = [ - new(NotificationMethods.LoggingMessageNotification, (notification) => + new(NotificationMethods.LoggingMessageNotification, (notification, cancellationToken) => { var loggingMessageNotificationParameters = JsonSerializer.Deserialize(notification.Params); if (loggingMessageNotificationParameters is not null) diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs index ed1be347..64dc9fb5 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs @@ -31,7 +31,7 @@ public McpServerBuilderExtensionsPromptsTests(ITestOutputHelper testOutputHelper sc.AddSingleton(LoggerFactory); _builder = sc .AddMcpServer() - .WithStdioServerTransport() + .WithStreamServerTransport(_clientToServerPipe.Reader.AsStream(), _serverToClientPipe.Writer.AsStream()) .WithListPromptsHandler(async (request, cancellationToken) => { var cursor = request.Params?.Cursor; @@ -93,8 +93,6 @@ public McpServerBuilderExtensionsPromptsTests(ITestOutputHelper testOutputHelper .WithPrompts(); - // Call WithStdioServerTransport to get the IMcpServer registration, then overwrite default transport with a pipe transport. - sc.AddSingleton(new StreamServerTransport(_clientToServerPipe.Reader.AsStream(), _serverToClientPipe.Writer.AsStream(), loggerFactory: LoggerFactory)); sc.AddSingleton(new ObjectWithId()); _serviceProvider = sc.BuildServiceProvider(); @@ -176,23 +174,12 @@ public async Task Can_List_And_Call_Registered_Prompts() [Fact] public async Task Can_Be_Notified_Of_Prompt_Changes() { - Channel listChanged = Channel.CreateUnbounded(); - - IMcpClient client = await CreateMcpClientForServer(new() - { - Capabilities = new() - { - NotificationHandlers = [new("notifications/prompts/list_changed", notification => - { - listChanged.Writer.TryWrite(notification); - return Task.CompletedTask; - })], - }, - }); + IMcpClient client = await CreateMcpClientForServer(); var prompts = await client.ListPromptsAsync(TestContext.Current.CancellationToken); Assert.Equal(6, prompts.Count); + Channel listChanged = Channel.CreateUnbounded(); var notificationRead = listChanged.Reader.ReadAsync(TestContext.Current.CancellationToken); Assert.False(notificationRead.IsCompleted); @@ -201,17 +188,24 @@ public async Task Can_Be_Notified_Of_Prompt_Changes() Assert.NotNull(serverPrompts); var newPrompt = McpServerPrompt.Create([McpServerPrompt(Name = "NewPrompt")] () => "42"); - serverPrompts.Add(newPrompt); - await notificationRead; - - prompts = await client.ListPromptsAsync(TestContext.Current.CancellationToken); - Assert.Equal(7, prompts.Count); - Assert.Contains(prompts, t => t.Name == "NewPrompt"); - - notificationRead = listChanged.Reader.ReadAsync(TestContext.Current.CancellationToken); - Assert.False(notificationRead.IsCompleted); - serverPrompts.Remove(newPrompt); - await notificationRead; + await using (client.RegisterNotificationHandler("notifications/prompts/list_changed", (notification, cancellationToken) => + { + listChanged.Writer.TryWrite(notification); + return Task.CompletedTask; + })) + { + serverPrompts.Add(newPrompt); + await notificationRead; + + prompts = await client.ListPromptsAsync(TestContext.Current.CancellationToken); + Assert.Equal(7, prompts.Count); + Assert.Contains(prompts, t => t.Name == "NewPrompt"); + + notificationRead = listChanged.Reader.ReadAsync(TestContext.Current.CancellationToken); + Assert.False(notificationRead.IsCompleted); + serverPrompts.Remove(newPrompt); + await notificationRead; + } prompts = await client.ListPromptsAsync(TestContext.Current.CancellationToken); Assert.Equal(6, prompts.Count); diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs index 827c7056..075ab8f8 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs @@ -35,7 +35,7 @@ public McpServerBuilderExtensionsToolsTests(ITestOutputHelper testOutputHelper) sc.AddSingleton(LoggerFactory); _builder = sc .AddMcpServer() - .WithStdioServerTransport() + .WithStreamServerTransport(_clientToServerPipe.Reader.AsStream(), _serverToClientPipe.Writer.AsStream()) .WithListToolsHandler(async (request, cancellationToken) => { var cursor = request.Params?.Cursor; @@ -117,8 +117,6 @@ public McpServerBuilderExtensionsToolsTests(ITestOutputHelper testOutputHelper) }) .WithTools(); - // Call WithStdioServerTransport to get the IMcpServer registration, then overwrite default transport with a pipe transport. - sc.AddSingleton(new StreamServerTransport(_clientToServerPipe.Reader.AsStream(), _serverToClientPipe.Writer.AsStream(), loggerFactory: LoggerFactory)); sc.AddSingleton(new ObjectWithId()); _serviceProvider = sc.BuildServiceProvider(); @@ -243,23 +241,12 @@ public async Task Can_Create_Multiple_Servers_From_Options_And_List_Registered_T [Fact] public async Task Can_Be_Notified_Of_Tool_Changes() { - Channel listChanged = Channel.CreateUnbounded(); - - IMcpClient client = await CreateMcpClientForServer(new() - { - Capabilities = new() - { - NotificationHandlers = [new(NotificationMethods.ToolListChangedNotification, notification => - { - listChanged.Writer.TryWrite(notification); - return Task.CompletedTask; - })], - }, - }); + IMcpClient client = await CreateMcpClientForServer(); var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); Assert.Equal(16, tools.Count); + Channel listChanged = Channel.CreateUnbounded(); var notificationRead = listChanged.Reader.ReadAsync(TestContext.Current.CancellationToken); Assert.False(notificationRead.IsCompleted); @@ -268,17 +255,24 @@ public async Task Can_Be_Notified_Of_Tool_Changes() Assert.NotNull(serverTools); var newTool = McpServerTool.Create([McpServerTool(Name = "NewTool")] () => "42"); - serverTools.Add(newTool); - await notificationRead; + await using (client.RegisterNotificationHandler(NotificationMethods.ToolListChangedNotification, (notification, cancellationToken) => + { + listChanged.Writer.TryWrite(notification); + return Task.CompletedTask; + })) + { + serverTools.Add(newTool); + await notificationRead; - tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); - Assert.Equal(17, tools.Count); - Assert.Contains(tools, t => t.Name == "NewTool"); + tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); + Assert.Equal(17, tools.Count); + Assert.Contains(tools, t => t.Name == "NewTool"); - notificationRead = listChanged.Reader.ReadAsync(TestContext.Current.CancellationToken); - Assert.False(notificationRead.IsCompleted); - serverTools.Remove(newTool); - await notificationRead; + notificationRead = listChanged.Reader.ReadAsync(TestContext.Current.CancellationToken); + Assert.False(notificationRead.IsCompleted); + serverTools.Remove(newTool); + await notificationRead; + } tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); Assert.Equal(16, tools.Count); @@ -626,20 +620,7 @@ public void Create_ExtractsToolAnnotations_SomeSet() [Fact] public async Task HandlesIProgressParameter() { - ConcurrentQueue notifications = new(); - - IMcpClient client = await CreateMcpClientForServer(new() - { - Capabilities = new() - { - NotificationHandlers = [new(NotificationMethods.ProgressNotification, notification => - { - ProgressNotification pn = JsonSerializer.Deserialize(notification.Params)!; - notifications.Enqueue(pn); - return Task.CompletedTask; - })], - }, - }); + IMcpClient client = await CreateMcpClientForServer(); var tools = await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); Assert.NotNull(tools); @@ -647,17 +628,26 @@ public async Task HandlesIProgressParameter() McpClientTool progressTool = tools.First(t => t.Name == nameof(EchoTool.SendsProgressNotifications)); - var result = await client.SendRequestAsync( - RequestMethods.ToolsCall, - new CallToolRequestParams - { - Name = progressTool.ProtocolTool.Name, - Meta = new() { ProgressToken = new("abc123") }, - }, - cancellationToken: TestContext.Current.CancellationToken); + ConcurrentQueue notifications = new(); + await using (client.RegisterNotificationHandler(NotificationMethods.ProgressNotification, (notification, cancellationToken) => + { + ProgressNotification pn = JsonSerializer.Deserialize(notification.Params)!; + notifications.Enqueue(pn); + return Task.CompletedTask; + })) + { + var result = await client.SendRequestAsync( + RequestMethods.ToolsCall, + new CallToolRequestParams + { + Name = progressTool.ProtocolTool.Name, + Meta = new() { ProgressToken = new("abc123") }, + }, + cancellationToken: TestContext.Current.CancellationToken); - Assert.Contains("done", JsonSerializer.Serialize(result)); - SpinWait.SpinUntil(() => notifications.Count == 10, TimeSpan.FromSeconds(10)); + Assert.Contains("done", JsonSerializer.Serialize(result)); + SpinWait.SpinUntil(() => notifications.Count == 10, TimeSpan.FromSeconds(10)); + } ProgressNotification[] array = notifications.OrderBy(n => n.Progress.Progress).ToArray(); Assert.Equal(10, array.Length); diff --git a/tests/ModelContextProtocol.Tests/Protocol/NotificationHandlerTests.cs b/tests/ModelContextProtocol.Tests/Protocol/NotificationHandlerTests.cs new file mode 100644 index 00000000..ed1f4b20 --- /dev/null +++ b/tests/ModelContextProtocol.Tests/Protocol/NotificationHandlerTests.cs @@ -0,0 +1,248 @@ +using Microsoft.Extensions.AI; +using Microsoft.Extensions.DependencyInjection; +using ModelContextProtocol.Client; +using ModelContextProtocol.Protocol.Transport; +using ModelContextProtocol.Server; +using ModelContextProtocol.Tests.Utils; +using System.IO.Pipelines; + +namespace ModelContextProtocol.Tests; + +public class NotificationHandlerTests : LoggedTest, IAsyncDisposable +{ + private readonly Pipe _clientToServerPipe = new(); + private readonly Pipe _serverToClientPipe = new(); + private readonly ServiceProvider _serviceProvider; + private readonly IMcpServerBuilder _builder; + private readonly CancellationTokenSource _cts; + private readonly Task _serverTask; + private readonly IMcpServer _server; + + public NotificationHandlerTests(ITestOutputHelper testOutputHelper) + : base(testOutputHelper) + { + ServiceCollection sc = new(); + sc.AddSingleton(LoggerFactory); + _builder = sc + .AddMcpServer() + .WithStreamServerTransport(_clientToServerPipe.Reader.AsStream(), _serverToClientPipe.Writer.AsStream()); + _serviceProvider = sc.BuildServiceProvider(); + + _cts = CancellationTokenSource.CreateLinkedTokenSource(TestContext.Current.CancellationToken); + _server = _serviceProvider.GetRequiredService(); + _serverTask = _server.RunAsync(_cts.Token); + } + + public async ValueTask DisposeAsync() + { + await _cts.CancelAsync(); + + _clientToServerPipe.Writer.Complete(); + _serverToClientPipe.Writer.Complete(); + + await _serverTask; + + await _serviceProvider.DisposeAsync(); + _cts.Dispose(); + Dispose(); + } + + private async Task CreateMcpClientForServer(McpClientOptions? options = null) + { + return await McpClientFactory.CreateAsync( + new McpServerConfig() + { + Id = "TestServer", + Name = "TestServer", + TransportType = "ignored", + }, + options, + createTransportFunc: (_, _) => new StreamClientTransport( + serverInput: _clientToServerPipe.Writer.AsStream(), + _serverToClientPipe.Reader.AsStream(), + LoggerFactory), + loggerFactory: LoggerFactory, + cancellationToken: TestContext.Current.CancellationToken); + } + + [Fact] + public async Task RegistrationsAreRemovedWhenDisposed() + { + const string NotificationName = "somethingsomething"; + IMcpClient client = await CreateMcpClientForServer(); + + const int Iterations = 10; + + int counter = 0; + for (int i = 0; i < Iterations; i++) + { + var tcs = new TaskCompletionSource(); + await using (client.RegisterNotificationHandler(NotificationName, (notification, cancellationToken) => + { + Interlocked.Increment(ref counter); + tcs.SetResult(true); + return Task.CompletedTask; + })) + { + await _server.SendNotificationAsync(NotificationName, TestContext.Current.CancellationToken); + await tcs.Task; + } + } + + Assert.Equal(Iterations, counter); + } + + [Fact] + public async Task MultipleRegistrationsResultInMultipleCallbacks() + { + const string NotificationName = "somethingsomething"; + IMcpClient client = await CreateMcpClientForServer(); + + const int RegistrationCount = 10; + + int remaining = RegistrationCount; + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + IAsyncDisposable[] registrations = new IAsyncDisposable[RegistrationCount]; + for (int i = 0; i < registrations.Length; i++) + { + registrations[i] = client.RegisterNotificationHandler(NotificationName, (notification, cancellationToken) => + { + int result = Interlocked.Decrement(ref remaining); + Assert.InRange(result, 0, RegistrationCount); + if (result == 0) + { + tcs.TrySetResult(true); + } + + return Task.CompletedTask; + }); + } + + try + { + await _server.SendNotificationAsync(NotificationName, TestContext.Current.CancellationToken); + await tcs.Task; + } + finally + { + for (int i = registrations.Length - 1; i >= 0; i--) + { + await registrations[i].DisposeAsync(); + } + } + } + + [Fact] + public async Task MultipleHandlersRunEvenIfOneThrows() + { + const string NotificationName = "somethingsomething"; + IMcpClient client = await CreateMcpClientForServer(); + + const int RegistrationCount = 10; + + int remaining = RegistrationCount; + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + IAsyncDisposable[] registrations = new IAsyncDisposable[RegistrationCount]; + for (int i = 0; i < registrations.Length; i++) + { + registrations[i] = client.RegisterNotificationHandler(NotificationName, (notification, cancellationToken) => + { + int result = Interlocked.Decrement(ref remaining); + Assert.InRange(result, 0, RegistrationCount); + if (result == 0) + { + tcs.TrySetResult(true); + } + + throw new InvalidOperationException("Test exception"); + }); + } + + try + { + await _server.SendNotificationAsync(NotificationName, TestContext.Current.CancellationToken); + await tcs.Task; + } + finally + { + for (int i = registrations.Length - 1; i >= 0; i--) + { + await registrations[i].DisposeAsync(); + } + } + } + + [Theory] + [InlineData(1)] + [InlineData(3)] + public async Task DisposeAsyncDoesNotCompleteWhileNotificationHandlerRuns(int numberOfDisposals) + { + const string NotificationName = "somethingsomething"; + IMcpClient client = await CreateMcpClientForServer(); + + var handlerRunning = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var releaseHandler = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + IAsyncDisposable registration = client.RegisterNotificationHandler(NotificationName, async (notification, cancellationToken) => + { + handlerRunning.SetResult(true); + await releaseHandler.Task; + }); + + await _server.SendNotificationAsync(NotificationName, TestContext.Current.CancellationToken); + await handlerRunning.Task; + + var disposals = new ValueTask[numberOfDisposals]; + for (int i = 0; i < numberOfDisposals; i++) + { + disposals[i] = registration.DisposeAsync(); + } + + await Task.Delay(1, TestContext.Current.CancellationToken); + + foreach (ValueTask disposal in disposals) + { + Assert.False(disposal.IsCompleted); + } + + releaseHandler.SetResult(true); + + foreach (ValueTask disposal in disposals) + { + await disposal; + } + } + + [Theory] + [InlineData(1)] + [InlineData(3)] + public async Task DisposeAsyncCompletesImmediatelyWhenInvokedFromHandler(int numberOfDisposals) + { + const string NotificationName = "somethingsomething"; + IMcpClient client = await CreateMcpClientForServer(); + + var handlerRunning = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var releaseHandler = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + IAsyncDisposable? registration = null; + registration = client.RegisterNotificationHandler(NotificationName, async (notification, cancellationToken) => + { + for (int i = 0; i < numberOfDisposals; i++) + { + Assert.NotNull(registration); + ValueTask disposal = registration!.DisposeAsync(); + Assert.True(disposal.IsCompletedSuccessfully); + await disposal; + } + + handlerRunning.SetResult(true); + }); + + await _server.SendNotificationAsync(NotificationName, TestContext.Current.CancellationToken); + await handlerRunning.Task; + + ValueTask disposal = registration.DisposeAsync(); + Assert.True(disposal.IsCompletedSuccessfully); + await disposal; + } +} diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs index fb462748..b1ef6c2c 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs @@ -637,6 +637,8 @@ public Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancella throw new NotImplementedException(); public Task RunAsync(CancellationToken cancellationToken = default) => throw new NotImplementedException(); + public IAsyncDisposable RegisterNotificationHandler(string method, Func handler) => + throw new NotImplementedException(); } [Fact] @@ -648,7 +650,7 @@ public async Task NotifyProgress_Should_Be_Handled() var notificationReceived = new TaskCompletionSource(); options.Capabilities = new() { - NotificationHandlers = [new(NotificationMethods.ProgressNotification, notification => + NotificationHandlers = [new(NotificationMethods.ProgressNotification, (notification, cancellationToken) => { notificationReceived.TrySetResult(notification); return Task.CompletedTask; diff --git a/tests/ModelContextProtocol.Tests/SseIntegrationTests.cs b/tests/ModelContextProtocol.Tests/SseIntegrationTests.cs index 91e99898..94cc786e 100644 --- a/tests/ModelContextProtocol.Tests/SseIntegrationTests.cs +++ b/tests/ModelContextProtocol.Tests/SseIntegrationTests.cs @@ -213,9 +213,9 @@ public async Task ConnectAndReceiveNotification_InMemoryServer() { Capabilities = new() { - NotificationHandlers = [new("test/notification", args => + NotificationHandlers = [new("test/notification", (notification, cancellationToken) => { - var msg = args.Params?["message"]?.GetValue(); + var msg = notification.Params?["message"]?.GetValue(); receivedNotification.SetResult(msg); return Task.CompletedTask;