From 8ee7d7c871547d57687c085d960b6bafcb3896b4 Mon Sep 17 00:00:00 2001 From: Alberto Pontini Date: Wed, 21 May 2025 16:15:16 +0200 Subject: [PATCH 1/4] Add an interface for McpServerSession to allow for more extensibility Having an interface rather than a concrete class can help with extensibility of the McpServerSession class. This PR implements that and changes usages around the code to only mention the interface rather than the implementation class --- .../WebFluxSseServerTransportProvider.java | 15 +- .../WebMvcSseServerTransportProvider.java | 17 +- .../server/McpAsyncServer.java | 3 +- ...HttpServletSseServerTransportProvider.java | 15 +- .../StdioServerTransportProvider.java | 24 +- .../spec/McpServerSession.java | 430 ++++-------------- .../spec/McpServerSessionImpl.java | 276 +++++++++++ .../MockMcpServerTransportProvider.java | 2 - .../StdioServerTransportProviderTests.java | 56 +-- 9 files changed, 421 insertions(+), 417 deletions(-) create mode 100644 mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSessionImpl.java diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java index 62264d9a..03dc65aa 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java @@ -1,9 +1,5 @@ package io.modelcontextprotocol.server.transport; -import java.io.IOException; -import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; - import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; import io.modelcontextprotocol.spec.McpError; @@ -12,13 +8,10 @@ import io.modelcontextprotocol.spec.McpServerTransport; import io.modelcontextprotocol.spec.McpServerTransportProvider; import io.modelcontextprotocol.util.Assert; +import java.io.IOException; +import java.util.concurrent.ConcurrentHashMap; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import reactor.core.Exceptions; -import reactor.core.publisher.Flux; -import reactor.core.publisher.FluxSink; -import reactor.core.publisher.Mono; - import org.springframework.http.HttpStatus; import org.springframework.http.MediaType; import org.springframework.http.codec.ServerSentEvent; @@ -26,6 +19,10 @@ import org.springframework.web.reactive.function.server.RouterFunctions; import org.springframework.web.reactive.function.server.ServerRequest; import org.springframework.web.reactive.function.server.ServerResponse; +import reactor.core.Exceptions; +import reactor.core.publisher.Flux; +import reactor.core.publisher.FluxSink; +import reactor.core.publisher.Mono; /** * Server-side implementation of the MCP (Model Context Protocol) HTTP transport using diff --git a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java index fc86cfaa..17cc9f84 100644 --- a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java @@ -4,31 +4,28 @@ package io.modelcontextprotocol.server.transport; -import java.io.IOException; -import java.time.Duration; -import java.util.Map; -import java.util.UUID; -import java.util.concurrent.ConcurrentHashMap; - import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpServerSession; import io.modelcontextprotocol.spec.McpServerTransport; import io.modelcontextprotocol.spec.McpServerTransportProvider; -import io.modelcontextprotocol.spec.McpServerSession; import io.modelcontextprotocol.util.Assert; +import java.io.IOException; +import java.time.Duration; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; - import org.springframework.http.HttpStatus; import org.springframework.web.servlet.function.RouterFunction; import org.springframework.web.servlet.function.RouterFunctions; import org.springframework.web.servlet.function.ServerRequest; import org.springframework.web.servlet.function.ServerResponse; import org.springframework.web.servlet.function.ServerResponse.SseBuilder; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; /** * Server-side implementation of the Model Context Protocol (MCP) transport layer using diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java index 1efa13de..f143bbcd 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java @@ -4,6 +4,7 @@ package io.modelcontextprotocol.server; +import io.modelcontextprotocol.spec.McpServerSessionImpl; import java.time.Duration; import java.util.ArrayList; import java.util.HashMap; @@ -184,7 +185,7 @@ public class McpAsyncServer { asyncRootsListChangedNotificationHandler(rootsChangeConsumers)); mcpTransportProvider.setSessionFactory( - transport -> new McpServerSession(UUID.randomUUID().toString(), requestTimeout, transport, + transport -> new McpServerSessionImpl(UUID.randomUUID().toString(), requestTimeout, transport, this::asyncInitializeRequestHandler, Mono::empty, requestHandlers, notificationHandlers)); } diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java index afdbff47..1c527e42 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java @@ -3,14 +3,6 @@ */ package io.modelcontextprotocol.server.transport; -import java.io.BufferedReader; -import java.io.IOException; -import java.io.PrintWriter; -import java.util.Map; -import java.util.UUID; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.atomic.AtomicBoolean; - import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; import io.modelcontextprotocol.spec.McpError; @@ -25,6 +17,13 @@ import jakarta.servlet.http.HttpServlet; import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletResponse; +import java.io.BufferedReader; +import java.io.IOException; +import java.io.PrintWriter; +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicBoolean; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import reactor.core.publisher.Flux; diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java index 819da977..dc4c7d4d 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java @@ -4,18 +4,6 @@ package io.modelcontextprotocol.server.transport; -import java.io.BufferedReader; -import java.io.IOException; -import java.io.InputStream; -import java.io.InputStreamReader; -import java.io.OutputStream; -import java.io.Reader; -import java.nio.charset.StandardCharsets; -import java.util.Map; -import java.util.concurrent.Executors; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.function.Function; - import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; import io.modelcontextprotocol.spec.McpError; @@ -24,7 +12,17 @@ import io.modelcontextprotocol.spec.McpServerSession; import io.modelcontextprotocol.spec.McpServerTransport; import io.modelcontextprotocol.spec.McpServerTransportProvider; +import io.modelcontextprotocol.spec.McpSession; import io.modelcontextprotocol.util.Assert; +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.io.OutputStream; +import java.nio.charset.StandardCharsets; +import java.util.concurrent.Executors; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Function; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import reactor.core.publisher.Flux; @@ -50,7 +48,7 @@ public class StdioServerTransportProvider implements McpServerTransportProvider private final OutputStream outputStream; - private McpServerSession session; + private McpSession session; private final AtomicBoolean isClosing = new AtomicBoolean(false); diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java index 86906d85..4ab73be5 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java @@ -1,353 +1,89 @@ package io.modelcontextprotocol.spec; -import java.time.Duration; -import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.concurrent.atomic.AtomicLong; -import java.util.concurrent.atomic.AtomicReference; - -import com.fasterxml.jackson.core.type.TypeReference; import io.modelcontextprotocol.server.McpAsyncServerExchange; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; import reactor.core.publisher.Mono; -import reactor.core.publisher.MonoSink; -import reactor.core.publisher.Sinks; - -/** - * Represents a Model Control Protocol (MCP) session on the server side. It manages - * bidirectional JSON-RPC communication with the client. - */ -public class McpServerSession implements McpSession { - - private static final Logger logger = LoggerFactory.getLogger(McpServerSession.class); - - private final ConcurrentHashMap> pendingResponses = new ConcurrentHashMap<>(); - - private final String id; - - /** Duration to wait for request responses before timing out */ - private final Duration requestTimeout; - - private final AtomicLong requestCounter = new AtomicLong(0); - - private final InitRequestHandler initRequestHandler; - - private final InitNotificationHandler initNotificationHandler; - - private final Map> requestHandlers; - - private final Map notificationHandlers; - - private final McpServerTransport transport; - - private final Sinks.One exchangeSink = Sinks.one(); - - private final AtomicReference clientCapabilities = new AtomicReference<>(); - - private final AtomicReference clientInfo = new AtomicReference<>(); - - private static final int STATE_UNINITIALIZED = 0; - - private static final int STATE_INITIALIZING = 1; - - private static final int STATE_INITIALIZED = 2; - - private final AtomicInteger state = new AtomicInteger(STATE_UNINITIALIZED); - - /** - * Creates a new server session with the given parameters and the transport to use. - * @param id session id - * @param transport the transport to use - * @param initHandler called when a - * {@link io.modelcontextprotocol.spec.McpSchema.InitializeRequest} is received by the - * server - * @param initNotificationHandler called when a - * {@link io.modelcontextprotocol.spec.McpSchema#METHOD_NOTIFICATION_INITIALIZED} is - * received. - * @param requestHandlers map of request handlers to use - * @param notificationHandlers map of notification handlers to use - */ - public McpServerSession(String id, Duration requestTimeout, McpServerTransport transport, - InitRequestHandler initHandler, InitNotificationHandler initNotificationHandler, - Map> requestHandlers, Map notificationHandlers) { - this.id = id; - this.requestTimeout = requestTimeout; - this.transport = transport; - this.initRequestHandler = initHandler; - this.initNotificationHandler = initNotificationHandler; - this.requestHandlers = requestHandlers; - this.notificationHandlers = notificationHandlers; - } - - /** - * Retrieve the session id. - * @return session id - */ - public String getId() { - return this.id; - } - - /** - * Called upon successful initialization sequence between the client and the server - * with the client capabilities and information. - * - * Initialization - * Spec - * @param clientCapabilities the capabilities the connected client provides - * @param clientInfo the information about the connected client - */ - public void init(McpSchema.ClientCapabilities clientCapabilities, McpSchema.Implementation clientInfo) { - this.clientCapabilities.lazySet(clientCapabilities); - this.clientInfo.lazySet(clientInfo); - } - - private String generateRequestId() { - return this.id + "-" + this.requestCounter.getAndIncrement(); - } - - @Override - public Mono sendRequest(String method, Object requestParams, TypeReference typeRef) { - String requestId = this.generateRequestId(); - - return Mono.create(sink -> { - this.pendingResponses.put(requestId, sink); - McpSchema.JSONRPCRequest jsonrpcRequest = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, method, - requestId, requestParams); - this.transport.sendMessage(jsonrpcRequest).subscribe(v -> { - }, error -> { - this.pendingResponses.remove(requestId); - sink.error(error); - }); - }).timeout(requestTimeout).handle((jsonRpcResponse, sink) -> { - if (jsonRpcResponse.error() != null) { - sink.error(new McpError(jsonRpcResponse.error())); - } - else { - if (typeRef.getType().equals(Void.class)) { - sink.complete(); - } - else { - sink.next(this.transport.unmarshalFrom(jsonRpcResponse.result(), typeRef)); - } - } - }); - } - - @Override - public Mono sendNotification(String method, Object params) { - McpSchema.JSONRPCNotification jsonrpcNotification = new McpSchema.JSONRPCNotification(McpSchema.JSONRPC_VERSION, - method, params); - return this.transport.sendMessage(jsonrpcNotification); - } - - /** - * Called by the {@link McpServerTransportProvider} once the session is determined. - * The purpose of this method is to dispatch the message to an appropriate handler as - * specified by the MCP server implementation - * ({@link io.modelcontextprotocol.server.McpAsyncServer} or - * {@link io.modelcontextprotocol.server.McpSyncServer}) via - * {@link McpServerSession.Factory} that the server creates. - * @param message the incoming JSON-RPC message - * @return a Mono that completes when the message is processed - */ - public Mono handle(McpSchema.JSONRPCMessage message) { - return Mono.defer(() -> { - // TODO handle errors for communication to without initialization happening - // first - if (message instanceof McpSchema.JSONRPCResponse response) { - logger.debug("Received Response: {}", response); - var sink = pendingResponses.remove(response.id()); - if (sink == null) { - logger.warn("Unexpected response for unknown id {}", response.id()); - } - else { - sink.success(response); - } - return Mono.empty(); - } - else if (message instanceof McpSchema.JSONRPCRequest request) { - logger.debug("Received request: {}", request); - return handleIncomingRequest(request).onErrorResume(error -> { - var errorResponse = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), null, - new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, - error.getMessage(), null)); - // TODO: Should the error go to SSE or back as POST return? - return this.transport.sendMessage(errorResponse).then(Mono.empty()); - }).flatMap(this.transport::sendMessage); - } - else if (message instanceof McpSchema.JSONRPCNotification notification) { - // TODO handle errors for communication to without initialization - // happening first - logger.debug("Received notification: {}", notification); - // TODO: in case of error, should the POST request be signalled? - return handleIncomingNotification(notification) - .doOnError(error -> logger.error("Error handling notification: {}", error.getMessage())); - } - else { - logger.warn("Received unknown message type: {}", message); - return Mono.empty(); - } - }); - } - - /** - * Handles an incoming JSON-RPC request by routing it to the appropriate handler. - * @param request The incoming JSON-RPC request - * @return A Mono containing the JSON-RPC response - */ - private Mono handleIncomingRequest(McpSchema.JSONRPCRequest request) { - return Mono.defer(() -> { - Mono resultMono; - if (McpSchema.METHOD_INITIALIZE.equals(request.method())) { - // TODO handle situation where already initialized! - McpSchema.InitializeRequest initializeRequest = transport.unmarshalFrom(request.params(), - new TypeReference() { - }); - - this.state.lazySet(STATE_INITIALIZING); - this.init(initializeRequest.capabilities(), initializeRequest.clientInfo()); - resultMono = this.initRequestHandler.handle(initializeRequest); - } - else { - // TODO handle errors for communication to this session without - // initialization happening first - var handler = this.requestHandlers.get(request.method()); - if (handler == null) { - MethodNotFoundError error = getMethodNotFoundError(request.method()); - return Mono.just(new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), null, - new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.METHOD_NOT_FOUND, - error.message(), error.data()))); - } - - resultMono = this.exchangeSink.asMono().flatMap(exchange -> handler.handle(exchange, request.params())); - } - return resultMono - .map(result -> new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), result, null)) - .onErrorResume(error -> Mono.just(new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), - null, new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, - error.getMessage(), null)))); // TODO: add error message - // through the data field - }); - } - - /** - * Handles an incoming JSON-RPC notification by routing it to the appropriate handler. - * @param notification The incoming JSON-RPC notification - * @return A Mono that completes when the notification is processed - */ - private Mono handleIncomingNotification(McpSchema.JSONRPCNotification notification) { - return Mono.defer(() -> { - if (McpSchema.METHOD_NOTIFICATION_INITIALIZED.equals(notification.method())) { - this.state.lazySet(STATE_INITIALIZED); - exchangeSink.tryEmitValue(new McpAsyncServerExchange(this, clientCapabilities.get(), clientInfo.get())); - return this.initNotificationHandler.handle(); - } - - var handler = notificationHandlers.get(notification.method()); - if (handler == null) { - logger.error("No handler registered for notification method: {}", notification.method()); - return Mono.empty(); - } - return this.exchangeSink.asMono().flatMap(exchange -> handler.handle(exchange, notification.params())); - }); - } - - record MethodNotFoundError(String method, String message, Object data) { - } - - private MethodNotFoundError getMethodNotFoundError(String method) { - return new MethodNotFoundError(method, "Method not found: " + method, null); - } - - @Override - public Mono closeGracefully() { - return this.transport.closeGracefully(); - } - - @Override - public void close() { - this.transport.close(); - } - - /** - * Request handler for the initialization request. - */ - public interface InitRequestHandler { - - /** - * Handles the initialization request. - * @param initializeRequest the initialization request by the client - * @return a Mono that will emit the result of the initialization - */ - Mono handle(McpSchema.InitializeRequest initializeRequest); - - } - - /** - * Notification handler for the initialization notification from the client. - */ - public interface InitNotificationHandler { - - /** - * Specifies an action to take upon successful initialization. - * @return a Mono that will complete when the initialization is acted upon. - */ - Mono handle(); - - } - - /** - * A handler for client-initiated notifications. - */ - public interface NotificationHandler { - - /** - * Handles a notification from the client. - * @param exchange the exchange associated with the client that allows calling - * back to the connected client or inspecting its capabilities. - * @param params the parameters of the notification. - * @return a Mono that completes once the notification is handled. - */ - Mono handle(McpAsyncServerExchange exchange, Object params); - - } - - /** - * A handler for client-initiated requests. - * - * @param the type of the response that is expected as a result of handling the - * request. - */ - public interface RequestHandler { - - /** - * Handles a request from the client. - * @param exchange the exchange associated with the client that allows calling - * back to the connected client or inspecting its capabilities. - * @param params the parameters of the request. - * @return a Mono that will emit the response to the request. - */ - Mono handle(McpAsyncServerExchange exchange, Object params); - - } - - /** - * Factory for creating server sessions which delegate to a provided 1:1 transport - * with a connected client. - */ - @FunctionalInterface - public interface Factory { - - /** - * Creates a new 1:1 representation of the client-server interaction. - * @param sessionTransport the transport to use for communication with the client. - * @return a new server session. - */ - McpServerSession create(McpServerTransport sessionTransport); - - } +public interface McpServerSession extends McpSession { + + String getId(); + + Mono handle(McpSchema.JSONRPCMessage message); + + /** + * Request handler for the initialization request. + */ + interface InitRequestHandler { + + /** + * Handles the initialization request. + * @param initializeRequest the initialization request by the client + * @return a Mono that will emit the result of the initialization + */ + Mono handle(McpSchema.InitializeRequest initializeRequest); + + } + + /** + * Notification handler for the initialization notification from the client. + */ + interface InitNotificationHandler { + + /** + * Specifies an action to take upon successful initialization. + * @return a Mono that will complete when the initialization is acted upon. + */ + Mono handle(); + + } + + /** + * A handler for client-initiated notifications. + */ + interface NotificationHandler { + + /** + * Handles a notification from the client. + * @param exchange the exchange associated with the client that allows calling + * back to the connected client or inspecting its capabilities. + * @param params the parameters of the notification. + * @return a Mono that completes once the notification is handled. + */ + Mono handle(McpAsyncServerExchange exchange, Object params); + + } + + /** + * A handler for client-initiated requests. + * + * @param the type of the response that is expected as a result of handling the + * request. + */ + interface RequestHandler { + + /** + * Handles a request from the client. + * @param exchange the exchange associated with the client that allows calling + * back to the connected client or inspecting its capabilities. + * @param params the parameters of the request. + * @return a Mono that will emit the response to the request. + */ + Mono handle(McpAsyncServerExchange exchange, Object params); + + } + + /** + * Factory for creating server sessions which delegate to a provided 1:1 transport + * with a connected client. + */ + @FunctionalInterface + interface Factory { + + /** + * Creates a new 1:1 representation of the client-server interaction. + * @param sessionTransport the transport to use for communication with the client. + * @return a new server session. + */ + McpServerSession create(McpServerTransport sessionTransport); + + } } diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSessionImpl.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSessionImpl.java new file mode 100644 index 00000000..ad22d863 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSessionImpl.java @@ -0,0 +1,276 @@ +package io.modelcontextprotocol.spec; + +import java.time.Duration; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; + +import com.fasterxml.jackson.core.type.TypeReference; +import io.modelcontextprotocol.server.McpAsyncServerExchange; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Mono; +import reactor.core.publisher.MonoSink; +import reactor.core.publisher.Sinks; + +/** + * Represents a Model Control Protocol (MCP) session on the server side. It manages + * bidirectional JSON-RPC communication with the client. + */ +public class McpServerSessionImpl implements McpServerSession { + + private static final Logger logger = LoggerFactory.getLogger(McpServerSessionImpl.class); + + private final ConcurrentHashMap> pendingResponses = new ConcurrentHashMap<>(); + + private final String id; + + /** Duration to wait for request responses before timing out */ + private final Duration requestTimeout; + + private final AtomicLong requestCounter = new AtomicLong(0); + + private final InitRequestHandler initRequestHandler; + + private final InitNotificationHandler initNotificationHandler; + + private final Map> requestHandlers; + + private final Map notificationHandlers; + + private final McpServerTransport transport; + + private final Sinks.One exchangeSink = Sinks.one(); + + private final AtomicReference clientCapabilities = new AtomicReference<>(); + + private final AtomicReference clientInfo = new AtomicReference<>(); + + private static final int STATE_UNINITIALIZED = 0; + + private static final int STATE_INITIALIZING = 1; + + private static final int STATE_INITIALIZED = 2; + + private final AtomicInteger state = new AtomicInteger(STATE_UNINITIALIZED); + + /** + * Creates a new server session with the given parameters and the transport to use. + * @param id session id + * @param transport the transport to use + * @param initHandler called when a + * {@link io.modelcontextprotocol.spec.McpSchema.InitializeRequest} is received by the + * server + * @param initNotificationHandler called when a + * {@link io.modelcontextprotocol.spec.McpSchema#METHOD_NOTIFICATION_INITIALIZED} is + * received. + * @param requestHandlers map of request handlers to use + * @param notificationHandlers map of notification handlers to use + */ + public McpServerSessionImpl(String id, Duration requestTimeout, McpServerTransport transport, + InitRequestHandler initHandler, InitNotificationHandler initNotificationHandler, + Map> requestHandlers, Map notificationHandlers) { + this.id = id; + this.requestTimeout = requestTimeout; + this.transport = transport; + this.initRequestHandler = initHandler; + this.initNotificationHandler = initNotificationHandler; + this.requestHandlers = requestHandlers; + this.notificationHandlers = notificationHandlers; + } + + /** + * Retrieve the session id. + * @return session id + */ + @Override + public String getId() { + return this.id; + } + + /** + * Called upon successful initialization sequence between the client and the server + * with the client capabilities and information. + * + * Initialization + * Spec + * @param clientCapabilities the capabilities the connected client provides + * @param clientInfo the information about the connected client + */ + public void init(McpSchema.ClientCapabilities clientCapabilities, McpSchema.Implementation clientInfo) { + this.clientCapabilities.lazySet(clientCapabilities); + this.clientInfo.lazySet(clientInfo); + } + + private String generateRequestId() { + return this.id + "-" + this.requestCounter.getAndIncrement(); + } + + @Override + public Mono sendRequest(String method, Object requestParams, TypeReference typeRef) { + String requestId = this.generateRequestId(); + + return Mono.create(sink -> { + this.pendingResponses.put(requestId, sink); + McpSchema.JSONRPCRequest jsonrpcRequest = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, method, + requestId, requestParams); + this.transport.sendMessage(jsonrpcRequest).subscribe(v -> { + }, error -> { + this.pendingResponses.remove(requestId); + sink.error(error); + }); + }).timeout(requestTimeout).handle((jsonRpcResponse, sink) -> { + if (jsonRpcResponse.error() != null) { + sink.error(new McpError(jsonRpcResponse.error())); + } + else { + if (typeRef.getType().equals(Void.class)) { + sink.complete(); + } + else { + sink.next(this.transport.unmarshalFrom(jsonRpcResponse.result(), typeRef)); + } + } + }); + } + + @Override + public Mono sendNotification(String method, Object params) { + McpSchema.JSONRPCNotification jsonrpcNotification = new McpSchema.JSONRPCNotification(McpSchema.JSONRPC_VERSION, + method, params); + return this.transport.sendMessage(jsonrpcNotification); + } + + /** + * Called by the {@link McpServerTransportProvider} once the session is determined. + * The purpose of this method is to dispatch the message to an appropriate handler as + * specified by the MCP server implementation + * ({@link io.modelcontextprotocol.server.McpAsyncServer} or + * {@link io.modelcontextprotocol.server.McpSyncServer}) via + * {@link McpServerSession.Factory} that the server creates. + * @param message the incoming JSON-RPC message + * @return a Mono that completes when the message is processed + */ + @Override + public Mono handle(McpSchema.JSONRPCMessage message) { + return Mono.defer(() -> { + // TODO handle errors for communication to without initialization happening + // first + if (message instanceof McpSchema.JSONRPCResponse response) { + logger.debug("Received Response: {}", response); + var sink = pendingResponses.remove(response.id()); + if (sink == null) { + logger.warn("Unexpected response for unknown id {}", response.id()); + } + else { + sink.success(response); + } + return Mono.empty(); + } + else if (message instanceof McpSchema.JSONRPCRequest request) { + logger.debug("Received request: {}", request); + return handleIncomingRequest(request).onErrorResume(error -> { + var errorResponse = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), null, + new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, + error.getMessage(), null)); + // TODO: Should the error go to SSE or back as POST return? + return this.transport.sendMessage(errorResponse).then(Mono.empty()); + }).flatMap(this.transport::sendMessage); + } + else if (message instanceof McpSchema.JSONRPCNotification notification) { + // TODO handle errors for communication to without initialization + // happening first + logger.debug("Received notification: {}", notification); + // TODO: in case of error, should the POST request be signalled? + return handleIncomingNotification(notification) + .doOnError(error -> logger.error("Error handling notification: {}", error.getMessage())); + } + else { + logger.warn("Received unknown message type: {}", message); + return Mono.empty(); + } + }); + } + + /** + * Handles an incoming JSON-RPC request by routing it to the appropriate handler. + * @param request The incoming JSON-RPC request + * @return A Mono containing the JSON-RPC response + */ + private Mono handleIncomingRequest(McpSchema.JSONRPCRequest request) { + return Mono.defer(() -> { + Mono resultMono; + if (McpSchema.METHOD_INITIALIZE.equals(request.method())) { + // TODO handle situation where already initialized! + McpSchema.InitializeRequest initializeRequest = transport.unmarshalFrom(request.params(), + new TypeReference() { + }); + + this.state.lazySet(STATE_INITIALIZING); + this.init(initializeRequest.capabilities(), initializeRequest.clientInfo()); + resultMono = this.initRequestHandler.handle(initializeRequest); + } + else { + // TODO handle errors for communication to this session without + // initialization happening first + var handler = this.requestHandlers.get(request.method()); + if (handler == null) { + MethodNotFoundError error = getMethodNotFoundError(request.method()); + return Mono.just(new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), null, + new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.METHOD_NOT_FOUND, + error.message(), error.data()))); + } + + resultMono = this.exchangeSink.asMono().flatMap(exchange -> handler.handle(exchange, request.params())); + } + return resultMono + .map(result -> new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), result, null)) + .onErrorResume(error -> Mono.just(new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), + null, new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, + error.getMessage(), null)))); // TODO: add error message + // through the data field + }); + } + + /** + * Handles an incoming JSON-RPC notification by routing it to the appropriate handler. + * @param notification The incoming JSON-RPC notification + * @return A Mono that completes when the notification is processed + */ + private Mono handleIncomingNotification(McpSchema.JSONRPCNotification notification) { + return Mono.defer(() -> { + if (McpSchema.METHOD_NOTIFICATION_INITIALIZED.equals(notification.method())) { + this.state.lazySet(STATE_INITIALIZED); + exchangeSink.tryEmitValue(new McpAsyncServerExchange(this, clientCapabilities.get(), clientInfo.get())); + return this.initNotificationHandler.handle(); + } + + var handler = notificationHandlers.get(notification.method()); + if (handler == null) { + logger.error("No handler registered for notification method: {}", notification.method()); + return Mono.empty(); + } + return this.exchangeSink.asMono().flatMap(exchange -> handler.handle(exchange, notification.params())); + }); + } + + record MethodNotFoundError(String method, String message, Object data) { + } + + private MethodNotFoundError getMethodNotFoundError(String method) { + return new MethodNotFoundError(method, "Method not found: " + method, null); + } + + @Override + public Mono closeGracefully() { + return this.transport.closeGracefully(); + } + + @Override + public void close() { + this.transport.close(); + } +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/MockMcpServerTransportProvider.java b/mcp/src/test/java/io/modelcontextprotocol/MockMcpServerTransportProvider.java index 20a8c0cf..7ba35bbf 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/MockMcpServerTransportProvider.java +++ b/mcp/src/test/java/io/modelcontextprotocol/MockMcpServerTransportProvider.java @@ -15,8 +15,6 @@ */ package io.modelcontextprotocol; -import java.util.Map; - import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpServerSession; import io.modelcontextprotocol.spec.McpServerSession.Factory; diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/StdioServerTransportProviderTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/transport/StdioServerTransportProviderTests.java index 14987b5a..854fdeff 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/transport/StdioServerTransportProviderTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/transport/StdioServerTransportProviderTests.java @@ -4,6 +4,17 @@ package io.modelcontextprotocol.server.transport; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpServerSession; +import io.modelcontextprotocol.spec.McpServerTransport; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.InputStream; @@ -13,12 +24,6 @@ import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; - -import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.spec.McpError; -import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpServerSession; -import io.modelcontextprotocol.spec.McpServerTransport; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Disabled; @@ -26,12 +31,6 @@ import reactor.core.publisher.Mono; import reactor.test.StepVerifier; -import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - /** * Tests for {@link StdioServerTransportProvider}. * @@ -52,9 +51,9 @@ class StdioServerTransportProviderTests { private ObjectMapper objectMapper; - private McpServerSession.Factory sessionFactory; + private McpServerSession.Factory sessionFactory; - private McpServerSession mockSession; + private McpServerSession mockSession; @BeforeEach void setUp() { @@ -66,9 +65,9 @@ void setUp() { objectMapper = new ObjectMapper(); - // Create mocks for session factory and session - mockSession = mock(McpServerSession.class); - sessionFactory = mock(McpServerSession.Factory.class); + // Create mocks for session factory and session + mockSession = mock(McpServerSession.class); + sessionFactory = mock(McpServerSession.Factory.class); // Configure mock behavior when(sessionFactory.create(any(McpServerTransport.class))).thenReturn(mockSession); @@ -110,16 +109,19 @@ void shouldHandleIncomingMessages() throws Exception { AtomicReference capturedMessage = new AtomicReference<>(); CountDownLatch messageLatch = new CountDownLatch(1); - McpServerSession.Factory realSessionFactory = transport -> { - McpServerSession session = mock(McpServerSession.class); - when(session.handle(any())).thenAnswer(invocation -> { - capturedMessage.set(invocation.getArgument(0)); - messageLatch.countDown(); - return Mono.empty(); - }); - when(session.closeGracefully()).thenReturn(Mono.empty()); - return session; - }; + McpServerSession.Factory realSessionFactory = + transport -> { + McpServerSession session = mock(McpServerSession.class); + when(session.handle(any())) + .thenAnswer( + invocation -> { + capturedMessage.set(invocation.getArgument(0)); + messageLatch.countDown(); + return Mono.empty(); + }); + when(session.closeGracefully()).thenReturn(Mono.empty()); + return session; + }; // Set session factory transportProvider.setSessionFactory(realSessionFactory); From 7dbfaa43d592897c470c0fb2fb85bc09bf955ea9 Mon Sep 17 00:00:00 2001 From: Alberto Pontini Date: Wed, 21 May 2025 16:21:05 +0200 Subject: [PATCH 2/4] Fix formatting and one interface reference --- .../StdioServerTransportProvider.java | 2 +- .../spec/McpServerSession.java | 163 +++++++++--------- .../spec/McpServerSessionImpl.java | 1 + .../StdioServerTransportProviderTests.java | 33 ++-- 4 files changed, 99 insertions(+), 100 deletions(-) diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java index dc4c7d4d..2c84d83f 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java @@ -48,7 +48,7 @@ public class StdioServerTransportProvider implements McpServerTransportProvider private final OutputStream outputStream; - private McpSession session; + private McpServerSession session; private final AtomicBoolean isClosing = new AtomicBoolean(false); diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java index 4ab73be5..cfaebe7f 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java @@ -5,85 +5,86 @@ public interface McpServerSession extends McpSession { - String getId(); - - Mono handle(McpSchema.JSONRPCMessage message); - - /** - * Request handler for the initialization request. - */ - interface InitRequestHandler { - - /** - * Handles the initialization request. - * @param initializeRequest the initialization request by the client - * @return a Mono that will emit the result of the initialization - */ - Mono handle(McpSchema.InitializeRequest initializeRequest); - - } - - /** - * Notification handler for the initialization notification from the client. - */ - interface InitNotificationHandler { - - /** - * Specifies an action to take upon successful initialization. - * @return a Mono that will complete when the initialization is acted upon. - */ - Mono handle(); - - } - - /** - * A handler for client-initiated notifications. - */ - interface NotificationHandler { - - /** - * Handles a notification from the client. - * @param exchange the exchange associated with the client that allows calling - * back to the connected client or inspecting its capabilities. - * @param params the parameters of the notification. - * @return a Mono that completes once the notification is handled. - */ - Mono handle(McpAsyncServerExchange exchange, Object params); - - } - - /** - * A handler for client-initiated requests. - * - * @param the type of the response that is expected as a result of handling the - * request. - */ - interface RequestHandler { - - /** - * Handles a request from the client. - * @param exchange the exchange associated with the client that allows calling - * back to the connected client or inspecting its capabilities. - * @param params the parameters of the request. - * @return a Mono that will emit the response to the request. - */ - Mono handle(McpAsyncServerExchange exchange, Object params); - - } - - /** - * Factory for creating server sessions which delegate to a provided 1:1 transport - * with a connected client. - */ - @FunctionalInterface - interface Factory { - - /** - * Creates a new 1:1 representation of the client-server interaction. - * @param sessionTransport the transport to use for communication with the client. - * @return a new server session. - */ - McpServerSession create(McpServerTransport sessionTransport); - - } + String getId(); + + Mono handle(McpSchema.JSONRPCMessage message); + + /** + * Request handler for the initialization request. + */ + interface InitRequestHandler { + + /** + * Handles the initialization request. + * @param initializeRequest the initialization request by the client + * @return a Mono that will emit the result of the initialization + */ + Mono handle(McpSchema.InitializeRequest initializeRequest); + + } + + /** + * Notification handler for the initialization notification from the client. + */ + interface InitNotificationHandler { + + /** + * Specifies an action to take upon successful initialization. + * @return a Mono that will complete when the initialization is acted upon. + */ + Mono handle(); + + } + + /** + * A handler for client-initiated notifications. + */ + interface NotificationHandler { + + /** + * Handles a notification from the client. + * @param exchange the exchange associated with the client that allows calling + * back to the connected client or inspecting its capabilities. + * @param params the parameters of the notification. + * @return a Mono that completes once the notification is handled. + */ + Mono handle(McpAsyncServerExchange exchange, Object params); + + } + + /** + * A handler for client-initiated requests. + * + * @param the type of the response that is expected as a result of handling the + * request. + */ + interface RequestHandler { + + /** + * Handles a request from the client. + * @param exchange the exchange associated with the client that allows calling + * back to the connected client or inspecting its capabilities. + * @param params the parameters of the request. + * @return a Mono that will emit the response to the request. + */ + Mono handle(McpAsyncServerExchange exchange, Object params); + + } + + /** + * Factory for creating server sessions which delegate to a provided 1:1 transport + * with a connected client. + */ + @FunctionalInterface + interface Factory { + + /** + * Creates a new 1:1 representation of the client-server interaction. + * @param sessionTransport the transport to use for communication with the client. + * @return a new server session. + */ + McpServerSession create(McpServerTransport sessionTransport); + + } + } diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSessionImpl.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSessionImpl.java index ad22d863..dea6b79f 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSessionImpl.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSessionImpl.java @@ -273,4 +273,5 @@ public Mono closeGracefully() { public void close() { this.transport.close(); } + } diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/StdioServerTransportProviderTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/transport/StdioServerTransportProviderTests.java index 854fdeff..5be18754 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/transport/StdioServerTransportProviderTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/transport/StdioServerTransportProviderTests.java @@ -51,9 +51,9 @@ class StdioServerTransportProviderTests { private ObjectMapper objectMapper; - private McpServerSession.Factory sessionFactory; + private McpServerSession.Factory sessionFactory; - private McpServerSession mockSession; + private McpServerSession mockSession; @BeforeEach void setUp() { @@ -65,9 +65,9 @@ void setUp() { objectMapper = new ObjectMapper(); - // Create mocks for session factory and session - mockSession = mock(McpServerSession.class); - sessionFactory = mock(McpServerSession.Factory.class); + // Create mocks for session factory and session + mockSession = mock(McpServerSession.class); + sessionFactory = mock(McpServerSession.Factory.class); // Configure mock behavior when(sessionFactory.create(any(McpServerTransport.class))).thenReturn(mockSession); @@ -109,19 +109,16 @@ void shouldHandleIncomingMessages() throws Exception { AtomicReference capturedMessage = new AtomicReference<>(); CountDownLatch messageLatch = new CountDownLatch(1); - McpServerSession.Factory realSessionFactory = - transport -> { - McpServerSession session = mock(McpServerSession.class); - when(session.handle(any())) - .thenAnswer( - invocation -> { - capturedMessage.set(invocation.getArgument(0)); - messageLatch.countDown(); - return Mono.empty(); - }); - when(session.closeGracefully()).thenReturn(Mono.empty()); - return session; - }; + McpServerSession.Factory realSessionFactory = transport -> { + McpServerSession session = mock(McpServerSession.class); + when(session.handle(any())).thenAnswer(invocation -> { + capturedMessage.set(invocation.getArgument(0)); + messageLatch.countDown(); + return Mono.empty(); + }); + when(session.closeGracefully()).thenReturn(Mono.empty()); + return session; + }; // Set session factory transportProvider.setSessionFactory(realSessionFactory); From 1e52334083ee8baa7d24cbbb6574455877fc3271 Mon Sep 17 00:00:00 2001 From: Alberto Pontini Date: Thu, 22 May 2025 07:45:15 +0200 Subject: [PATCH 3/4] Fix changes in imports order --- .../WebFluxSseServerTransportProvider.java | 14 ++++++----- .../WebMvcSseServerTransportProvider.java | 18 ++++++++------- ...HttpServletSseServerTransportProvider.java | 15 ++++++------ .../StdioServerTransportProvider.java | 20 ++++++++-------- .../StdioServerTransportProviderTests.java | 23 ++++++++++--------- 5 files changed, 48 insertions(+), 42 deletions(-) diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java index 03dc65aa..fde067f0 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java @@ -1,5 +1,8 @@ package io.modelcontextprotocol.server.transport; +import java.io.IOException; +import java.util.concurrent.ConcurrentHashMap; + import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; import io.modelcontextprotocol.spec.McpError; @@ -8,10 +11,13 @@ import io.modelcontextprotocol.spec.McpServerTransport; import io.modelcontextprotocol.spec.McpServerTransportProvider; import io.modelcontextprotocol.util.Assert; -import java.io.IOException; -import java.util.concurrent.ConcurrentHashMap; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.Exceptions; +import reactor.core.publisher.Flux; +import reactor.core.publisher.FluxSink; +import reactor.core.publisher.Mono; + import org.springframework.http.HttpStatus; import org.springframework.http.MediaType; import org.springframework.http.codec.ServerSentEvent; @@ -19,10 +25,6 @@ import org.springframework.web.reactive.function.server.RouterFunctions; import org.springframework.web.reactive.function.server.ServerRequest; import org.springframework.web.reactive.function.server.ServerResponse; -import reactor.core.Exceptions; -import reactor.core.publisher.Flux; -import reactor.core.publisher.FluxSink; -import reactor.core.publisher.Mono; /** * Server-side implementation of the MCP (Model Context Protocol) HTTP transport using diff --git a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java index 17cc9f84..da76a3a1 100644 --- a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java @@ -4,28 +4,30 @@ package io.modelcontextprotocol.server.transport; +import java.io.IOException; +import java.time.Duration; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; + import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpServerSession; import io.modelcontextprotocol.spec.McpServerTransport; import io.modelcontextprotocol.spec.McpServerTransportProvider; +import io.modelcontextprotocol.spec.McpServerSession; import io.modelcontextprotocol.util.Assert; -import java.io.IOException; -import java.time.Duration; -import java.util.UUID; -import java.util.concurrent.ConcurrentHashMap; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + import org.springframework.http.HttpStatus; import org.springframework.web.servlet.function.RouterFunction; import org.springframework.web.servlet.function.RouterFunctions; import org.springframework.web.servlet.function.ServerRequest; import org.springframework.web.servlet.function.ServerResponse; import org.springframework.web.servlet.function.ServerResponse.SseBuilder; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; /** * Server-side implementation of the Model Context Protocol (MCP) transport layer using @@ -297,7 +299,7 @@ private ServerResponse handleMessage(ServerRequest request) { return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).body("Server is shutting down"); } - if (request.param("sessionId").isEmpty()) { + if (!request.param("sessionId").isPresent()) { return ServerResponse.badRequest().body(new McpError("Session ID missing in message endpoint")); } diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java index 1c527e42..afdbff47 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java @@ -3,6 +3,14 @@ */ package io.modelcontextprotocol.server.transport; +import java.io.BufferedReader; +import java.io.IOException; +import java.io.PrintWriter; +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicBoolean; + import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; import io.modelcontextprotocol.spec.McpError; @@ -17,13 +25,6 @@ import jakarta.servlet.http.HttpServlet; import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletResponse; -import java.io.BufferedReader; -import java.io.IOException; -import java.io.PrintWriter; -import java.util.Map; -import java.util.UUID; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.atomic.AtomicBoolean; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import reactor.core.publisher.Flux; diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java index 2c84d83f..9ef9c782 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java @@ -4,16 +4,6 @@ package io.modelcontextprotocol.server.transport; -import com.fasterxml.jackson.core.type.TypeReference; -import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.spec.McpError; -import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage; -import io.modelcontextprotocol.spec.McpServerSession; -import io.modelcontextprotocol.spec.McpServerTransport; -import io.modelcontextprotocol.spec.McpServerTransportProvider; -import io.modelcontextprotocol.spec.McpSession; -import io.modelcontextprotocol.util.Assert; import java.io.BufferedReader; import java.io.IOException; import java.io.InputStream; @@ -23,6 +13,16 @@ import java.util.concurrent.Executors; import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Function; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage; +import io.modelcontextprotocol.spec.McpServerSession; +import io.modelcontextprotocol.spec.McpServerTransport; +import io.modelcontextprotocol.spec.McpServerTransportProvider; +import io.modelcontextprotocol.util.Assert; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import reactor.core.publisher.Flux; diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/StdioServerTransportProviderTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/transport/StdioServerTransportProviderTests.java index 5be18754..14987b5a 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/transport/StdioServerTransportProviderTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/transport/StdioServerTransportProviderTests.java @@ -4,17 +4,6 @@ package io.modelcontextprotocol.server.transport; -import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - -import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.spec.McpError; -import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpServerSession; -import io.modelcontextprotocol.spec.McpServerTransport; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.InputStream; @@ -24,6 +13,12 @@ import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpServerSession; +import io.modelcontextprotocol.spec.McpServerTransport; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Disabled; @@ -31,6 +26,12 @@ import reactor.core.publisher.Mono; import reactor.test.StepVerifier; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + /** * Tests for {@link StdioServerTransportProvider}. * From 5027807cc156bb676acd89d564f13742ac3d8744 Mon Sep 17 00:00:00 2001 From: Alberto Pontini Date: Thu, 22 May 2025 07:47:38 +0200 Subject: [PATCH 4/4] Revert accidental if change --- .../server/transport/WebMvcSseServerTransportProvider.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java index da76a3a1..114eff60 100644 --- a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java @@ -299,7 +299,7 @@ private ServerResponse handleMessage(ServerRequest request) { return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).body("Server is shutting down"); } - if (!request.param("sessionId").isPresent()) { + if (request.param("sessionId").isEmpty()) { return ServerResponse.badRequest().body(new McpError("Session ID missing in message endpoint")); }