Skip to content

Commit ce27db5

Browse files
committed
feat(client): small enhancements
1 parent 1c0a336 commit ce27db5

File tree

1 file changed

+18
-9
lines changed

1 file changed

+18
-9
lines changed

mcp/src/main/java/io/modelcontextprotocol/client/transport/StreamableHttpClientTransport.java

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ public Mono<Void> connect(final Function<Mono<McpSchema.JSONRPCMessage>, Mono<Mc
185185
fallbackToSse.set(true);
186186
return sseClientTransport.connect(handler);
187187
}
188-
return handleStreamingResponse(handler, response);
188+
return handleStreamingResponse(response, handler);
189189
})
190190
.retryWhen(Retry.backoff(3, Duration.ofSeconds(3)).filter(err -> err instanceof IllegalStateException))
191191
.doOnSuccess(v -> state.set(TransportState.CONNECTED))
@@ -198,6 +198,11 @@ public Mono<Void> connect(final Function<Mono<McpSchema.JSONRPCMessage>, Mono<Mc
198198

199199
@Override
200200
public Mono<Void> sendMessage(final McpSchema.JSONRPCMessage message) {
201+
return sendMessage(message, msg -> msg);
202+
}
203+
204+
public Mono<Void> sendMessage(final McpSchema.JSONRPCMessage message,
205+
final Function<Mono<McpSchema.JSONRPCMessage>, Mono<McpSchema.JSONRPCMessage>> handler) {
201206
if (fallbackToSse.get()) {
202207
return sseClientTransport.sendMessage(message);
203208
}
@@ -206,7 +211,7 @@ public Mono<Void> sendMessage(final McpSchema.JSONRPCMessage message) {
206211
return Mono.empty();
207212
}
208213

209-
return sentPost(message).onErrorResume(e -> {
214+
return sentPost(message, handler).onErrorResume(e -> {
210215
LOGGER.error("Streamable transport sendMessage error", e);
211216
return Mono.error(e);
212217
});
@@ -217,7 +222,8 @@ public Mono<Void> sendMessage(final McpSchema.JSONRPCMessage message) {
217222
* @param messages the list of messages to send
218223
* @return a Mono that completes when all messages have been sent
219224
*/
220-
public Mono<Void> sendMessages(final List<McpSchema.JSONRPCMessage> messages) {
225+
public Mono<Void> sendMessages(final List<McpSchema.JSONRPCMessage> messages,
226+
final Function<Mono<McpSchema.JSONRPCMessage>, Mono<McpSchema.JSONRPCMessage>> handler) {
221227
if (fallbackToSse.get()) {
222228
return Flux.fromIterable(messages).flatMap(this::sendMessage).then();
223229
}
@@ -226,13 +232,14 @@ public Mono<Void> sendMessages(final List<McpSchema.JSONRPCMessage> messages) {
226232
return Mono.empty();
227233
}
228234

229-
return sentPost(messages).onErrorResume(e -> {
235+
return sentPost(messages, handler).onErrorResume(e -> {
230236
LOGGER.error("Streamable transport sendMessages error", e);
231237
return Mono.error(e);
232238
});
233239
}
234240

235-
private Mono<Void> sentPost(final Object msg) {
241+
private Mono<Void> sentPost(final Object msg,
242+
final Function<Mono<McpSchema.JSONRPCMessage>, Mono<McpSchema.JSONRPCMessage>> handler) {
236243
return serializeJson(msg).flatMap(json -> {
237244
final HttpRequest request = requestBuilder.copy()
238245
.POST(HttpRequest.BodyPublishers.ofString(json))
@@ -265,7 +272,7 @@ private Mono<Void> sentPost(final Object msg) {
265272
.error(new IllegalArgumentException("Unexpected status code: " + response.statusCode()));
266273
}
267274

268-
return handleStreamingResponse(it -> it, response);
275+
return handleStreamingResponse(response, handler);
269276
});
270277
});
271278

@@ -286,9 +293,8 @@ private Mono<String> serializeJson(final Object input) {
286293
}
287294
}
288295

289-
private Mono<Void> handleStreamingResponse(
290-
final Function<Mono<McpSchema.JSONRPCMessage>, Mono<McpSchema.JSONRPCMessage>> handler,
291-
final HttpResponse<InputStream> response) {
296+
private Mono<Void> handleStreamingResponse(final HttpResponse<InputStream> response,
297+
final Function<Mono<McpSchema.JSONRPCMessage>, Mono<McpSchema.JSONRPCMessage>> handler) {
292298
final String contentType = response.headers().firstValue("Content-Type").orElse("");
293299
if (contentType.contains("application/json-seq")) {
294300
return handleJsonStream(response, handler);
@@ -386,6 +392,9 @@ else if (node.isObject()) {
386392
@Override
387393
public Mono<Void> closeGracefully() {
388394
state.set(TransportState.CLOSED);
395+
if (fallbackToSse.get()) {
396+
return sseClientTransport.closeGracefully();
397+
}
389398
return Mono.empty();
390399
}
391400

0 commit comments

Comments
 (0)