diff --git a/mcp/pom.xml b/mcp/pom.xml
index 6b0f4a9f..13390c67 100644
--- a/mcp/pom.xml
+++ b/mcp/pom.xml
@@ -126,6 +126,12 @@
${junit.version}
test
+
+ org.junit.jupiter
+ junit-jupiter-params
+ ${junit.version}
+ test
+
org.mockito
mockito-core
diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java
index 632d3844..99cf2a62 100644
--- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java
+++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java
@@ -24,6 +24,7 @@
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage;
import io.modelcontextprotocol.util.Assert;
+import io.modelcontextprotocol.util.Utils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Mono;
@@ -69,7 +70,7 @@ public class HttpClientSseClientTransport implements McpClientTransport {
private static final String DEFAULT_SSE_ENDPOINT = "/sse";
/** Base URI for the MCP server */
- private final String baseUri;
+ private final URI baseUri;
/** SSE endpoint path */
private final String sseEndpoint;
@@ -178,7 +179,7 @@ public HttpClientSseClientTransport(HttpClient.Builder clientBuilder, HttpReques
Assert.hasText(sseEndpoint, "sseEndpoint must not be empty");
Assert.notNull(httpClient, "httpClient must not be null");
Assert.notNull(requestBuilder, "requestBuilder must not be null");
- this.baseUri = baseUri;
+ this.baseUri = URI.create(baseUri);
this.sseEndpoint = sseEndpoint;
this.objectMapper = objectMapper;
this.httpClient = httpClient;
@@ -340,7 +341,8 @@ public Mono connect(Function, Mono> h
CompletableFuture future = new CompletableFuture<>();
connectionFuture.set(future);
- sseClient.subscribe(this.baseUri + this.sseEndpoint, new FlowSseClient.SseEventHandler() {
+ URI clientUri = Utils.resolveUri(this.baseUri, this.sseEndpoint);
+ sseClient.subscribe(clientUri.toString(), new FlowSseClient.SseEventHandler() {
@Override
public void onEvent(SseEvent event) {
if (isClosing) {
@@ -412,7 +414,8 @@ public Mono sendMessage(JSONRPCMessage message) {
try {
String jsonText = this.objectMapper.writeValueAsString(message);
- HttpRequest request = this.requestBuilder.uri(URI.create(this.baseUri + endpoint))
+ URI requestUri = Utils.resolveUri(baseUri, endpoint);
+ HttpRequest request = this.requestBuilder.uri(requestUri)
.POST(HttpRequest.BodyPublishers.ofString(jsonText))
.build();
diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/Utils.java b/mcp/src/main/java/io/modelcontextprotocol/util/Utils.java
index 0f799ca0..8e654e59 100644
--- a/mcp/src/main/java/io/modelcontextprotocol/util/Utils.java
+++ b/mcp/src/main/java/io/modelcontextprotocol/util/Utils.java
@@ -4,11 +4,12 @@
package io.modelcontextprotocol.util;
+import reactor.util.annotation.Nullable;
+
+import java.net.URI;
import java.util.Collection;
import java.util.Map;
-import reactor.util.annotation.Nullable;
-
/**
* Miscellaneous utility methods.
*
@@ -52,4 +53,55 @@ public static boolean isEmpty(@Nullable Map, ?> map) {
return (map == null || map.isEmpty());
}
+ /**
+ * Resolves the given endpoint URL against the base URL.
+ *
+ * - If the endpoint URL is relative, it will be resolved against the base URL.
+ * - If the endpoint URL is absolute, it will be validated to ensure it matches the
+ * base URL's scheme, authority, and path prefix.
+ * - If validation fails for an absolute URL, an {@link IllegalArgumentException} is
+ * thrown.
+ *
+ * @param baseUrl The base URL (must be absolute)
+ * @param endpointUrl The endpoint URL (can be relative or absolute)
+ * @return The resolved endpoint URI
+ * @throws IllegalArgumentException If the absolute endpoint URL does not match the
+ * base URL or URI is malformed
+ */
+ public static URI resolveUri(URI baseUrl, String endpointUrl) {
+ URI endpointUri = URI.create(endpointUrl);
+ if (endpointUri.isAbsolute() && !isUnderBaseUri(baseUrl, endpointUri)) {
+ throw new IllegalArgumentException("Absolute endpoint URL does not match the base URL.");
+ }
+ else {
+ return baseUrl.resolve(endpointUri);
+ }
+ }
+
+ /**
+ * Checks if the given absolute endpoint URI falls under the base URI. It validates
+ * the scheme, authority (host and port), and ensures that the base path is a prefix
+ * of the endpoint path.
+ * @param baseUri The base URI
+ * @param endpointUri The endpoint URI to check
+ * @return true if endpointUri is within baseUri's hierarchy, false otherwise
+ */
+ private static boolean isUnderBaseUri(URI baseUri, URI endpointUri) {
+ if (!baseUri.getScheme().equals(endpointUri.getScheme())
+ || !baseUri.getAuthority().equals(endpointUri.getAuthority())) {
+ return false;
+ }
+
+ URI normalizedBase = baseUri.normalize();
+ URI normalizedEndpoint = endpointUri.normalize();
+
+ String basePath = normalizedBase.getPath();
+ String endpointPath = normalizedEndpoint.getPath();
+
+ if (basePath.endsWith("/")) {
+ basePath = basePath.substring(0, basePath.length() - 1);
+ }
+ return endpointPath.startsWith(basePath);
+ }
+
}
diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java
index e5178c0e..a75f7675 100644
--- a/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java
+++ b/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java
@@ -7,12 +7,13 @@
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
+import java.net.http.HttpResponse;
import java.time.Duration;
import java.util.Map;
+import java.util.concurrent.CompletableFuture;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
-import java.util.function.Consumer;
import java.util.function.Function;
import io.modelcontextprotocol.spec.McpSchema;
@@ -21,6 +22,8 @@
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Timeout;
+import org.mockito.ArgumentCaptor;
+import org.mockito.Mockito;
import org.testcontainers.containers.GenericContainer;
import org.testcontainers.containers.wait.strategy.Wait;
import reactor.core.publisher.Mono;
@@ -31,6 +34,9 @@
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatCode;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
import com.fasterxml.jackson.databind.ObjectMapper;
@@ -364,4 +370,23 @@ void testChainedCustomizations() {
customizedTransport.closeGracefully().block();
}
+ @Test
+ @SuppressWarnings("unchecked")
+ void testResolvingClientEndpoint() {
+ HttpClient httpClient = Mockito.mock(HttpClient.class);
+ HttpResponse httpResponse = Mockito.mock(HttpResponse.class);
+ CompletableFuture> future = new CompletableFuture<>();
+ future.complete(httpResponse);
+ when(httpClient.sendAsync(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))).thenReturn(future);
+
+ HttpClientSseClientTransport transport = new HttpClientSseClientTransport(httpClient, HttpRequest.newBuilder(),
+ "http://example.com", "http://example.com/sse", new ObjectMapper());
+
+ transport.connect(Function.identity());
+
+ ArgumentCaptor httpRequestCaptor = ArgumentCaptor.forClass(HttpRequest.class);
+ verify(httpClient).sendAsync(httpRequestCaptor.capture(), any(HttpResponse.BodyHandler.class));
+ assertThat(httpRequestCaptor.getValue().uri()).isEqualTo(URI.create("http://example.com/sse"));
+ }
+
}
diff --git a/mcp/src/test/java/io/modelcontextprotocol/util/UtilsTests.java b/mcp/src/test/java/io/modelcontextprotocol/util/UtilsTests.java
index aced20cb..0f2e689b 100644
--- a/mcp/src/test/java/io/modelcontextprotocol/util/UtilsTests.java
+++ b/mcp/src/test/java/io/modelcontextprotocol/util/UtilsTests.java
@@ -6,12 +6,17 @@
import org.junit.jupiter.api.Test;
+import java.net.URI;
import java.util.Collection;
import java.util.List;
import java.util.Map;
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.CsvSource;
class UtilsTests {
@@ -37,4 +42,28 @@ void testMapIsEmpty() {
assertFalse(Utils.isEmpty(Map.of("key", "value")));
}
+ @ParameterizedTest
+ @CsvSource({
+ // relative endpoints
+ "http://localhost:8080/root, /api/v1, http://localhost:8080/api/v1",
+ "http://localhost:8080/root/, api, http://localhost:8080/root/api",
+ "http://localhost:8080, /api, http://localhost:8080/api",
+ // absolute endpoints matching base
+ "http://localhost:8080/root, http://localhost:8080/root/api/v1, http://localhost:8080/root/api/v1",
+ "http://localhost:8080/root, http://localhost:8080/root, http://localhost:8080/root" })
+ void testValidUriResolution(String baseUrl, String endpoint, String expectedResult) {
+ URI result = Utils.resolveUri(URI.create(baseUrl), endpoint);
+ assertThat(result.toString()).isEqualTo(expectedResult);
+ }
+
+ @ParameterizedTest
+ @CsvSource({ "http://localhost:8080/root, http://localhost:8080/other/api",
+ "http://localhost:8080/root, http://otherhost/api",
+ "http://localhost:8080/root, http://localhost:9090/root/api" })
+ void testAbsoluteUriNotMatchingBase(String baseUrl, String endpoint) {
+ assertThatThrownBy(() -> Utils.resolveUri(URI.create(baseUrl), endpoint))
+ .isInstanceOf(IllegalArgumentException.class)
+ .hasMessageContaining("does not match the base URL");
+ }
+
}
\ No newline at end of file