Skip to content

Allow notification handlers post-creation temporary registration #223

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/ModelContextProtocol/Client/McpClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ public McpClient(IClientTransport clientTransport, McpClientOptions? options, Mc
{
if (capabilities.NotificationHandlers is { } notificationHandlers)
{
NotificationHandlers.AddRange(notificationHandlers);
NotificationHandlers.RegisterRange(notificationHandlers);
}

if (capabilities.Sampling is { } samplingCapability)
Expand Down
2 changes: 1 addition & 1 deletion src/ModelContextProtocol/Client/McpClientExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -600,7 +600,7 @@ internal static CreateMessageResult ToCreateMessageResult(this ChatResponse chat
var progressToken = requestParams.Meta?.ProgressToken;

List<ChatResponseUpdate> updates = [];
await foreach (var update in chatClient.GetStreamingResponseAsync(messages, options, cancellationToken))
await foreach (var update in chatClient.GetStreamingResponseAsync(messages, options, cancellationToken).ConfigureAwait(false))
{
updates.Add(update);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,37 @@ public static IMcpServerBuilder WithStdioServerTransport(this IMcpServerBuilder
Throw.IfNull(builder);

builder.Services.AddSingleton<ITransport, StdioServerTransport>();
builder.Services.AddHostedService<StdioMcpServerHostedService>();
builder.Services.AddHostedService<SingleSessionMcpServerHostedService>();

builder.Services.AddSingleton(services =>
{
ITransport serverTransport = services.GetRequiredService<ITransport>();
IOptions<McpServerOptions> options = services.GetRequiredService<IOptions<McpServerOptions>>();
ILoggerFactory? loggerFactory = services.GetService<ILoggerFactory>();

return McpServerFactory.Create(serverTransport, options.Value, loggerFactory, services);
});

return builder;
}

/// <summary>
/// Adds a server transport that uses the specified input and output streams for communication.
/// </summary>
/// <param name="builder">The builder instance.</param>
/// <param name="inputStream">The input <see cref="Stream"/> to use as standard input.</param>
/// <param name="outputStream">The output <see cref="Stream"/> to use as standard output.</param>
public static IMcpServerBuilder WithStreamServerTransport(
this IMcpServerBuilder builder,
Stream inputStream,
Stream outputStream)
{
Throw.IfNull(builder);
Throw.IfNull(inputStream);
Throw.IfNull(outputStream);

builder.Services.AddSingleton<ITransport>(new StreamServerTransport(inputStream, outputStream));
builder.Services.AddHostedService<SingleSessionMcpServerHostedService>();

builder.Services.AddSingleton(services =>
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
namespace ModelContextProtocol.Hosting;

/// <summary>
/// Hosted service for a single-session (i.e stdio) MCP server.
/// Hosted service for a single-session (e.g. stdio) MCP server.
/// </summary>
internal sealed class StdioMcpServerHostedService(IMcpServer session) : BackgroundService
internal sealed class SingleSessionMcpServerHostedService(IMcpServer session) : BackgroundService
{
/// <inheritdoc />
protected override Task ExecuteAsync(CancellationToken stoppingToken) => session.RunAsync(stoppingToken);
Expand Down
6 changes: 6 additions & 0 deletions src/ModelContextProtocol/IMcpEndpoint.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,10 @@ public interface IMcpEndpoint : IAsyncDisposable
/// <param name="message">The message.</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default);

/// <summary>Registers a handler to be invoked when a notification for the specified method is received.</summary>
/// <param name="method">The notification method.</param>
/// <param name="handler">The handler to be invoked.</param>
/// <returns>An <see cref="IDisposable"/> that will remove the registered handler when disposed.</returns>
IAsyncDisposable RegisterNotificationHandler(string method, Func<JsonRpcNotification, CancellationToken, Task> handler);
}
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ public override async ValueTask DisposeAsync()
{
try
{
await CloseAsync();
await CloseAsync().ConfigureAwait(false);
}
catch (Exception)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ public async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken ca
}

// Emit redundant "event: message" lines for better compatibility with other SDKs.
await _outgoingSseChannel.Writer.WriteAsync(new SseItem<IJsonRpcMessage?>(message, SseParser.EventTypeDefault), cancellationToken);
await _outgoingSseChannel.Writer.WriteAsync(new SseItem<IJsonRpcMessage?>(message, SseParser.EventTypeDefault), cancellationToken).ConfigureAwait(false);
}

/// <summary>
Expand All @@ -94,7 +94,7 @@ public async Task OnMessageReceivedAsync(IJsonRpcMessage message, CancellationTo
throw new InvalidOperationException($"Transport is not connected. Make sure to call {nameof(RunAsync)} first.");
}

await _incomingChannel.Writer.WriteAsync(message, cancellationToken);
await _incomingChannel.Writer.WriteAsync(message, cancellationToken).ConfigureAwait(false);
}

private static Channel<T> CreateBoundedChannel<T>(int capacity = 1) =>
Expand Down
2 changes: 1 addition & 1 deletion src/ModelContextProtocol/Protocol/Types/Capabilities.cs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ public class ClientCapabilities
/// The client will not re-enumerate the sequence.
/// </remarks>
[JsonIgnore]
public IEnumerable<KeyValuePair<string, Func<JsonRpcNotification, Task>>>? NotificationHandlers { get; set; }
public IEnumerable<KeyValuePair<string, Func<JsonRpcNotification, CancellationToken, Task>>>? NotificationHandlers { get; set; }
}

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,5 +45,5 @@ public class ServerCapabilities
/// The server will not re-enumerate the sequence.
/// </remarks>
[JsonIgnore]
public IEnumerable<KeyValuePair<string, Func<JsonRpcNotification, Task>>>? NotificationHandlers { get; set; }
public IEnumerable<KeyValuePair<string, Func<JsonRpcNotification, CancellationToken, Task>>>? NotificationHandlers { get; set; }
}
93 changes: 45 additions & 48 deletions src/ModelContextProtocol/Server/McpServer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using ModelContextProtocol.Shared;
using ModelContextProtocol.Utils;
using ModelContextProtocol.Utils.Json;
using System.Diagnostics;

namespace ModelContextProtocol.Server;

Expand Down Expand Up @@ -44,49 +45,43 @@ public McpServer(ITransport transport, McpServerOptions options, ILoggerFactory?
Services = serviceProvider;
_endpointName = $"Server ({options.ServerInfo?.Name ?? DefaultImplementation.Name} {options.ServerInfo?.Version ?? DefaultImplementation.Version})";

_toolsChangedDelegate = delegate
{
_ = SendMessageAsync(new JsonRpcNotification()
{
Method = NotificationMethods.ToolListChangedNotification,
});
};
_promptsChangedDelegate = delegate
{
_ = SendMessageAsync(new JsonRpcNotification()
{
Method = NotificationMethods.PromptListChangedNotification,
});
};
// Configure all request handlers based on the supplied options.
SetInitializeHandler(options);
SetToolsHandler(options);
SetPromptsHandler(options);
SetResourcesHandler(options);
SetSetLoggingLevelHandler(options);
SetCompletionHandler(options);
SetPingHandler();

NotificationHandlers.Add(NotificationMethods.InitializedNotification, _ =>
// Register any notification handlers that were provided.
if (options.Capabilities?.NotificationHandlers is { } notificationHandlers)
{
if (ServerOptions.Capabilities?.Tools?.ToolCollection is { } tools)
{
tools.Changed += _toolsChangedDelegate;
}
NotificationHandlers.RegisterRange(notificationHandlers);
}

if (ServerOptions.Capabilities?.Prompts?.PromptCollection is { } prompts)
// Now that everything has been configured, subscribe to any necessary notifications.
if (ServerOptions.Capabilities?.Tools?.ToolCollection is { } tools)
{
_toolsChangedDelegate = delegate
{
prompts.Changed += _promptsChangedDelegate;
}
_ = SendMessageAsync(new JsonRpcNotification() { Method = NotificationMethods.ToolListChangedNotification });
};

return Task.CompletedTask;
});
tools.Changed += _toolsChangedDelegate;
}

if (options.Capabilities?.NotificationHandlers is { } notificationHandlers)
if (ServerOptions.Capabilities?.Prompts?.PromptCollection is { } prompts)
{
NotificationHandlers.AddRange(notificationHandlers);
}
_promptsChangedDelegate = delegate
{
_ = SendMessageAsync(new JsonRpcNotification() { Method = NotificationMethods.PromptListChangedNotification });
};

SetToolsHandler(options);
SetInitializeHandler(options);
SetCompletionHandler(options);
SetPingHandler();
SetPromptsHandler(options);
SetResourcesHandler(options);
SetSetLoggingLevelHandler(options);
prompts.Changed += _promptsChangedDelegate;
}

