diff --git a/src/server/streamableHttp.test.ts b/src/server/streamableHttp.test.ts index aff9e511..bdbfb6ba 100644 --- a/src/server/streamableHttp.test.ts +++ b/src/server/streamableHttp.test.ts @@ -842,8 +842,8 @@ describe("StreamableHTTPServerTransport", () => { }, body: JSON.stringify(initMessage), }); - - await transport.handleRequest(initReq, mockResponse); + const initResponse = createMockResponse(); + await transport.handleRequest(initReq, initResponse); mockResponse.writeHead.mockClear(); }); @@ -934,6 +934,136 @@ describe("StreamableHTTPServerTransport", () => { // Now stream should be closed expect(mockResponse.end).toHaveBeenCalled(); }); + + it("should keep stream open when multiple requests share the same connection", async () => { + // Create a fresh response for this test + const sharedResponse = createMockResponse(); + + // Send two requests in a batch that will share the same connection + const batchRequests: JSONRPCMessage[] = [ + { jsonrpc: "2.0", method: "method1", params: {}, id: "req1" }, + { jsonrpc: "2.0", method: "method2", params: {}, id: "req2" } + ]; + + const req = createMockRequest({ + method: "POST", + headers: { + "content-type": "application/json", + "accept": "application/json, text/event-stream", + "mcp-session-id": transport.sessionId + }, + body: JSON.stringify(batchRequests) + }); + + await transport.handleRequest(req, sharedResponse); + + // Respond to first request + const response1: JSONRPCMessage = { + jsonrpc: "2.0", + result: { value: "result1" }, + id: "req1" + }; + + await transport.send(response1); + + // Connection should remain open because req2 is still pending + expect(sharedResponse.write).toHaveBeenCalledWith( + expect.stringContaining(`event: message\ndata: ${JSON.stringify(response1)}\n\n`) + ); + expect(sharedResponse.end).not.toHaveBeenCalled(); + + // Respond to second request + const response2: JSONRPCMessage = { + jsonrpc: "2.0", + result: { value: "result2" }, + id: "req2" + }; + + await transport.send(response2); + + // Now connection should close as all requests are complete + expect(sharedResponse.write).toHaveBeenCalledWith( + expect.stringContaining(`event: message\ndata: ${JSON.stringify(response2)}\n\n`) + ); + expect(sharedResponse.end).toHaveBeenCalled(); + }); + + it("should clean up connection tracking when a response is sent", async () => { + const req = createMockRequest({ + method: "POST", + headers: { + "content-type": "application/json", + "accept": "application/json, text/event-stream", + "mcp-session-id": transport.sessionId + }, + body: JSON.stringify({ + jsonrpc: "2.0", + method: "test", + params: {}, + id: "cleanup-test" + }) + }); + + const response = createMockResponse(); + await transport.handleRequest(req, response); + + // Verify that the request is tracked in the SSE map + expect(transport["_sseResponseMapping"].size).toBe(2); + expect(transport["_sseResponseMapping"].has("cleanup-test")).toBe(true); + + // Send a response + await transport.send({ + jsonrpc: "2.0", + result: {}, + id: "cleanup-test" + }); + + // Verify that the mapping was cleaned up + expect(transport["_sseResponseMapping"].size).toBe(1); + expect(transport["_sseResponseMapping"].has("cleanup-test")).toBe(false); + }); + + it("should clean up connection tracking when client disconnects", async () => { + // Setup two requests that share a connection + const req = createMockRequest({ + method: "POST", + headers: { + "content-type": "application/json", + "accept": "application/json, text/event-stream", + "mcp-session-id": transport.sessionId + }, + body: JSON.stringify([ + { jsonrpc: "2.0", method: "longRunning1", params: {}, id: "req1" }, + { jsonrpc: "2.0", method: "longRunning2", params: {}, id: "req2" } + ]) + }); + + const response = createMockResponse(); + + // We need to manually store the callback to trigger it later + let closeCallback: (() => void) | undefined; + response.on.mockImplementation((event, callback: () => void) => { + if (typeof event === "string" && event === "close") { + closeCallback = callback; + } + return response; + }); + + await transport.handleRequest(req, response); + + // Both requests should be mapped to the same response + expect(transport["_sseResponseMapping"].size).toBe(3); + expect(transport["_sseResponseMapping"].get("req1")).toBe(response); + expect(transport["_sseResponseMapping"].get("req2")).toBe(response); + + // Simulate client disconnect by triggering the stored callback + if (closeCallback) closeCallback(); + + // All entries using this response should be removed + expect(transport["_sseResponseMapping"].size).toBe(1); + expect(transport["_sseResponseMapping"].has("req1")).toBe(false); + expect(transport["_sseResponseMapping"].has("req2")).toBe(false); + }); }); describe("Message Targeting", () => { diff --git a/src/server/streamableHttp.ts b/src/server/streamableHttp.ts index 34b4fd95..b0fcce6d 100644 --- a/src/server/streamableHttp.ts +++ b/src/server/streamableHttp.ts @@ -197,19 +197,7 @@ export class StreamableHTTPServerTransport implements Transport { } this.sessionId = this.sessionIdGenerator(); this._initialized = true; - const headers: Record = {}; - if (this.sessionId !== undefined) { - headers["mcp-session-id"] = this.sessionId; - } - - // Process initialization messages before responding - for (const message of messages) { - this.onmessage?.(message); - } - - res.writeHead(200, headers).end(); - return; } // If an Mcp-Session-Id is returned by the server during initialization, // clients using the Streamable HTTP transport MUST include it @@ -254,6 +242,16 @@ export class StreamableHTTPServerTransport implements Transport { } } + // Set up close handler for client disconnects + res.on("close", () => { + // Remove all entries that reference this response + for (const [id, storedRes] of this._sseResponseMapping.entries()) { + if (storedRes === res) { + this._sseResponseMapping.delete(id); + } + } + }); + // handle each message for (const message of messages) { this.onmessage?.(message); @@ -360,36 +358,31 @@ export class StreamableHTTPServerTransport implements Transport { } async send(message: JSONRPCMessage, options?: { relatedRequestId?: RequestId }): Promise { - const relatedRequestId = options?.relatedRequestId; - // SSE connections are established per POST request, for now we don't support it through the GET - // this will be changed when we implement the GET SSE connection - if (relatedRequestId === undefined) { - throw new Error("relatedRequestId is required for Streamable HTTP transport"); + let requestId = options?.relatedRequestId; + if ('result' in message || 'error' in message) { + // If the message is a response, use the request ID from the message + requestId = message.id; + } + if (requestId === undefined) { + throw new Error("No request ID provided for the message"); } - const sseResponse = this._sseResponseMapping.get(relatedRequestId); + const sseResponse = this._sseResponseMapping.get(requestId); if (!sseResponse) { - throw new Error(`No SSE connection established for request ID: ${String(relatedRequestId)}`); + throw new Error(`No SSE connection established for request ID: ${String(requestId)}`); } // Send the message as an SSE event sseResponse.write( `event: message\ndata: ${JSON.stringify(message)}\n\n`, ); - - // If this is a response message with the same ID as the request, we can check - // if we need to close the stream after sending the response + // After all JSON-RPC responses have been sent, the server SHOULD close the SSE stream. if ('result' in message || 'error' in message) { - if (message.id === relatedRequestId) { - // This is a response to the original request, we can close the stream - // after sending all related responses - this._sseResponseMapping.delete(relatedRequestId); - - // Only close the connection if it's not needed by other requests - const canCloseConnection = ![...this._sseResponseMapping.entries()].some(([id, res]) => res === sseResponse && id !== relatedRequestId); - if (canCloseConnection) { - sseResponse.end(); - } + this._sseResponseMapping.delete(requestId); + // Only close the connection if it's not needed by other requests + const canCloseConnection = ![...this._sseResponseMapping.entries()].some(([id, res]) => res === sseResponse && id !== requestId); + if (canCloseConnection) { + sseResponse?.end(); } } }