Skip to content

Commit 3725859

Browse files
committed
Move notification handler registrations to capabilities
Currently request handlers are set on the capability objects, but notification handlers are set after construction via an AddNotificationHandler method on the IMcpEndpoint interface. This moves handler specification to be at construction as well. This makes it more consistent with request handlers, simplifies the IMcpEndpoint interface to just be about message sending, and avoids a concurrency bug that could occur if someone tried to add a handler while the endpoint was processing notifications.
1 parent 8fcdf95 commit 3725859

19 files changed

+279
-188
lines changed

src/ModelContextProtocol/Client/McpClient.cs

Lines changed: 61 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,18 @@
55
using ModelContextProtocol.Protocol.Types;
66
using ModelContextProtocol.Shared;
77
using ModelContextProtocol.Utils.Json;
8+
using System.Diagnostics;
9+
using System.Reflection;
810
using System.Text.Json;
911

1012
namespace ModelContextProtocol.Client;
1113

1214
/// <inheritdoc/>
1315
internal sealed class McpClient : McpEndpoint, IMcpClient
1416
{
17+
/// <summary>Cached naming information used for client name/version when none is specified.</summary>
18+
private static readonly AssemblyName s_asmName = (Assembly.GetEntryAssembly() ?? Assembly.GetExecutingAssembly()).GetName();
19+
1520
private readonly IClientTransport _clientTransport;
1621
private readonly McpClientOptions _options;
1722

@@ -29,43 +34,61 @@ internal sealed class McpClient : McpEndpoint, IMcpClient
2934
/// <param name="options">Options for the client, defining protocol version and capabilities.</param>
3035
/// <param name="serverConfig">The server configuration.</param>
3136
/// <param name="loggerFactory">The logger factory.</param>
32-
public McpClient(IClientTransport clientTransport, McpClientOptions options, McpServerConfig serverConfig, ILoggerFactory? loggerFactory)
37+
public McpClient(IClientTransport clientTransport, McpClientOptions? options, McpServerConfig serverConfig, ILoggerFactory? loggerFactory)
3338
: base(loggerFactory)
3439
{
3540
_clientTransport = clientTransport;
41+
42+
if (options?.ClientInfo is null)
43+
{
44+
options = options?.Clone() ?? new();
45+
options.ClientInfo = new()
46+
{
47+
Name = s_asmName.Name ?? nameof(McpClient),
48+
Version = s_asmName.Version?.ToString() ?? "1.0.0",
49+
};
50+
}
3651
_options = options;
3752

3853
EndpointName = $"Client ({serverConfig.Id}: {serverConfig.Name})";
3954

40-
if (options.Capabilities?.Sampling is { } samplingCapability)
55+
if (options.Capabilities is { } capabilities)
4156
{
42-
if (samplingCapability.SamplingHandler is not { } samplingHandler)
57+
if (capabilities.NotificationHandlers is { } notificationHandlers)
4358
{
44-
throw new InvalidOperationException($"Sampling capability was set but it did not provide a handler.");
59+
NotificationHandlers.AddRange(notificationHandlers);
4560
}
4661

47-
SetRequestHandler(
48-
RequestMethods.SamplingCreateMessage,
49-
(request, cancellationToken) => samplingHandler(
50-
request,
51-
request?.Meta?.ProgressToken is { } token ? new TokenProgress(this, token) : NullProgress.Instance,
52-
cancellationToken),
53-
McpJsonUtilities.JsonContext.Default.CreateMessageRequestParams,
54-
McpJsonUtilities.JsonContext.Default.CreateMessageResult);
55-
}
56-
57-
if (options.Capabilities?.Roots is { } rootsCapability)
58-
{
59-
if (rootsCapability.RootsHandler is not { } rootsHandler)
62+
if (capabilities.Sampling is { } samplingCapability)
6063
{
61-
throw new InvalidOperationException($"Roots capability was set but it did not provide a handler.");
64+
if (samplingCapability.SamplingHandler is not { } samplingHandler)
65+
{
66+
throw new InvalidOperationException($"Sampling capability was set but it did not provide a handler.");
67+
}
68+
69+
RequestHandlers.Set(
70+
RequestMethods.SamplingCreateMessage,
71+
(request, cancellationToken) => samplingHandler(
72+
request,
73+
request?.Meta?.ProgressToken is { } token ? new TokenProgress(this, token) : NullProgress.Instance,
74+
cancellationToken),
75+
McpJsonUtilities.JsonContext.Default.CreateMessageRequestParams,
76+
McpJsonUtilities.JsonContext.Default.CreateMessageResult);
6277
}
6378

64-
SetRequestHandler(
65-
RequestMethods.RootsList,
66-
rootsHandler,
67-
McpJsonUtilities.JsonContext.Default.ListRootsRequestParams,
68-
McpJsonUtilities.JsonContext.Default.ListRootsResult);
79+
if (capabilities.Roots is { } rootsCapability)
80+
{
81+
if (rootsCapability.RootsHandler is not { } rootsHandler)
82+
{
83+
throw new InvalidOperationException($"Roots capability was set but it did not provide a handler.");
84+
}
85+
86+
RequestHandlers.Set(
87+
RequestMethods.RootsList,
88+
rootsHandler,
89+
McpJsonUtilities.JsonContext.Default.ListRootsRequestParams,
90+
McpJsonUtilities.JsonContext.Default.ListRootsResult);
91+
}
6992
}
7093
}
7194

@@ -96,20 +119,21 @@ public async Task ConnectAsync(CancellationToken cancellationToken = default)
96119
using var initializationCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
97120
initializationCts.CancelAfter(_options.InitializationTimeout);
98121

99-
try
100-
{
101-
// Send initialize request
102-
var initializeResponse = await this.SendRequestAsync(
103-
RequestMethods.Initialize,
104-
new InitializeRequestParams
105-
{
106-
ProtocolVersion = _options.ProtocolVersion,
107-
Capabilities = _options.Capabilities ?? new ClientCapabilities(),
108-
ClientInfo = _options.ClientInfo
109-
},
110-
McpJsonUtilities.JsonContext.Default.InitializeRequestParams,
111-
McpJsonUtilities.JsonContext.Default.InitializeResult,
112-
cancellationToken: initializationCts.Token).ConfigureAwait(false);
122+
try
123+
{
124+
// Send initialize request
125+
Debug.Assert(_options.ClientInfo is not null, "ClientInfo should be set by the constructor");
126+
var initializeResponse = await this.SendRequestAsync(
127+
RequestMethods.Initialize,
128+
new InitializeRequestParams
129+
{
130+
ProtocolVersion = _options.ProtocolVersion,
131+
Capabilities = _options.Capabilities ?? new ClientCapabilities(),
132+
ClientInfo = _options.ClientInfo!
133+
},
134+
McpJsonUtilities.JsonContext.Default.InitializeRequestParams,
135+
McpJsonUtilities.JsonContext.Default.InitializeResult,
136+
cancellationToken: initializationCts.Token).ConfigureAwait(false);
113137

114138
// Store server information
115139
_logger.ServerCapabilitiesReceived(EndpointName,

src/ModelContextProtocol/Client/McpClientFactory.cs

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -12,23 +12,6 @@ namespace ModelContextProtocol.Client;
1212
/// <summary>Provides factory methods for creating MCP clients.</summary>
1313
public static class McpClientFactory
1414
{
15-
/// <summary>Default client options to use when none are supplied.</summary>
16-
private static readonly McpClientOptions s_defaultClientOptions = CreateDefaultClientOptions();
17-
18-
/// <summary>Creates default client options to use when no options are supplied.</summary>
19-
private static McpClientOptions CreateDefaultClientOptions()
20-
{
21-
var asmName = (Assembly.GetEntryAssembly() ?? Assembly.GetExecutingAssembly()).GetName();
22-
return new()
23-
{
24-
ClientInfo = new()
25-
{
26-
Name = asmName.Name ?? "McpClient",
27-
Version = asmName.Version?.ToString() ?? "1.0.0",
28-
},
29-
};
30-
}
31-
3215
/// <summary>Creates an <see cref="IMcpClient"/>, connecting it to the specified server.</summary>
3316
/// <param name="serverConfig">Configuration for the target server to which the client should connect.</param>
3417
/// <param name="clientOptions">
@@ -52,7 +35,6 @@ public static async Task<IMcpClient> CreateAsync(
5235
{
5336
Throw.IfNull(serverConfig);
5437

55-
clientOptions ??= s_defaultClientOptions;
5638
createTransportFunc ??= CreateTransport;
5739

5840
string endpointName = $"Client ({serverConfig.Id}: {serverConfig.Name})";

src/ModelContextProtocol/Client/McpClientOptions.cs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ public class McpClientOptions
1212
/// <summary>
1313
/// Information about this client implementation.
1414
/// </summary>
15-
public required Implementation ClientInfo { get; set; }
15+
public Implementation? ClientInfo { get; set; }
1616

1717
/// <summary>
1818
/// Client capabilities to advertise to the server.
@@ -28,4 +28,14 @@ public class McpClientOptions
2828
/// Timeout for initialization sequence.
2929
/// </summary>
3030
public TimeSpan InitializationTimeout { get; set; } = TimeSpan.FromSeconds(60);
31+
32+
/// <summary>Creates a shallow clone of the options.</summary>
33+
internal McpClientOptions Clone() =>
34+
new()
35+
{
36+
ClientInfo = ClientInfo,
37+
Capabilities = Capabilities,
38+
ProtocolVersion = ProtocolVersion,
39+
InitializationTimeout = InitializationTimeout
40+
};
3141
}

src/ModelContextProtocol/IMcpEndpoint.cs

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -15,20 +15,4 @@ public interface IMcpEndpoint : IAsyncDisposable
1515
/// <param name="message">The message.</param>
1616
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
1717
Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default);
18-
19-
/// <summary>
20-
/// Adds a handler for server notifications of a specific method.
21-
/// </summary>
22-
/// <param name="method">The notification method to handle.</param>
23-
/// <param name="handler">The async handler function to process notifications.</param>
24-
/// <remarks>
25-
/// <para>
26-
/// Each method may have multiple handlers. Adding a handler for a method that already has one
27-
/// will not replace the existing handler.
28-
/// </para>
29-
/// <para>
30-
/// <see cref="NotificationMethods"> provides constants for common notification methods.</see>
31-
/// </para>
32-
/// </remarks>
33-
void AddNotificationHandler(string method, Func<JsonRpcNotification, Task> handler);
3418
}

src/ModelContextProtocol/Protocol/Transport/McpTransportException.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ public McpTransportException(string message)
3030
/// </summary>
3131
/// <param name="message">The message that describes the error.</param>
3232
/// <param name="innerException">The exception that is the cause of the current exception.</param>
33-
public McpTransportException(string message, Exception innerException)
33+
public McpTransportException(string message, Exception? innerException)
3434
: base(message, innerException)
3535
{
3636
}

src/ModelContextProtocol/Protocol/Transport/StdioClientSessionTransport.cs

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,22 @@ public StdioClientSessionTransport(StdioClientTransportOptions options, Process
2121
/// <inheritdoc/>
2222
public override async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default)
2323
{
24-
if (_process.HasExited)
24+
Exception? processException = null;
25+
bool hasExited = false;
26+
try
27+
{
28+
hasExited = _process.HasExited;
29+
}
30+
catch (Exception e)
31+
{
32+
processException = e;
33+
hasExited = true;
34+
}
35+
36+
if (hasExited)
2537
{
2638
Logger.TransportNotConnected(EndpointName);
27-
throw new McpTransportException("Transport is not connected");
39+
throw new McpTransportException("Transport is not connected", processException);
2840
}
2941

3042
await base.SendMessageAsync(message, cancellationToken).ConfigureAwait(false);
@@ -33,7 +45,7 @@ public override async Task SendMessageAsync(IJsonRpcMessage message, Cancellatio
3345
/// <inheritdoc/>
3446
protected override ValueTask CleanupAsync(CancellationToken cancellationToken)
3547
{
36-
StdioClientTransport.DisposeProcess(_process, processStarted: true, Logger, _options.ShutdownTimeout, EndpointName);
48+
StdioClientTransport.DisposeProcess(_process, processRunning: true, Logger, _options.ShutdownTimeout, EndpointName);
3749

3850
return base.CleanupAsync(cancellationToken);
3951
}

src/ModelContextProtocol/Protocol/Transport/StdioClientTransport.cs

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
using Microsoft.Extensions.Logging.Abstractions;
33
using ModelContextProtocol.Logging;
44
using ModelContextProtocol.Utils;
5+
using System.ComponentModel;
56
using System.Diagnostics;
67
using System.Text;
78

@@ -129,13 +130,25 @@ public async Task<ITransport> ConnectAsync(CancellationToken cancellationToken =
129130
}
130131

131132
internal static void DisposeProcess(
132-
Process? process, bool processStarted, ILogger logger, TimeSpan shutdownTimeout, string endpointName)
133+
Process? process, bool processRunning, ILogger logger, TimeSpan shutdownTimeout, string endpointName)
133134
{
134135
if (process is not null)
135136
{
137+
if (processRunning)
138+
{
139+
try
140+
{
141+
processRunning = !process.HasExited;
142+
}
143+
catch
144+
{
145+
processRunning = false;
146+
}
147+
}
148+
136149
try
137150
{
138-
if (processStarted && !process.HasExited)
151+
if (processRunning)
139152
{
140153
// Wait for the process to exit.
141154
// Kill the while process tree because the process may spawn child processes

src/ModelContextProtocol/Protocol/Types/Capabilities.cs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
using ModelContextProtocol.Server;
1+
using ModelContextProtocol.Protocol.Messages;
2+
using ModelContextProtocol.Server;
23
using System.Text.Json.Serialization;
34

45
namespace ModelContextProtocol.Protocol.Types;
@@ -26,6 +27,14 @@ public class ClientCapabilities
2627
/// </summary>
2728
[JsonPropertyName("sampling")]
2829
public SamplingCapability? Sampling { get; set; }
30+
31+
/// <summary>Gets or sets notification handlers to register with the client.</summary>
32+
/// <remarks>
33+
/// When constructed, the client will enumerate these handlers, which may contain multiple handlers per key.
34+
/// The client will not re-enumerate the sequence.
35+
/// </remarks>
36+
[JsonIgnore]
37+
public IEnumerable<KeyValuePair<string, Func<JsonRpcNotification, Task>>>? NotificationHandlers { get; set; }
2938
}
3039

3140
/// <summary>

src/ModelContextProtocol/Protocol/Types/ServerCapabilities.cs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
using System.Text.Json.Serialization;
1+
using ModelContextProtocol.Protocol.Messages;
2+
using System.Text.Json.Serialization;
23

34
namespace ModelContextProtocol.Protocol.Types;
45

@@ -37,4 +38,12 @@ public class ServerCapabilities
3738
/// </summary>
3839
[JsonPropertyName("tools")]
3940
public ToolsCapability? Tools { get; set; }
41+
42+
/// <summary>Gets or sets notification handlers to register with the server.</summary>
43+
/// <remarks>
44+
/// When constructed, the server will enumerate these handlers, which may contain multiple handlers per key.
45+
/// The server will not re-enumerate the sequence.
46+
/// </remarks>
47+
[JsonIgnore]
48+
public IEnumerable<KeyValuePair<string, Func<JsonRpcNotification, Task>>>? NotificationHandlers { get; set; }
4049
}

0 commit comments

Comments
 (0)