diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java index e3a997ba..39f33c95 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java @@ -160,7 +160,7 @@ public class McpAsyncClient { * @param features the MCP Client supported features. */ McpAsyncClient(McpClientTransport transport, Duration requestTimeout, Duration initializationTimeout, - McpClientFeatures.Async features) { + McpClientFeatures.Async features, boolean connectOnInit) { Assert.notNull(transport, "Transport must not be null"); Assert.notNull(requestTimeout, "Request timeout must not be null"); @@ -235,7 +235,9 @@ public class McpAsyncClient { asyncLoggingNotificationHandler(loggingConsumersFinal)); this.mcpSession = new McpClientSession(requestTimeout, transport, requestHandlers, notificationHandlers); - + if (connectOnInit) { + this.mcpSession.openSSE(); + } } /** @@ -302,6 +304,18 @@ public Mono closeGracefully() { return this.mcpSession.closeGracefully(); } + // --------------------------- + // open an SSE stream + // --------------------------- + /** + * The client may issue an HTTP GET to the MCP endpoint. This can be used to open an + * SSE stream, allowing the server to communicate to the client, without the client + * first sending data via HTTP POST. + */ + public void openSSE() { + this.mcpSession.openSSE(); + } + // -------------------------- // Initialization // -------------------------- diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java index a1dc1168..dac2ee8a 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java @@ -157,11 +157,13 @@ class SyncSpec { private Duration requestTimeout = Duration.ofSeconds(20); // Default timeout + private boolean connectOnInit = true; // Default true, for backward compatibility + private Duration initializationTimeout = Duration.ofSeconds(20); private ClientCapabilities capabilities; - private Implementation clientInfo = new Implementation("Java SDK MCP Client", "1.0.0"); + private Implementation clientInfo = new Implementation("Java SDK MCP Sync Client", "0.10.0"); private final Map roots = new HashMap<>(); @@ -195,6 +197,17 @@ public SyncSpec requestTimeout(Duration requestTimeout) { return this; } + /** + * Sets whether to connect to the server during the initialization phase (open an + * SSE stream). + * @param connectOnInit true to open an SSE stream during the initialization + * @return This builder instance for method chaining + */ + public SyncSpec withConnectOnInit(final boolean connectOnInit) { + this.connectOnInit = connectOnInit; + return this; + } + /** * @param initializationTimeout The duration to wait for the initialization * lifecycle step to complete. @@ -368,8 +381,8 @@ public McpSyncClient build() { McpClientFeatures.Async asyncFeatures = McpClientFeatures.Async.fromSync(syncFeatures); - return new McpSyncClient( - new McpAsyncClient(transport, this.requestTimeout, this.initializationTimeout, asyncFeatures)); + return new McpSyncClient(new McpAsyncClient(transport, this.requestTimeout, this.initializationTimeout, + asyncFeatures, this.connectOnInit)); } } @@ -396,11 +409,13 @@ class AsyncSpec { private Duration requestTimeout = Duration.ofSeconds(20); // Default timeout + private boolean connectOnInit = true; // Default true, for backward compatibility + private Duration initializationTimeout = Duration.ofSeconds(20); private ClientCapabilities capabilities; - private Implementation clientInfo = new Implementation("Spring AI MCP Client", "0.3.1"); + private Implementation clientInfo = new Implementation("Java SDK MCP Async Client", "0.10.0"); private final Map roots = new HashMap<>(); @@ -434,6 +449,17 @@ public AsyncSpec requestTimeout(Duration requestTimeout) { return this; } + /** + * Sets whether to connect to the server during the initialization phase (open an + * SSE stream). + * @param connectOnInit true to open an SSE stream during the initialization + * @return This builder instance for method chaining + */ + public AsyncSpec withConnectOnInit(final boolean connectOnInit) { + this.connectOnInit = connectOnInit; + return this; + } + /** * @param initializationTimeout The duration to wait for the initialization * lifecycle step to complete. @@ -606,7 +632,8 @@ public McpAsyncClient build() { return new McpAsyncClient(this.transport, this.requestTimeout, this.initializationTimeout, new McpClientFeatures.Async(this.clientInfo, this.capabilities, this.roots, this.toolsChangeConsumers, this.resourcesChangeConsumers, this.promptsChangeConsumers, - this.loggingConsumers, this.samplingHandler)); + this.loggingConsumers, this.samplingHandler), + this.connectOnInit); } } diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java index a8fb979e..e9676e8a 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java @@ -137,6 +137,18 @@ public boolean closeGracefully() { return true; } + // --------------------------- + // open an SSE stream + // --------------------------- + /** + * The client may issue an HTTP GET to the MCP endpoint. This can be used to open an + * SSE stream, allowing the server to communicate to the client, without the client + * first sending data via HTTP POST. + */ + public void openSSE() { + this.delegate.openSSE(); + } + /** * The initialization phase MUST be the first interaction between client and server. * During this phase, the client and server: @@ -156,9 +168,7 @@ public boolean closeGracefully() { * The server MUST respond with its own capabilities and information: * {@link McpSchema.ServerCapabilities}.
* After successful initialization, the client MUST send an initialized notification - * to indicate it is ready to begin normal operations. - * - *
+ * to indicate it is ready to begin normal operations.
* * Initialization @@ -280,9 +290,8 @@ public McpSchema.ReadResourceResult readResource(McpSchema.ReadResourceRequest r /** * Resource templates allow servers to expose parameterized resources using URI - * templates. Arguments may be auto-completed through the completion API. - * - * Request a list of resource templates the server has. + * templates. Arguments may be auto-completed through the completion API. Request a + * list of resource templates the server has. * @param cursor the cursor * @return the list of resource templates result. */ @@ -301,9 +310,7 @@ public McpSchema.ListResourceTemplatesResult listResourceTemplates() { /** * Subscriptions. The protocol supports optional subscriptions to resource changes. * Clients can subscribe to specific resources and receive notifications when they - * change. - * - * Send a resources/subscribe request. + * change. Send a resources/subscribe request. * @param subscribeRequest the subscribe request contains the uri of the resource to * subscribe to. */ diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/StreamableHttpClientTransport.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/StreamableHttpClientTransport.java new file mode 100644 index 00000000..6e0c5d7b --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/StreamableHttpClientTransport.java @@ -0,0 +1,421 @@ +/* + * Copyright 2024 - 2024 the original author or authors. + */ + +package io.modelcontextprotocol.client.transport; + +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; +import java.util.function.Function; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; + +import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.util.Assert; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.util.retry.Retry; + +/** + * A transport implementation for the Model Context Protocol (MCP) using JSON streaming. + * + * @author Aliaksei Darafeyeu + */ +public class StreamableHttpClientTransport implements McpClientTransport { + + private static final Logger LOGGER = LoggerFactory.getLogger(StreamableHttpClientTransport.class); + + private static final String DEFAULT_MCP_ENDPOINT = "/mcp"; + + private static final String MCP_SESSION_ID = "Mcp-Session-Id"; + + private static final String LAST_EVENT_ID = "Last-Event-ID"; + + private static final String ACCEPT = "Accept"; + + private static final String CONTENT_TYPE = "Content-Type"; + + private static final String APPLICATION_JSON = "application/json"; + + private static final String TEXT_EVENT_STREAM = "text/event-stream"; + + private static final String APPLICATION_JSON_SEQ = "application/json-seq"; + + private static final String DEFAULT_ACCEPT_VALUES = "%s, %s".formatted(APPLICATION_JSON, TEXT_EVENT_STREAM); + + private final HttpClientSseClientTransport sseClientTransport; + + private final HttpClient httpClient; + + private final HttpRequest.Builder requestBuilder; + + private final ObjectMapper objectMapper; + + private final URI uri; + + private final AtomicReference lastEventId = new AtomicReference<>(); + + private final AtomicReference mcpSessionId = new AtomicReference<>(); + + private final AtomicBoolean fallbackToSse = new AtomicBoolean(false); + + StreamableHttpClientTransport(final HttpClient httpClient, final HttpRequest.Builder requestBuilder, + final ObjectMapper objectMapper, final String baseUri, final String endpoint, + final HttpClientSseClientTransport sseClientTransport) { + this.httpClient = httpClient; + this.requestBuilder = requestBuilder; + this.objectMapper = objectMapper; + this.uri = URI.create(baseUri + endpoint); + this.sseClientTransport = sseClientTransport; + } + + /** + * Creates a new StreamableHttpClientTransport instance with the specified URI. + * @param uri the URI to connect to + * @return a new Builder instance + */ + public static Builder builder(final String uri) { + return new Builder().withBaseUri(uri); + } + + /** + * A builder for creating instances of WebSocketClientTransport. + */ + public static class Builder { + + private final HttpClient.Builder clientBuilder = HttpClient.newBuilder() + .version(HttpClient.Version.HTTP_1_1) + .connectTimeout(Duration.ofSeconds(10)); + + private final HttpRequest.Builder requestBuilder = HttpRequest.newBuilder(); + + private ObjectMapper objectMapper = new ObjectMapper(); + + private String baseUri; + + private String endpoint = DEFAULT_MCP_ENDPOINT; + + private Consumer clientCustomizer; + + private Consumer requestCustomizer; + + public Builder withCustomizeClient(final Consumer clientCustomizer) { + Assert.notNull(clientCustomizer, "clientCustomizer must not be null"); + clientCustomizer.accept(clientBuilder); + this.clientCustomizer = clientCustomizer; + return this; + } + + public Builder withCustomizeRequest(final Consumer requestCustomizer) { + Assert.notNull(requestCustomizer, "requestCustomizer must not be null"); + requestCustomizer.accept(requestBuilder); + this.requestCustomizer = requestCustomizer; + return this; + } + + public Builder withObjectMapper(final ObjectMapper objectMapper) { + Assert.notNull(objectMapper, "objectMapper must not be null"); + this.objectMapper = objectMapper; + return this; + } + + public Builder withBaseUri(final String baseUri) { + Assert.hasText(baseUri, "baseUri must not be empty"); + this.baseUri = baseUri; + return this; + } + + public Builder withEndpoint(final String endpoint) { + Assert.hasText(endpoint, "endpoint must not be empty"); + this.endpoint = endpoint; + return this; + } + + public StreamableHttpClientTransport build() { + final HttpClientSseClientTransport.Builder builder = HttpClientSseClientTransport.builder(baseUri) + .objectMapper(objectMapper); + if (clientCustomizer != null) { + builder.customizeClient(clientCustomizer); + } + + if (requestCustomizer != null) { + builder.customizeRequest(requestCustomizer); + } + + if (!endpoint.equals(DEFAULT_MCP_ENDPOINT)) { + builder.sseEndpoint(endpoint); + } + + return new StreamableHttpClientTransport(clientBuilder.build(), requestBuilder, objectMapper, baseUri, + endpoint, builder.build()); + } + + } + + @Override + public Mono connect(final Function, Mono> handler) { + if (fallbackToSse.get()) { + return sseClientTransport.connect(handler); + } + + return Mono.defer(() -> Mono.fromFuture(() -> { + final HttpRequest.Builder request = requestBuilder.copy().GET().header(ACCEPT, TEXT_EVENT_STREAM).uri(uri); + final String lastId = lastEventId.get(); + if (lastId != null) { + request.header(LAST_EVENT_ID, lastId); + } + if (mcpSessionId.get() != null) { + request.header(MCP_SESSION_ID, mcpSessionId.get()); + } + + return httpClient.sendAsync(request.build(), HttpResponse.BodyHandlers.ofInputStream()); + }).flatMap(response -> { + // must like server terminate session and the client need to start a + // new session by sending a new `InitializeRequest` without a session + // ID attached. + if (mcpSessionId.get() != null && response.statusCode() == 404) { + mcpSessionId.set(null); + } + + if (response.statusCode() == 405 || response.statusCode() == 404) { + LOGGER.warn("Operation not allowed, falling back to SSE"); + fallbackToSse.set(true); + return sseClientTransport.connect(handler); + } + return handleStreamingResponse(response, handler); + }) + .retryWhen(Retry.backoff(3, Duration.ofSeconds(3)).filter(err -> err instanceof IllegalStateException)) + .onErrorResume(e -> { + LOGGER.error("Streamable transport connection error", e); + return Mono.error(e); + })).doOnTerminate(this::closeGracefully); + } + + @Override + public Mono sendMessage(final McpSchema.JSONRPCMessage message) { + return sendMessage(message, msg -> msg); + } + + public Mono sendMessage(final McpSchema.JSONRPCMessage message, + final Function, Mono> handler) { + if (fallbackToSse.get()) { + return fallbackToSse(message); + } + + return serializeJson(message).flatMap(json -> { + final HttpRequest.Builder request = requestBuilder.copy() + .POST(HttpRequest.BodyPublishers.ofString(json)) + .header(ACCEPT, DEFAULT_ACCEPT_VALUES) + .header(CONTENT_TYPE, APPLICATION_JSON) + .uri(uri); + if (mcpSessionId.get() != null) { + request.header(MCP_SESSION_ID, mcpSessionId.get()); + } + + return Mono.fromFuture(httpClient.sendAsync(request.build(), HttpResponse.BodyHandlers.ofInputStream())) + .flatMap(response -> { + + // server may assign a session ID at initialization time, if yes we + // have to use it for any subsequent requests + if (message instanceof McpSchema.JSONRPCRequest + && ((McpSchema.JSONRPCRequest) message).method().equals(McpSchema.METHOD_INITIALIZE)) { + response.headers() + .firstValue(MCP_SESSION_ID) + .map(String::trim) + .ifPresent(this.mcpSessionId::set); + } + + // If the response is 202 Accepted, there's no body to process + if (response.statusCode() == 202) { + return Mono.empty(); + } + + // must like server terminate session and the client need to start a + // new session by sending a new `InitializeRequest` without a session + // ID attached. + if (mcpSessionId.get() != null && response.statusCode() == 404) { + mcpSessionId.set(null); + } + + if (response.statusCode() == 405 || response.statusCode() == 404) { + LOGGER.warn("Operation not allowed, falling back to SSE"); + fallbackToSse.set(true); + return fallbackToSse(message); + } + + if (response.statusCode() >= 400) { + return Mono + .error(new IllegalArgumentException("Unexpected status code: " + response.statusCode())); + } + + return handleStreamingResponse(response, handler); + }); + }).onErrorResume(e -> { + LOGGER.error("Streamable transport sendMessages error", e); + return Mono.error(e); + }); + + } + + private Mono fallbackToSse(final McpSchema.JSONRPCMessage msg) { + if (msg instanceof McpSchema.JSONRPCBatchRequest batch) { + return Flux.fromIterable(batch.items()).flatMap(sseClientTransport::sendMessage).then(); + } + + if (msg instanceof McpSchema.JSONRPCBatchResponse batch) { + return Flux.fromIterable(batch.items()).flatMap(sseClientTransport::sendMessage).then(); + } + + return sseClientTransport.sendMessage(msg); + } + + private Mono serializeJson(final McpSchema.JSONRPCMessage msg) { + try { + return Mono.just(objectMapper.writeValueAsString(msg)); + } + catch (IOException e) { + LOGGER.error("Error serializing JSON-RPC message", e); + return Mono.error(e); + } + } + + private Mono handleStreamingResponse(final HttpResponse response, + final Function, Mono> handler) { + final String contentType = response.headers().firstValue(CONTENT_TYPE).orElse(""); + if (contentType.contains(APPLICATION_JSON_SEQ)) { + return handleJsonStream(response, handler); + } + else if (contentType.contains(TEXT_EVENT_STREAM)) { + return handleSseStream(response, handler); + } + else if (contentType.contains(APPLICATION_JSON)) { + return handleSingleJson(response, handler); + } + return Mono.error(new UnsupportedOperationException("Unsupported Content-Type: " + contentType)); + } + + private Mono handleSingleJson(final HttpResponse response, + final Function, Mono> handler) { + return Mono.fromCallable(() -> { + try { + final McpSchema.JSONRPCMessage msg = McpSchema.deserializeJsonRpcMessage(objectMapper, + new String(response.body().readAllBytes(), StandardCharsets.UTF_8)); + return handler.apply(Mono.just(msg)); + } + catch (IOException e) { + LOGGER.error("Error processing JSON response", e); + return Mono.error(e); + } + }).flatMap(Function.identity()).then(); + } + + private Mono handleJsonStream(final HttpResponse response, + final Function, Mono> handler) { + return Flux.fromStream(new BufferedReader(new InputStreamReader(response.body())).lines()).flatMap(jsonLine -> { + try { + final McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, jsonLine); + return handler.apply(Mono.just(message)); + } + catch (IOException e) { + LOGGER.error("Error processing JSON line", e); + return Mono.error(e); + } + }).then(); + } + + private Mono handleSseStream(final HttpResponse response, + final Function, Mono> handler) { + return Flux.fromStream(new BufferedReader(new InputStreamReader(response.body())).lines()) + .map(String::trim) + .bufferUntil(String::isEmpty) + .map(eventLines -> { + String event = ""; + String data = ""; + String id = ""; + + for (String line : eventLines) { + if (line.startsWith("event: ")) + event = line.substring(7).trim(); + else if (line.startsWith("data: ")) + data += line.substring(6) + "\n"; + else if (line.startsWith("id: ")) + id = line.substring(4).trim(); + } + + if (data.endsWith("\n")) { + data = data.substring(0, data.length() - 1); + } + + return new FlowSseClient.SseEvent(id, event, data); + }) + .filter(sseEvent -> "message".equals(sseEvent.type())) + .concatMap(sseEvent -> { + String rawData = sseEvent.data().trim(); + try { + JsonNode node = objectMapper.readTree(rawData); + List messages = new ArrayList<>(); + if (node.isArray()) { + for (JsonNode item : node) { + messages.add(McpSchema.deserializeJsonRpcMessage(objectMapper, item.toString())); + } + } + else if (node.isObject()) { + messages.add(McpSchema.deserializeJsonRpcMessage(objectMapper, node.toString())); + } + else { + String warning = "Unexpected JSON in SSE data: " + rawData; + LOGGER.warn(warning); + return Mono.error(new IllegalArgumentException(warning)); + } + + return Flux.fromIterable(messages) + .concatMap(msg -> handler.apply(Mono.just(msg))) + .then(Mono.fromRunnable(() -> { + if (!sseEvent.id().isEmpty()) { + lastEventId.set(sseEvent.id()); + } + })); + } + catch (IOException e) { + LOGGER.error("Error parsing SSE JSON: {}", rawData, e); + return Mono.error(e); + } + }) + .then(); + } + + @Override + public Mono closeGracefully() { + mcpSessionId.set(null); + lastEventId.set(null); + if (fallbackToSse.get()) { + return sseClientTransport.closeGracefully(); + } + return Mono.empty(); + } + + @Override + public T unmarshalFrom(final Object data, final TypeReference typeRef) { + return objectMapper.convertValue(data, typeRef); + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/StreamableHttpServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/StreamableHttpServerTransportProvider.java new file mode 100644 index 00000000..20934f85 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/StreamableHttpServerTransportProvider.java @@ -0,0 +1,276 @@ +package io.modelcontextprotocol.server.transport; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpServerSession; +import io.modelcontextprotocol.spec.McpServerTransport; +import io.modelcontextprotocol.spec.McpServerTransportProvider; +import io.modelcontextprotocol.spec.McpSession; +import io.modelcontextprotocol.spec.StatelessMcpSession; +import jakarta.servlet.AsyncContext; +import jakarta.servlet.ServletException; +import jakarta.servlet.http.HttpServlet; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; + +/** + * @author Aliaksei_Darafeyeu + */ +public class StreamableHttpServerTransportProvider extends HttpServlet implements McpServerTransportProvider { + /** + * Logger for this class + */ + private static final Logger logger = LoggerFactory.getLogger(StreamableHttpServerTransportProvider.class); + + private static final String MCP_SESSION_ID = "Mcp-Session-Id"; + private static final String APPLICATION_JSON = "application/json"; + private static final String TEXT_EVENT_STREAM = "text/event-stream"; + + private McpServerSession.Factory sessionFactory; + + private final ObjectMapper objectMapper; + + private final McpServerTransportProvider legacyTransportProvider; + + private final Set allowedOrigins; + + /** + * Map of active client sessions, keyed by session ID + */ + private final Map sessions = new ConcurrentHashMap<>(); + + public StreamableHttpServerTransportProvider(final ObjectMapper objectMapper, final McpServerTransportProvider legacyTransportProvider, final Set allowedOrigins) { + this.objectMapper = objectMapper; + this.legacyTransportProvider = legacyTransportProvider; + this.allowedOrigins = allowedOrigins; + } + + @Override + public void setSessionFactory(McpServerSession.Factory sessionFactory) { + this.sessionFactory = sessionFactory; + } + + @Override + public Mono notifyClients(String method, Object params) { + if (sessions.isEmpty()) { + logger.debug("No active sessions to broadcast message to"); + return Mono.empty(); + } + + logger.debug("Attempting to broadcast message to {} active sessions", sessions.size()); + return Flux.fromIterable(sessions.values()) + .flatMap(session -> session.sendNotification(method, params) + .doOnError(e -> logger.error("Failed to send message to session {}: {}", session.getId(), e.getMessage())) + .onErrorComplete()) + .then(); + } + + @Override + public Mono closeGracefully() { + logger.debug("Initiating graceful shutdown with {} active sessions", sessions.size()); + return Flux.fromIterable(sessions.values()).flatMap(McpSession::closeGracefully).then(); + } + + @Override + public void destroy() { + closeGracefully().block(); + super.destroy(); + } + + @Override + protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException { + // 1. Origin header check + String origin = req.getHeader("Origin"); + if (origin != null && !allowedOrigins.contains(origin)) { + resp.sendError(HttpServletResponse.SC_FORBIDDEN, "Origin not allowed"); + return; + } + + // 2. Accept header routing + final String accept = Optional.ofNullable(req.getHeader("Accept")).orElse(""); + final List acceptTypes = Arrays.stream(accept.split(",")) + .map(String::trim) + .toList(); + + // todo!!!! + if (!acceptTypes.contains(APPLICATION_JSON) && !acceptTypes.contains(TEXT_EVENT_STREAM)) { + if (legacyTransportProvider instanceof HttpServletSseServerTransportProvider legacy) { + legacy.doPost(req, resp); + } else { + resp.sendError(HttpServletResponse.SC_NOT_ACCEPTABLE, "Legacy transport not available"); + } + return; + } + + // 3. Enable async + final AsyncContext asyncContext = req.startAsync(); + asyncContext.setTimeout(0); + + // resp + resp.setStatus(HttpServletResponse.SC_OK); + resp.setCharacterEncoding("UTF-8"); + + final McpServerTransport transport = new StreamableHttpServerTransport(resp.getOutputStream(), objectMapper); + final McpSession session = getOrCreateSession(req.getHeader(MCP_SESSION_ID), transport); + if (!"stateless".equals(session.getId())) { + resp.setHeader(MCP_SESSION_ID, session.getId()); + } + final Flux messages = parseRequestBodyAsStream(req); + + if (accept.contains(TEXT_EVENT_STREAM)) { + // TODO: Handle streaming JSON-RPC over HTTP + resp.setContentType(TEXT_EVENT_STREAM); + resp.setHeader("Connection", "keep-alive"); + + messages.flatMap(session::handle) + .doOnError(e -> sendError(resp, 500, "Streaming failed: " + e.getMessage())) + .then(transport.closeGracefully()) + .subscribe(); + } else if (accept.contains(APPLICATION_JSON)) { + // TODO: Handle traditional JSON-RPC response + resp.setContentType(APPLICATION_JSON); + + messages.flatMap(session::handle) + .collectList() + .flatMap(responses -> { + try { + String json = new ObjectMapper().writeValueAsString( + responses.size() == 1 ? responses.get(0) : responses + ); + resp.getWriter().write(json); + return transport.closeGracefully(); + } catch (IOException e) { + return Mono.error(e); + } + }) + .doOnError(e -> sendError(resp, 500, "JSON response failed: " + e.getMessage())) + .subscribe(); + + } else { + resp.sendError(HttpServletResponse.SC_NOT_ACCEPTABLE, "Unsupported Accept header"); + } + } + + @Override + protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException { + if (legacyTransportProvider instanceof HttpServletSseServerTransportProvider legacy) { + legacy.doGet(req, resp); + } else { + resp.sendError(HttpServletResponse.SC_NOT_ACCEPTABLE, "Legacy transport not available"); + } + } + + protected void doDelete(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException { + final String sessionId = req.getHeader("mcp-session-id"); + if (sessionId == null || !sessions.containsKey(sessionId)) { + resp.sendError(HttpServletResponse.SC_NOT_FOUND, "Session not found"); + return; + } + + final McpSession session = sessions.remove(sessionId); + session.closeGracefully().subscribe(); + resp.setStatus(HttpServletResponse.SC_NO_CONTENT); + } + + // todo:!!! + private Flux parseRequestBodyAsStream(final HttpServletRequest req) { + return Mono.fromCallable(() -> { + try (final InputStream inputStream = req.getInputStream()) { + final JsonNode node = objectMapper.readTree(inputStream); + if (node.isArray()) { + final List messages = new ArrayList<>(); + for (final JsonNode item : node) { + messages.add(objectMapper.treeToValue(item, McpSchema.JSONRPCMessage.class)); + } + return messages; + } else if (node.isObject()) { + return List.of(objectMapper.treeToValue(node, McpSchema.JSONRPCMessage.class)); + } else { + throw new IllegalArgumentException("Invalid JSON-RPC request: not object or array"); + } + } + }).flatMapMany(Flux::fromIterable); + } + + private McpSession getOrCreateSession(final String sessionId, final McpServerTransport transport) { + if (sessionId != null && sessionFactory != null) { + // Reuse or track sessions if you support that; for now, we just create new ones + return sessions.get(sessionId); + } else if (sessionFactory != null) { + final String newSessionId = UUID.randomUUID().toString(); + return sessions.put(newSessionId, sessionFactory.create(transport)); + } else { + return new StatelessMcpSession(transport); + } + } + + private void sendError(final HttpServletResponse resp, final int code, final String msg) { + try { + resp.sendError(code, msg); + } catch (IOException ignored) { + logger.debug("Exception during send error"); + } + } + + public static class StreamableHttpServerTransport implements McpServerTransport { + private final ObjectMapper objectMapper; + private final OutputStream outputStream; + + public StreamableHttpServerTransport(final OutputStream outputStream, final ObjectMapper objectMapper) { + this.objectMapper = objectMapper; + this.outputStream = outputStream; + } + + @Override + public Mono sendMessage(final McpSchema.JSONRPCMessage message) { + return Mono.fromRunnable(() -> { + try { + String json = objectMapper.writeValueAsString(message); + outputStream.write(json.getBytes(StandardCharsets.UTF_8)); + outputStream.write('\n'); + outputStream.flush(); + } catch (IOException e) { + throw new RuntimeException("Failed to send message", e); + } + }); + } + + @Override + public T unmarshalFrom(Object data, TypeReference typeRef) { + return objectMapper.convertValue(data, typeRef); + } + + @Override + public Mono closeGracefully() { + return Mono.fromRunnable(() -> { + try { + outputStream.flush(); + outputStream.close(); + } catch (IOException e) { + // ignore or log + } + }); + } + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java index f577b493..4c10fba6 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java @@ -61,7 +61,7 @@ public class McpClientSession implements McpSession { /** Atomic counter for generating unique request IDs */ private final AtomicLong requestCounter = new AtomicLong(0); - private final Disposable connection; + private Disposable connection; /** * Functional interface for handling incoming JSON-RPC requests. Implementations @@ -116,6 +116,17 @@ public McpClientSession(Duration requestTimeout, McpClientTransport transport, this.transport = transport; this.requestHandlers.putAll(requestHandlers); this.notificationHandlers.putAll(notificationHandlers); + } + + /** + * The client may issue an HTTP GET to the MCP endpoint. This can be used to open an + * SSE stream, allowing the server to communicate to the client, without the client + * first sending data via HTTP POST. + */ + public void openSSE() { + if (this.connection != null && !this.connection.isDisposed()) { + return; // already connected and still active + } // TODO: consider mono.transformDeferredContextual where the Context contains // the @@ -288,7 +299,9 @@ public Mono closeGracefully() { */ @Override public void close() { - this.connection.dispose(); + if (this.connection != null) { + this.connection.dispose(); + } transport.close(); } diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java index 8df8a158..93bfb748 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java @@ -10,6 +10,7 @@ import java.util.List; import java.util.Map; +import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; @@ -173,12 +174,37 @@ else if (map.containsKey("result") || map.containsKey("error")) { // --------------------------- // JSON-RPC Message Types // --------------------------- - public sealed interface JSONRPCMessage permits JSONRPCRequest, JSONRPCNotification, JSONRPCResponse { + public sealed interface JSONRPCMessage + permits JSONRPCBatchRequest, JSONRPCBatchResponse, JSONRPCRequest, JSONRPCNotification, JSONRPCResponse { String jsonrpc(); } + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record JSONRPCBatchRequest( // @formatter:off + @JsonProperty("items") List items) implements JSONRPCMessage { + + @Override + @JsonIgnore + public String jsonrpc() { + return JSONRPC_VERSION; + } + } // @formatter:on + + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record JSONRPCBatchResponse( // @formatter:off + @JsonProperty("items") List items) implements JSONRPCMessage { + + @Override + @JsonIgnore + public String jsonrpc() { + return JSONRPC_VERSION; + } + } // @formatter:on + @JsonInclude(JsonInclude.Include.NON_ABSENT) @JsonIgnoreProperties(ignoreUnknown = true) public record JSONRPCRequest( // @formatter:off diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSession.java index 473a860c..28fe44f9 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSession.java @@ -25,6 +25,24 @@ */ public interface McpSession { + /** + * Retrieve the session id. + * @return session id + */ + String getId(); + + /** + * Called by the {@link McpServerTransportProvider} once the session is determined. + * The purpose of this method is to dispatch the message to an appropriate handler as + * specified by the MCP server implementation + * ({@link io.modelcontextprotocol.server.McpAsyncServer} or + * {@link io.modelcontextprotocol.server.McpSyncServer}) via + * {@link McpServerSession.Factory} that the server creates. + * @param message the incoming JSON-RPC message + * @return a Mono that completes when the message is processed + */ + Mono handle(McpSchema.JSONRPCMessage message); + /** * Sends a request to the model counterparty and expects a response of type T. * diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/StatelessMcpSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/StatelessMcpSession.java new file mode 100644 index 00000000..2b911d11 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/StatelessMcpSession.java @@ -0,0 +1,82 @@ +package io.modelcontextprotocol.spec; + +import com.fasterxml.jackson.core.type.TypeReference; +import reactor.core.publisher.Mono; + +import java.util.UUID; + +/** + * @author Aliaksei_Darafeyeu + */ +public class StatelessMcpSession implements McpSession { + + private final McpTransport transport; + + public StatelessMcpSession(final McpTransport transport) { + this.transport = transport; + } + + @Override + public String getId() { + return "stateless"; + } + + @Override + public Mono handle(McpSchema.JSONRPCMessage message) { + if (message instanceof McpSchema.JSONRPCRequest request) { + // Stateless sessions do not support incoming requests + McpSchema.JSONRPCResponse errorResponse = new McpSchema.JSONRPCResponse( + McpSchema.JSONRPC_VERSION, + request.id(), + null, + new McpSchema.JSONRPCResponse.JSONRPCError( + McpSchema.ErrorCodes.METHOD_NOT_FOUND, + "Stateless session does not handle requests", + null + ) + ); + return transport.sendMessage(errorResponse); + } + else if (message instanceof McpSchema.JSONRPCNotification notification) { + // Stateless session ignores incoming notifications + return Mono.empty(); + } + else if (message instanceof McpSchema.JSONRPCResponse response) { + // No request/response correlation in stateless mode + return Mono.empty(); + } + else { + return Mono.empty(); + } + } + + + @Override + public Mono sendRequest(String method, Object requestParams, TypeReference typeRef) { + // Stateless = no request/response correlation + String requestId = UUID.randomUUID().toString(); + McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest( + McpSchema.JSONRPC_VERSION, method, requestId, requestParams + ); + + return Mono.defer(() -> Mono.from(this.transport.sendMessage(request)).then(Mono.error(new IllegalStateException("Stateless session cannot receive responses"))) + ); + } + + @Override + public Mono sendNotification(String method, Object params) { + McpSchema.JSONRPCNotification notification = + new McpSchema.JSONRPCNotification(McpSchema.JSONRPC_VERSION, method, params); + return Mono.from(this.transport.sendMessage(notification)); + } + + @Override + public Mono closeGracefully() { + return this.transport.closeGracefully(); + } + + @Override + public void close() { + this.closeGracefully().subscribe(); + } +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/StreamableHttpClientTransportAsyncTest.java b/mcp/src/test/java/io/modelcontextprotocol/client/StreamableHttpClientTransportAsyncTest.java new file mode 100644 index 00000000..4447d0b5 --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/client/StreamableHttpClientTransportAsyncTest.java @@ -0,0 +1,44 @@ +package io.modelcontextprotocol.client; + +import org.junit.jupiter.api.Timeout; +import org.testcontainers.containers.GenericContainer; +import org.testcontainers.containers.wait.strategy.Wait; + +import io.modelcontextprotocol.client.transport.StreamableHttpClientTransport; +import io.modelcontextprotocol.spec.McpClientTransport; + +/** + * Tests for the {@link McpAsyncClient} with {@link StreamableHttpClientTransport}. + * + * @author Aliaksei Darafeyeu + */ +@Timeout(15) +public class StreamableHttpClientTransportAsyncTest extends AbstractMcpAsyncClientTests { + + String host = "http://localhost:3003"; + + // Uses the https://github.com/tzolov/mcp-everything-server-docker-image + @SuppressWarnings("resource") + GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v1") + .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) + .withExposedPorts(3001) + .waitingFor(Wait.forHttp("/").forStatusCode(404)); + + @Override + protected McpClientTransport createMcpTransport() { + return StreamableHttpClientTransport.builder(host).build(); + } + + @Override + protected void onStart() { + container.start(); + int port = container.getMappedPort(3001); + host = "http://" + container.getHost() + ":" + port; + } + + @Override + protected void onClose() { + container.stop(); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/StreamableHttpClientTransportSyncTest.java b/mcp/src/test/java/io/modelcontextprotocol/client/StreamableHttpClientTransportSyncTest.java new file mode 100644 index 00000000..4fc20395 --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/client/StreamableHttpClientTransportSyncTest.java @@ -0,0 +1,44 @@ +package io.modelcontextprotocol.client; + +import org.junit.jupiter.api.Timeout; +import org.testcontainers.containers.GenericContainer; +import org.testcontainers.containers.wait.strategy.Wait; + +import io.modelcontextprotocol.client.transport.StreamableHttpClientTransport; +import io.modelcontextprotocol.spec.McpClientTransport; + +/** + * Tests for the {@link McpSyncClient} with {@link StreamableHttpClientTransport}. + * + * @author Aliaksei Darafeyeu + */ +@Timeout(15) +public class StreamableHttpClientTransportSyncTest extends AbstractMcpSyncClientTests { + + String host = "http://localhost:3003"; + + // Uses the https://github.com/tzolov/mcp-everything-server-docker-image + @SuppressWarnings("resource") + GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v1") + .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) + .withExposedPorts(3001) + .waitingFor(Wait.forHttp("/").forStatusCode(404)); + + @Override + protected McpClientTransport createMcpTransport() { + return StreamableHttpClientTransport.builder(host).build(); + } + + @Override + protected void onStart() { + container.start(); + int port = container.getMappedPort(3001); + host = "http://" + container.getHost() + ":" + port; + } + + @Override + protected void onClose() { + container.stop(); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/spec/McpClientSessionTests.java b/mcp/src/test/java/io/modelcontextprotocol/spec/McpClientSessionTests.java index f72be43e..223e17eb 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/spec/McpClientSessionTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/spec/McpClientSessionTests.java @@ -48,6 +48,7 @@ void setUp() { transport = new MockMcpClientTransport(); session = new McpClientSession(TIMEOUT, transport, Map.of(), Map.of(TEST_NOTIFICATION, params -> Mono.fromRunnable(() -> logger.info("Status update: " + params)))); + session.openSSE(); } @AfterEach @@ -141,6 +142,7 @@ void testRequestHandling() { params -> Mono.just(params)); transport = new MockMcpClientTransport(); session = new McpClientSession(TIMEOUT, transport, requestHandlers, Map.of()); + session.openSSE(); // Simulate incoming request McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, ECHO_METHOD, @@ -162,7 +164,7 @@ void testNotificationHandling() { transport = new MockMcpClientTransport(); session = new McpClientSession(TIMEOUT, transport, Map.of(), Map.of(TEST_NOTIFICATION, params -> Mono.fromRunnable(() -> receivedParams.tryEmitValue(params)))); - + session.openSSE(); // Simulate incoming notification from the server Map notificationParams = Map.of("status", "ready");