diff --git a/src/client/streamableHttp.test.ts b/src/client/streamableHttp.test.ts index 0dc582d4..40f22139 100644 --- a/src/client/streamableHttp.test.ts +++ b/src/client/streamableHttp.test.ts @@ -80,7 +80,7 @@ describe("StreamableHTTPClientTransport", () => { (global.fetch as jest.Mock).mockResolvedValueOnce({ ok: true, status: 200, - headers: new Headers({ "mcp-session-id": "test-session-id" }), + headers: new Headers({ "content-type": "text/event-stream", "mcp-session-id": "test-session-id" }), }); await transport.send(message); @@ -164,7 +164,7 @@ describe("StreamableHTTPClientTransport", () => { // We expect the 405 error to be caught and handled gracefully // This should not throw an error that breaks the transport await transport.start(); - await expect(transport.openSseStream()).rejects.toThrow('Failed to open SSE stream: Method Not Allowed'); + await expect(transport.openSseStream()).rejects.toThrow("Failed to open SSE stream: Method Not Allowed"); // Check that GET was attempted expect(global.fetch).toHaveBeenCalledWith( @@ -192,7 +192,7 @@ describe("StreamableHTTPClientTransport", () => { const stream = new ReadableStream({ start(controller) { // Send a server notification via SSE - const event = 'event: message\ndata: {"jsonrpc": "2.0", "method": "serverNotification", "params": {}}\n\n'; + const event = "event: message\ndata: {\"jsonrpc\": \"2.0\", \"method\": \"serverNotification\", \"params\": {}}\n\n"; controller.enqueue(encoder.encode(event)); } }); @@ -237,7 +237,7 @@ describe("StreamableHTTPClientTransport", () => { (global.fetch as jest.Mock) .mockResolvedValueOnce({ - ok: true, + ok: true, status: 200, headers: new Headers({ "content-type": "text/event-stream" }), body: makeStream("request1") @@ -263,13 +263,13 @@ describe("StreamableHTTPClientTransport", () => { // Both streams should have delivered their messages expect(messageSpy).toHaveBeenCalledTimes(2); - + // Verify received messages without assuming specific order expect(messageSpy.mock.calls.some(call => { const msg = call[0]; return msg.id === "request1" && msg.result?.id === "request1"; })).toBe(true); - + expect(messageSpy.mock.calls.some(call => { const msg = call[0]; return msg.id === "request2" && msg.result?.id === "request2"; @@ -281,7 +281,7 @@ describe("StreamableHTTPClientTransport", () => { const encoder = new TextEncoder(); const stream = new ReadableStream({ start(controller) { - const event = 'id: event-123\nevent: message\ndata: {"jsonrpc": "2.0", "method": "serverNotification", "params": {}}\n\n'; + const event = "id: event-123\nevent: message\ndata: {\"jsonrpc\": \"2.0\", \"method\": \"serverNotification\", \"params\": {}}\n\n"; controller.enqueue(encoder.encode(event)); controller.close(); } @@ -313,4 +313,67 @@ describe("StreamableHTTPClientTransport", () => { const lastCall = calls[calls.length - 1]; expect(lastCall[1].headers.get("last-event-id")).toBe("event-123"); }); -}); \ No newline at end of file + + it("should throw error when invalid content-type is received", async () => { + const message: JSONRPCMessage = { + jsonrpc: "2.0", + method: "test", + params: {}, + id: "test-id" + }; + + const stream = new ReadableStream({ + start(controller) { + controller.enqueue("invalid text response"); + controller.close(); + } + }); + + const errorSpy = jest.fn(); + transport.onerror = errorSpy; + + (global.fetch as jest.Mock).mockResolvedValueOnce({ + ok: true, + status: 200, + headers: new Headers({ "content-type": "text/plain" }), + body: stream + }); + + await transport.start(); + await expect(transport.send(message)).rejects.toThrow("Unexpected content type: text/plain"); + expect(errorSpy).toHaveBeenCalled(); + }); + + + it("should always send specified custom headers", async () => { + const requestInit = { + headers: { + "X-Custom-Header": "CustomValue" + } + }; + transport = new StreamableHTTPClientTransport(new URL("http://localhost:1234/mcp"), { + requestInit: requestInit + }); + + let actualReqInit: RequestInit = {}; + + ((global.fetch as jest.Mock)).mockImplementation( + async (_url, reqInit) => { + actualReqInit = reqInit; + return new Response(null, { status: 200, headers: { "content-type": "text/event-stream" } }); + } + ); + + await transport.start(); + + await transport.openSseStream(); + expect((actualReqInit.headers as Headers).get("x-custom-header")).toBe("CustomValue"); + + requestInit.headers["X-Custom-Header"] = "SecondCustomValue"; + + await transport.send({ jsonrpc: "2.0", method: "test", params: {} } as JSONRPCMessage); + expect((actualReqInit.headers as Headers).get("x-custom-header")).toBe("SecondCustomValue"); + + expect(global.fetch).toHaveBeenCalledTimes(2); + }); +}); diff --git a/src/client/streamableHttp.ts b/src/client/streamableHttp.ts index 0c667e35..5ea537c7 100644 --- a/src/client/streamableHttp.ts +++ b/src/client/streamableHttp.ts @@ -1,7 +1,8 @@ import { Transport } from "../shared/transport.js"; import { JSONRPCMessage, JSONRPCMessageSchema } from "../types.js"; import { auth, AuthResult, OAuthClientProvider, UnauthorizedError } from "./auth.js"; -import { EventSourceParserStream } from 'eventsource-parser/stream'; +import { EventSourceParserStream } from "eventsource-parser/stream"; + export class StreamableHTTPError extends Error { constructor( public readonly code: number | undefined, @@ -17,16 +18,16 @@ export class StreamableHTTPError extends Error { export type StreamableHTTPClientTransportOptions = { /** * An OAuth client provider to use for authentication. - * + * * When an `authProvider` is specified and the connection is started: * 1. The connection is attempted with any existing access token from the `authProvider`. * 2. If the access token has expired, the `authProvider` is used to refresh the token. * 3. If token refresh fails or no access token exists, and auth is required, `OAuthClientProvider.redirectToAuthorization` is called, and an `UnauthorizedError` will be thrown from `connect`/`start`. - * + * * After the user has finished authorizing via their user agent, and is redirected back to the MCP client application, call `StreamableHTTPClientTransport.finishAuth` with the authorization code before retrying the connection. - * + * * If an `authProvider` is not provided, and auth is required, an `UnauthorizedError` will be thrown. - * + * * `UnauthorizedError` might also be thrown when sending any message over the transport, indicating that the session has expired, and needs to be re-authed and reconnected. */ authProvider?: OAuthClientProvider; @@ -83,7 +84,7 @@ export class StreamableHTTPClientTransport implements Transport { return await this._startOrAuthStandaloneSSE(); } - private async _commonHeaders(): Promise { + private async _commonHeaders(): Promise { const headers: HeadersInit = {}; if (this._authProvider) { const tokens = await this._authProvider.tokens(); @@ -96,24 +97,25 @@ export class StreamableHTTPClientTransport implements Transport { headers["mcp-session-id"] = this._sessionId; } - return headers; + return new Headers( + { ...headers, ...this._requestInit?.headers } + ); } private async _startOrAuthStandaloneSSE(): Promise { try { // Try to open an initial SSE stream with GET to listen for server messages // This is optional according to the spec - server may not support it - const commonHeaders = await this._commonHeaders(); - const headers = new Headers(commonHeaders); - headers.set('Accept', 'text/event-stream'); + const headers = await this._commonHeaders(); + headers.set("Accept", "text/event-stream"); // Include Last-Event-ID header for resumable streams if (this._lastEventId) { - headers.set('last-event-id', this._lastEventId); + headers.set("last-event-id", this._lastEventId); } const response = await fetch(this._url, { - method: 'GET', + method: "GET", headers, signal: this._abortController?.signal, }); @@ -124,12 +126,10 @@ export class StreamableHTTPClientTransport implements Transport { return await this._authThenStart(); } - const error = new StreamableHTTPError( + throw new StreamableHTTPError( response.status, `Failed to open SSE stream: ${response.statusText}`, ); - this.onerror?.(error); - throw error; } // Successful connection, handle the SSE stream as a standalone listener @@ -144,42 +144,32 @@ export class StreamableHTTPClientTransport implements Transport { if (!stream) { return; } - // Create a pipeline: binary stream -> text decoder -> SSE parser - const eventStream = stream - .pipeThrough(new TextDecoderStream()) - .pipeThrough(new EventSourceParserStream()); - const reader = eventStream.getReader(); const processStream = async () => { - try { - while (true) { - const { done, value: event } = await reader.read(); - if (done) { - break; - } - - // Update last event ID if provided - if (event.id) { - this._lastEventId = event.id; - } - - // Handle message events (default event type is undefined per docs) - // or explicit 'message' event type - if (!event.event || event.event === 'message') { - try { - const message = JSONRPCMessageSchema.parse(JSON.parse(event.data)); - this.onmessage?.(message); - } catch (error) { - this.onerror?.(error as Error); - } + // Create a pipeline: binary stream -> text decoder -> SSE parser + const eventStream = stream + .pipeThrough(new TextDecoderStream()) + .pipeThrough(new EventSourceParserStream()); + + for await (const event of eventStream) { + // Update last event ID if provided + if (event.id) { + this._lastEventId = event.id; + } + // Handle message events (default event type is undefined per docs) + // or explicit 'message' event type + if (!event.event || event.event === "message") { + try { + const message = JSONRPCMessageSchema.parse(JSON.parse(event.data)); + this.onmessage?.(message); + } catch (error) { + this.onerror?.(error as Error); } } - } catch (error) { - this.onerror?.(error as Error); } }; - processStream(); + processStream().catch(err => this.onerror?.(err)); } async start() { @@ -215,8 +205,7 @@ export class StreamableHTTPClientTransport implements Transport { async send(message: JSONRPCMessage | JSONRPCMessage[]): Promise { try { - const commonHeaders = await this._commonHeaders(); - const headers = new Headers({ ...commonHeaders, ...this._requestInit?.headers }); + const headers = await this._commonHeaders(); headers.set("content-type", "application/json"); headers.set("accept", "application/json, text/event-stream"); @@ -261,20 +250,13 @@ export class StreamableHTTPClientTransport implements Transport { // Get original message(s) for detecting request IDs const messages = Array.isArray(message) ? message : [message]; - // Extract IDs from request messages for tracking responses - const requestIds = messages.filter(msg => 'method' in msg && 'id' in msg) - .map(msg => 'id' in msg ? msg.id : undefined) - .filter(id => id !== undefined); - - // If we have request IDs and an SSE response, create a unique stream ID - const hasRequests = requestIds.length > 0; + const hasRequests = messages.filter(msg => "method" in msg && "id" in msg && msg.id !== undefined).length > 0; // Check the response type const contentType = response.headers.get("content-type"); if (hasRequests) { if (contentType?.includes("text/event-stream")) { - // For streaming responses, create a unique stream ID based on request IDs this._handleSseStream(response.body); } else if (contentType?.includes("application/json")) { // For non-streaming servers, we might get direct JSON responses @@ -286,6 +268,11 @@ export class StreamableHTTPClientTransport implements Transport { for (const msg of responseMessages) { this.onmessage?.(msg); } + } else { + throw new StreamableHTTPError( + -1, + `Unexpected content type: ${contentType}`, + ); } } } catch (error) { @@ -296,7 +283,7 @@ export class StreamableHTTPClientTransport implements Transport { /** * Opens SSE stream to receive messages from the server. - * + * * This allows the server to push messages to the client without requiring the client * to first send a request via HTTP POST. Some servers may not support this feature. * If authentication is required but fails, this method will throw an UnauthorizedError. @@ -309,4 +296,4 @@ export class StreamableHTTPClientTransport implements Transport { } await this._startOrAuthStandaloneSSE(); } -} \ No newline at end of file +} diff --git a/src/examples/README.md b/src/examples/README.md index 6e53fdec..cc6af51c 100644 --- a/src/examples/README.md +++ b/src/examples/README.md @@ -6,6 +6,62 @@ This directory contains example implementations of MCP clients and servers using Multi node with stete management example will be added soon after we add support. +### Server with JSON response mode (`server/jsonResponseStreamableHttp.ts`) + +A simple MCP server that uses the Streamable HTTP transport with JSON response mode enabled, implemented with Express. The server provides a simple `greet` tool that returns a greeting for a name. + +#### Running the server + +```bash +npx tsx src/examples/server/jsonResponseStreamableHttp.ts +``` + +The server will start on port 3000. You can test the initialization and tool calling: + +```bash +# Initialize the server and get the session ID from headers +SESSION_ID=$(curl -X POST \ + -H "Content-Type: application/json" \ + -H "Accept: application/json" \ + -H "Accept: text/event-stream" \ + -d '{ + "jsonrpc": "2.0", + "method": "initialize", + "params": { + "capabilities": {}, + "protocolVersion": "2025-03-26", + "clientInfo": { + "name": "test", + "version": "1.0.0" + } + }, + "id": "1" + }' \ + -i http://localhost:3000/mcp 2>&1 | grep -i "mcp-session-id" | cut -d' ' -f2 | tr -d '\r') +echo "Session ID: $SESSION_ID" + +# Call the greet tool using the saved session ID +curl -X POST \ + -H "Content-Type: application/json" \ + -H "Accept: application/json" \ + -H "Accept: text/event-stream" \ + -H "mcp-session-id: $SESSION_ID" \ + -d '{ + "jsonrpc": "2.0", + "method": "tools/call", + "params": { + "name": "greet", + "arguments": { + "name": "World" + } + }, + "id": "2" + }' \ + http://localhost:3000/mcp +``` + +Note that in this example, we're using plain JSON response mode by setting `Accept: application/json` header. + ### Server (`server/simpleStreamableHttp.ts`) A simple MCP server that uses the Streamable HTTP transport, implemented with Express. The server provides: @@ -24,10 +80,25 @@ The server will start on port 3000. You can test the initialization and tool lis ```bash # First initialize the server and save the session ID to a variable -SESSION_ID=$(curl -X POST -H "Content-Type: application/json" -H "Accept: application/json, text/event-stream" \ - -d '{"jsonrpc":"2.0","method":"initialize","params":{"capabilities":{}},"id":"1"}' \ +SESSION_ID=$(curl -X POST \ + -H "Content-Type: application/json" \ + -H "Accept: application/json" \ + -H "Accept: text/event-stream" \ + -d '{ + "jsonrpc": "2.0", + "method": "initialize", + "params": { + "capabilities": {}, + "protocolVersion": "2025-03-26", + "clientInfo": { + "name": "test", + "version": "1.0.0" + } + }, + "id": "1" + }' \ -i http://localhost:3000/mcp 2>&1 | grep -i "mcp-session-id" | cut -d' ' -f2 | tr -d '\r') -echo "Session ID: $SESSION_ID" +echo "Session ID: $SESSION_ID # Then list tools using the saved session ID curl -X POST -H "Content-Type: application/json" -H "Accept: application/json, text/event-stream" \ diff --git a/src/examples/client/simpleStreamableHttp.ts b/src/examples/client/simpleStreamableHttp.ts index 9bf43ce8..b17add14 100644 --- a/src/examples/client/simpleStreamableHttp.ts +++ b/src/examples/client/simpleStreamableHttp.ts @@ -10,24 +10,29 @@ import { GetPromptRequest, GetPromptResultSchema, ListResourcesRequest, - ListResourcesResultSchema + ListResourcesResultSchema, + LoggingMessageNotificationSchema } from '../../types.js'; async function main(): Promise { // Create a new client with streamable HTTP transport - const client = new Client({ - name: 'example-client', - version: '1.0.0' + const client = new Client({ + name: 'example-client', + version: '1.0.0' }); + const transport = new StreamableHTTPClientTransport( new URL('http://localhost:3000/mcp') ); // Connect the client using the transport and initialize the server await client.connect(transport); + client.setNotificationHandler(LoggingMessageNotificationSchema, (notification) => { + console.log(`Notification received: ${notification.params.level} - ${notification.params.data}`); + }); + console.log('Connected to MCP server'); - // List available tools const toolsRequest: ListToolsRequest = { method: 'tools/list', @@ -47,33 +52,62 @@ async function main(): Promise { const greetResult = await client.request(greetRequest, CallToolResultSchema); console.log('Greeting result:', greetResult.content[0].text); - // List available prompts - const promptsRequest: ListPromptsRequest = { - method: 'prompts/list', - params: {} - }; - const promptsResult = await client.request(promptsRequest, ListPromptsResultSchema); - console.log('Available prompts:', promptsResult.prompts); - - // Get a prompt - const promptRequest: GetPromptRequest = { - method: 'prompts/get', + // Call the new 'multi-greet' tool + console.log('\nCalling multi-greet tool (with notifications)...'); + const multiGreetRequest: CallToolRequest = { + method: 'tools/call', params: { - name: 'greeting-template', + name: 'multi-greet', arguments: { name: 'MCP User' } } }; - const promptResult = await client.request(promptRequest, GetPromptResultSchema); - console.log('Prompt template:', promptResult.messages[0].content.text); + const multiGreetResult = await client.request(multiGreetRequest, CallToolResultSchema); + console.log('Multi-greet results:'); + multiGreetResult.content.forEach(item => { + if (item.type === 'text') { + console.log(`- ${item.text}`); + } + }); + + // List available prompts + try { + const promptsRequest: ListPromptsRequest = { + method: 'prompts/list', + params: {} + }; + const promptsResult = await client.request(promptsRequest, ListPromptsResultSchema); + console.log('Available prompts:', promptsResult.prompts); + } catch (error) { + console.log(`Prompts not supported by this server (${error})`); + } + + // Get a prompt + try { + const promptRequest: GetPromptRequest = { + method: 'prompts/get', + params: { + name: 'greeting-template', + arguments: { name: 'MCP User' } + } + }; + const promptResult = await client.request(promptRequest, GetPromptResultSchema); + console.log('Prompt template:', promptResult.messages[0].content.text); + } catch (error) { + console.log(`Prompt retrieval not supported by this server (${error})`); + } // List available resources - const resourcesRequest: ListResourcesRequest = { - method: 'resources/list', - params: {} - }; - const resourcesResult = await client.request(resourcesRequest, ListResourcesResultSchema); - console.log('Available resources:', resourcesResult.resources); - + try { + const resourcesRequest: ListResourcesRequest = { + method: 'resources/list', + params: {} + }; + const resourcesResult = await client.request(resourcesRequest, ListResourcesResultSchema); + console.log('Available resources:', resourcesResult.resources); + } catch (error) { + console.log(`Resources not supported by this server (${error})`); + } + // Close the connection await client.close(); } diff --git a/src/examples/server/jsonResponseStreamableHttp.ts b/src/examples/server/jsonResponseStreamableHttp.ts new file mode 100644 index 00000000..1d322112 --- /dev/null +++ b/src/examples/server/jsonResponseStreamableHttp.ts @@ -0,0 +1,182 @@ +import express, { Request, Response } from 'express'; +import { randomUUID } from 'node:crypto'; +import { McpServer } from '../../server/mcp.js'; +import { StreamableHTTPServerTransport } from '../../server/streamableHttp.js'; +import { z } from 'zod'; +import { CallToolResult } from '../../types.js'; + +// Create an MCP server with implementation details +const server = new McpServer({ + name: 'json-response-streamable-http-server', + version: '1.0.0', +}, { + capabilities: { + logging: {}, + } +}); + +// Register a simple tool that returns a greeting +server.tool( + 'greet', + 'A simple greeting tool', + { + name: z.string().describe('Name to greet'), + }, + async ({ name }): Promise => { + return { + content: [ + { + type: 'text', + text: `Hello, ${name}!`, + }, + ], + }; + } +); + +// Register a tool that sends multiple greetings with notifications +server.tool( + 'multi-greet', + 'A tool that sends different greetings with delays between them', + { + name: z.string().describe('Name to greet'), + }, + async ({ name }, { sendNotification }): Promise => { + const sleep = (ms: number) => new Promise(resolve => setTimeout(resolve, ms)); + + await sendNotification({ + method: "notifications/message", + params: { level: "debug", data: `Starting multi-greet for ${name}` } + }); + + await sleep(1000); // Wait 1 second before first greeting + + await sendNotification({ + method: "notifications/message", + params: { level: "info", data: `Sending first greeting to ${name}` } + }); + + await sleep(1000); // Wait another second before second greeting + + await sendNotification({ + method: "notifications/message", + params: { level: "info", data: `Sending second greeting to ${name}` } + }); + + return { + content: [ + { + type: 'text', + text: `Good morning, ${name}!`, + } + ], + }; + } +); + +const app = express(); +app.use(express.json()); + +// Map to store transports by session ID +const transports: { [sessionId: string]: StreamableHTTPServerTransport } = {}; + +app.post('/mcp', async (req: Request, res: Response) => { + console.log('Received MCP request:', req.body); + try { + // Check for existing session ID + const sessionId = req.headers['mcp-session-id'] as string | undefined; + let transport: StreamableHTTPServerTransport; + + if (sessionId && transports[sessionId]) { + // Reuse existing transport + transport = transports[sessionId]; + } else if (!sessionId && isInitializeRequest(req.body)) { + // New initialization request - use JSON response mode + transport = new StreamableHTTPServerTransport({ + sessionIdGenerator: () => randomUUID(), + enableJsonResponse: true, // Enable JSON response mode + }); + + // Connect the transport to the MCP server BEFORE handling the request + await server.connect(transport); + + // After handling the request, if we get a session ID back, store the transport + await transport.handleRequest(req, res, req.body); + + // Store the transport by session ID for future requests + if (transport.sessionId) { + transports[transport.sessionId] = transport; + } + return; // Already handled + } else { + // Invalid request - no session ID or not initialization request + res.status(400).json({ + jsonrpc: '2.0', + error: { + code: -32000, + message: 'Bad Request: No valid session ID provided', + }, + id: null, + }); + return; + } + + // Handle the request with existing transport - no need to reconnect + await transport.handleRequest(req, res, req.body); + } catch (error) { + console.error('Error handling MCP request:', error); + if (!res.headersSent) { + res.status(500).json({ + jsonrpc: '2.0', + error: { + code: -32603, + message: 'Internal server error', + }, + id: null, + }); + } + } +}); + +// Helper function to detect initialize requests +function isInitializeRequest(body: unknown): boolean { + if (Array.isArray(body)) { + return body.some(msg => typeof msg === 'object' && msg !== null && 'method' in msg && msg.method === 'initialize'); + } + return typeof body === 'object' && body !== null && 'method' in body && body.method === 'initialize'; +} + +// Start the server +const PORT = 3000; +app.listen(PORT, () => { + console.log(`MCP Streamable HTTP Server listening on port ${PORT}`); + console.log(`Initialize session with the command below id you are using curl for testing: + ----------------------------- + SESSION_ID=$(curl -X POST \ + -H "Content-Type: application/json" \ + -H "Accept: application/json" \ + -H "Accept: text/event-stream" \ + -d '{ + "jsonrpc": "2.0", + "method": "initialize", + "params": { + "capabilities": {}, + "protocolVersion": "2025-03-26", + "clientInfo": { + "name": "test", + "version": "1.0.0" + } + }, + "id": "1" + }' \ + -i http://localhost:3000/mcp 2>&1 | grep -i "mcp-session-id" | cut -d' ' -f2 | tr -d '\\r') + echo "Session ID: $SESSION_ID" + -----------------------------`); +}); + +// Handle server shutdown +process.on('SIGINT', async () => { + console.log('Shutting down server...'); + await server.close(); + process.exit(0); +}); \ No newline at end of file diff --git a/src/examples/server/simpleStreamableHttp.ts b/src/examples/server/simpleStreamableHttp.ts index e6ebe4b9..5b228cbd 100644 --- a/src/examples/server/simpleStreamableHttp.ts +++ b/src/examples/server/simpleStreamableHttp.ts @@ -9,7 +9,7 @@ import { CallToolResult, GetPromptResult, ReadResourceResult } from '../../types const server = new McpServer({ name: 'simple-streamable-http-server', version: '1.0.0', -}); +}, { capabilities: { logging: {} } }); // Register a simple tool that returns a greeting server.tool( @@ -30,6 +30,46 @@ server.tool( } ); +// Register a tool that sends multiple greetings with notifications +server.tool( + 'multi-greet', + 'A tool that sends different greetings with delays between them', + { + name: z.string().describe('Name to greet'), + }, + async ({ name }, { sendNotification }): Promise => { + const sleep = (ms: number) => new Promise(resolve => setTimeout(resolve, ms)); + + await sendNotification({ + method: "notifications/message", + params: { level: "debug", data: `Starting multi-greet for ${name}` } + }); + + await sleep(1000); // Wait 1 second before first greeting + + await sendNotification({ + method: "notifications/message", + params: { level: "info", data: `Sending first greeting to ${name}` } + }); + + await sleep(1000); // Wait another second before second greeting + + await sendNotification({ + method: "notifications/message", + params: { level: "info", data: `Sending second greeting to ${name}` } + }); + + return { + content: [ + { + type: 'text', + text: `Good morning, ${name}!`, + } + ], + }; + } +); + // Register a simple prompt server.prompt( 'greeting-template', @@ -81,7 +121,7 @@ app.post('/mcp', async (req: Request, res: Response) => { // Check for existing session ID const sessionId = req.headers['mcp-session-id'] as string | undefined; let transport: StreamableHTTPServerTransport; - + if (sessionId && transports[sessionId]) { // Reuse existing transport transport = transports[sessionId]; @@ -90,14 +130,14 @@ app.post('/mcp', async (req: Request, res: Response) => { transport = new StreamableHTTPServerTransport({ sessionIdGenerator: () => randomUUID(), }); - + // Connect the transport to the MCP server BEFORE handling the request // so responses can flow back through the same transport await server.connect(transport); - + // After handling the request, if we get a session ID back, store the transport await transport.handleRequest(req, res, req.body); - + // Store the transport by session ID for future requests if (transport.sessionId) { transports[transport.sessionId] = transport; @@ -146,7 +186,28 @@ function isInitializeRequest(body: unknown): boolean { const PORT = 3000; app.listen(PORT, () => { console.log(`MCP Streamable HTTP Server listening on port ${PORT}`); - console.log(`Test with: curl -X POST -H "Content-Type: application/json" -H "Accept: application/json, text/event-stream" -d '{"jsonrpc":"2.0","method":"initialize","params":{"capabilities":{}},"id":"1"}' http://localhost:${PORT}/mcp`); + console.log(`Initialize session with the command below id you are using curl for testing: + ----------------------------- + SESSION_ID=$(curl -X POST \ + -H "Content-Type: application/json" \ + -H "Accept: application/json" \ + -H "Accept: text/event-stream" \ + -d '{ + "jsonrpc": "2.0", + "method": "initialize", + "params": { + "capabilities": {}, + "protocolVersion": "2025-03-26", + "clientInfo": { + "name": "test", + "version": "1.0.0" + } + }, + "id": "1" + }' \ + -i http://localhost:3000/mcp 2>&1 | grep -i "mcp-session-id" | cut -d' ' -f2 | tr -d '\\r') + echo "Session ID: $SESSION_ID" + -----------------------------`); }); // Handle server shutdown diff --git a/src/server/streamableHttp.test.ts b/src/server/streamableHttp.test.ts index bdbfb6ba..ad80ea62 100644 --- a/src/server/streamableHttp.test.ts +++ b/src/server/streamableHttp.test.ts @@ -1008,8 +1008,8 @@ describe("StreamableHTTPServerTransport", () => { await transport.handleRequest(req, response); // Verify that the request is tracked in the SSE map - expect(transport["_sseResponseMapping"].size).toBe(2); - expect(transport["_sseResponseMapping"].has("cleanup-test")).toBe(true); + expect(transport["_responseMapping"].size).toBe(2); + expect(transport["_responseMapping"].has("cleanup-test")).toBe(true); // Send a response await transport.send({ @@ -1019,8 +1019,8 @@ describe("StreamableHTTPServerTransport", () => { }); // Verify that the mapping was cleaned up - expect(transport["_sseResponseMapping"].size).toBe(1); - expect(transport["_sseResponseMapping"].has("cleanup-test")).toBe(false); + expect(transport["_responseMapping"].size).toBe(1); + expect(transport["_responseMapping"].has("cleanup-test")).toBe(false); }); it("should clean up connection tracking when client disconnects", async () => { @@ -1052,17 +1052,17 @@ describe("StreamableHTTPServerTransport", () => { await transport.handleRequest(req, response); // Both requests should be mapped to the same response - expect(transport["_sseResponseMapping"].size).toBe(3); - expect(transport["_sseResponseMapping"].get("req1")).toBe(response); - expect(transport["_sseResponseMapping"].get("req2")).toBe(response); + expect(transport["_responseMapping"].size).toBe(3); + expect(transport["_responseMapping"].get("req1")).toBe(response); + expect(transport["_responseMapping"].get("req2")).toBe(response); // Simulate client disconnect by triggering the stored callback if (closeCallback) closeCallback(); // All entries using this response should be removed - expect(transport["_sseResponseMapping"].size).toBe(1); - expect(transport["_sseResponseMapping"].has("req1")).toBe(false); - expect(transport["_sseResponseMapping"].has("req2")).toBe(false); + expect(transport["_responseMapping"].size).toBe(1); + expect(transport["_responseMapping"].has("req1")).toBe(false); + expect(transport["_responseMapping"].has("req2")).toBe(false); }); }); @@ -1214,6 +1214,140 @@ describe("StreamableHTTPServerTransport", () => { }); }); + describe("JSON Response Mode", () => { + let jsonResponseTransport: StreamableHTTPServerTransport; + let mockResponse: jest.Mocked; + + beforeEach(async () => { + jsonResponseTransport = new StreamableHTTPServerTransport({ + sessionIdGenerator: () => randomUUID(), + enableJsonResponse: true, + }); + + // Initialize the transport + const initMessage: JSONRPCMessage = { + jsonrpc: "2.0", + method: "initialize", + params: { + clientInfo: { name: "test-client", version: "1.0" }, + protocolVersion: "2025-03-26" + }, + id: "init-1", + }; + + const initReq = createMockRequest({ + method: "POST", + headers: { + "content-type": "application/json", + "accept": "application/json, text/event-stream", + }, + body: JSON.stringify(initMessage), + }); + + mockResponse = createMockResponse(); + await jsonResponseTransport.handleRequest(initReq, mockResponse); + mockResponse = createMockResponse(); // Reset for tests + }); + + it("should return JSON response for a single request", async () => { + const requestMessage: JSONRPCMessage = { + jsonrpc: "2.0", + method: "tools/list", + params: {}, + id: "test-req-id", + }; + + const req = createMockRequest({ + method: "POST", + headers: { + "content-type": "application/json", + "accept": "application/json, text/event-stream", + "mcp-session-id": jsonResponseTransport.sessionId, + }, + body: JSON.stringify(requestMessage), + }); + + // Mock immediate response + jsonResponseTransport.onmessage = (message) => { + if ('method' in message && 'id' in message) { + const responseMessage: JSONRPCMessage = { + jsonrpc: "2.0", + result: { value: `test-result` }, + id: message.id, + }; + void jsonResponseTransport.send(responseMessage); + } + }; + + await jsonResponseTransport.handleRequest(req, mockResponse); + // Should respond with application/json header + expect(mockResponse.writeHead).toHaveBeenCalledWith( + 200, + expect.objectContaining({ + "Content-Type": "application/json", + }) + ); + + // Should return the response as JSON + const expectedResponse = { + jsonrpc: "2.0", + result: { value: "test-result" }, + id: "test-req-id", + }; + + expect(mockResponse.end).toHaveBeenCalledWith(JSON.stringify(expectedResponse)); + }); + + it("should return JSON response for batch requests", async () => { + const batchMessages: JSONRPCMessage[] = [ + { jsonrpc: "2.0", method: "tools/list", params: {}, id: "req1" }, + { jsonrpc: "2.0", method: "tools/call", params: {}, id: "req2" }, + ]; + + const req = createMockRequest({ + method: "POST", + headers: { + "content-type": "application/json", + "accept": "application/json, text/event-stream", + "mcp-session-id": jsonResponseTransport.sessionId, + }, + body: JSON.stringify(batchMessages), + }); + + // Mock responses without enforcing specific order + jsonResponseTransport.onmessage = (message) => { + if ('method' in message && 'id' in message) { + const responseMessage: JSONRPCMessage = { + jsonrpc: "2.0", + result: { value: `result-for-${message.id}` }, + id: message.id, + }; + void jsonResponseTransport.send(responseMessage); + } + }; + + await jsonResponseTransport.handleRequest(req, mockResponse); + + // Should respond with application/json header + expect(mockResponse.writeHead).toHaveBeenCalledWith( + 200, + expect.objectContaining({ + "Content-Type": "application/json", + }) + ); + + // Verify response was sent but don't assume specific order + expect(mockResponse.end).toHaveBeenCalled(); + const responseJson = JSON.parse(mockResponse.end.mock.calls[0][0] as string); + expect(Array.isArray(responseJson)).toBe(true); + expect(responseJson).toHaveLength(2); + + // Check each response exists separately without assuming order + expect(responseJson).toContainEqual(expect.objectContaining({ id: "req1", result: { value: "result-for-req1" } })); + expect(responseJson).toContainEqual(expect.objectContaining({ id: "req2", result: { value: "result-for-req2" } })); + }); + }); + describe("Handling Pre-Parsed Body", () => { beforeEach(async () => { // Initialize the transport diff --git a/src/server/streamableHttp.ts b/src/server/streamableHttp.ts index b0fcce6d..e8844529 100644 --- a/src/server/streamableHttp.ts +++ b/src/server/streamableHttp.ts @@ -18,8 +18,12 @@ export interface StreamableHTTPServerTransportOptions { */ sessionIdGenerator: () => string | undefined; - - + /** + * If true, the server will return JSON responses instead of starting an SSE stream. + * This can be useful for simple request/response scenarios without streaming. + * Default is false (SSE streams are preferred). + */ + enableJsonResponse?: boolean; } /** @@ -60,8 +64,11 @@ export class StreamableHTTPServerTransport implements Transport { // when sessionId is not set (undefined), it means the transport is in stateless mode private sessionIdGenerator: () => string | undefined; private _started: boolean = false; - private _sseResponseMapping: Map = new Map(); + private _responseMapping: Map = new Map(); + private _requestResponseMap: Map = new Map(); private _initialized: boolean = false; + private _enableJsonResponse: boolean = false; + sessionId?: string | undefined; onclose?: () => void; @@ -70,6 +77,7 @@ export class StreamableHTTPServerTransport implements Transport { constructor(options: StreamableHTTPServerTransportOptions) { this.sessionIdGenerator = options.sessionIdGenerator; + this._enableJsonResponse = options.enableJsonResponse ?? false; } /** @@ -221,33 +229,37 @@ export class StreamableHTTPServerTransport implements Transport { this.onmessage?.(message); } } else if (hasRequests) { - const headers: Record = { - "Content-Type": "text/event-stream", - "Cache-Control": "no-cache", - Connection: "keep-alive", - }; - - // After initialization, always include the session ID if we have one - if (this.sessionId !== undefined) { - headers["mcp-session-id"] = this.sessionId; - } - - res.writeHead(200, headers); + // The default behavior is to use SSE streaming + // but in some cases server will return JSON responses + if (!this._enableJsonResponse) { + const headers: Record = { + "Content-Type": "text/event-stream", + "Cache-Control": "no-cache", + Connection: "keep-alive", + }; + + // After initialization, always include the session ID if we have one + if (this.sessionId !== undefined) { + headers["mcp-session-id"] = this.sessionId; + } + res.writeHead(200, headers); + } // Store the response for this request to send messages back through this connection // We need to track by request ID to maintain the connection for (const message of messages) { if ('method' in message && 'id' in message) { - this._sseResponseMapping.set(message.id, res); + this._responseMapping.set(message.id, res); } } // Set up close handler for client disconnects res.on("close", () => { // Remove all entries that reference this response - for (const [id, storedRes] of this._sseResponseMapping.entries()) { + for (const [id, storedRes] of this._responseMapping.entries()) { if (storedRes === res) { - this._sseResponseMapping.delete(id); + this._responseMapping.delete(id); + this._requestResponseMap.delete(id); } } }); @@ -350,10 +362,14 @@ export class StreamableHTTPServerTransport implements Transport { async close(): Promise { // Close all SSE connections - this._sseResponseMapping.forEach((response) => { + this._responseMapping.forEach((response) => { response.end(); }); - this._sseResponseMapping.clear(); + this._responseMapping.clear(); + + // Clear any pending responses + this._requestResponseMap.clear(); + this.onclose?.(); } @@ -367,24 +383,57 @@ export class StreamableHTTPServerTransport implements Transport { throw new Error("No request ID provided for the message"); } - const sseResponse = this._sseResponseMapping.get(requestId); - if (!sseResponse) { - throw new Error(`No SSE connection established for request ID: ${String(requestId)}`); + // Get the response for this request + const response = this._responseMapping.get(requestId); + if (!response) { + throw new Error(`No connection established for request ID: ${String(requestId)}`); } - // Send the message as an SSE event - sseResponse.write( - `event: message\ndata: ${JSON.stringify(message)}\n\n`, - ); - // After all JSON-RPC responses have been sent, the server SHOULD close the SSE stream. + + if (!this._enableJsonResponse) { + response.write(`event: message\ndata: ${JSON.stringify(message)}\n\n`); + } if ('result' in message || 'error' in message) { - this._sseResponseMapping.delete(requestId); - // Only close the connection if it's not needed by other requests - const canCloseConnection = ![...this._sseResponseMapping.entries()].some(([id, res]) => res === sseResponse && id !== requestId); - if (canCloseConnection) { - sseResponse?.end(); + this._requestResponseMap.set(requestId, message); + + // Get all request IDs that share the same request response object + const relatedIds = Array.from(this._responseMapping.entries()) + .filter(([_, res]) => res === response) + .map(([id]) => id); + + // Check if we have responses for all requests using this connection + const allResponsesReady = relatedIds.every(id => this._requestResponseMap.has(id)); + + if (allResponsesReady) { + if (this._enableJsonResponse) { + // All responses ready, send as JSON + const headers: Record = { + 'Content-Type': 'application/json', + }; + if (this.sessionId !== undefined) { + headers['mcp-session-id'] = this.sessionId; + } + + const responses = relatedIds + .map(id => this._requestResponseMap.get(id)!); + + response.writeHead(200, headers); + if (responses.length === 1) { + response.end(JSON.stringify(responses[0])); + } else { + response.end(JSON.stringify(responses)); + } + } else { + // End the SSE stream + response.end(); + } + // Clean up + for (const id of relatedIds) { + this._requestResponseMap.delete(id); + this._responseMapping.delete(id); + } } } } +} -} \ No newline at end of file