// And start the session.
StartSession(transport);
}

Expand Down Expand Up @@ -129,12 +124,14 @@ public async Task RunAsync(CancellationToken cancellationToken = default)

public override async ValueTask DisposeUnsynchronizedAsync()
{
if (ServerOptions.Capabilities?.Tools?.ToolCollection is { } tools)
if (_toolsChangedDelegate is not null &&
ServerOptions.Capabilities?.Tools?.ToolCollection is { } tools)
{
tools.Changed -= _toolsChangedDelegate;
}

if (ServerOptions.Capabilities?.Prompts?.PromptCollection is { } prompts)
if (_promptsChangedDelegate is not null &&
ServerOptions.Capabilities?.Prompts?.PromptCollection is { } prompts)
{
prompts.Changed -= _promptsChangedDelegate;
}
Expand Down Expand Up @@ -179,8 +176,8 @@ private void SetCompletionHandler(McpServerOptions options)
// This capability is not optional, so return an empty result if there is no handler.
RequestHandlers.Set(RequestMethods.CompletionComplete,
options.GetCompletionHandler is { } handler ?
(request, ct) => handler(new(this, request), ct) :
(request, ct) => Task.FromResult(new CompleteResult() { Completion = new() { Values = [], Total = 0, HasMore = false } }),
(request, cancellationToken) => handler(new(this, request), cancellationToken) :
(request, cancellationToken) => Task.FromResult(new CompleteResult() { Completion = new() { Values = [], Total = 0, HasMore = false } }),
McpJsonUtilities.JsonContext.Default.CompleteRequestParams,
McpJsonUtilities.JsonContext.Default.CompleteResult);
}
Expand All @@ -205,20 +202,20 @@ private void SetResourcesHandler(McpServerOptions options)

