From 8ef320a6b48bc26fdb9f1a530234cf56c155cef7 Mon Sep 17 00:00:00 2001 From: mipengcheng3 Date: Tue, 8 Apr 2025 17:53:59 +0800 Subject: [PATCH 1/2] Add requestTimeout to MCP server --- .../WebFluxSseIntegrationTests.java | 149 ++++++++++++++++++ .../server/WebMvcSseIntegrationTests.java | 145 +++++++++++++++++ .../server/McpAsyncServer.java | 13 +- .../server/McpServer.java | 39 ++++- .../spec/McpServerSession.java | 12 +- ...rverTransportProviderIntegrationTests.java | 145 +++++++++++++++++ 6 files changed, 491 insertions(+), 12 deletions(-) diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java index dbfad821..f3c155a5 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java @@ -7,6 +7,7 @@ import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; @@ -45,6 +46,7 @@ import org.springframework.web.reactive.function.server.RouterFunctions; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.awaitility.Awaitility.await; import static org.mockito.Mockito.mock; @@ -195,6 +197,153 @@ void testCreateMessageSuccess(String clientType) throws InterruptedException { mcpServer.close(); } + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testCreateMessageWithRequestTimeoutSuccess(String clientType) throws InterruptedException { + + // Client + var clientBuilder = clientBuilders.get(clientType); + + Function samplingHandler = request -> { + assertThat(request.messages()).hasSize(1); + assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); + try { + TimeUnit.SECONDS.sleep(2); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", + CreateMessageResult.StopReason.STOP_SEQUENCE); + }; + + var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().sampling().build()) + .sampling(samplingHandler) + .build(); + + // Server + + CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), + null); + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + var craeteMessageRequest = McpSchema.CreateMessageRequest.builder() + .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, + new McpSchema.TextContent("Test message")))) + .modelPreferences(ModelPreferences.builder() + .hints(List.of()) + .costPriority(1.0) + .speedPriority(1.0) + .intelligencePriority(1.0) + .build()) + .build(); + + StepVerifier.create(exchange.createMessage(craeteMessageRequest)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.role()).isEqualTo(Role.USER); + assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); + assertThat(result.model()).isEqualTo("MockModelName"); + assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); + }).verifyComplete(); + + return Mono.just(callResponse); + }); + + var mcpServer = McpServer.async(mcpServerTransportProvider) + .requestTimeout(Duration.ofSeconds(4)) + .serverInfo("test-server", "1.0.0") + .tools(tool) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + + mcpClient.close(); + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testCreateMessageWithRequestTimeoutFail(String clientType) throws InterruptedException { + + // Client + var clientBuilder = clientBuilders.get(clientType); + + Function samplingHandler = request -> { + assertThat(request.messages()).hasSize(1); + assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); + try { + TimeUnit.SECONDS.sleep(3); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", + CreateMessageResult.StopReason.STOP_SEQUENCE); + }; + + var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().sampling().build()) + .sampling(samplingHandler) + .build(); + + // Server + + CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), + null); + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + var craeteMessageRequest = McpSchema.CreateMessageRequest.builder() + .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, + new McpSchema.TextContent("Test message")))) + .modelPreferences(ModelPreferences.builder() + .hints(List.of()) + .costPriority(1.0) + .speedPriority(1.0) + .intelligencePriority(1.0) + .build()) + .build(); + + StepVerifier.create(exchange.createMessage(craeteMessageRequest)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.role()).isEqualTo(Role.USER); + assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); + assertThat(result.model()).isEqualTo("MockModelName"); + assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); + }).verifyComplete(); + + return Mono.just(callResponse); + }); + + var mcpServer = McpServer.async(mcpServerTransportProvider) + .requestTimeout(Duration.ofSeconds(1)) + .serverInfo("test-server", "1.0.0") + .tools(tool) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThatExceptionOfType(McpError.class) + .isThrownBy(() -> { + mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + }) + .withMessageContaining("Timeout"); + + mcpClient.close(); + mcpServer.close(); + } + // --------------------------------------- // Roots Tests // --------------------------------------- diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java index f9190fd7..7aa76abd 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java @@ -6,6 +6,7 @@ import java.time.Duration; import java.util.List; import java.util.Map; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; @@ -41,6 +42,7 @@ import org.springframework.web.servlet.function.ServerResponse; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.awaitility.Awaitility.await; import static org.mockito.Mockito.mock; @@ -207,6 +209,149 @@ void testCreateMessageSuccess() throws InterruptedException { mcpServer.close(); } + @Test + void testCreateMessageWithRequestTimeoutSuccess() throws InterruptedException { + + // Client + + Function samplingHandler = request -> { + assertThat(request.messages()).hasSize(1); + assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); + try { + TimeUnit.SECONDS.sleep(2); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", + CreateMessageResult.StopReason.STOP_SEQUENCE); + }; + + var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().sampling().build()) + .sampling(samplingHandler) + .build(); + + // Server + + CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), + null); + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + var craeteMessageRequest = McpSchema.CreateMessageRequest.builder() + .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, + new McpSchema.TextContent("Test message")))) + .modelPreferences(ModelPreferences.builder() + .hints(List.of()) + .costPriority(1.0) + .speedPriority(1.0) + .intelligencePriority(1.0) + .build()) + .build(); + + StepVerifier.create(exchange.createMessage(craeteMessageRequest)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.role()).isEqualTo(Role.USER); + assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); + assertThat(result.model()).isEqualTo("MockModelName"); + assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); + }).verifyComplete(); + + return Mono.just(callResponse); + }); + + var mcpServer = McpServer.async(mcpServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .requestTimeout(Duration.ofSeconds(4)) + .tools(tool) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + + mcpClient.close(); + mcpServer.close(); + } + + @Test + void testCreateMessageWithRequestTimeoutFail() throws InterruptedException { + + // Client + + Function samplingHandler = request -> { + assertThat(request.messages()).hasSize(1); + assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); + try { + TimeUnit.SECONDS.sleep(2); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", + CreateMessageResult.StopReason.STOP_SEQUENCE); + }; + + var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().sampling().build()) + .sampling(samplingHandler) + .build(); + + // Server + + CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), + null); + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + var craeteMessageRequest = McpSchema.CreateMessageRequest.builder() + .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, + new McpSchema.TextContent("Test message")))) + .modelPreferences(ModelPreferences.builder() + .hints(List.of()) + .costPriority(1.0) + .speedPriority(1.0) + .intelligencePriority(1.0) + .build()) + .build(); + + StepVerifier.create(exchange.createMessage(craeteMessageRequest)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.role()).isEqualTo(Role.USER); + assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); + assertThat(result.model()).isEqualTo("MockModelName"); + assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); + }).verifyComplete(); + + return Mono.just(callResponse); + }); + + var mcpServer = McpServer.async(mcpServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .requestTimeout(Duration.ofSeconds(1)) + .tools(tool) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThatExceptionOfType(McpError.class) + .isThrownBy(() -> { + mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + }) + .withMessageContaining("Timeout"); + + mcpClient.close(); + mcpServer.close(); + } + // --------------------------------------- // Roots Tests // --------------------------------------- diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java index df938668..3a520381 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java @@ -4,6 +4,7 @@ package io.modelcontextprotocol.server; +import java.time.Duration; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -89,8 +90,8 @@ public class McpAsyncServer { * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization */ McpAsyncServer(McpServerTransportProvider mcpTransportProvider, ObjectMapper objectMapper, - McpServerFeatures.Async features) { - this.delegate = new AsyncServerImpl(mcpTransportProvider, objectMapper, features); + McpServerFeatures.Async features, Duration requestTimeout) { + this.delegate = new AsyncServerImpl(mcpTransportProvider, objectMapper, requestTimeout, features); } /** @@ -262,7 +263,7 @@ private static class AsyncServerImpl extends McpAsyncServer { private List protocolVersions = List.of(McpSchema.LATEST_PROTOCOL_VERSION); AsyncServerImpl(McpServerTransportProvider mcpTransportProvider, ObjectMapper objectMapper, - McpServerFeatures.Async features) { + Duration requestTimeout, McpServerFeatures.Async features) { this.mcpTransportProvider = mcpTransportProvider; this.objectMapper = objectMapper; this.serverInfo = features.serverInfo(); @@ -321,9 +322,9 @@ private static class AsyncServerImpl extends McpAsyncServer { notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_ROOTS_LIST_CHANGED, asyncRootsListChangedNotificationHandler(rootsChangeConsumers)); - mcpTransportProvider - .setSessionFactory(transport -> new McpServerSession(UUID.randomUUID().toString(), transport, - this::asyncInitializeRequestHandler, Mono::empty, requestHandlers, notificationHandlers)); + mcpTransportProvider.setSessionFactory( + transport -> new McpServerSession(UUID.randomUUID().toString(), requestTimeout, transport, + this::asyncInitializeRequestHandler, Mono::empty, requestHandlers, notificationHandlers)); } // --------------------------------------- diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java index d5427335..60434a84 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java @@ -4,6 +4,7 @@ package io.modelcontextprotocol.server; +import java.time.Duration; import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; @@ -193,11 +194,28 @@ class AsyncSpecification { private final List, Mono>> rootsChangeHandlers = new ArrayList<>(); + private Duration requestTimeout = Duration.ofSeconds(10); // Default timeout + private AsyncSpecification(McpServerTransportProvider transportProvider) { Assert.notNull(transportProvider, "Transport provider must not be null"); this.transportProvider = transportProvider; } + /** + * Sets the duration to wait for server responses before timing out requests. This + * timeout applies to all requests made through the client, including tool calls, + * resource access, and prompt operations. + * @param requestTimeout The duration to wait before timing out requests. Must not + * be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if requestTimeout is null + */ + public AsyncSpecification requestTimeout(Duration requestTimeout) { + Assert.notNull(requestTimeout, "Request timeout must not be null"); + this.requestTimeout = requestTimeout; + return this; + } + /** * Sets the server implementation information that will be shared with clients * during connection initialization. This helps with version compatibility, @@ -565,7 +583,7 @@ public McpAsyncServer build() { var features = new McpServerFeatures.Async(this.serverInfo, this.serverCapabilities, this.tools, this.resources, this.resourceTemplates, this.prompts, this.rootsChangeHandlers, this.instructions); var mapper = this.objectMapper != null ? this.objectMapper : new ObjectMapper(); - return new McpAsyncServer(this.transportProvider, mapper, features); + return new McpAsyncServer(this.transportProvider, mapper, features, this.requestTimeout); } } @@ -619,11 +637,28 @@ class SyncSpecification { private final List>> rootsChangeHandlers = new ArrayList<>(); + private Duration requestTimeout = Duration.ofSeconds(10); // Default timeout + private SyncSpecification(McpServerTransportProvider transportProvider) { Assert.notNull(transportProvider, "Transport provider must not be null"); this.transportProvider = transportProvider; } + /** + * Sets the duration to wait for server responses before timing out requests. This + * timeout applies to all requests made through the client, including tool calls, + * resource access, and prompt operations. + * @param requestTimeout The duration to wait before timing out requests. Must not + * be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if requestTimeout is null + */ + public SyncSpecification requestTimeout(Duration requestTimeout) { + Assert.notNull(requestTimeout, "Request timeout must not be null"); + this.requestTimeout = requestTimeout; + return this; + } + /** * Sets the server implementation information that will be shared with clients * during connection initialization. This helps with version compatibility, @@ -992,7 +1027,7 @@ public McpSyncServer build() { this.instructions); McpServerFeatures.Async asyncFeatures = McpServerFeatures.Async.fromSync(syncFeatures); var mapper = this.objectMapper != null ? this.objectMapper : new ObjectMapper(); - var asyncServer = new McpAsyncServer(this.transportProvider, mapper, asyncFeatures); + var asyncServer = new McpAsyncServer(this.transportProvider, mapper, asyncFeatures, this.requestTimeout); return new McpSyncServer(asyncServer); } diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java index bcdf2248..0bc7261a 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java @@ -27,6 +27,9 @@ public class McpServerSession implements McpSession { private final String id; + /** Duration to wait for request responses before timing out */ + private final Duration requestTimeout; + private final AtomicLong requestCounter = new AtomicLong(0); private final InitRequestHandler initRequestHandler; @@ -65,10 +68,11 @@ public class McpServerSession implements McpSession { * @param requestHandlers map of request handlers to use * @param notificationHandlers map of notification handlers to use */ - public McpServerSession(String id, McpServerTransport transport, InitRequestHandler initHandler, - InitNotificationHandler initNotificationHandler, Map> requestHandlers, - Map notificationHandlers) { + public McpServerSession(String id, Duration requestTimeout, McpServerTransport transport, + InitRequestHandler initHandler, InitNotificationHandler initNotificationHandler, + Map> requestHandlers, Map notificationHandlers) { this.id = id; + this.requestTimeout = requestTimeout; this.transport = transport; this.initRequestHandler = initHandler; this.initNotificationHandler = initNotificationHandler; @@ -116,7 +120,7 @@ public Mono sendRequest(String method, Object requestParams, TypeReferenc this.pendingResponses.remove(requestId); sink.error(error); }); - }).timeout(Duration.ofSeconds(10)).handle((jsonRpcResponse, sink) -> { + }).timeout(requestTimeout).handle((jsonRpcResponse, sink) -> { if (jsonRpcResponse.error() != null) { sink.error(new McpError(jsonRpcResponse.error())); } diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java index b04940c7..9306cbb2 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java @@ -6,6 +6,7 @@ import java.time.Duration; import java.util.List; import java.util.Map; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; @@ -39,6 +40,7 @@ import org.springframework.web.client.RestClient; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.awaitility.Awaitility.await; import static org.mockito.Mockito.mock; @@ -192,6 +194,149 @@ void testCreateMessageSuccess() throws InterruptedException { mcpServer.close(); } + @Test + void testCreateMessageWithRequestTimeoutSuccess() throws InterruptedException { + + // Client + + Function samplingHandler = request -> { + assertThat(request.messages()).hasSize(1); + assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); + try { + TimeUnit.SECONDS.sleep(2); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", + CreateMessageResult.StopReason.STOP_SEQUENCE); + }; + + var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().sampling().build()) + .sampling(samplingHandler) + .build(); + + // Server + + CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), + null); + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + var craeteMessageRequest = McpSchema.CreateMessageRequest.builder() + .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, + new McpSchema.TextContent("Test message")))) + .modelPreferences(ModelPreferences.builder() + .hints(List.of()) + .costPriority(1.0) + .speedPriority(1.0) + .intelligencePriority(1.0) + .build()) + .build(); + + StepVerifier.create(exchange.createMessage(craeteMessageRequest)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.role()).isEqualTo(Role.USER); + assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); + assertThat(result.model()).isEqualTo("MockModelName"); + assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); + }).verifyComplete(); + + return Mono.just(callResponse); + }); + + var mcpServer = McpServer.async(mcpServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .requestTimeout(Duration.ofSeconds(3)) + .tools(tool) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + + mcpClient.close(); + mcpServer.close(); + } + + @Test + void testCreateMessageWithRequestTimeoutFail() throws InterruptedException { + + // Client + + Function samplingHandler = request -> { + assertThat(request.messages()).hasSize(1); + assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); + try { + TimeUnit.SECONDS.sleep(2); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", + CreateMessageResult.StopReason.STOP_SEQUENCE); + }; + + var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().sampling().build()) + .sampling(samplingHandler) + .build(); + + // Server + + CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), + null); + + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { + + var craeteMessageRequest = McpSchema.CreateMessageRequest.builder() + .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, + new McpSchema.TextContent("Test message")))) + .modelPreferences(ModelPreferences.builder() + .hints(List.of()) + .costPriority(1.0) + .speedPriority(1.0) + .intelligencePriority(1.0) + .build()) + .build(); + + StepVerifier.create(exchange.createMessage(craeteMessageRequest)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.role()).isEqualTo(Role.USER); + assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); + assertThat(result.model()).isEqualTo("MockModelName"); + assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); + }).verifyComplete(); + + return Mono.just(callResponse); + }); + + var mcpServer = McpServer.async(mcpServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .requestTimeout(Duration.ofSeconds(1)) + .tools(tool) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThatExceptionOfType(McpError.class) + .isThrownBy(() -> { + mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + }) + .withMessageContaining("Timeout"); + + mcpClient.close(); + mcpServer.close(); + } + // --------------------------------------- // Roots Tests // --------------------------------------- From 0dfa700faa1d2664d3708cde86c0e7c88864e399 Mon Sep 17 00:00:00 2001 From: mipengcheng3 Date: Tue, 8 Apr 2025 17:55:02 +0800 Subject: [PATCH 2/2] Add requestTimeout to MCP server --- .../WebFluxSseIntegrationTests.java | 120 +++++++++--------- .../server/WebMvcSseIntegrationTests.java | 120 +++++++++--------- ...rverTransportProviderIntegrationTests.java | 120 +++++++++--------- 3 files changed, 180 insertions(+), 180 deletions(-) diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java index f3c155a5..67ac3588 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java @@ -207,19 +207,20 @@ void testCreateMessageWithRequestTimeoutSuccess(String clientType) throws Interr Function samplingHandler = request -> { assertThat(request.messages()).hasSize(1); assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); - try { - TimeUnit.SECONDS.sleep(2); - } catch (InterruptedException e) { - throw new RuntimeException(e); - } - return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", + try { + TimeUnit.SECONDS.sleep(2); + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", CreateMessageResult.StopReason.STOP_SEQUENCE); }; var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().sampling().build()) - .sampling(samplingHandler) - .build(); + .capabilities(ClientCapabilities.builder().sampling().build()) + .sampling(samplingHandler) + .build(); // Server @@ -229,34 +230,34 @@ void testCreateMessageWithRequestTimeoutSuccess(String clientType) throws Interr McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { - var craeteMessageRequest = McpSchema.CreateMessageRequest.builder() - .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, - new McpSchema.TextContent("Test message")))) - .modelPreferences(ModelPreferences.builder() + var craeteMessageRequest = McpSchema.CreateMessageRequest.builder() + .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, + new McpSchema.TextContent("Test message")))) + .modelPreferences(ModelPreferences.builder() .hints(List.of()) .costPriority(1.0) .speedPriority(1.0) .intelligencePriority(1.0) .build()) - .build(); - - StepVerifier.create(exchange.createMessage(craeteMessageRequest)).consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.role()).isEqualTo(Role.USER); - assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); - assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); - assertThat(result.model()).isEqualTo("MockModelName"); - assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); - }).verifyComplete(); - - return Mono.just(callResponse); - }); + .build(); + + StepVerifier.create(exchange.createMessage(craeteMessageRequest)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.role()).isEqualTo(Role.USER); + assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); + assertThat(result.model()).isEqualTo("MockModelName"); + assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); + }).verifyComplete(); + + return Mono.just(callResponse); + }); var mcpServer = McpServer.async(mcpServerTransportProvider) - .requestTimeout(Duration.ofSeconds(4)) - .serverInfo("test-server", "1.0.0") - .tools(tool) - .build(); + .requestTimeout(Duration.ofSeconds(4)) + .serverInfo("test-server", "1.0.0") + .tools(tool) + .build(); InitializeResult initResult = mcpClient.initialize(); assertThat(initResult).isNotNull(); @@ -282,7 +283,8 @@ void testCreateMessageWithRequestTimeoutFail(String clientType) throws Interrupt assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); try { TimeUnit.SECONDS.sleep(3); - } catch (InterruptedException e) { + } + catch (InterruptedException e) { throw new RuntimeException(e); } return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", @@ -290,9 +292,9 @@ void testCreateMessageWithRequestTimeoutFail(String clientType) throws Interrupt }; var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().sampling().build()) - .sampling(samplingHandler) - .build(); + .capabilities(ClientCapabilities.builder().sampling().build()) + .sampling(samplingHandler) + .build(); // Server @@ -302,43 +304,41 @@ void testCreateMessageWithRequestTimeoutFail(String clientType) throws Interrupt McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { - var craeteMessageRequest = McpSchema.CreateMessageRequest.builder() - .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, - new McpSchema.TextContent("Test message")))) - .modelPreferences(ModelPreferences.builder() + var craeteMessageRequest = McpSchema.CreateMessageRequest.builder() + .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, + new McpSchema.TextContent("Test message")))) + .modelPreferences(ModelPreferences.builder() .hints(List.of()) .costPriority(1.0) .speedPriority(1.0) .intelligencePriority(1.0) .build()) - .build(); - - StepVerifier.create(exchange.createMessage(craeteMessageRequest)).consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.role()).isEqualTo(Role.USER); - assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); - assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); - assertThat(result.model()).isEqualTo("MockModelName"); - assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); - }).verifyComplete(); - - return Mono.just(callResponse); - }); + .build(); + + StepVerifier.create(exchange.createMessage(craeteMessageRequest)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.role()).isEqualTo(Role.USER); + assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); + assertThat(result.model()).isEqualTo("MockModelName"); + assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); + }).verifyComplete(); + + return Mono.just(callResponse); + }); var mcpServer = McpServer.async(mcpServerTransportProvider) - .requestTimeout(Duration.ofSeconds(1)) - .serverInfo("test-server", "1.0.0") - .tools(tool) - .build(); + .requestTimeout(Duration.ofSeconds(1)) + .serverInfo("test-server", "1.0.0") + .tools(tool) + .build(); InitializeResult initResult = mcpClient.initialize(); assertThat(initResult).isNotNull(); - assertThatExceptionOfType(McpError.class) - .isThrownBy(() -> { - mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - }) - .withMessageContaining("Timeout"); + assertThatExceptionOfType(McpError.class).isThrownBy(() -> { + mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + }).withMessageContaining("Timeout"); mcpClient.close(); mcpServer.close(); diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java index 7aa76abd..ba9d4db7 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java @@ -217,19 +217,20 @@ void testCreateMessageWithRequestTimeoutSuccess() throws InterruptedException { Function samplingHandler = request -> { assertThat(request.messages()).hasSize(1); assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); - try { - TimeUnit.SECONDS.sleep(2); - } catch (InterruptedException e) { - throw new RuntimeException(e); - } - return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", + try { + TimeUnit.SECONDS.sleep(2); + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", CreateMessageResult.StopReason.STOP_SEQUENCE); }; var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().sampling().build()) - .sampling(samplingHandler) - .build(); + .capabilities(ClientCapabilities.builder().sampling().build()) + .sampling(samplingHandler) + .build(); // Server @@ -239,34 +240,34 @@ void testCreateMessageWithRequestTimeoutSuccess() throws InterruptedException { McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { - var craeteMessageRequest = McpSchema.CreateMessageRequest.builder() - .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, - new McpSchema.TextContent("Test message")))) - .modelPreferences(ModelPreferences.builder() + var craeteMessageRequest = McpSchema.CreateMessageRequest.builder() + .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, + new McpSchema.TextContent("Test message")))) + .modelPreferences(ModelPreferences.builder() .hints(List.of()) .costPriority(1.0) .speedPriority(1.0) .intelligencePriority(1.0) .build()) - .build(); - - StepVerifier.create(exchange.createMessage(craeteMessageRequest)).consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.role()).isEqualTo(Role.USER); - assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); - assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); - assertThat(result.model()).isEqualTo("MockModelName"); - assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); - }).verifyComplete(); - - return Mono.just(callResponse); - }); + .build(); + + StepVerifier.create(exchange.createMessage(craeteMessageRequest)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.role()).isEqualTo(Role.USER); + assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); + assertThat(result.model()).isEqualTo("MockModelName"); + assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); + }).verifyComplete(); + + return Mono.just(callResponse); + }); var mcpServer = McpServer.async(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .requestTimeout(Duration.ofSeconds(4)) - .tools(tool) - .build(); + .serverInfo("test-server", "1.0.0") + .requestTimeout(Duration.ofSeconds(4)) + .tools(tool) + .build(); InitializeResult initResult = mcpClient.initialize(); assertThat(initResult).isNotNull(); @@ -290,7 +291,8 @@ void testCreateMessageWithRequestTimeoutFail() throws InterruptedException { assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); try { TimeUnit.SECONDS.sleep(2); - } catch (InterruptedException e) { + } + catch (InterruptedException e) { throw new RuntimeException(e); } return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", @@ -298,9 +300,9 @@ void testCreateMessageWithRequestTimeoutFail() throws InterruptedException { }; var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().sampling().build()) - .sampling(samplingHandler) - .build(); + .capabilities(ClientCapabilities.builder().sampling().build()) + .sampling(samplingHandler) + .build(); // Server @@ -310,43 +312,41 @@ void testCreateMessageWithRequestTimeoutFail() throws InterruptedException { McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { - var craeteMessageRequest = McpSchema.CreateMessageRequest.builder() - .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, - new McpSchema.TextContent("Test message")))) - .modelPreferences(ModelPreferences.builder() + var craeteMessageRequest = McpSchema.CreateMessageRequest.builder() + .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, + new McpSchema.TextContent("Test message")))) + .modelPreferences(ModelPreferences.builder() .hints(List.of()) .costPriority(1.0) .speedPriority(1.0) .intelligencePriority(1.0) .build()) - .build(); - - StepVerifier.create(exchange.createMessage(craeteMessageRequest)).consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.role()).isEqualTo(Role.USER); - assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); - assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); - assertThat(result.model()).isEqualTo("MockModelName"); - assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); - }).verifyComplete(); - - return Mono.just(callResponse); - }); + .build(); + + StepVerifier.create(exchange.createMessage(craeteMessageRequest)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.role()).isEqualTo(Role.USER); + assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); + assertThat(result.model()).isEqualTo("MockModelName"); + assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); + }).verifyComplete(); + + return Mono.just(callResponse); + }); var mcpServer = McpServer.async(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .requestTimeout(Duration.ofSeconds(1)) - .tools(tool) - .build(); + .serverInfo("test-server", "1.0.0") + .requestTimeout(Duration.ofSeconds(1)) + .tools(tool) + .build(); InitializeResult initResult = mcpClient.initialize(); assertThat(initResult).isNotNull(); - assertThatExceptionOfType(McpError.class) - .isThrownBy(() -> { - mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - }) - .withMessageContaining("Timeout"); + assertThatExceptionOfType(McpError.class).isThrownBy(() -> { + mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + }).withMessageContaining("Timeout"); mcpClient.close(); mcpServer.close(); diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java index 9306cbb2..972b6fd5 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java @@ -202,19 +202,20 @@ void testCreateMessageWithRequestTimeoutSuccess() throws InterruptedException { Function samplingHandler = request -> { assertThat(request.messages()).hasSize(1); assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); - try { - TimeUnit.SECONDS.sleep(2); - } catch (InterruptedException e) { - throw new RuntimeException(e); - } - return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", + try { + TimeUnit.SECONDS.sleep(2); + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", CreateMessageResult.StopReason.STOP_SEQUENCE); }; var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().sampling().build()) - .sampling(samplingHandler) - .build(); + .capabilities(ClientCapabilities.builder().sampling().build()) + .sampling(samplingHandler) + .build(); // Server @@ -224,34 +225,34 @@ void testCreateMessageWithRequestTimeoutSuccess() throws InterruptedException { McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { - var craeteMessageRequest = McpSchema.CreateMessageRequest.builder() - .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, - new McpSchema.TextContent("Test message")))) - .modelPreferences(ModelPreferences.builder() + var craeteMessageRequest = McpSchema.CreateMessageRequest.builder() + .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, + new McpSchema.TextContent("Test message")))) + .modelPreferences(ModelPreferences.builder() .hints(List.of()) .costPriority(1.0) .speedPriority(1.0) .intelligencePriority(1.0) .build()) - .build(); - - StepVerifier.create(exchange.createMessage(craeteMessageRequest)).consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.role()).isEqualTo(Role.USER); - assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); - assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); - assertThat(result.model()).isEqualTo("MockModelName"); - assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); - }).verifyComplete(); - - return Mono.just(callResponse); - }); + .build(); + + StepVerifier.create(exchange.createMessage(craeteMessageRequest)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.role()).isEqualTo(Role.USER); + assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); + assertThat(result.model()).isEqualTo("MockModelName"); + assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); + }).verifyComplete(); + + return Mono.just(callResponse); + }); var mcpServer = McpServer.async(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .requestTimeout(Duration.ofSeconds(3)) - .tools(tool) - .build(); + .serverInfo("test-server", "1.0.0") + .requestTimeout(Duration.ofSeconds(3)) + .tools(tool) + .build(); InitializeResult initResult = mcpClient.initialize(); assertThat(initResult).isNotNull(); @@ -275,7 +276,8 @@ void testCreateMessageWithRequestTimeoutFail() throws InterruptedException { assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); try { TimeUnit.SECONDS.sleep(2); - } catch (InterruptedException e) { + } + catch (InterruptedException e) { throw new RuntimeException(e); } return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", @@ -283,9 +285,9 @@ void testCreateMessageWithRequestTimeoutFail() throws InterruptedException { }; var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().sampling().build()) - .sampling(samplingHandler) - .build(); + .capabilities(ClientCapabilities.builder().sampling().build()) + .sampling(samplingHandler) + .build(); // Server @@ -295,43 +297,41 @@ void testCreateMessageWithRequestTimeoutFail() throws InterruptedException { McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { - var craeteMessageRequest = McpSchema.CreateMessageRequest.builder() - .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, - new McpSchema.TextContent("Test message")))) - .modelPreferences(ModelPreferences.builder() + var craeteMessageRequest = McpSchema.CreateMessageRequest.builder() + .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, + new McpSchema.TextContent("Test message")))) + .modelPreferences(ModelPreferences.builder() .hints(List.of()) .costPriority(1.0) .speedPriority(1.0) .intelligencePriority(1.0) .build()) - .build(); - - StepVerifier.create(exchange.createMessage(craeteMessageRequest)).consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.role()).isEqualTo(Role.USER); - assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); - assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); - assertThat(result.model()).isEqualTo("MockModelName"); - assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); - }).verifyComplete(); - - return Mono.just(callResponse); - }); + .build(); + + StepVerifier.create(exchange.createMessage(craeteMessageRequest)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.role()).isEqualTo(Role.USER); + assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); + assertThat(result.model()).isEqualTo("MockModelName"); + assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); + }).verifyComplete(); + + return Mono.just(callResponse); + }); var mcpServer = McpServer.async(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .requestTimeout(Duration.ofSeconds(1)) - .tools(tool) - .build(); + .serverInfo("test-server", "1.0.0") + .requestTimeout(Duration.ofSeconds(1)) + .tools(tool) + .build(); InitializeResult initResult = mcpClient.initialize(); assertThat(initResult).isNotNull(); - assertThatExceptionOfType(McpError.class) - .isThrownBy(() -> { - mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - }) - .withMessageContaining("Timeout"); + assertThatExceptionOfType(McpError.class).isThrownBy(() -> { + mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + }).withMessageContaining("Timeout"); mcpClient.close(); mcpServer.close();