diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java index e3a997ba..39f33c95 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java @@ -160,7 +160,7 @@ public class McpAsyncClient { * @param features the MCP Client supported features. */ McpAsyncClient(McpClientTransport transport, Duration requestTimeout, Duration initializationTimeout, - McpClientFeatures.Async features) { + McpClientFeatures.Async features, boolean connectOnInit) { Assert.notNull(transport, "Transport must not be null"); Assert.notNull(requestTimeout, "Request timeout must not be null"); @@ -235,7 +235,9 @@ public class McpAsyncClient { asyncLoggingNotificationHandler(loggingConsumersFinal)); this.mcpSession = new McpClientSession(requestTimeout, transport, requestHandlers, notificationHandlers); - + if (connectOnInit) { + this.mcpSession.openSSE(); + } } /** @@ -302,6 +304,18 @@ public Mono closeGracefully() { return this.mcpSession.closeGracefully(); } + // --------------------------- + // open an SSE stream + // --------------------------- + /** + * The client may issue an HTTP GET to the MCP endpoint. This can be used to open an + * SSE stream, allowing the server to communicate to the client, without the client + * first sending data via HTTP POST. + */ + public void openSSE() { + this.mcpSession.openSSE(); + } + // -------------------------- // Initialization // -------------------------- diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java index a1dc1168..dac2ee8a 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java @@ -157,11 +157,13 @@ class SyncSpec { private Duration requestTimeout = Duration.ofSeconds(20); // Default timeout + private boolean connectOnInit = true; // Default true, for backward compatibility + private Duration initializationTimeout = Duration.ofSeconds(20); private ClientCapabilities capabilities; - private Implementation clientInfo = new Implementation("Java SDK MCP Client", "1.0.0"); + private Implementation clientInfo = new Implementation("Java SDK MCP Sync Client", "0.10.0"); private final Map roots = new HashMap<>(); @@ -195,6 +197,17 @@ public SyncSpec requestTimeout(Duration requestTimeout) { return this; } + /** + * Sets whether to connect to the server during the initialization phase (open an + * SSE stream). + * @param connectOnInit true to open an SSE stream during the initialization + * @return This builder instance for method chaining + */ + public SyncSpec withConnectOnInit(final boolean connectOnInit) { + this.connectOnInit = connectOnInit; + return this; + } + /** * @param initializationTimeout The duration to wait for the initialization * lifecycle step to complete. @@ -368,8 +381,8 @@ public McpSyncClient build() { McpClientFeatures.Async asyncFeatures = McpClientFeatures.Async.fromSync(syncFeatures); - return new McpSyncClient( - new McpAsyncClient(transport, this.requestTimeout, this.initializationTimeout, asyncFeatures)); + return new McpSyncClient(new McpAsyncClient(transport, this.requestTimeout, this.initializationTimeout, + asyncFeatures, this.connectOnInit)); } } @@ -396,11 +409,13 @@ class AsyncSpec { private Duration requestTimeout = Duration.ofSeconds(20); // Default timeout + private boolean connectOnInit = true; // Default true, for backward compatibility + private Duration initializationTimeout = Duration.ofSeconds(20); private ClientCapabilities capabilities; - private Implementation clientInfo = new Implementation("Spring AI MCP Client", "0.3.1"); + private Implementation clientInfo = new Implementation("Java SDK MCP Async Client", "0.10.0"); private final Map roots = new HashMap<>(); @@ -434,6 +449,17 @@ public AsyncSpec requestTimeout(Duration requestTimeout) { return this; } + /** + * Sets whether to connect to the server during the initialization phase (open an + * SSE stream). + * @param connectOnInit true to open an SSE stream during the initialization + * @return This builder instance for method chaining + */ + public AsyncSpec withConnectOnInit(final boolean connectOnInit) { + this.connectOnInit = connectOnInit; + return this; + } + /** * @param initializationTimeout The duration to wait for the initialization * lifecycle step to complete. @@ -606,7 +632,8 @@ public McpAsyncClient build() { return new McpAsyncClient(this.transport, this.requestTimeout, this.initializationTimeout, new McpClientFeatures.Async(this.clientInfo, this.capabilities, this.roots, this.toolsChangeConsumers, this.resourcesChangeConsumers, this.promptsChangeConsumers, - this.loggingConsumers, this.samplingHandler)); + this.loggingConsumers, this.samplingHandler), + this.connectOnInit); } } diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java index a8fb979e..e9676e8a 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java @@ -137,6 +137,18 @@ public boolean closeGracefully() { return true; } + // --------------------------- + // open an SSE stream + // --------------------------- + /** + * The client may issue an HTTP GET to the MCP endpoint. This can be used to open an + * SSE stream, allowing the server to communicate to the client, without the client + * first sending data via HTTP POST. + */ + public void openSSE() { + this.delegate.openSSE(); + } + /** * The initialization phase MUST be the first interaction between client and server. * During this phase, the client and server: @@ -156,9 +168,7 @@ public boolean closeGracefully() { * The server MUST respond with its own capabilities and information: * {@link McpSchema.ServerCapabilities}.
* After successful initialization, the client MUST send an initialized notification - * to indicate it is ready to begin normal operations. - * - *
+ * to indicate it is ready to begin normal operations.
* * Initialization @@ -280,9 +290,8 @@ public McpSchema.ReadResourceResult readResource(McpSchema.ReadResourceRequest r /** * Resource templates allow servers to expose parameterized resources using URI - * templates. Arguments may be auto-completed through the completion API. - * - * Request a list of resource templates the server has. + * templates. Arguments may be auto-completed through the completion API. Request a + * list of resource templates the server has. * @param cursor the cursor * @return the list of resource templates result. */ @@ -301,9 +310,7 @@ public McpSchema.ListResourceTemplatesResult listResourceTemplates() { /** * Subscriptions. The protocol supports optional subscriptions to resource changes. * Clients can subscribe to specific resources and receive notifications when they - * change. - * - * Send a resources/subscribe request. + * change. Send a resources/subscribe request. * @param subscribeRequest the subscribe request contains the uri of the resource to * subscribe to. */ diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/StreamableHttpClientTransport.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/StreamableHttpClientTransport.java new file mode 100644 index 00000000..6e0c5d7b --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/StreamableHttpClientTransport.java @@ -0,0 +1,421 @@ +/* + * Copyright 2024 - 2024 the original author or authors. + */ + +package io.modelcontextprotocol.client.transport; + +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; +import java.util.function.Function; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; + +import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.util.Assert; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.util.retry.Retry; + +/** + * A transport implementation for the Model Context Protocol (MCP) using JSON streaming. + * + * @author Aliaksei Darafeyeu + */ +public class StreamableHttpClientTransport implements McpClientTransport { + + private static final Logger LOGGER = LoggerFactory.getLogger(StreamableHttpClientTransport.class); + + private static final String DEFAULT_MCP_ENDPOINT = "/mcp"; + + private static final String MCP_SESSION_ID = "Mcp-Session-Id"; + + private static final String LAST_EVENT_ID = "Last-Event-ID"; + + private static final String ACCEPT = "Accept"; + + private static final String CONTENT_TYPE = "Content-Type"; + + private static final String APPLICATION_JSON = "application/json"; + + private static final String TEXT_EVENT_STREAM = "text/event-stream"; + + private static final String APPLICATION_JSON_SEQ = "application/json-seq"; + + private static final String DEFAULT_ACCEPT_VALUES = "%s, %s".formatted(APPLICATION_JSON, TEXT_EVENT_STREAM); + + private final HttpClientSseClientTransport sseClientTransport; + + private final HttpClient httpClient; + + private final HttpRequest.Builder requestBuilder; + + private final ObjectMapper objectMapper; + + private final URI uri; + + private final AtomicReference lastEventId = new AtomicReference<>(); + + private final AtomicReference mcpSessionId = new AtomicReference<>(); + + private final AtomicBoolean fallbackToSse = new AtomicBoolean(false); + + StreamableHttpClientTransport(final HttpClient httpClient, final HttpRequest.Builder requestBuilder, + final ObjectMapper objectMapper, final String baseUri, final String endpoint, + final HttpClientSseClientTransport sseClientTransport) { + this.httpClient = httpClient; + this.requestBuilder = requestBuilder; + this.objectMapper = objectMapper; + this.uri = URI.create(baseUri + endpoint); + this.sseClientTransport = sseClientTransport; + } + + /** + * Creates a new StreamableHttpClientTransport instance with the specified URI. + * @param uri the URI to connect to + * @return a new Builder instance + */ + public static Builder builder(final String uri) { + return new Builder().withBaseUri(uri); + } + + /** + * A builder for creating instances of WebSocketClientTransport. + */ + public static class Builder { + + private final HttpClient.Builder clientBuilder = HttpClient.newBuilder() + .version(HttpClient.Version.HTTP_1_1) + .connectTimeout(Duration.ofSeconds(10)); + + private final HttpRequest.Builder requestBuilder = HttpRequest.newBuilder(); + + private ObjectMapper objectMapper = new ObjectMapper(); + + private String baseUri; + + private String endpoint = DEFAULT_MCP_ENDPOINT; + + private Consumer clientCustomizer; + + private Consumer requestCustomizer; + + public Builder withCustomizeClient(final Consumer clientCustomizer) { + Assert.notNull(clientCustomizer, "clientCustomizer must not be null"); + clientCustomizer.accept(clientBuilder); + this.clientCustomizer = clientCustomizer; + return this; + } + + public Builder withCustomizeRequest(final Consumer requestCustomizer) { + Assert.notNull(requestCustomizer, "requestCustomizer must not be null"); + requestCustomizer.accept(requestBuilder); + this.requestCustomizer = requestCustomizer; + return this; + } + + public Builder withObjectMapper(final ObjectMapper objectMapper) { + Assert.notNull(objectMapper, "objectMapper must not be null"); + this.objectMapper = objectMapper; + return this; + } + + public Builder withBaseUri(final String baseUri) { + Assert.hasText(baseUri, "baseUri must not be empty"); + this.baseUri = baseUri; + return this; + } + + public Builder withEndpoint(final String endpoint) { + Assert.hasText(endpoint, "endpoint must not be empty"); + this.endpoint = endpoint; + return this; + } + + public StreamableHttpClientTransport build() { + final HttpClientSseClientTransport.Builder builder = HttpClientSseClientTransport.builder(baseUri) + .objectMapper(objectMapper); + if (clientCustomizer != null) { + builder.customizeClient(clientCustomizer); + } + + if (requestCustomizer != null) { + builder.customizeRequest(requestCustomizer); + } + + if (!endpoint.equals(DEFAULT_MCP_ENDPOINT)) { + builder.sseEndpoint(endpoint); + } + + return new StreamableHttpClientTransport(clientBuilder.build(), requestBuilder, objectMapper, baseUri, + endpoint, builder.build()); + } + + } + + @Override + public Mono connect(final Function, Mono> handler) { + if (fallbackToSse.get()) { + return sseClientTransport.connect(handler); + } + + return Mono.defer(() -> Mono.fromFuture(() -> { + final HttpRequest.Builder request = requestBuilder.copy().GET().header(ACCEPT, TEXT_EVENT_STREAM).uri(uri); + final String lastId = lastEventId.get(); + if (lastId != null) { + request.header(LAST_EVENT_ID, lastId); + } + if (mcpSessionId.get() != null) { + request.header(MCP_SESSION_ID, mcpSessionId.get()); + } + + return httpClient.sendAsync(request.build(), HttpResponse.BodyHandlers.ofInputStream()); + }).flatMap(response -> { + // must like server terminate session and the client need to start a + // new session by sending a new `InitializeRequest` without a session + // ID attached. + if (mcpSessionId.get() != null && response.statusCode() == 404) { + mcpSessionId.set(null); + } + + if (response.statusCode() == 405 || response.statusCode() == 404) { + LOGGER.warn("Operation not allowed, falling back to SSE"); + fallbackToSse.set(true); + return sseClientTransport.connect(handler); + } + return handleStreamingResponse(response, handler); + }) + .retryWhen(Retry.backoff(3, Duration.ofSeconds(3)).filter(err -> err instanceof IllegalStateException)) + .onErrorResume(e -> { + LOGGER.error("Streamable transport connection error", e); + return Mono.error(e); + })).doOnTerminate(this::closeGracefully); + } + + @Override + public Mono sendMessage(final McpSchema.JSONRPCMessage message) { + return sendMessage(message, msg -> msg); + } + + public Mono sendMessage(final McpSchema.JSONRPCMessage message, + final Function, Mono> handler) { + if (fallbackToSse.get()) { + return fallbackToSse(message); + } + + return serializeJson(message).flatMap(json -> { + final HttpRequest.Builder request = requestBuilder.copy() + .POST(HttpRequest.BodyPublishers.ofString(json)) + .header(ACCEPT, DEFAULT_ACCEPT_VALUES) + .header(CONTENT_TYPE, APPLICATION_JSON) + .uri(uri); + if (mcpSessionId.get() != null) { + request.header(MCP_SESSION_ID, mcpSessionId.get()); + } + + return Mono.fromFuture(httpClient.sendAsync(request.build(), HttpResponse.BodyHandlers.ofInputStream())) + .flatMap(response -> { + + // server may assign a session ID at initialization time, if yes we + // have to use it for any subsequent requests + if (message instanceof McpSchema.JSONRPCRequest + && ((McpSchema.JSONRPCRequest) message).method().equals(McpSchema.METHOD_INITIALIZE)) { + response.headers() + .firstValue(MCP_SESSION_ID) + .map(String::trim) + .ifPresent(this.mcpSessionId::set); + } + + // If the response is 202 Accepted, there's no body to process + if (response.statusCode() == 202) { + return Mono.empty(); + } + + // must like server terminate session and the client need to start a + // new session by sending a new `InitializeRequest` without a session + // ID attached. + if (mcpSessionId.get() != null && response.statusCode() == 404) { + mcpSessionId.set(null); + } + + if (response.statusCode() == 405 || response.statusCode() == 404) { + LOGGER.warn("Operation not allowed, falling back to SSE"); + fallbackToSse.set(true); + return fallbackToSse(message); + } + + if (response.statusCode() >= 400) { + return Mono + .error(new IllegalArgumentException("Unexpected status code: " + response.statusCode())); + } + + return handleStreamingResponse(response, handler); + }); + }).onErrorResume(e -> { + LOGGER.error("Streamable transport sendMessages error", e); + return Mono.error(e); + }); + + } + + private Mono fallbackToSse(final McpSchema.JSONRPCMessage msg) { + if (msg instanceof McpSchema.JSONRPCBatchRequest batch) { + return Flux.fromIterable(batch.items()).flatMap(sseClientTransport::sendMessage).then(); + } + + if (msg instanceof McpSchema.JSONRPCBatchResponse batch) { + return Flux.fromIterable(batch.items()).flatMap(sseClientTransport::sendMessage).then(); + } + + return sseClientTransport.sendMessage(msg); + } + + private Mono serializeJson(final McpSchema.JSONRPCMessage msg) { + try { + return Mono.just(objectMapper.writeValueAsString(msg)); + } + catch (IOException e) { + LOGGER.error("Error serializing JSON-RPC message", e); + return Mono.error(e); + } + } + + private Mono handleStreamingResponse(final HttpResponse response, + final Function, Mono> handler) { + final String contentType = response.headers().firstValue(CONTENT_TYPE).orElse(""); + if (contentType.contains(APPLICATION_JSON_SEQ)) { + return handleJsonStream(response, handler); + } + else if (contentType.contains(TEXT_EVENT_STREAM)) { + return handleSseStream(response, handler); + } + else if (contentType.contains(APPLICATION_JSON)) { + return handleSingleJson(response, handler); + } + return Mono.error(new UnsupportedOperationException("Unsupported Content-Type: " + contentType)); + } + + private Mono handleSingleJson(final HttpResponse response, + final Function, Mono> handler) { + return Mono.fromCallable(() -> { + try { + final McpSchema.JSONRPCMessage msg = McpSchema.deserializeJsonRpcMessage(objectMapper, + new String(response.body().readAllBytes(), StandardCharsets.UTF_8)); + return handler.apply(Mono.just(msg)); + } + catch (IOException e) { + LOGGER.error("Error processing JSON response", e); + return Mono.error(e); + } + }).flatMap(Function.identity()).then(); + } + + private Mono handleJsonStream(final HttpResponse response, + final Function, Mono> handler) { + return Flux.fromStream(new BufferedReader(new InputStreamReader(response.body())).lines()).flatMap(jsonLine -> { + try { + final McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, jsonLine); + return handler.apply(Mono.just(message)); + } + catch (IOException e) { + LOGGER.error("Error processing JSON line", e); + return Mono.error(e); + } + }).then(); + } + + private Mono handleSseStream(final HttpResponse response, + final Function, Mono> handler) { + return Flux.fromStream(new BufferedReader(new InputStreamReader(response.body())).lines()) + .map(String::trim) + .bufferUntil(String::isEmpty) + .map(eventLines -> { + String event = ""; + String data = ""; + String id = ""; + + for (String line : eventLines) { + if (line.startsWith("event: ")) + event = line.substring(7).trim(); + else if (line.startsWith("data: ")) + data += line.substring(6) + "\n"; + else if (line.startsWith("id: ")) + id = line.substring(4).trim(); + } + + if (data.endsWith("\n")) { + data = data.substring(0, data.length() - 1); + } + + return new FlowSseClient.SseEvent(id, event, data); + }) + .filter(sseEvent -> "message".equals(sseEvent.type())) + .concatMap(sseEvent -> { + String rawData = sseEvent.data().trim(); + try { + JsonNode node = objectMapper.readTree(rawData); + List messages = new ArrayList<>(); + if (node.isArray()) { + for (JsonNode item : node) { + messages.add(McpSchema.deserializeJsonRpcMessage(objectMapper, item.toString())); + } + } + else if (node.isObject()) { + messages.add(McpSchema.deserializeJsonRpcMessage(objectMapper, node.toString())); + } + else { + String warning = "Unexpected JSON in SSE data: " + rawData; + LOGGER.warn(warning); + return Mono.error(new IllegalArgumentException(warning)); + } + + return Flux.fromIterable(messages) + .concatMap(msg -> handler.apply(Mono.just(msg))) + .then(Mono.fromRunnable(() -> { + if (!sseEvent.id().isEmpty()) { + lastEventId.set(sseEvent.id()); + } + })); + } + catch (IOException e) { + LOGGER.error("Error parsing SSE JSON: {}", rawData, e); + return Mono.error(e); + } + }) + .then(); + } + + @Override + public Mono closeGracefully() { + mcpSessionId.set(null); + lastEventId.set(null); + if (fallbackToSse.get()) { + return sseClientTransport.closeGracefully(); + } + return Mono.empty(); + } + + @Override + public T unmarshalFrom(final Object data, final TypeReference typeRef) { + return objectMapper.convertValue(data, typeRef); + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java index f577b493..4c10fba6 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java @@ -61,7 +61,7 @@ public class McpClientSession implements McpSession { /** Atomic counter for generating unique request IDs */ private final AtomicLong requestCounter = new AtomicLong(0); - private final Disposable connection; + private Disposable connection; /** * Functional interface for handling incoming JSON-RPC requests. Implementations @@ -116,6 +116,17 @@ public McpClientSession(Duration requestTimeout, McpClientTransport transport, this.transport = transport; this.requestHandlers.putAll(requestHandlers); this.notificationHandlers.putAll(notificationHandlers); + } + + /** + * The client may issue an HTTP GET to the MCP endpoint. This can be used to open an + * SSE stream, allowing the server to communicate to the client, without the client + * first sending data via HTTP POST. + */ + public void openSSE() { + if (this.connection != null && !this.connection.isDisposed()) { + return; // already connected and still active + } // TODO: consider mono.transformDeferredContextual where the Context contains // the @@ -288,7 +299,9 @@ public Mono closeGracefully() { */ @Override public void close() { - this.connection.dispose(); + if (this.connection != null) { + this.connection.dispose(); + } transport.close(); } diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java index 8df8a158..93bfb748 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java @@ -10,6 +10,7 @@ import java.util.List; import java.util.Map; +import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; @@ -173,12 +174,37 @@ else if (map.containsKey("result") || map.containsKey("error")) { // --------------------------- // JSON-RPC Message Types // --------------------------- - public sealed interface JSONRPCMessage permits JSONRPCRequest, JSONRPCNotification, JSONRPCResponse { + public sealed interface JSONRPCMessage + permits JSONRPCBatchRequest, JSONRPCBatchResponse, JSONRPCRequest, JSONRPCNotification, JSONRPCResponse { String jsonrpc(); } + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record JSONRPCBatchRequest( // @formatter:off + @JsonProperty("items") List items) implements JSONRPCMessage { + + @Override + @JsonIgnore + public String jsonrpc() { + return JSONRPC_VERSION; + } + } // @formatter:on + + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record JSONRPCBatchResponse( // @formatter:off + @JsonProperty("items") List items) implements JSONRPCMessage { + + @Override + @JsonIgnore + public String jsonrpc() { + return JSONRPC_VERSION; + } + } // @formatter:on + @JsonInclude(JsonInclude.Include.NON_ABSENT) @JsonIgnoreProperties(ignoreUnknown = true) public record JSONRPCRequest( // @formatter:off diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/StreamableHttpClientTransportAsyncTest.java b/mcp/src/test/java/io/modelcontextprotocol/client/StreamableHttpClientTransportAsyncTest.java new file mode 100644 index 00000000..4447d0b5 --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/client/StreamableHttpClientTransportAsyncTest.java @@ -0,0 +1,44 @@ +package io.modelcontextprotocol.client; + +import org.junit.jupiter.api.Timeout; +import org.testcontainers.containers.GenericContainer; +import org.testcontainers.containers.wait.strategy.Wait; + +import io.modelcontextprotocol.client.transport.StreamableHttpClientTransport; +import io.modelcontextprotocol.spec.McpClientTransport; + +/** + * Tests for the {@link McpAsyncClient} with {@link StreamableHttpClientTransport}. + * + * @author Aliaksei Darafeyeu + */ +@Timeout(15) +public class StreamableHttpClientTransportAsyncTest extends AbstractMcpAsyncClientTests { + + String host = "http://localhost:3003"; + + // Uses the https://github.com/tzolov/mcp-everything-server-docker-image + @SuppressWarnings("resource") + GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v1") + .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) + .withExposedPorts(3001) + .waitingFor(Wait.forHttp("/").forStatusCode(404)); + + @Override + protected McpClientTransport createMcpTransport() { + return StreamableHttpClientTransport.builder(host).build(); + } + + @Override + protected void onStart() { + container.start(); + int port = container.getMappedPort(3001); + host = "http://" + container.getHost() + ":" + port; + } + + @Override + protected void onClose() { + container.stop(); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/StreamableHttpClientTransportSyncTest.java b/mcp/src/test/java/io/modelcontextprotocol/client/StreamableHttpClientTransportSyncTest.java new file mode 100644 index 00000000..4fc20395 --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/client/StreamableHttpClientTransportSyncTest.java @@ -0,0 +1,44 @@ +package io.modelcontextprotocol.client; + +import org.junit.jupiter.api.Timeout; +import org.testcontainers.containers.GenericContainer; +import org.testcontainers.containers.wait.strategy.Wait; + +import io.modelcontextprotocol.client.transport.StreamableHttpClientTransport; +import io.modelcontextprotocol.spec.McpClientTransport; + +/** + * Tests for the {@link McpSyncClient} with {@link StreamableHttpClientTransport}. + * + * @author Aliaksei Darafeyeu + */ +@Timeout(15) +public class StreamableHttpClientTransportSyncTest extends AbstractMcpSyncClientTests { + + String host = "http://localhost:3003"; + + // Uses the https://github.com/tzolov/mcp-everything-server-docker-image + @SuppressWarnings("resource") + GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v1") + .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) + .withExposedPorts(3001) + .waitingFor(Wait.forHttp("/").forStatusCode(404)); + + @Override + protected McpClientTransport createMcpTransport() { + return StreamableHttpClientTransport.builder(host).build(); + } + + @Override + protected void onStart() { + container.start(); + int port = container.getMappedPort(3001); + host = "http://" + container.getHost() + ":" + port; + } + + @Override + protected void onClose() { + container.stop(); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/spec/McpClientSessionTests.java b/mcp/src/test/java/io/modelcontextprotocol/spec/McpClientSessionTests.java index f72be43e..223e17eb 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/spec/McpClientSessionTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/spec/McpClientSessionTests.java @@ -48,6 +48,7 @@ void setUp() { transport = new MockMcpClientTransport(); session = new McpClientSession(TIMEOUT, transport, Map.of(), Map.of(TEST_NOTIFICATION, params -> Mono.fromRunnable(() -> logger.info("Status update: " + params)))); + session.openSSE(); } @AfterEach @@ -141,6 +142,7 @@ void testRequestHandling() { params -> Mono.just(params)); transport = new MockMcpClientTransport(); session = new McpClientSession(TIMEOUT, transport, requestHandlers, Map.of()); + session.openSSE(); // Simulate incoming request McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, ECHO_METHOD, @@ -162,7 +164,7 @@ void testNotificationHandling() { transport = new MockMcpClientTransport(); session = new McpClientSession(TIMEOUT, transport, Map.of(), Map.of(TEST_NOTIFICATION, params -> Mono.fromRunnable(() -> receivedParams.tryEmitValue(params)))); - + session.openSSE(); // Simulate incoming notification from the server Map notificationParams = Map.of("status", "ready");