RequestHandlers.Set(
RequestMethods.ResourcesList,
(request, ct) => listResourcesHandler(new(this, request), ct),
(request, cancellationToken) => listResourcesHandler(new(this, request), cancellationToken),
McpJsonUtilities.JsonContext.Default.ListResourcesRequestParams,
McpJsonUtilities.JsonContext.Default.ListResourcesResult);

RequestHandlers.Set(
RequestMethods.ResourcesRead,
(request, ct) => readResourceHandler(new(this, request), ct),
(request, cancellationToken) => readResourceHandler(new(this, request), cancellationToken),
McpJsonUtilities.JsonContext.Default.ReadResourceRequestParams,
McpJsonUtilities.JsonContext.Default.ReadResourceResult);

listResourceTemplatesHandler ??= (static (_, _) => Task.FromResult(new ListResourceTemplatesResult()));
RequestHandlers.Set(
RequestMethods.ResourcesTemplatesList,
(request, ct) => listResourceTemplatesHandler(new(this, request), ct),
(request, cancellationToken) => listResourceTemplatesHandler(new(this, request), cancellationToken),
McpJsonUtilities.JsonContext.Default.ListResourceTemplatesRequestParams,
McpJsonUtilities.JsonContext.Default.ListResourceTemplatesResult);

Expand All @@ -236,13 +233,13 @@ private void SetResourcesHandler(McpServerOptions options)

