From adbacc63881908031914a527b44ab820523e7f57 Mon Sep 17 00:00:00 2001 From: David Dworken Date: Thu, 29 May 2025 13:32:26 -0700 Subject: [PATCH 1/6] Add DNS rebinding protection for SSE transport --- package-lock.json | 4 +- src/server/sse.test.ts | 240 ++++++++++++++++++++++++++++++++++++++++- src/server/sse.ts | 66 +++++++++++- 3 files changed, 306 insertions(+), 4 deletions(-) diff --git a/package-lock.json b/package-lock.json index 40bad9fe..ef539382 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,12 @@ { "name": "@modelcontextprotocol/sdk", - "version": "1.11.4", + "version": "1.12.1", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "@modelcontextprotocol/sdk", - "version": "1.11.4", + "version": "1.12.1", "license": "MIT", "dependencies": { "ajv": "^6.12.6", diff --git a/src/server/sse.test.ts b/src/server/sse.test.ts index 2fd2c042..38ba9e59 100644 --- a/src/server/sse.test.ts +++ b/src/server/sse.test.ts @@ -1,12 +1,14 @@ import http from 'http'; import { jest } from '@jest/globals'; -import { SSEServerTransport } from './sse.js'; +import { SSEServerTransport } from './sse.js'; +import { AuthInfo } from './auth/types.js'; const createMockResponse = () => { const res = { writeHead: jest.fn(), write: jest.fn().mockReturnValue(true), on: jest.fn(), + end: jest.fn().mockReturnThis(), }; res.writeHead.mockReturnThis(); res.on.mockReturnThis(); @@ -14,6 +16,12 @@ const createMockResponse = () => { return res as unknown as http.ServerResponse; }; +const createMockRequest = (headers: Record = {}) => { + return { + headers, + } as unknown as http.IncomingMessage & { auth?: AuthInfo }; +}; + describe('SSEServerTransport', () => { describe('start method', () => { it('should correctly append sessionId to a simple relative endpoint', async () => { @@ -106,4 +114,234 @@ describe('SSEServerTransport', () => { ); }); }); + + describe('DNS rebinding protection', () => { + beforeEach(() => { + jest.clearAllMocks(); + }); + + describe('Host header validation', () => { + it('should accept requests with allowed host headers', async () => { + const mockRes = createMockResponse(); + const transport = new SSEServerTransport('/messages', mockRes, { + allowedHosts: ['localhost:3000', 'example.com'], + }); + await transport.start(); + + const mockReq = createMockRequest({ + host: 'localhost:3000', + 'content-type': 'application/json', + }); + const mockHandleRes = createMockResponse(); + + await transport.handlePostMessage(mockReq, mockHandleRes, { jsonrpc: '2.0', method: 'test' }); + + expect(mockHandleRes.writeHead).toHaveBeenCalledWith(202); + expect(mockHandleRes.end).toHaveBeenCalledWith('Accepted'); + }); + + it('should reject requests with disallowed host headers', async () => { + const mockRes = createMockResponse(); + const transport = new SSEServerTransport('/messages', mockRes, { + allowedHosts: ['localhost:3000'], + }); + await transport.start(); + + const mockReq = createMockRequest({ + host: 'evil.com', + 'content-type': 'application/json', + }); + const mockHandleRes = createMockResponse(); + + await transport.handlePostMessage(mockReq, mockHandleRes, { jsonrpc: '2.0', method: 'test' }); + + expect(mockHandleRes.writeHead).toHaveBeenCalledWith(403); + expect(mockHandleRes.end).toHaveBeenCalledWith('Invalid Host header: evil.com'); + }); + + it('should reject requests without host header when allowedHosts is configured', async () => { + const mockRes = createMockResponse(); + const transport = new SSEServerTransport('/messages', mockRes, { + allowedHosts: ['localhost:3000'], + }); + await transport.start(); + + const mockReq = createMockRequest({ + 'content-type': 'application/json', + }); + const mockHandleRes = createMockResponse(); + + await transport.handlePostMessage(mockReq, mockHandleRes, { jsonrpc: '2.0', method: 'test' }); + + expect(mockHandleRes.writeHead).toHaveBeenCalledWith(403); + expect(mockHandleRes.end).toHaveBeenCalledWith('Invalid Host header: undefined'); + }); + }); + + describe('Origin header validation', () => { + it('should accept requests with allowed origin headers', async () => { + const mockRes = createMockResponse(); + const transport = new SSEServerTransport('/messages', mockRes, { + allowedOrigins: ['http://localhost:3000', 'https://example.com'], + }); + await transport.start(); + + const mockReq = createMockRequest({ + origin: 'http://localhost:3000', + 'content-type': 'application/json', + }); + const mockHandleRes = createMockResponse(); + + await transport.handlePostMessage(mockReq, mockHandleRes, { jsonrpc: '2.0', method: 'test' }); + + expect(mockHandleRes.writeHead).toHaveBeenCalledWith(202); + expect(mockHandleRes.end).toHaveBeenCalledWith('Accepted'); + }); + + it('should reject requests with disallowed origin headers', async () => { + const mockRes = createMockResponse(); + const transport = new SSEServerTransport('/messages', mockRes, { + allowedOrigins: ['http://localhost:3000'], + }); + await transport.start(); + + const mockReq = createMockRequest({ + origin: 'http://evil.com', + 'content-type': 'application/json', + }); + const mockHandleRes = createMockResponse(); + + await transport.handlePostMessage(mockReq, mockHandleRes, { jsonrpc: '2.0', method: 'test' }); + + expect(mockHandleRes.writeHead).toHaveBeenCalledWith(403); + expect(mockHandleRes.end).toHaveBeenCalledWith('Invalid Origin header: http://evil.com'); + }); + }); + + describe('Content-Type validation', () => { + it('should accept requests with application/json content-type', async () => { + const mockRes = createMockResponse(); + const transport = new SSEServerTransport('/messages', mockRes); + await transport.start(); + + const mockReq = createMockRequest({ + 'content-type': 'application/json', + }); + const mockHandleRes = createMockResponse(); + + await transport.handlePostMessage(mockReq, mockHandleRes, { jsonrpc: '2.0', method: 'test' }); + + expect(mockHandleRes.writeHead).toHaveBeenCalledWith(202); + expect(mockHandleRes.end).toHaveBeenCalledWith('Accepted'); + }); + + it('should accept requests with application/json with charset', async () => { + const mockRes = createMockResponse(); + const transport = new SSEServerTransport('/messages', mockRes); + await transport.start(); + + const mockReq = createMockRequest({ + 'content-type': 'application/json; charset=utf-8', + }); + const mockHandleRes = createMockResponse(); + + await transport.handlePostMessage(mockReq, mockHandleRes, { jsonrpc: '2.0', method: 'test' }); + + expect(mockHandleRes.writeHead).toHaveBeenCalledWith(202); + expect(mockHandleRes.end).toHaveBeenCalledWith('Accepted'); + }); + + it('should reject requests with non-application/json content-type when protection is enabled', async () => { + const mockRes = createMockResponse(); + const transport = new SSEServerTransport('/messages', mockRes); + await transport.start(); + + const mockReq = createMockRequest({ + 'content-type': 'text/plain', + }); + const mockHandleRes = createMockResponse(); + + await transport.handlePostMessage(mockReq, mockHandleRes, { jsonrpc: '2.0', method: 'test' }); + + expect(mockHandleRes.writeHead).toHaveBeenCalledWith(400); + expect(mockHandleRes.end).toHaveBeenCalledWith('Error: Content-Type must start with application/json, got: text/plain'); + }); + }); + + describe('disableDnsRebindingProtection option', () => { + it('should skip all validations when disableDnsRebindingProtection is true', async () => { + const mockRes = createMockResponse(); + const transport = new SSEServerTransport('/messages', mockRes, { + allowedHosts: ['localhost:3000'], + allowedOrigins: ['http://localhost:3000'], + disableDnsRebindingProtection: true, + }); + await transport.start(); + + const mockReq = createMockRequest({ + host: 'evil.com', + origin: 'http://evil.com', + 'content-type': 'text/plain', + }); + const mockHandleRes = createMockResponse(); + + await transport.handlePostMessage(mockReq, mockHandleRes, { jsonrpc: '2.0', method: 'test' }); + + // Should pass even with invalid headers because protection is disabled + expect(mockHandleRes.writeHead).toHaveBeenCalledWith(400); + // The error should be from content-type parsing, not DNS rebinding protection + expect(mockHandleRes.end).toHaveBeenCalledWith('Error: Unsupported content-type: text/plain'); + }); + }); + + describe('Combined validations', () => { + it('should validate both host and origin when both are configured', async () => { + const mockRes = createMockResponse(); + const transport = new SSEServerTransport('/messages', mockRes, { + allowedHosts: ['localhost:3000'], + allowedOrigins: ['http://localhost:3000'], + }); + await transport.start(); + + // Valid host, invalid origin + const mockReq1 = createMockRequest({ + host: 'localhost:3000', + origin: 'http://evil.com', + 'content-type': 'application/json', + }); + const mockHandleRes1 = createMockResponse(); + + await transport.handlePostMessage(mockReq1, mockHandleRes1, { jsonrpc: '2.0', method: 'test' }); + + expect(mockHandleRes1.writeHead).toHaveBeenCalledWith(403); + expect(mockHandleRes1.end).toHaveBeenCalledWith('Invalid Origin header: http://evil.com'); + + // Invalid host, valid origin + const mockReq2 = createMockRequest({ + host: 'evil.com', + origin: 'http://localhost:3000', + 'content-type': 'application/json', + }); + const mockHandleRes2 = createMockResponse(); + + await transport.handlePostMessage(mockReq2, mockHandleRes2, { jsonrpc: '2.0', method: 'test' }); + + expect(mockHandleRes2.writeHead).toHaveBeenCalledWith(403); + expect(mockHandleRes2.end).toHaveBeenCalledWith('Invalid Host header: evil.com'); + + // Both valid + const mockReq3 = createMockRequest({ + host: 'localhost:3000', + origin: 'http://localhost:3000', + 'content-type': 'application/json', + }); + const mockHandleRes3 = createMockResponse(); + + await transport.handlePostMessage(mockReq3, mockHandleRes3, { jsonrpc: '2.0', method: 'test' }); + + expect(mockHandleRes3.writeHead).toHaveBeenCalledWith(202); + expect(mockHandleRes3.end).toHaveBeenCalledWith('Accepted'); + }); + }); + }); }); diff --git a/src/server/sse.ts b/src/server/sse.ts index 03f6fefc..6ef2baef 100644 --- a/src/server/sse.ts +++ b/src/server/sse.ts @@ -9,6 +9,29 @@ import { URL } from 'url'; const MAXIMUM_MESSAGE_SIZE = "4mb"; +/** + * Configuration options for SSEServerTransport. + */ +export interface SSEServerTransportOptions { + /** + * List of allowed host header values for DNS rebinding protection. + * If not specified, host validation is disabled. + */ + allowedHosts?: string[]; + + /** + * List of allowed origin header values for DNS rebinding protection. + * If not specified, origin validation is disabled. + */ + allowedOrigins?: string[]; + + /** + * Disable DNS rebinding protection entirely (overrides allowedHosts and allowedOrigins). + * Default is false. + */ + disableDnsRebindingProtection?: boolean; +} + /** * Server transport for SSE: this will send messages over an SSE connection and receive messages from HTTP POST requests. * @@ -17,6 +40,7 @@ const MAXIMUM_MESSAGE_SIZE = "4mb"; export class SSEServerTransport implements Transport { private _sseResponse?: ServerResponse; private _sessionId: string; + private _options: SSEServerTransportOptions; onclose?: () => void; onerror?: (error: Error) => void; @@ -28,8 +52,39 @@ export class SSEServerTransport implements Transport { constructor( private _endpoint: string, private res: ServerResponse, + options?: SSEServerTransportOptions, ) { this._sessionId = randomUUID(); + this._options = options || {disableDnsRebindingProtection: true}; + } + + /** + * Validates request headers for DNS rebinding protection. + * @returns Error message if validation fails, undefined if validation passes. + */ + private validateRequestHeaders(req: IncomingMessage): string | undefined { + // Skip validation if protection is disabled + if (this._options.disableDnsRebindingProtection) { + return undefined; + } + + // Validate Host header if allowedHosts is configured + if (this._options.allowedHosts && this._options.allowedHosts.length > 0) { + const hostHeader = req.headers.host; + if (!hostHeader || !this._options.allowedHosts.includes(hostHeader)) { + return `Invalid Host header: ${hostHeader}`; + } + } + + // Validate Origin header if allowedOrigins is configured + if (this._options.allowedOrigins && this._options.allowedOrigins.length > 0) { + const originHeader = req.headers.origin; + if (!originHeader || !this._options.allowedOrigins.includes(originHeader)) { + return `Invalid Origin header: ${originHeader}`; + } + } + + return undefined; } /** @@ -86,13 +141,22 @@ export class SSEServerTransport implements Transport { res.writeHead(500).end(message); throw new Error(message); } + + // Validate request headers for DNS rebinding protection + const validationError = this.validateRequestHeaders(req); + if (validationError) { + res.writeHead(403).end(validationError); + this.onerror?.(new Error(validationError)); + return; + } + const authInfo: AuthInfo | undefined = req.auth; let body: string | unknown; try { const ct = contentType.parse(req.headers["content-type"] ?? ""); if (ct.type !== "application/json") { - throw new Error(`Unsupported content-type: ${ct}`); + throw new Error(`Unsupported content-type: ${ct.type}`); } body = parsedBody ?? await getRawBody(req, { From ebf2535b90ce4c81553f237f49450b6ea4f9ce71 Mon Sep 17 00:00:00 2001 From: David Dworken Date: Thu, 29 May 2025 13:52:30 -0700 Subject: [PATCH 2/6] Add protections for streamable HTTP too --- README.md | 22 ++- src/server/streamableHttp.test.ts | 263 +++++++++++++++++++++++++++++- src/server/streamableHttp.ts | 68 ++++++++ 3 files changed, 351 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index c9e27c27..ccc627b2 100644 --- a/README.md +++ b/README.md @@ -251,7 +251,11 @@ app.post('/mcp', async (req, res) => { onsessioninitialized: (sessionId) => { // Store the transport by session ID transports[sessionId] = transport; - } + }, + // DNS rebinding protection is disabled by default for backwards compatibility. If you are running this server + // locally, make sure to set: + // disableDnsRebindingProtection: true, + // allowedHosts: ['127.0.0.1'], }); // Clean up transport when closed @@ -386,6 +390,22 @@ This stateless approach is useful for: - RESTful scenarios where each request is independent - Horizontally scaled deployments without shared session state +#### DNS Rebinding Protection + +The Streamable HTTP transport includes DNS rebinding protection to prevent security vulnerabilities. By default, this protection is **disabled** for backwards compatibility. + +**Important**: If you are running this server locally, enable DNS rebinding protection: + +```typescript +const transport = new StreamableHTTPServerTransport({ + sessionIdGenerator: () => randomUUID(), + disableDnsRebindingProtection: false, + + allowedHosts: ['127.0.0.1', ...], + allowedOrigins: ['https://yourdomain.com', 'https://www.yourdomain.com'] +}); +``` + ### Testing and Debugging To test your server, you can use the [MCP Inspector](https://github.com/modelcontextprotocol/inspector). See its README for more information. diff --git a/src/server/streamableHttp.test.ts b/src/server/streamableHttp.test.ts index b961f6c4..68fe8ee7 100644 --- a/src/server/streamableHttp.test.ts +++ b/src/server/streamableHttp.test.ts @@ -1293,4 +1293,265 @@ describe("StreamableHTTPServerTransport in stateless mode", () => { }); expect(stream2.status).toBe(409); // Conflict - only one stream allowed }); -}); \ No newline at end of file +}); + +// Test DNS rebinding protection +describe("StreamableHTTPServerTransport DNS rebinding protection", () => { + let server: Server; + let transport: StreamableHTTPServerTransport; + let baseUrl: URL; + + afterEach(async () => { + if (server && transport) { + await stopTestServer({ server, transport }); + } + }); + + describe("Host header validation", () => { + it("should accept requests with allowed host headers", async () => { + const result = await createTestServerWithDnsProtection({ + sessionIdGenerator: undefined, + allowedHosts: ['localhost:3001'], + disableDnsRebindingProtection: false, + }); + server = result.server; + transport = result.transport; + baseUrl = result.baseUrl; + + // Note: fetch() automatically sets Host header to match the URL + // Since we're connecting to localhost:3001 and that's in allowedHosts, this should work + const response = await fetch(baseUrl, { + method: "POST", + headers: { + "Content-Type": "application/json", + Accept: "application/json, text/event-stream", + }, + body: JSON.stringify(TEST_MESSAGES.initialize), + }); + + expect(response.status).toBe(200); + }); + + it("should reject requests with disallowed host headers", async () => { + // Test DNS rebinding protection by creating a server that only allows example.com + // but we're connecting via localhost, so it should be rejected + const result = await createTestServerWithDnsProtection({ + sessionIdGenerator: undefined, + allowedHosts: ['example.com:3001'], + disableDnsRebindingProtection: false, + }); + server = result.server; + transport = result.transport; + baseUrl = result.baseUrl; + + const response = await fetch(baseUrl, { + method: "POST", + headers: { + "Content-Type": "application/json", + Accept: "application/json, text/event-stream", + }, + body: JSON.stringify(TEST_MESSAGES.initialize), + }); + + expect(response.status).toBe(403); + const body = await response.json(); + expect(body.error.message).toContain("Invalid Host header:"); + }); + + it("should reject GET requests with disallowed host headers", async () => { + const result = await createTestServerWithDnsProtection({ + sessionIdGenerator: undefined, + allowedHosts: ['example.com:3001'], + disableDnsRebindingProtection: false, + }); + server = result.server; + transport = result.transport; + baseUrl = result.baseUrl; + + const response = await fetch(baseUrl, { + method: "GET", + headers: { + Accept: "text/event-stream", + }, + }); + + expect(response.status).toBe(403); + }); + }); + + describe("Origin header validation", () => { + it("should accept requests with allowed origin headers", async () => { + const result = await createTestServerWithDnsProtection({ + sessionIdGenerator: undefined, + allowedOrigins: ['http://localhost:3000', 'https://example.com'], + disableDnsRebindingProtection: false, + }); + server = result.server; + transport = result.transport; + baseUrl = result.baseUrl; + + const response = await fetch(baseUrl, { + method: "POST", + headers: { + "Content-Type": "application/json", + Accept: "application/json, text/event-stream", + Origin: "http://localhost:3000", + }, + body: JSON.stringify(TEST_MESSAGES.initialize), + }); + + expect(response.status).toBe(200); + }); + + it("should reject requests with disallowed origin headers", async () => { + const result = await createTestServerWithDnsProtection({ + sessionIdGenerator: undefined, + allowedOrigins: ['http://localhost:3000'], + disableDnsRebindingProtection: false, + }); + server = result.server; + transport = result.transport; + baseUrl = result.baseUrl; + + const response = await fetch(baseUrl, { + method: "POST", + headers: { + "Content-Type": "application/json", + Accept: "application/json, text/event-stream", + Origin: "http://evil.com", + }, + body: JSON.stringify(TEST_MESSAGES.initialize), + }); + + expect(response.status).toBe(403); + const body = await response.json(); + expect(body.error.message).toBe("Invalid Origin header: http://evil.com"); + }); + }); + + describe("disableDnsRebindingProtection option", () => { + it("should skip all validations when disableDnsRebindingProtection is true", async () => { + const result = await createTestServerWithDnsProtection({ + sessionIdGenerator: undefined, + allowedHosts: ['localhost:3001'], + allowedOrigins: ['http://localhost:3000'], + disableDnsRebindingProtection: true, + }); + server = result.server; + transport = result.transport; + baseUrl = result.baseUrl; + + const response = await fetch(baseUrl, { + method: "POST", + headers: { + "Content-Type": "application/json", + Accept: "application/json, text/event-stream", + Host: "evil.com", + Origin: "http://evil.com", + }, + body: JSON.stringify(TEST_MESSAGES.initialize), + }); + + // Should pass even with invalid headers because protection is disabled + expect(response.status).toBe(200); + }); + }); + + describe("Combined validations", () => { + it("should validate both host and origin when both are configured", async () => { + const result = await createTestServerWithDnsProtection({ + sessionIdGenerator: undefined, + allowedHosts: ['localhost:3001'], + allowedOrigins: ['http://localhost:3001'], + disableDnsRebindingProtection: false, + }); + server = result.server; + transport = result.transport; + baseUrl = result.baseUrl; + + // Test with invalid origin (host will be automatically correct via fetch) + const response1 = await fetch(baseUrl, { + method: "POST", + headers: { + "Content-Type": "application/json", + Accept: "application/json, text/event-stream", + Origin: "http://evil.com", + }, + body: JSON.stringify(TEST_MESSAGES.initialize), + }); + + expect(response1.status).toBe(403); + const body1 = await response1.json(); + expect(body1.error.message).toBe("Invalid Origin header: http://evil.com"); + + // Test with valid origin + const response2 = await fetch(baseUrl, { + method: "POST", + headers: { + "Content-Type": "application/json", + Accept: "application/json, text/event-stream", + Origin: "http://localhost:3001", + }, + body: JSON.stringify(TEST_MESSAGES.initialize), + }); + + expect(response2.status).toBe(200); + }); + }); +}); + +/** + * Helper to create test server with DNS rebinding protection options + */ +async function createTestServerWithDnsProtection(config: { + sessionIdGenerator: (() => string) | undefined; + allowedHosts?: string[]; + allowedOrigins?: string[]; + disableDnsRebindingProtection?: boolean; +}): Promise<{ + server: Server; + transport: StreamableHTTPServerTransport; + mcpServer: McpServer; + baseUrl: URL; +}> { + const mcpServer = new McpServer( + { name: "test-server", version: "1.0.0" }, + { capabilities: { logging: {} } } + ); + + const transport = new StreamableHTTPServerTransport({ + sessionIdGenerator: config.sessionIdGenerator, + allowedHosts: config.allowedHosts, + allowedOrigins: config.allowedOrigins, + disableDnsRebindingProtection: config.disableDnsRebindingProtection, + }); + + await mcpServer.connect(transport); + + const httpServer = createServer(async (req, res) => { + if (req.method === "POST") { + let body = ""; + req.on("data", (chunk) => (body += chunk)); + req.on("end", async () => { + const parsedBody = JSON.parse(body); + await transport.handleRequest(req as IncomingMessage & { auth?: AuthInfo }, res, parsedBody); + }); + } else { + await transport.handleRequest(req as IncomingMessage & { auth?: AuthInfo }, res); + } + }); + + await new Promise((resolve) => { + httpServer.listen(3001, () => resolve()); + }); + + const port = (httpServer.address() as AddressInfo).port; + const serverUrl = new URL(`http://localhost:${port}/`); + + return { + server: httpServer, + transport, + mcpServer, + baseUrl: serverUrl, + }; +} \ No newline at end of file diff --git a/src/server/streamableHttp.ts b/src/server/streamableHttp.ts index dc99c306..8feeac55 100644 --- a/src/server/streamableHttp.ts +++ b/src/server/streamableHttp.ts @@ -61,6 +61,24 @@ export interface StreamableHTTPServerTransportOptions { * If provided, resumability will be enabled, allowing clients to reconnect and resume messages */ eventStore?: EventStore; + + /** + * List of allowed host header values for DNS rebinding protection. + * If not specified, host validation is disabled. + */ + allowedHosts?: string[]; + + /** + * List of allowed origin header values for DNS rebinding protection. + * If not specified, origin validation is disabled. + */ + allowedOrigins?: string[]; + + /** + * Disable DNS rebinding protection entirely (overrides allowedHosts and allowedOrigins). + * Default is true for backwards compatibility. + */ + disableDnsRebindingProtection?: boolean; } /** @@ -109,6 +127,9 @@ export class StreamableHTTPServerTransport implements Transport { private _standaloneSseStreamId: string = '_GET_stream'; private _eventStore?: EventStore; private _onsessioninitialized?: (sessionId: string) => void; + private _allowedHosts?: string[]; + private _allowedOrigins?: string[]; + private _disableDnsRebindingProtection: boolean; sessionId?: string | undefined; onclose?: () => void; @@ -120,6 +141,9 @@ export class StreamableHTTPServerTransport implements Transport { this._enableJsonResponse = options.enableJsonResponse ?? false; this._eventStore = options.eventStore; this._onsessioninitialized = options.onsessioninitialized; + this._allowedHosts = options.allowedHosts; + this._allowedOrigins = options.allowedOrigins; + this._disableDnsRebindingProtection = options.disableDnsRebindingProtection ?? true; } /** @@ -133,10 +157,54 @@ export class StreamableHTTPServerTransport implements Transport { this._started = true; } + /** + * Validates request headers for DNS rebinding protection. + * @returns Error message if validation fails, undefined if validation passes. + */ + private validateRequestHeaders(req: IncomingMessage): string | undefined { + // Skip validation if protection is disabled + if (this._disableDnsRebindingProtection) { + return undefined; + } + + // Validate Host header if allowedHosts is configured + if (this._allowedHosts && this._allowedHosts.length > 0) { + const hostHeader = req.headers.host; + if (!hostHeader || !this._allowedHosts.includes(hostHeader)) { + return `Invalid Host header: ${hostHeader}`; + } + } + + // Validate Origin header if allowedOrigins is configured + if (this._allowedOrigins && this._allowedOrigins.length > 0) { + const originHeader = req.headers.origin; + if (!originHeader || !this._allowedOrigins.includes(originHeader)) { + return `Invalid Origin header: ${originHeader}`; + } + } + + return undefined; + } + /** * Handles an incoming HTTP request, whether GET or POST */ async handleRequest(req: IncomingMessage & { auth?: AuthInfo }, res: ServerResponse, parsedBody?: unknown): Promise { + // Validate request headers for DNS rebinding protection + const validationError = this.validateRequestHeaders(req); + if (validationError) { + res.writeHead(403).end(JSON.stringify({ + jsonrpc: "2.0", + error: { + code: -32000, + message: validationError + }, + id: null + })); + this.onerror?.(new Error(validationError)); + return; + } + if (req.method === "POST") { await this.handlePostRequest(req, res, parsedBody); } else if (req.method === "GET") { From 41c7ed09954d63151c8f96ec8af96a827c399edf Mon Sep 17 00:00:00 2001 From: David Dworken Date: Thu, 29 May 2025 13:53:31 -0700 Subject: [PATCH 3/6] Revert package-lock change --- package-lock.json | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/package-lock.json b/package-lock.json index ef539382..40bad9fe 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,12 @@ { "name": "@modelcontextprotocol/sdk", - "version": "1.12.1", + "version": "1.11.4", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "@modelcontextprotocol/sdk", - "version": "1.12.1", + "version": "1.11.4", "license": "MIT", "dependencies": { "ajv": "^6.12.6", From 88c6098495522fc34d9554076998363862617df0 Mon Sep 17 00:00:00 2001 From: David Dworken Date: Thu, 29 May 2025 13:54:49 -0700 Subject: [PATCH 4/6] Clean up --- src/server/sse.ts | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/server/sse.ts b/src/server/sse.ts index 6ef2baef..65e8c7c8 100644 --- a/src/server/sse.ts +++ b/src/server/sse.ts @@ -27,7 +27,6 @@ export interface SSEServerTransportOptions { /** * Disable DNS rebinding protection entirely (overrides allowedHosts and allowedOrigins). - * Default is false. */ disableDnsRebindingProtection?: boolean; } @@ -156,7 +155,7 @@ export class SSEServerTransport implements Transport { try { const ct = contentType.parse(req.headers["content-type"] ?? ""); if (ct.type !== "application/json") { - throw new Error(`Unsupported content-type: ${ct.type}`); + throw new Error(`Unsupported content-type: ${ct}`); } body = parsedBody ?? await getRawBody(req, { From 970905c3131276d1f6b392f060faac91bddbf3fa Mon Sep 17 00:00:00 2001 From: David Dworken Date: Thu, 29 May 2025 14:01:36 -0700 Subject: [PATCH 5/6] Fix SSE content-type error message format --- src/server/sse.test.ts | 2 +- src/server/sse.ts | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/server/sse.test.ts b/src/server/sse.test.ts index 38ba9e59..9fb1f30c 100644 --- a/src/server/sse.test.ts +++ b/src/server/sse.test.ts @@ -264,7 +264,7 @@ describe('SSEServerTransport', () => { await transport.handlePostMessage(mockReq, mockHandleRes, { jsonrpc: '2.0', method: 'test' }); expect(mockHandleRes.writeHead).toHaveBeenCalledWith(400); - expect(mockHandleRes.end).toHaveBeenCalledWith('Error: Content-Type must start with application/json, got: text/plain'); + expect(mockHandleRes.end).toHaveBeenCalledWith('Error: Unsupported content-type: text/plain'); }); }); diff --git a/src/server/sse.ts b/src/server/sse.ts index 65e8c7c8..73c74bb3 100644 --- a/src/server/sse.ts +++ b/src/server/sse.ts @@ -155,7 +155,7 @@ export class SSEServerTransport implements Transport { try { const ct = contentType.parse(req.headers["content-type"] ?? ""); if (ct.type !== "application/json") { - throw new Error(`Unsupported content-type: ${ct}`); + throw new Error(`Unsupported content-type: ${ct.type}`); } body = parsedBody ?? await getRawBody(req, { From ab900839fbd450fbd86a1fb0034803a892048c29 Mon Sep 17 00:00:00 2001 From: David Dworken Date: Fri, 30 May 2025 09:06:01 -0700 Subject: [PATCH 6/6] Invert variable to improve code readability --- README.md | 4 ++-- src/server/sse.test.ts | 12 +++++++++--- src/server/sse.ts | 11 ++++++----- src/server/streamableHttp.test.ts | 22 +++++++++++----------- src/server/streamableHttp.ts | 14 +++++++------- 5 files changed, 35 insertions(+), 28 deletions(-) diff --git a/README.md b/README.md index ccc627b2..32037f90 100644 --- a/README.md +++ b/README.md @@ -254,7 +254,7 @@ app.post('/mcp', async (req, res) => { }, // DNS rebinding protection is disabled by default for backwards compatibility. If you are running this server // locally, make sure to set: - // disableDnsRebindingProtection: true, + // enableDnsRebindingProtection: true, // allowedHosts: ['127.0.0.1'], }); @@ -399,7 +399,7 @@ The Streamable HTTP transport includes DNS rebinding protection to prevent secur ```typescript const transport = new StreamableHTTPServerTransport({ sessionIdGenerator: () => randomUUID(), - disableDnsRebindingProtection: false, + enableDnsRebindingProtection: true, allowedHosts: ['127.0.0.1', ...], allowedOrigins: ['https://yourdomain.com', 'https://www.yourdomain.com'] diff --git a/src/server/sse.test.ts b/src/server/sse.test.ts index 9fb1f30c..aee6eaf6 100644 --- a/src/server/sse.test.ts +++ b/src/server/sse.test.ts @@ -125,6 +125,7 @@ describe('SSEServerTransport', () => { const mockRes = createMockResponse(); const transport = new SSEServerTransport('/messages', mockRes, { allowedHosts: ['localhost:3000', 'example.com'], + enableDnsRebindingProtection: true, }); await transport.start(); @@ -144,6 +145,7 @@ describe('SSEServerTransport', () => { const mockRes = createMockResponse(); const transport = new SSEServerTransport('/messages', mockRes, { allowedHosts: ['localhost:3000'], + enableDnsRebindingProtection: true, }); await transport.start(); @@ -163,6 +165,7 @@ describe('SSEServerTransport', () => { const mockRes = createMockResponse(); const transport = new SSEServerTransport('/messages', mockRes, { allowedHosts: ['localhost:3000'], + enableDnsRebindingProtection: true, }); await transport.start(); @@ -183,6 +186,7 @@ describe('SSEServerTransport', () => { const mockRes = createMockResponse(); const transport = new SSEServerTransport('/messages', mockRes, { allowedOrigins: ['http://localhost:3000', 'https://example.com'], + enableDnsRebindingProtection: true, }); await transport.start(); @@ -202,6 +206,7 @@ describe('SSEServerTransport', () => { const mockRes = createMockResponse(); const transport = new SSEServerTransport('/messages', mockRes, { allowedOrigins: ['http://localhost:3000'], + enableDnsRebindingProtection: true, }); await transport.start(); @@ -268,13 +273,13 @@ describe('SSEServerTransport', () => { }); }); - describe('disableDnsRebindingProtection option', () => { - it('should skip all validations when disableDnsRebindingProtection is true', async () => { + describe('enableDnsRebindingProtection option', () => { + it('should skip all validations when enableDnsRebindingProtection is false', async () => { const mockRes = createMockResponse(); const transport = new SSEServerTransport('/messages', mockRes, { allowedHosts: ['localhost:3000'], allowedOrigins: ['http://localhost:3000'], - disableDnsRebindingProtection: true, + enableDnsRebindingProtection: false, }); await transport.start(); @@ -300,6 +305,7 @@ describe('SSEServerTransport', () => { const transport = new SSEServerTransport('/messages', mockRes, { allowedHosts: ['localhost:3000'], allowedOrigins: ['http://localhost:3000'], + enableDnsRebindingProtection: true, }); await transport.start(); diff --git a/src/server/sse.ts b/src/server/sse.ts index 73c74bb3..bd5d80b9 100644 --- a/src/server/sse.ts +++ b/src/server/sse.ts @@ -26,9 +26,10 @@ export interface SSEServerTransportOptions { allowedOrigins?: string[]; /** - * Disable DNS rebinding protection entirely (overrides allowedHosts and allowedOrigins). + * Enable DNS rebinding protection (requires allowedHosts and/or allowedOrigins to be configured). + * Default is false for backwards compatibility. */ - disableDnsRebindingProtection?: boolean; + enableDnsRebindingProtection?: boolean; } /** @@ -54,7 +55,7 @@ export class SSEServerTransport implements Transport { options?: SSEServerTransportOptions, ) { this._sessionId = randomUUID(); - this._options = options || {disableDnsRebindingProtection: true}; + this._options = options || {enableDnsRebindingProtection: false}; } /** @@ -62,8 +63,8 @@ export class SSEServerTransport implements Transport { * @returns Error message if validation fails, undefined if validation passes. */ private validateRequestHeaders(req: IncomingMessage): string | undefined { - // Skip validation if protection is disabled - if (this._options.disableDnsRebindingProtection) { + // Skip validation if protection is not enabled + if (!this._options.enableDnsRebindingProtection) { return undefined; } diff --git a/src/server/streamableHttp.test.ts b/src/server/streamableHttp.test.ts index 68fe8ee7..4683024b 100644 --- a/src/server/streamableHttp.test.ts +++ b/src/server/streamableHttp.test.ts @@ -1312,7 +1312,7 @@ describe("StreamableHTTPServerTransport DNS rebinding protection", () => { const result = await createTestServerWithDnsProtection({ sessionIdGenerator: undefined, allowedHosts: ['localhost:3001'], - disableDnsRebindingProtection: false, + enableDnsRebindingProtection: true, }); server = result.server; transport = result.transport; @@ -1338,7 +1338,7 @@ describe("StreamableHTTPServerTransport DNS rebinding protection", () => { const result = await createTestServerWithDnsProtection({ sessionIdGenerator: undefined, allowedHosts: ['example.com:3001'], - disableDnsRebindingProtection: false, + enableDnsRebindingProtection: true, }); server = result.server; transport = result.transport; @@ -1362,7 +1362,7 @@ describe("StreamableHTTPServerTransport DNS rebinding protection", () => { const result = await createTestServerWithDnsProtection({ sessionIdGenerator: undefined, allowedHosts: ['example.com:3001'], - disableDnsRebindingProtection: false, + enableDnsRebindingProtection: true, }); server = result.server; transport = result.transport; @@ -1384,7 +1384,7 @@ describe("StreamableHTTPServerTransport DNS rebinding protection", () => { const result = await createTestServerWithDnsProtection({ sessionIdGenerator: undefined, allowedOrigins: ['http://localhost:3000', 'https://example.com'], - disableDnsRebindingProtection: false, + enableDnsRebindingProtection: true, }); server = result.server; transport = result.transport; @@ -1407,7 +1407,7 @@ describe("StreamableHTTPServerTransport DNS rebinding protection", () => { const result = await createTestServerWithDnsProtection({ sessionIdGenerator: undefined, allowedOrigins: ['http://localhost:3000'], - disableDnsRebindingProtection: false, + enableDnsRebindingProtection: true, }); server = result.server; transport = result.transport; @@ -1429,13 +1429,13 @@ describe("StreamableHTTPServerTransport DNS rebinding protection", () => { }); }); - describe("disableDnsRebindingProtection option", () => { - it("should skip all validations when disableDnsRebindingProtection is true", async () => { + describe("enableDnsRebindingProtection option", () => { + it("should skip all validations when enableDnsRebindingProtection is false", async () => { const result = await createTestServerWithDnsProtection({ sessionIdGenerator: undefined, allowedHosts: ['localhost:3001'], allowedOrigins: ['http://localhost:3000'], - disableDnsRebindingProtection: true, + enableDnsRebindingProtection: false, }); server = result.server; transport = result.transport; @@ -1463,7 +1463,7 @@ describe("StreamableHTTPServerTransport DNS rebinding protection", () => { sessionIdGenerator: undefined, allowedHosts: ['localhost:3001'], allowedOrigins: ['http://localhost:3001'], - disableDnsRebindingProtection: false, + enableDnsRebindingProtection: true, }); server = result.server; transport = result.transport; @@ -1507,7 +1507,7 @@ async function createTestServerWithDnsProtection(config: { sessionIdGenerator: (() => string) | undefined; allowedHosts?: string[]; allowedOrigins?: string[]; - disableDnsRebindingProtection?: boolean; + enableDnsRebindingProtection?: boolean; }): Promise<{ server: Server; transport: StreamableHTTPServerTransport; @@ -1523,7 +1523,7 @@ async function createTestServerWithDnsProtection(config: { sessionIdGenerator: config.sessionIdGenerator, allowedHosts: config.allowedHosts, allowedOrigins: config.allowedOrigins, - disableDnsRebindingProtection: config.disableDnsRebindingProtection, + enableDnsRebindingProtection: config.enableDnsRebindingProtection, }); await mcpServer.connect(transport); diff --git a/src/server/streamableHttp.ts b/src/server/streamableHttp.ts index 8feeac55..084147dc 100644 --- a/src/server/streamableHttp.ts +++ b/src/server/streamableHttp.ts @@ -75,10 +75,10 @@ export interface StreamableHTTPServerTransportOptions { allowedOrigins?: string[]; /** - * Disable DNS rebinding protection entirely (overrides allowedHosts and allowedOrigins). - * Default is true for backwards compatibility. + * Enable DNS rebinding protection (requires allowedHosts and/or allowedOrigins to be configured). + * Default is false for backwards compatibility. */ - disableDnsRebindingProtection?: boolean; + enableDnsRebindingProtection?: boolean; } /** @@ -129,7 +129,7 @@ export class StreamableHTTPServerTransport implements Transport { private _onsessioninitialized?: (sessionId: string) => void; private _allowedHosts?: string[]; private _allowedOrigins?: string[]; - private _disableDnsRebindingProtection: boolean; + private _enableDnsRebindingProtection: boolean; sessionId?: string | undefined; onclose?: () => void; @@ -143,7 +143,7 @@ export class StreamableHTTPServerTransport implements Transport { this._onsessioninitialized = options.onsessioninitialized; this._allowedHosts = options.allowedHosts; this._allowedOrigins = options.allowedOrigins; - this._disableDnsRebindingProtection = options.disableDnsRebindingProtection ?? true; + this._enableDnsRebindingProtection = options.enableDnsRebindingProtection ?? false; } /** @@ -162,8 +162,8 @@ export class StreamableHTTPServerTransport implements Transport { * @returns Error message if validation fails, undefined if validation passes. */ private validateRequestHeaders(req: IncomingMessage): string | undefined { - // Skip validation if protection is disabled - if (this._disableDnsRebindingProtection) { + // Skip validation if protection is not enabled + if (!this._enableDnsRebindingProtection) { return undefined; }