Skip to content

Commit 3582abf

Browse files
committed
feat(client): small enhancements + adds Batch to McpSchema to simplify StreamableHttpClientTransport
1 parent 3491098 commit 3582abf

File tree

2 files changed

+147
-85
lines changed

2 files changed

+147
-85
lines changed

mcp/src/main/java/io/modelcontextprotocol/client/transport/StreamableHttpClientTransport.java

Lines changed: 120 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import java.net.http.HttpResponse;
1515
import java.nio.charset.StandardCharsets;
1616
import java.time.Duration;
17+
import java.util.ArrayList;
1718
import java.util.List;
1819
import java.util.concurrent.atomic.AtomicBoolean;
1920
import java.util.concurrent.atomic.AtomicReference;
@@ -43,6 +44,24 @@ public class StreamableHttpClientTransport implements McpClientTransport {
4344

4445
private static final Logger LOGGER = LoggerFactory.getLogger(StreamableHttpClientTransport.class);
4546

47+
private static final String DEFAULT_MCP_ENDPOINT = "/mcp";
48+
49+
private static final String MCP_SESSION_ID = "Mcp-Session-Id";
50+
51+
private static final String LAST_EVENT_ID = "Last-Event-ID";
52+
53+
private static final String ACCEPT = "Accept";
54+
55+
private static final String CONTENT_TYPE = "Content-Type";
56+
57+
private static final String APPLICATION_JSON = "application/json";
58+
59+
private static final String TEXT_EVENT_STREAM = "text/event-stream";
60+
61+
private static final String APPLICATION_JSON_SEQ = "application/json-seq";
62+
63+
private static final String DEFAULT_ACCEPT_VALUES = "%s, %s".formatted(APPLICATION_JSON, TEXT_EVENT_STREAM);
64+
4665
private final HttpClientSseClientTransport sseClientTransport;
4766

4867
private final HttpClient httpClient;
@@ -57,6 +76,8 @@ public class StreamableHttpClientTransport implements McpClientTransport {
5776

5877
private final AtomicReference<String> lastEventId = new AtomicReference<>();
5978

79+
private final AtomicReference<String> mcpSessionId = new AtomicReference<>();
80+
6081
private final AtomicBoolean fallbackToSse = new AtomicBoolean(false);
6182

6283
StreamableHttpClientTransport(final HttpClient httpClient, final HttpRequest.Builder requestBuilder,
@@ -96,14 +117,13 @@ public static class Builder {
96117
.version(HttpClient.Version.HTTP_1_1)
97118
.connectTimeout(Duration.ofSeconds(10));
98119

99-
private final HttpRequest.Builder requestBuilder = HttpRequest.newBuilder()
100-
.header("Accept", "application/json, text/event-stream");
120+
private final HttpRequest.Builder requestBuilder = HttpRequest.newBuilder();
101121

102122
private ObjectMapper objectMapper = new ObjectMapper();
103123

104124
private String baseUri;
105125

106-
private String endpoint = "/mcp";
126+
private String endpoint = DEFAULT_MCP_ENDPOINT;
107127

108128
private Consumer<HttpClient.Builder> clientCustomizer;
109129

@@ -152,7 +172,7 @@ public StreamableHttpClientTransport build() {
152172
builder.customizeRequest(requestCustomizer);
153173
}
154174

155-
if (!endpoint.equals("/mcp")) {
175+
if (!endpoint.equals(DEFAULT_MCP_ENDPOINT)) {
156176
builder.sseEndpoint(endpoint);
157177
}
158178

@@ -173,13 +193,24 @@ public Mono<Void> connect(final Function<Mono<McpSchema.JSONRPCMessage>, Mono<Mc
173193
}
174194

175195
return Mono.defer(() -> Mono.fromFuture(() -> {
176-
final HttpRequest.Builder builder = requestBuilder.copy().GET().uri(uri);
196+
final HttpRequest.Builder request = requestBuilder.copy().GET().header(ACCEPT, TEXT_EVENT_STREAM).uri(uri);
177197
final String lastId = lastEventId.get();
178198
if (lastId != null) {
179-
builder.header("Last-Event-ID", lastId);
199+
request.header(LAST_EVENT_ID, lastId);
180200
}
181-
return httpClient.sendAsync(builder.build(), HttpResponse.BodyHandlers.ofInputStream());
201+
if (mcpSessionId.get() != null) {
202+
request.header(MCP_SESSION_ID, mcpSessionId.get());
203+
}
204+
205+
return httpClient.sendAsync(request.build(), HttpResponse.BodyHandlers.ofInputStream());
182206
}).flatMap(response -> {
207+
// must like server terminate session and the client need to start a
208+
// new session by sending a new `InitializeRequest` without a session
209+
// ID attached.
210+
if (mcpSessionId.get() != null && response.statusCode() == 404) {
211+
mcpSessionId.set(null);
212+
}
213+
183214
if (response.statusCode() == 405 || response.statusCode() == 404) {
184215
LOGGER.warn("Operation not allowed, falling back to SSE");
185216
fallbackToSse.set(true);
@@ -192,6 +223,7 @@ public Mono<Void> connect(final Function<Mono<McpSchema.JSONRPCMessage>, Mono<Mc
192223
.doOnTerminate(() -> state.set(TransportState.CLOSED))
193224
.onErrorResume(e -> {
194225
LOGGER.error("Streamable transport connection error", e);
226+
state.set(TransportState.DISCONNECTED);
195227
return Mono.error(e);
196228
}));
197229
}
@@ -204,67 +236,52 @@ public Mono<Void> sendMessage(final McpSchema.JSONRPCMessage message) {
204236
public Mono<Void> sendMessage(final McpSchema.JSONRPCMessage message,
205237
final Function<Mono<McpSchema.JSONRPCMessage>, Mono<McpSchema.JSONRPCMessage>> handler) {
206238
if (fallbackToSse.get()) {
207-
return sseClientTransport.sendMessage(message);
239+
return fallbackToSse(message);
208240
}
209241

210242
if (state.get() == TransportState.CLOSED) {
211243
return Mono.empty();
212244
}
213245

214-
return sentPost(message, handler).onErrorResume(e -> {
215-
LOGGER.error("Streamable transport sendMessage error", e);
216-
return Mono.error(e);
217-
});
218-
}
219-
220-
/**
221-
* Sends a list of messages to the server.
222-
* @param messages the list of messages to send
223-
* @return a Mono that completes when all messages have been sent
224-
*/
225-
public Mono<Void> sendMessages(final List<McpSchema.JSONRPCMessage> messages,
226-
final Function<Mono<McpSchema.JSONRPCMessage>, Mono<McpSchema.JSONRPCMessage>> handler) {
227-
if (fallbackToSse.get()) {
228-
return Flux.fromIterable(messages).flatMap(this::sendMessage).then();
229-
}
230-
231-
if (state.get() == TransportState.CLOSED) {
232-
return Mono.empty();
233-
}
234-
235-
return sentPost(messages, handler).onErrorResume(e -> {
236-
LOGGER.error("Streamable transport sendMessages error", e);
237-
return Mono.error(e);
238-
});
239-
}
240-
241-
private Mono<Void> sentPost(final Object msg,
242-
final Function<Mono<McpSchema.JSONRPCMessage>, Mono<McpSchema.JSONRPCMessage>> handler) {
243-
return serializeJson(msg).flatMap(json -> {
244-
final HttpRequest request = requestBuilder.copy()
246+
return serializeJson(message).flatMap(json -> {
247+
final HttpRequest.Builder request = requestBuilder.copy()
245248
.POST(HttpRequest.BodyPublishers.ofString(json))
246-
.uri(uri)
247-
.build();
248-
return Mono.fromFuture(httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofInputStream()))
249+
.header(ACCEPT, DEFAULT_ACCEPT_VALUES)
250+
.header(CONTENT_TYPE, APPLICATION_JSON)
251+
.uri(uri);
252+
if (mcpSessionId.get() != null) {
253+
request.header(MCP_SESSION_ID, mcpSessionId.get());
254+
}
255+
256+
return Mono.fromFuture(httpClient.sendAsync(request.build(), HttpResponse.BodyHandlers.ofInputStream()))
249257
.flatMap(response -> {
250258

259+
// server may assign a session ID at initialization time, if yes we
260+
// have to use it for any subsequent requests
261+
if (message instanceof McpSchema.JSONRPCRequest
262+
&& ((McpSchema.JSONRPCRequest) message).method().equals(McpSchema.METHOD_INITIALIZE)) {
263+
response.headers()
264+
.firstValue(MCP_SESSION_ID)
265+
.map(String::trim)
266+
.ifPresent(this.mcpSessionId::set);
267+
}
268+
251269
// If the response is 202 Accepted, there's no body to process
252270
if (response.statusCode() == 202) {
253271
return Mono.empty();
254272
}
255273

274+
// must like server terminate session and the client need to start a
275+
// new session by sending a new `InitializeRequest` without a session
276+
// ID attached.
277+
if (mcpSessionId.get() != null && response.statusCode() == 404) {
278+
mcpSessionId.set(null);
279+
}
280+
256281
if (response.statusCode() == 405 || response.statusCode() == 404) {
257282
LOGGER.warn("Operation not allowed, falling back to SSE");
258283
fallbackToSse.set(true);
259-
if (msg instanceof McpSchema.JSONRPCMessage message) {
260-
return sseClientTransport.sendMessage(message);
261-
}
262-
263-
if (msg instanceof List<?> list) {
264-
@SuppressWarnings("unchecked")
265-
final List<McpSchema.JSONRPCMessage> messages = (List<McpSchema.JSONRPCMessage>) list;
266-
return Flux.fromIterable(messages).flatMap(this::sendMessage).then();
267-
}
284+
return fallbackToSse(message);
268285
}
269286

270287
if (response.statusCode() >= 400) {
@@ -274,18 +291,28 @@ private Mono<Void> sentPost(final Object msg,
274291

275292
return handleStreamingResponse(response, handler);
276293
});
294+
}).onErrorResume(e -> {
295+
LOGGER.error("Streamable transport sendMessages error", e);
296+
return Mono.error(e);
277297
});
278298

279299
}
280300

281-
private Mono<String> serializeJson(final Object input) {
301+
private Mono<Void> fallbackToSse(final McpSchema.JSONRPCMessage msg) {
302+
if (msg instanceof McpSchema.JSONRPCBatchRequest batch) {
303+
return Flux.fromIterable(batch.items()).flatMap(sseClientTransport::sendMessage).then();
304+
}
305+
306+
if (msg instanceof McpSchema.JSONRPCBatchResponse batch) {
307+
return Flux.fromIterable(batch.items()).flatMap(sseClientTransport::sendMessage).then();
308+
}
309+
310+
return sseClientTransport.sendMessage(msg);
311+
}
312+
313+
private Mono<String> serializeJson(final McpSchema.JSONRPCMessage msg) {
282314
try {
283-
if (input instanceof McpSchema.JSONRPCMessage || input instanceof List) {
284-
return Mono.just(objectMapper.writeValueAsString(input));
285-
}
286-
else {
287-
return Mono.error(new IllegalArgumentException("Unsupported message type for serialization"));
288-
}
315+
return Mono.just(objectMapper.writeValueAsString(msg));
289316
}
290317
catch (IOException e) {
291318
LOGGER.error("Error serializing JSON-RPC message", e);
@@ -295,27 +322,31 @@ private Mono<String> serializeJson(final Object input) {
295322

296323
private Mono<Void> handleStreamingResponse(final HttpResponse<InputStream> response,
297324
final Function<Mono<McpSchema.JSONRPCMessage>, Mono<McpSchema.JSONRPCMessage>> handler) {
298-
final String contentType = response.headers().firstValue("Content-Type").orElse("");
299-
if (contentType.contains("application/json-seq")) {
325+
final String contentType = response.headers().firstValue(CONTENT_TYPE).orElse("");
326+
if (contentType.contains(APPLICATION_JSON_SEQ)) {
300327
return handleJsonStream(response, handler);
301328
}
302-
else if (contentType.contains("text/event-stream")) {
329+
else if (contentType.contains(TEXT_EVENT_STREAM)) {
303330
return handleSseStream(response, handler);
304331
}
305-
else if (contentType.contains("application/json")) {
332+
else if (contentType.contains(APPLICATION_JSON)) {
306333
return handleSingleJson(response, handler);
307334
}
308-
else {
309-
return Mono.error(new UnsupportedOperationException("Unsupported Content-Type: " + contentType));
310-
}
335+
return Mono.error(new UnsupportedOperationException("Unsupported Content-Type: " + contentType));
311336
}
312337

313338
private Mono<Void> handleSingleJson(final HttpResponse<InputStream> response,
314339
final Function<Mono<McpSchema.JSONRPCMessage>, Mono<McpSchema.JSONRPCMessage>> handler) {
315340
return Mono.fromCallable(() -> {
316-
final McpSchema.JSONRPCMessage msg = McpSchema.deserializeJsonRpcMessage(objectMapper,
317-
new String(response.body().readAllBytes(), StandardCharsets.UTF_8));
318-
return handler.apply(Mono.just(msg));
341+
try {
342+
final McpSchema.JSONRPCMessage msg = McpSchema.deserializeJsonRpcMessage(objectMapper,
343+
new String(response.body().readAllBytes(), StandardCharsets.UTF_8));
344+
return handler.apply(Mono.just(msg));
345+
}
346+
catch (IOException e) {
347+
LOGGER.error("Error processing JSON response", e);
348+
return Mono.error(e);
349+
}
319350
}).flatMap(Function.identity()).then();
320351
}
321352

@@ -328,7 +359,7 @@ private Mono<Void> handleJsonStream(final HttpResponse<InputStream> response,
328359
}
329360
catch (IOException e) {
330361
LOGGER.error("Error processing JSON line", e);
331-
return Mono.empty();
362+
return Mono.error(e);
332363
}
333364
}).then();
334365
}
@@ -347,7 +378,7 @@ private Mono<Void> handleSseStream(final HttpResponse<InputStream> response,
347378
if (line.startsWith("event: "))
348379
event = line.substring(7).trim();
349380
else if (line.startsWith("data: "))
350-
data += line.substring(6).trim() + "\n";
381+
data += line.substring(6) + "\n";
351382
else if (line.startsWith("id: "))
352383
id = line.substring(4).trim();
353384
}
@@ -356,34 +387,39 @@ else if (line.startsWith("id: "))
356387
data = data.substring(0, data.length() - 1);
357388
}
358389

359-
return new FlowSseClient.SseEvent(event, data, id);
390+
return new FlowSseClient.SseEvent(id, event, data);
360391
})
361392
.filter(sseEvent -> "message".equals(sseEvent.type()))
362-
.doOnNext(sseEvent -> {
363-
lastEventId.set(sseEvent.id());
393+
.concatMap(sseEvent -> {
394+
String rawData = sseEvent.data().trim();
364395
try {
365-
String rawData = sseEvent.data().trim();
366396
JsonNode node = objectMapper.readTree(rawData);
367-
397+
List<McpSchema.JSONRPCMessage> messages = new ArrayList<>();
368398
if (node.isArray()) {
369399
for (JsonNode item : node) {
370-
String rawMessage = objectMapper.writeValueAsString(item);
371-
McpSchema.JSONRPCMessage msg = McpSchema.deserializeJsonRpcMessage(objectMapper,
372-
rawMessage);
373-
handler.apply(Mono.just(msg)).subscribe();
400+
messages.add(McpSchema.deserializeJsonRpcMessage(objectMapper, item.toString()));
374401
}
375402
}
376403
else if (node.isObject()) {
377-
String rawMessage = objectMapper.writeValueAsString(node);
378-
McpSchema.JSONRPCMessage msg = McpSchema.deserializeJsonRpcMessage(objectMapper, rawMessage);
379-
handler.apply(Mono.just(msg)).subscribe();
404+
messages.add(McpSchema.deserializeJsonRpcMessage(objectMapper, node.toString()));
380405
}
381406
else {
382-
LOGGER.warn("Unexpected JSON in SSE data: {}", rawData);
407+
String warning = "Unexpected JSON in SSE data: " + rawData;
408+
LOGGER.warn(warning);
409+
return Mono.error(new IllegalArgumentException(warning));
383410
}
411+
412+
return Flux.fromIterable(messages)
413+
.concatMap(msg -> handler.apply(Mono.just(msg)))
414+
.then(Mono.fromRunnable(() -> {
415+
if (!sseEvent.id().isEmpty()) {
416+
lastEventId.set(sseEvent.id());
417+
}
418+
}));
384419
}
385420
catch (IOException e) {
386-
LOGGER.error("Error processing SSE event: {}", sseEvent.data(), e);
421+
LOGGER.error("Error parsing SSE JSON: {}", rawData, e);
422+
return Mono.error(e);
387423
}
388424
})
389425
.then();

mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import java.util.List;
1111
import java.util.Map;
1212

13+
import com.fasterxml.jackson.annotation.JsonIgnore;
1314
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
1415
import com.fasterxml.jackson.annotation.JsonInclude;
1516
import com.fasterxml.jackson.annotation.JsonProperty;
@@ -173,12 +174,37 @@ else if (map.containsKey("result") || map.containsKey("error")) {
173174
// ---------------------------
174175
// JSON-RPC Message Types
175176
// ---------------------------
176-
public sealed interface JSONRPCMessage permits JSONRPCRequest, JSONRPCNotification, JSONRPCResponse {
177+
public sealed interface JSONRPCMessage
178+
permits JSONRPCBatchRequest, JSONRPCBatchResponse, JSONRPCRequest, JSONRPCNotification, JSONRPCResponse {
177179

178180
String jsonrpc();
179181

180182
}
181183

184+
@JsonInclude(JsonInclude.Include.NON_ABSENT)
185+
@JsonIgnoreProperties(ignoreUnknown = true)
186+
public record JSONRPCBatchRequest( // @formatter:off
187+
@JsonProperty("items") List<JSONRPCMessage> items) implements JSONRPCMessage {
188+
189+
@Override
190+
@JsonIgnore
191+
public String jsonrpc() {
192+
return JSONRPC_VERSION;
193+
}
194+
} // @formatter:on
195+
196+
@JsonInclude(JsonInclude.Include.NON_ABSENT)
197+
@JsonIgnoreProperties(ignoreUnknown = true)
198+
public record JSONRPCBatchResponse( // @formatter:off
199+
@JsonProperty("items") List<JSONRPCMessage> items) implements JSONRPCMessage {
200+
201+
@Override
202+
@JsonIgnore
203+
public String jsonrpc() {
204+
return JSONRPC_VERSION;
205+
}
206+
} // @formatter:on
207+
182208
@JsonInclude(JsonInclude.Include.NON_ABSENT)
183209
@JsonIgnoreProperties(ignoreUnknown = true)
184210
public record JSONRPCRequest( // @formatter:off

0 commit comments

Comments
 (0)