RequestHandlers.Set(
RequestMethods.ResourcesSubscribe,
(request, ct) => subscribeHandler(new(this, request), ct),
(request, cancellationToken) => subscribeHandler(new(this, request), cancellationToken),
McpJsonUtilities.JsonContext.Default.SubscribeRequestParams,
McpJsonUtilities.JsonContext.Default.EmptyResult);

RequestHandlers.Set(
RequestMethods.ResourcesUnsubscribe,
(request, ct) => unsubscribeHandler(new(this, request), ct),
(request, cancellationToken) => unsubscribeHandler(new(this, request), cancellationToken),
McpJsonUtilities.JsonContext.Default.UnsubscribeRequestParams,
McpJsonUtilities.JsonContext.Default.EmptyResult);
}
Expand Down Expand Up @@ -329,13 +326,13 @@ await originalListPromptsHandler(request, cancellationToken).ConfigureAwait(fals

RequestHandlers.Set(
RequestMethods.PromptsList,
(request, ct) => listPromptsHandler(new(this, request), ct),
(request, cancellationToken) => listPromptsHandler(new(this, request), cancellationToken),
McpJsonUtilities.JsonContext.Default.ListPromptsRequestParams,
McpJsonUtilities.JsonContext.Default.ListPromptsResult);

RequestHandlers.Set(
RequestMethods.PromptsGet,
(request, ct) => getPromptHandler(new(this, request), ct),
(request, cancellationToken) => getPromptHandler(new(this, request), cancellationToken),
McpJsonUtilities.JsonContext.Default.GetPromptRequestParams,
McpJsonUtilities.JsonContext.Default.GetPromptResult);
}
Expand Down Expand Up @@ -422,13 +419,13 @@ await originalListToolsHandler(request, cancellationToken).ConfigureAwait(false)

RequestHandlers.Set(
RequestMethods.ToolsList,
(request, ct) => listToolsHandler(new(this, request), ct),
(request, cancellationToken) => listToolsHandler(new(this, request), cancellationToken),
McpJsonUtilities.JsonContext.Default.ListToolsRequestParams,
McpJsonUtilities.JsonContext.Default.ListToolsResult);

RequestHandlers.Set(
RequestMethods.ToolsCall,
(request, ct) => callToolHandler(new(this, request), ct),
(request, cancellationToken) => callToolHandler(new(this, request), cancellationToken),
McpJsonUtilities.JsonContext.Default.CallToolRequestParams,
McpJsonUtilities.JsonContext.Default.CallToolResponse);
}
Expand All @@ -447,7 +444,7 @@ private void SetSetLoggingLevelHandler(McpServerOptions options)

RequestHandlers.Set(
RequestMethods.LoggingSetLevel,
(request, ct) => setLoggingLevelHandler(new(this, request), ct),
(request, cancellationToken) => setLoggingLevelHandler(new(this, request), cancellationToken),
McpJsonUtilities.JsonContext.Default.SetLevelRequestParams,
McpJsonUtilities.JsonContext.Default.EmptyResult);
}
Expand Down
5 changes: 4 additions & 1 deletion src/ModelContextProtocol/Shared/McpEndpoint.cs
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,17 @@ protected McpEndpoint(ILoggerFactory? loggerFactory = null)

protected RequestHandlers RequestHandlers { get; } = [];

protected NotificationHandlers NotificationHandlers { get; } = [];
protected NotificationHandlers NotificationHandlers { get; } = new();

public Task<JsonRpcResponse> SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken = default)
=> GetSessionOrThrow().SendRequestAsync(request, cancellationToken);

public Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default)
=> GetSessionOrThrow().SendMessageAsync(message, cancellationToken);

public IAsyncDisposable RegisterNotificationHandler(string method, Func<JsonRpcNotification, CancellationToken, Task> handler) =>
GetSessionOrThrow().RegisterNotificationHandler(method, handler);

/// <summary>
/// Gets the name of the endpoint for logging and debug purposes.
/// </summary>
Expand Down
Loading