14
14
import java .net .http .HttpResponse ;
15
15
import java .nio .charset .StandardCharsets ;
16
16
import java .time .Duration ;
17
+ import java .util .ArrayList ;
17
18
import java .util .List ;
18
19
import java .util .concurrent .atomic .AtomicBoolean ;
19
20
import java .util .concurrent .atomic .AtomicReference ;
@@ -43,6 +44,24 @@ public class StreamableHttpClientTransport implements McpClientTransport {
43
44
44
45
private static final Logger LOGGER = LoggerFactory .getLogger (StreamableHttpClientTransport .class );
45
46
47
+ private static final String DEFAULT_MCP_ENDPOINT = "/mcp" ;
48
+
49
+ private static final String MCP_SESSION_ID = "Mcp-Session-Id" ;
50
+
51
+ private static final String LAST_EVENT_ID = "Last-Event-ID" ;
52
+
53
+ private static final String ACCEPT = "Accept" ;
54
+
55
+ private static final String CONTENT_TYPE = "Content-Type" ;
56
+
57
+ private static final String APPLICATION_JSON = "application/json" ;
58
+
59
+ private static final String TEXT_EVENT_STREAM = "text/event-stream" ;
60
+
61
+ private static final String APPLICATION_JSON_SEQ = "application/json-seq" ;
62
+
63
+ private static final String DEFAULT_ACCEPT_VALUES = "%s, %s" .formatted (APPLICATION_JSON , TEXT_EVENT_STREAM );
64
+
46
65
private final HttpClientSseClientTransport sseClientTransport ;
47
66
48
67
private final HttpClient httpClient ;
@@ -57,6 +76,8 @@ public class StreamableHttpClientTransport implements McpClientTransport {
57
76
58
77
private final AtomicReference <String > lastEventId = new AtomicReference <>();
59
78
79
+ private final AtomicReference <String > mcpSessionId = new AtomicReference <>();
80
+
60
81
private final AtomicBoolean fallbackToSse = new AtomicBoolean (false );
61
82
62
83
StreamableHttpClientTransport (final HttpClient httpClient , final HttpRequest .Builder requestBuilder ,
@@ -96,14 +117,13 @@ public static class Builder {
96
117
.version (HttpClient .Version .HTTP_1_1 )
97
118
.connectTimeout (Duration .ofSeconds (10 ));
98
119
99
- private final HttpRequest .Builder requestBuilder = HttpRequest .newBuilder ()
100
- .header ("Accept" , "application/json, text/event-stream" );
120
+ private final HttpRequest .Builder requestBuilder = HttpRequest .newBuilder ();
101
121
102
122
private ObjectMapper objectMapper = new ObjectMapper ();
103
123
104
124
private String baseUri ;
105
125
106
- private String endpoint = "/mcp" ;
126
+ private String endpoint = DEFAULT_MCP_ENDPOINT ;
107
127
108
128
private Consumer <HttpClient .Builder > clientCustomizer ;
109
129
@@ -152,7 +172,7 @@ public StreamableHttpClientTransport build() {
152
172
builder .customizeRequest (requestCustomizer );
153
173
}
154
174
155
- if (!endpoint .equals ("/mcp" )) {
175
+ if (!endpoint .equals (DEFAULT_MCP_ENDPOINT )) {
156
176
builder .sseEndpoint (endpoint );
157
177
}
158
178
@@ -173,13 +193,24 @@ public Mono<Void> connect(final Function<Mono<McpSchema.JSONRPCMessage>, Mono<Mc
173
193
}
174
194
175
195
return Mono .defer (() -> Mono .fromFuture (() -> {
176
- final HttpRequest .Builder builder = requestBuilder .copy ().GET ().uri (uri );
196
+ final HttpRequest .Builder request = requestBuilder .copy ().GET (). header ( ACCEPT , TEXT_EVENT_STREAM ).uri (uri );
177
197
final String lastId = lastEventId .get ();
178
198
if (lastId != null ) {
179
- builder .header ("Last-Event-ID" , lastId );
199
+ request .header (LAST_EVENT_ID , lastId );
180
200
}
181
- return httpClient .sendAsync (builder .build (), HttpResponse .BodyHandlers .ofInputStream ());
201
+ if (mcpSessionId .get () != null ) {
202
+ request .header (MCP_SESSION_ID , mcpSessionId .get ());
203
+ }
204
+
205
+ return httpClient .sendAsync (request .build (), HttpResponse .BodyHandlers .ofInputStream ());
182
206
}).flatMap (response -> {
207
+ // must like server terminate session and the client need to start a
208
+ // new session by sending a new `InitializeRequest` without a session
209
+ // ID attached.
210
+ if (mcpSessionId .get () != null && response .statusCode () == 404 ) {
211
+ mcpSessionId .set (null );
212
+ }
213
+
183
214
if (response .statusCode () == 405 || response .statusCode () == 404 ) {
184
215
LOGGER .warn ("Operation not allowed, falling back to SSE" );
185
216
fallbackToSse .set (true );
@@ -192,6 +223,7 @@ public Mono<Void> connect(final Function<Mono<McpSchema.JSONRPCMessage>, Mono<Mc
192
223
.doOnTerminate (() -> state .set (TransportState .CLOSED ))
193
224
.onErrorResume (e -> {
194
225
LOGGER .error ("Streamable transport connection error" , e );
226
+ state .set (TransportState .DISCONNECTED );
195
227
return Mono .error (e );
196
228
}));
197
229
}
@@ -204,67 +236,52 @@ public Mono<Void> sendMessage(final McpSchema.JSONRPCMessage message) {
204
236
public Mono <Void > sendMessage (final McpSchema .JSONRPCMessage message ,
205
237
final Function <Mono <McpSchema .JSONRPCMessage >, Mono <McpSchema .JSONRPCMessage >> handler ) {
206
238
if (fallbackToSse .get ()) {
207
- return sseClientTransport . sendMessage (message );
239
+ return fallbackToSse (message );
208
240
}
209
241
210
242
if (state .get () == TransportState .CLOSED ) {
211
243
return Mono .empty ();
212
244
}
213
245
214
- return sentPost (message , handler ).onErrorResume (e -> {
215
- LOGGER .error ("Streamable transport sendMessage error" , e );
216
- return Mono .error (e );
217
- });
218
- }
219
-
220
- /**
221
- * Sends a list of messages to the server.
222
- * @param messages the list of messages to send
223
- * @return a Mono that completes when all messages have been sent
224
- */
225
- public Mono <Void > sendMessages (final List <McpSchema .JSONRPCMessage > messages ,
226
- final Function <Mono <McpSchema .JSONRPCMessage >, Mono <McpSchema .JSONRPCMessage >> handler ) {
227
- if (fallbackToSse .get ()) {
228
- return Flux .fromIterable (messages ).flatMap (this ::sendMessage ).then ();
229
- }
230
-
231
- if (state .get () == TransportState .CLOSED ) {
232
- return Mono .empty ();
233
- }
234
-
235
- return sentPost (messages , handler ).onErrorResume (e -> {
236
- LOGGER .error ("Streamable transport sendMessages error" , e );
237
- return Mono .error (e );
238
- });
239
- }
240
-
241
- private Mono <Void > sentPost (final Object msg ,
242
- final Function <Mono <McpSchema .JSONRPCMessage >, Mono <McpSchema .JSONRPCMessage >> handler ) {
243
- return serializeJson (msg ).flatMap (json -> {
244
- final HttpRequest request = requestBuilder .copy ()
246
+ return serializeJson (message ).flatMap (json -> {
247
+ final HttpRequest .Builder request = requestBuilder .copy ()
245
248
.POST (HttpRequest .BodyPublishers .ofString (json ))
246
- .uri (uri )
247
- .build ();
248
- return Mono .fromFuture (httpClient .sendAsync (request , HttpResponse .BodyHandlers .ofInputStream ()))
249
+ .header (ACCEPT , DEFAULT_ACCEPT_VALUES )
250
+ .header (CONTENT_TYPE , APPLICATION_JSON )
251
+ .uri (uri );
252
+ if (mcpSessionId .get () != null ) {
253
+ request .header (MCP_SESSION_ID , mcpSessionId .get ());
254
+ }
255
+
256
+ return Mono .fromFuture (httpClient .sendAsync (request .build (), HttpResponse .BodyHandlers .ofInputStream ()))
249
257
.flatMap (response -> {
250
258
259
+ // server may assign a session ID at initialization time, if yes we
260
+ // have to use it for any subsequent requests
261
+ if (message instanceof McpSchema .JSONRPCRequest
262
+ && ((McpSchema .JSONRPCRequest ) message ).method ().equals (McpSchema .METHOD_INITIALIZE )) {
263
+ response .headers ()
264
+ .firstValue (MCP_SESSION_ID )
265
+ .map (String ::trim )
266
+ .ifPresent (this .mcpSessionId ::set );
267
+ }
268
+
251
269
// If the response is 202 Accepted, there's no body to process
252
270
if (response .statusCode () == 202 ) {
253
271
return Mono .empty ();
254
272
}
255
273
274
+ // must like server terminate session and the client need to start a
275
+ // new session by sending a new `InitializeRequest` without a session
276
+ // ID attached.
277
+ if (mcpSessionId .get () != null && response .statusCode () == 404 ) {
278
+ mcpSessionId .set (null );
279
+ }
280
+
256
281
if (response .statusCode () == 405 || response .statusCode () == 404 ) {
257
282
LOGGER .warn ("Operation not allowed, falling back to SSE" );
258
283
fallbackToSse .set (true );
259
- if (msg instanceof McpSchema .JSONRPCMessage message ) {
260
- return sseClientTransport .sendMessage (message );
261
- }
262
-
263
- if (msg instanceof List <?> list ) {
264
- @ SuppressWarnings ("unchecked" )
265
- final List <McpSchema .JSONRPCMessage > messages = (List <McpSchema .JSONRPCMessage >) list ;
266
- return Flux .fromIterable (messages ).flatMap (this ::sendMessage ).then ();
267
- }
284
+ return fallbackToSse (message );
268
285
}
269
286
270
287
if (response .statusCode () >= 400 ) {
@@ -274,18 +291,28 @@ private Mono<Void> sentPost(final Object msg,
274
291
275
292
return handleStreamingResponse (response , handler );
276
293
});
294
+ }).onErrorResume (e -> {
295
+ LOGGER .error ("Streamable transport sendMessages error" , e );
296
+ return Mono .error (e );
277
297
});
278
298
279
299
}
280
300
281
- private Mono <String > serializeJson (final Object input ) {
301
+ private Mono <Void > fallbackToSse (final McpSchema .JSONRPCMessage msg ) {
302
+ if (msg instanceof McpSchema .JSONRPCBatchRequest batch ) {
303
+ return Flux .fromIterable (batch .items ()).flatMap (sseClientTransport ::sendMessage ).then ();
304
+ }
305
+
306
+ if (msg instanceof McpSchema .JSONRPCBatchResponse batch ) {
307
+ return Flux .fromIterable (batch .items ()).flatMap (sseClientTransport ::sendMessage ).then ();
308
+ }
309
+
310
+ return sseClientTransport .sendMessage (msg );
311
+ }
312
+
313
+ private Mono <String > serializeJson (final McpSchema .JSONRPCMessage msg ) {
282
314
try {
283
- if (input instanceof McpSchema .JSONRPCMessage || input instanceof List ) {
284
- return Mono .just (objectMapper .writeValueAsString (input ));
285
- }
286
- else {
287
- return Mono .error (new IllegalArgumentException ("Unsupported message type for serialization" ));
288
- }
315
+ return Mono .just (objectMapper .writeValueAsString (msg ));
289
316
}
290
317
catch (IOException e ) {
291
318
LOGGER .error ("Error serializing JSON-RPC message" , e );
@@ -295,27 +322,31 @@ private Mono<String> serializeJson(final Object input) {
295
322
296
323
private Mono <Void > handleStreamingResponse (final HttpResponse <InputStream > response ,
297
324
final Function <Mono <McpSchema .JSONRPCMessage >, Mono <McpSchema .JSONRPCMessage >> handler ) {
298
- final String contentType = response .headers ().firstValue ("Content-Type" ).orElse ("" );
299
- if (contentType .contains ("application/json-seq" )) {
325
+ final String contentType = response .headers ().firstValue (CONTENT_TYPE ).orElse ("" );
326
+ if (contentType .contains (APPLICATION_JSON_SEQ )) {
300
327
return handleJsonStream (response , handler );
301
328
}
302
- else if (contentType .contains ("text/event-stream" )) {
329
+ else if (contentType .contains (TEXT_EVENT_STREAM )) {
303
330
return handleSseStream (response , handler );
304
331
}
305
- else if (contentType .contains ("application/json" )) {
332
+ else if (contentType .contains (APPLICATION_JSON )) {
306
333
return handleSingleJson (response , handler );
307
334
}
308
- else {
309
- return Mono .error (new UnsupportedOperationException ("Unsupported Content-Type: " + contentType ));
310
- }
335
+ return Mono .error (new UnsupportedOperationException ("Unsupported Content-Type: " + contentType ));
311
336
}
312
337
313
338
private Mono <Void > handleSingleJson (final HttpResponse <InputStream > response ,
314
339
final Function <Mono <McpSchema .JSONRPCMessage >, Mono <McpSchema .JSONRPCMessage >> handler ) {
315
340
return Mono .fromCallable (() -> {
316
- final McpSchema .JSONRPCMessage msg = McpSchema .deserializeJsonRpcMessage (objectMapper ,
317
- new String (response .body ().readAllBytes (), StandardCharsets .UTF_8 ));
318
- return handler .apply (Mono .just (msg ));
341
+ try {
342
+ final McpSchema .JSONRPCMessage msg = McpSchema .deserializeJsonRpcMessage (objectMapper ,
343
+ new String (response .body ().readAllBytes (), StandardCharsets .UTF_8 ));
344
+ return handler .apply (Mono .just (msg ));
345
+ }
346
+ catch (IOException e ) {
347
+ LOGGER .error ("Error processing JSON response" , e );
348
+ return Mono .error (e );
349
+ }
319
350
}).flatMap (Function .identity ()).then ();
320
351
}
321
352
@@ -328,7 +359,7 @@ private Mono<Void> handleJsonStream(final HttpResponse<InputStream> response,
328
359
}
329
360
catch (IOException e ) {
330
361
LOGGER .error ("Error processing JSON line" , e );
331
- return Mono .empty ( );
362
+ return Mono .error ( e );
332
363
}
333
364
}).then ();
334
365
}
@@ -347,7 +378,7 @@ private Mono<Void> handleSseStream(final HttpResponse<InputStream> response,
347
378
if (line .startsWith ("event: " ))
348
379
event = line .substring (7 ).trim ();
349
380
else if (line .startsWith ("data: " ))
350
- data += line .substring (6 ). trim () + "\n " ;
381
+ data += line .substring (6 ) + "\n " ;
351
382
else if (line .startsWith ("id: " ))
352
383
id = line .substring (4 ).trim ();
353
384
}
@@ -356,34 +387,39 @@ else if (line.startsWith("id: "))
356
387
data = data .substring (0 , data .length () - 1 );
357
388
}
358
389
359
- return new FlowSseClient .SseEvent (event , data , id );
390
+ return new FlowSseClient .SseEvent (id , event , data );
360
391
})
361
392
.filter (sseEvent -> "message" .equals (sseEvent .type ()))
362
- .doOnNext (sseEvent -> {
363
- lastEventId . set ( sseEvent .id () );
393
+ .concatMap (sseEvent -> {
394
+ String rawData = sseEvent .data (). trim ( );
364
395
try {
365
- String rawData = sseEvent .data ().trim ();
366
396
JsonNode node = objectMapper .readTree (rawData );
367
-
397
+ List < McpSchema . JSONRPCMessage > messages = new ArrayList <>();
368
398
if (node .isArray ()) {
369
399
for (JsonNode item : node ) {
370
- String rawMessage = objectMapper .writeValueAsString (item );
371
- McpSchema .JSONRPCMessage msg = McpSchema .deserializeJsonRpcMessage (objectMapper ,
372
- rawMessage );
373
- handler .apply (Mono .just (msg )).subscribe ();
400
+ messages .add (McpSchema .deserializeJsonRpcMessage (objectMapper , item .toString ()));
374
401
}
375
402
}
376
403
else if (node .isObject ()) {
377
- String rawMessage = objectMapper .writeValueAsString (node );
378
- McpSchema .JSONRPCMessage msg = McpSchema .deserializeJsonRpcMessage (objectMapper , rawMessage );
379
- handler .apply (Mono .just (msg )).subscribe ();
404
+ messages .add (McpSchema .deserializeJsonRpcMessage (objectMapper , node .toString ()));
380
405
}
381
406
else {
382
- LOGGER .warn ("Unexpected JSON in SSE data: {}" , rawData );
407
+ String warning = "Unexpected JSON in SSE data: " + rawData ;
408
+ LOGGER .warn (warning );
409
+ return Mono .error (new IllegalArgumentException (warning ));
383
410
}
411
+
412
+ return Flux .fromIterable (messages )
413
+ .concatMap (msg -> handler .apply (Mono .just (msg )))
414
+ .then (Mono .fromRunnable (() -> {
415
+ if (!sseEvent .id ().isEmpty ()) {
416
+ lastEventId .set (sseEvent .id ());
417
+ }
418
+ }));
384
419
}
385
420
catch (IOException e ) {
386
- LOGGER .error ("Error processing SSE event: {}" , sseEvent .data (), e );
421
+ LOGGER .error ("Error parsing SSE JSON: {}" , rawData , e );
422
+ return Mono .error (e );
387
423
}
388
424
})
389
425
.then ();
0 commit comments