Skip to content

Remove McpServerConfig and have McpClientFactory accept IClientTransport instances directly. #230

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 5 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,18 +31,15 @@ To get started writing a client, the `McpClientFactory.CreateAsync` method is us
to a server. Once you have an `IMcpClient`, you can interact with it, such as to enumerate all available tools and invoke tools.

```csharp
var client = await McpClientFactory.CreateAsync(new()
var clientTransport = new StdioClientTransport(new()
{
Id = "everything",
Name = "Everything",
TransportType = TransportTypes.StdIo,
TransportOptions = new()
{
["command"] = "npx",
["arguments"] = "-y @modelcontextprotocol/server-everything",
}
Command = "npx",
Arguments = ["-y", "@modelcontextprotocol/server-everything"],
});

var client = await McpClientFactory.CreateAsync(clientTransport);

// Print the list of tools available from the server.
foreach (var tool in await client.ListToolsAsync())
{
Expand Down
12 changes: 4 additions & 8 deletions samples/ChatWithTools/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,12 @@
// Connect to an MCP server
Console.WriteLine("Connecting client to MCP 'everything' server");
var mcpClient = await McpClientFactory.CreateAsync(
new()
new StdioClientTransport(new()
{
Id = "everything",
Command = "npx",
Arguments = ["-y", "--verbose", "@modelcontextprotocol/server-everything"],
Name = "Everything",
TransportType = TransportTypes.StdIo,
TransportOptions = new()
{
["command"] = "npx", ["arguments"] = "-y @modelcontextprotocol/server-everything",
}
});
}));

// Get all available tools
Console.WriteLine("Tools available:");
Expand Down
23 changes: 10 additions & 13 deletions samples/QuickstartClient/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,15 @@

var (command, arguments) = GetCommandAndArguments(args);

await using var mcpClient = await McpClientFactory.CreateAsync(new()
var clientTransport = new StdioClientTransport(new()
{
Id = "demo-server",
Name = "Demo Server",
TransportType = TransportTypes.StdIo,
TransportOptions = new()
{
["command"] = command,
["arguments"] = arguments,
}
Command = command,
Arguments = arguments,
});

await using var mcpClient = await McpClientFactory.CreateAsync(clientTransport);

var tools = await mcpClient.ListToolsAsync();
foreach (var tool in tools)
{
Expand Down Expand Up @@ -86,13 +83,13 @@ static void PromptForInput()
///
/// This method would only be required if you're creating a generic client, such as we use for the quickstart.
/// </remarks>
static (string command, string arguments) GetCommandAndArguments(string[] args)
static (string command, string[] arguments) GetCommandAndArguments(string[] args)
{
return args switch
{
[var script] when script.EndsWith(".py") => ("python", script),
[var script] when script.EndsWith(".js") => ("node", script),
[var script] when Directory.Exists(script) || (File.Exists(script) && script.EndsWith(".csproj")) => ("dotnet", $"run --project {script} --no-build"),
_ => ("dotnet", "run --project ../../../../QuickstartWeatherServer --no-build")
[var script] when script.EndsWith(".py") => ("python", args),
[var script] when script.EndsWith(".js") => ("node", args),
[var script] when Directory.Exists(script) || (File.Exists(script) && script.EndsWith(".csproj")) => ("dotnet", ["run", "--project", script, "--no-build"]),
_ => ("dotnet", ["run", "--project", "../../../../QuickstartWeatherServer", "--no-build"])
};
}
101 changes: 101 additions & 0 deletions src/Common/Polyfills/System/PasteArguments.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

// Copied from:
// https://github.com/dotnet/runtime/blob/d2650b6ae7023a2d9d2c74c56116f1f18472ab04/src/libraries/System.Private.CoreLib/src/System/PasteArguments.cs
// and changed from using ValueStringBuilder to StringBuilder.

using System.Text;

namespace System;

internal static partial class PasteArguments
{
internal static void AppendArgument(StringBuilder stringBuilder, string argument)
{
if (stringBuilder.Length != 0)
{
stringBuilder.Append(' ');
}

// Parsing rules for non-argv[0] arguments:
// - Backslash is a normal character except followed by a quote.
// - 2N backslashes followed by a quote ==> N literal backslashes followed by unescaped quote
// - 2N+1 backslashes followed by a quote ==> N literal backslashes followed by a literal quote
// - Parsing stops at first whitespace outside of quoted region.
// - (post 2008 rule): A closing quote followed by another quote ==> literal quote, and parsing remains in quoting mode.
if (argument.Length != 0 && ContainsNoWhitespaceOrQuotes(argument))
{
// Simple case - no quoting or changes needed.
stringBuilder.Append(argument);
}
else
{
stringBuilder.Append(Quote);
int idx = 0;
while (idx < argument.Length)
{
char c = argument[idx++];
if (c == Backslash)
{
int numBackSlash = 1;
while (idx < argument.Length && argument[idx] == Backslash)
{
idx++;
numBackSlash++;
}

if (idx == argument.Length)
{
// We'll emit an end quote after this so must double the number of backslashes.
stringBuilder.Append(Backslash, numBackSlash * 2);
}
else if (argument[idx] == Quote)
{
// Backslashes will be followed by a quote. Must double the number of backslashes.
stringBuilder.Append(Backslash, numBackSlash * 2 + 1);
stringBuilder.Append(Quote);
idx++;
}
else
{
// Backslash will not be followed by a quote, so emit as normal characters.
stringBuilder.Append(Backslash, numBackSlash);
}

continue;
}

if (c == Quote)
{
// Escape the quote so it appears as a literal. This also guarantees that we won't end up generating a closing quote followed
// by another quote (which parses differently pre-2008 vs. post-2008.)
stringBuilder.Append(Backslash);
stringBuilder.Append(Quote);
continue;
}

stringBuilder.Append(c);
}

stringBuilder.Append(Quote);
}
}

private static bool ContainsNoWhitespaceOrQuotes(string s)
{
for (int i = 0; i < s.Length; i++)
{
char c = s[i];
if (char.IsWhiteSpace(c) || c == Quote)
{
return false;
}
}

return true;
}

private const char Quote = '\"';
private const char Backslash = '\\';
}
5 changes: 2 additions & 3 deletions src/ModelContextProtocol/Client/McpClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -33,17 +33,16 @@ internal sealed class McpClient : McpEndpoint, IMcpClient
/// </summary>
/// <param name="clientTransport">The transport to use for communication with the server.</param>
/// <param name="options">Options for the client, defining protocol version and capabilities.</param>
/// <param name="serverConfig">The server configuration.</param>
/// <param name="loggerFactory">The logger factory.</param>
public McpClient(IClientTransport clientTransport, McpClientOptions? options, McpServerConfig serverConfig, ILoggerFactory? loggerFactory)
public McpClient(IClientTransport clientTransport, McpClientOptions? options, ILoggerFactory? loggerFactory)
: base(loggerFactory)
{
options ??= new();

_clientTransport = clientTransport;
_options = options;

EndpointName = $"Client ({serverConfig.Id}: {serverConfig.Name})";
EndpointName = clientTransport.Name;

if (options.Capabilities is { } capabilities)
{
Expand Down
109 changes: 10 additions & 99 deletions src/ModelContextProtocol/Client/McpClientFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,128 +12,39 @@ namespace ModelContextProtocol.Client;
public static class McpClientFactory
{
/// <summary>Creates an <see cref="IMcpClient"/>, connecting it to the specified server.</summary>
/// <param name="serverConfig">Configuration for the target server to which the client should connect.</param>
/// <param name="clientTransport">The transport instance used to communicate with the server.</param>
/// <param name="clientOptions">
/// A client configuration object which specifies client capabilities and protocol version.
/// If <see langword="null"/>, details based on the current process will be employed.
/// </param>
/// <param name="createTransportFunc">An optional factory method which returns transport implementations based on a server configuration.</param>
/// <param name="loggerFactory">A logger factory for creating loggers for clients.</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
/// <returns>An <see cref="IMcpClient"/> that's connected to the specified server.</returns>
/// <exception cref="ArgumentNullException"><paramref name="serverConfig"/> is <see langword="null"/>.</exception>
/// <exception cref="ArgumentNullException"><paramref name="clientTransport"/> is <see langword="null"/>.</exception>
/// <exception cref="ArgumentNullException"><paramref name="clientOptions"/> is <see langword="null"/>.</exception>
/// <exception cref="ArgumentException"><paramref name="serverConfig"/> contains invalid information.</exception>
/// <exception cref="InvalidOperationException"><paramref name="createTransportFunc"/> returns an invalid transport.</exception>
public static async Task<IMcpClient> CreateAsync(
McpServerConfig serverConfig,
IClientTransport clientTransport,
McpClientOptions? clientOptions = null,
Func<McpServerConfig, ILoggerFactory?, IClientTransport>? createTransportFunc = null,
ILoggerFactory? loggerFactory = null,
CancellationToken cancellationToken = default)
{
Throw.IfNull(serverConfig);

createTransportFunc ??= CreateTransport;

string endpointName = $"Client ({serverConfig.Id}: {serverConfig.Name})";
Throw.IfNull(clientTransport);

string endpointName = clientTransport.Name;
var logger = loggerFactory?.CreateLogger(typeof(McpClientFactory)) ?? NullLogger.Instance;
logger.CreatingClient(endpointName);

var transport =
createTransportFunc(serverConfig, loggerFactory) ??
throw new InvalidOperationException($"{nameof(createTransportFunc)} returned a null transport.");

McpClient client = new(clientTransport, clientOptions, loggerFactory);
try
{
McpClient client = new(transport, clientOptions, serverConfig, loggerFactory);
try
{
await client.ConnectAsync(cancellationToken).ConfigureAwait(false);
logger.ClientCreated(endpointName);
return client;
}
catch
{
await client.DisposeAsync().ConfigureAwait(false);
throw;
}
await client.ConnectAsync(cancellationToken).ConfigureAwait(false);
logger.ClientCreated(endpointName);
return client;
}
catch
{
if (transport is IAsyncDisposable asyncDisposableTransport)
{
await asyncDisposableTransport.DisposeAsync().ConfigureAwait(false);
}
else if (transport is IDisposable disposableTransport)
{
disposableTransport.Dispose();
}
await client.DisposeAsync().ConfigureAwait(false);
throw;
}
}

private static IClientTransport CreateTransport(McpServerConfig serverConfig, ILoggerFactory? loggerFactory)
{
if (string.Equals(serverConfig.TransportType, TransportTypes.StdIo, StringComparison.OrdinalIgnoreCase))
{
string? command = serverConfig.TransportOptions?.GetValueOrDefault("command");
if (string.IsNullOrWhiteSpace(command))
{
command = serverConfig.Location;
if (string.IsNullOrWhiteSpace(command))
{
throw new ArgumentException("Command is required for stdio transport.", nameof(serverConfig));
}
}

string? arguments = serverConfig.TransportOptions?.GetValueOrDefault("arguments");

if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows) &&
serverConfig.TransportType.Equals(TransportTypes.StdIo, StringComparison.OrdinalIgnoreCase) &&
!string.IsNullOrEmpty(command) &&
!string.Equals(Path.GetFileName(command), "cmd.exe", StringComparison.OrdinalIgnoreCase))
{
// On Windows, for stdio, we need to wrap non-shell commands with cmd.exe /c {command} (usually npx or uvicorn).
// The stdio transport will not work correctly if the command is not run in a shell.
arguments = string.IsNullOrWhiteSpace(arguments) ?
$"/c {command}" :
$"/c {command} {arguments}";
command = "cmd.exe";
}

return new StdioClientTransport(new StdioClientTransportOptions
{
Command = command!,
Arguments = arguments,
WorkingDirectory = serverConfig.TransportOptions?.GetValueOrDefault("workingDirectory"),
EnvironmentVariables = serverConfig.TransportOptions?
.Where(kv => kv.Key.StartsWith("env:", StringComparison.Ordinal))
.ToDictionary(kv => kv.Key.Substring("env:".Length), kv => kv.Value),
ShutdownTimeout = TimeSpan.TryParse(serverConfig.TransportOptions?.GetValueOrDefault("shutdownTimeout"), CultureInfo.InvariantCulture, out var timespan) ? timespan : StdioClientTransportOptions.DefaultShutdownTimeout
}, serverConfig, loggerFactory);
}

if (string.Equals(serverConfig.TransportType, TransportTypes.Sse, StringComparison.OrdinalIgnoreCase) ||
string.Equals(serverConfig.TransportType, "http", StringComparison.OrdinalIgnoreCase))
{
return new SseClientTransport(new SseClientTransportOptions
{
ConnectionTimeout = TimeSpan.FromSeconds(ParseInt32OrDefault(serverConfig.TransportOptions, "connectionTimeout", 30)),
MaxReconnectAttempts = ParseInt32OrDefault(serverConfig.TransportOptions, "maxReconnectAttempts", 3),
ReconnectDelay = TimeSpan.FromSeconds(ParseInt32OrDefault(serverConfig.TransportOptions, "reconnectDelay", 5)),
AdditionalHeaders = serverConfig.TransportOptions?
.Where(kv => kv.Key.StartsWith("header.", StringComparison.Ordinal))
.ToDictionary(kv => kv.Key.Substring("header.".Length), kv => kv.Value)
}, serverConfig, loggerFactory);

static int ParseInt32OrDefault(Dictionary<string, string>? options, string key, int defaultValue) =>
options?.TryGetValue(key, out var value) is not true ? defaultValue :
int.TryParse(value, out var result) ? result :
throw new ArgumentException($"Invalid value '{value}' for option '{key}' in transport options.", nameof(serverConfig));
}

throw new ArgumentException($"Unsupported transport type '{serverConfig.TransportType}'.", nameof(serverConfig));
}
}
34 changes: 0 additions & 34 deletions src/ModelContextProtocol/Configuration/McpServerConfig.cs

This file was deleted.

1 change: 1 addition & 0 deletions src/ModelContextProtocol/ModelContextProtocol.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
<PackageId>ModelContextProtocol</PackageId>
<Description>.NET SDK for the Model Context Protocol (MCP)</Description>
<PackageReadmeFile>README.md</PackageReadmeFile>
<LangVersion>preview</LangVersion>
</PropertyGroup>

<PropertyGroup Condition="'$(TargetFramework)' != 'netstandard2.0'">
Expand Down
Loading