Skip to content

Implement DNS Rebinding Protections per spec #565

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,11 @@ app.post('/mcp', async (req, res) => {
onsessioninitialized: (sessionId) => {
// Store the transport by session ID
transports[sessionId] = transport;
}
},
// DNS rebinding protection is disabled by default for backwards compatibility. If you are running this server
// locally, make sure to set:
// enableDnsRebindingProtection: true,
// allowedHosts: ['127.0.0.1'],
});

// Clean up transport when closed
Expand Down Expand Up @@ -386,6 +390,22 @@ This stateless approach is useful for:
- RESTful scenarios where each request is independent
- Horizontally scaled deployments without shared session state

#### DNS Rebinding Protection

The Streamable HTTP transport includes DNS rebinding protection to prevent security vulnerabilities. By default, this protection is **disabled** for backwards compatibility.

**Important**: If you are running this server locally, enable DNS rebinding protection:

```typescript
const transport = new StreamableHTTPServerTransport({
sessionIdGenerator: () => randomUUID(),
enableDnsRebindingProtection: true,

allowedHosts: ['127.0.0.1', ...],
allowedOrigins: ['https://yourdomain.com', 'https://www.yourdomain.com']
});
```

### Testing and Debugging

To test your server, you can use the [MCP Inspector](https://github.com/modelcontextprotocol/inspector). See its README for more information.
Expand Down
246 changes: 245 additions & 1 deletion src/server/sse.test.ts
Original file line number Diff line number Diff line change
@@ -1,19 +1,27 @@
import http from 'http';
import { jest } from '@jest/globals';
import { SSEServerTransport } from './sse.js';
import { SSEServerTransport } from './sse.js';
import { AuthInfo } from './auth/types.js';

const createMockResponse = () => {
const res = {
writeHead: jest.fn<http.ServerResponse['writeHead']>(),
write: jest.fn<http.ServerResponse['write']>().mockReturnValue(true),
on: jest.fn<http.ServerResponse['on']>(),
end: jest.fn<http.ServerResponse['end']>().mockReturnThis(),
};
res.writeHead.mockReturnThis();
res.on.mockReturnThis();

return res as unknown as http.ServerResponse;
};

const createMockRequest = (headers: Record<string, string> = {}) => {
return {
headers,
} as unknown as http.IncomingMessage & { auth?: AuthInfo };
};

describe('SSEServerTransport', () => {
describe('start method', () => {
it('should correctly append sessionId to a simple relative endpoint', async () => {
Expand Down Expand Up @@ -106,4 +114,240 @@ describe('SSEServerTransport', () => {
);
});
});

describe('DNS rebinding protection', () => {
beforeEach(() => {
jest.clearAllMocks();
});

describe('Host header validation', () => {
it('should accept requests with allowed host headers', async () => {
const mockRes = createMockResponse();
const transport = new SSEServerTransport('/messages', mockRes, {
allowedHosts: ['localhost:3000', 'example.com'],
enableDnsRebindingProtection: true,
});
await transport.start();

const mockReq = createMockRequest({
host: 'localhost:3000',
'content-type': 'application/json',
});
const mockHandleRes = createMockResponse();

await transport.handlePostMessage(mockReq, mockHandleRes, { jsonrpc: '2.0', method: 'test' });

expect(mockHandleRes.writeHead).toHaveBeenCalledWith(202);
expect(mockHandleRes.end).toHaveBeenCalledWith('Accepted');
});

it('should reject requests with disallowed host headers', async () => {
const mockRes = createMockResponse();
const transport = new SSEServerTransport('/messages', mockRes, {
allowedHosts: ['localhost:3000'],
enableDnsRebindingProtection: true,
});
await transport.start();

const mockReq = createMockRequest({
host: 'evil.com',
'content-type': 'application/json',
});
const mockHandleRes = createMockResponse();

await transport.handlePostMessage(mockReq, mockHandleRes, { jsonrpc: '2.0', method: 'test' });

expect(mockHandleRes.writeHead).toHaveBeenCalledWith(403);
expect(mockHandleRes.end).toHaveBeenCalledWith('Invalid Host header: evil.com');
});

it('should reject requests without host header when allowedHosts is configured', async () => {
const mockRes = createMockResponse();
const transport = new SSEServerTransport('/messages', mockRes, {
allowedHosts: ['localhost:3000'],
enableDnsRebindingProtection: true,
});
await transport.start();

const mockReq = createMockRequest({
'content-type': 'application/json',
});
const mockHandleRes = createMockResponse();

await transport.handlePostMessage(mockReq, mockHandleRes, { jsonrpc: '2.0', method: 'test' });

expect(mockHandleRes.writeHead).toHaveBeenCalledWith(403);
expect(mockHandleRes.end).toHaveBeenCalledWith('Invalid Host header: undefined');
});
});

describe('Origin header validation', () => {
it('should accept requests with allowed origin headers', async () => {
const mockRes = createMockResponse();
const transport = new SSEServerTransport('/messages', mockRes, {
allowedOrigins: ['http://localhost:3000', 'https://example.com'],
enableDnsRebindingProtection: true,
});
await transport.start();

const mockReq = createMockRequest({
origin: 'http://localhost:3000',
'content-type': 'application/json',
});
const mockHandleRes = createMockResponse();

await transport.handlePostMessage(mockReq, mockHandleRes, { jsonrpc: '2.0', method: 'test' });

expect(mockHandleRes.writeHead).toHaveBeenCalledWith(202);
expect(mockHandleRes.end).toHaveBeenCalledWith('Accepted');
});

it('should reject requests with disallowed origin headers', async () => {
const mockRes = createMockResponse();
const transport = new SSEServerTransport('/messages', mockRes, {
allowedOrigins: ['http://localhost:3000'],
enableDnsRebindingProtection: true,
});
await transport.start();

const mockReq = createMockRequest({
origin: 'http://evil.com',
'content-type': 'application/json',
});
const mockHandleRes = createMockResponse();

await transport.handlePostMessage(mockReq, mockHandleRes, { jsonrpc: '2.0', method: 'test' });

expect(mockHandleRes.writeHead).toHaveBeenCalledWith(403);
expect(mockHandleRes.end).toHaveBeenCalledWith('Invalid Origin header: http://evil.com');
});
});

describe('Content-Type validation', () => {
it('should accept requests with application/json content-type', async () => {
const mockRes = createMockResponse();
const transport = new SSEServerTransport('/messages', mockRes);
await transport.start();

const mockReq = createMockRequest({
'content-type': 'application/json',
});
const mockHandleRes = createMockResponse();

await transport.handlePostMessage(mockReq, mockHandleRes, { jsonrpc: '2.0', method: 'test' });

expect(mockHandleRes.writeHead).toHaveBeenCalledWith(202);
expect(mockHandleRes.end).toHaveBeenCalledWith('Accepted');
});

it('should accept requests with application/json with charset', async () => {
const mockRes = createMockResponse();
const transport = new SSEServerTransport('/messages', mockRes);
await transport.start();

const mockReq = createMockRequest({
'content-type': 'application/json; charset=utf-8',
});
const mockHandleRes = createMockResponse();

await transport.handlePostMessage(mockReq, mockHandleRes, { jsonrpc: '2.0', method: 'test' });

expect(mockHandleRes.writeHead).toHaveBeenCalledWith(202);
expect(mockHandleRes.end).toHaveBeenCalledWith('Accepted');
});

it('should reject requests with non-application/json content-type when protection is enabled', async () => {
const mockRes = createMockResponse();
const transport = new SSEServerTransport('/messages', mockRes);
await transport.start();

const mockReq = createMockRequest({
'content-type': 'text/plain',
});
const mockHandleRes = createMockResponse();

await transport.handlePostMessage(mockReq, mockHandleRes, { jsonrpc: '2.0', method: 'test' });

expect(mockHandleRes.writeHead).toHaveBeenCalledWith(400);
expect(mockHandleRes.end).toHaveBeenCalledWith('Error: Unsupported content-type: text/plain');
});
});

describe('enableDnsRebindingProtection option', () => {
it('should skip all validations when enableDnsRebindingProtection is false', async () => {
const mockRes = createMockResponse();
const transport = new SSEServerTransport('/messages', mockRes, {
allowedHosts: ['localhost:3000'],
allowedOrigins: ['http://localhost:3000'],
enableDnsRebindingProtection: false,
});
await transport.start();

const mockReq = createMockRequest({
host: 'evil.com',
origin: 'http://evil.com',
'content-type': 'text/plain',
});
const mockHandleRes = createMockResponse();

await transport.handlePostMessage(mockReq, mockHandleRes, { jsonrpc: '2.0', method: 'test' });

// Should pass even with invalid headers because protection is disabled
expect(mockHandleRes.writeHead).toHaveBeenCalledWith(400);
// The error should be from content-type parsing, not DNS rebinding protection
expect(mockHandleRes.end).toHaveBeenCalledWith('Error: Unsupported content-type: text/plain');
});
});

describe('Combined validations', () => {
it('should validate both host and origin when both are configured', async () => {
const mockRes = createMockResponse();
const transport = new SSEServerTransport('/messages', mockRes, {
allowedHosts: ['localhost:3000'],
allowedOrigins: ['http://localhost:3000'],
enableDnsRebindingProtection: true,
});
await transport.start();

// Valid host, invalid origin
const mockReq1 = createMockRequest({
host: 'localhost:3000',
origin: 'http://evil.com',
'content-type': 'application/json',
});
const mockHandleRes1 = createMockResponse();

await transport.handlePostMessage(mockReq1, mockHandleRes1, { jsonrpc: '2.0', method: 'test' });

expect(mockHandleRes1.writeHead).toHaveBeenCalledWith(403);
expect(mockHandleRes1.end).toHaveBeenCalledWith('Invalid Origin header: http://evil.com');

// Invalid host, valid origin
const mockReq2 = createMockRequest({
host: 'evil.com',
origin: 'http://localhost:3000',
'content-type': 'application/json',
});
const mockHandleRes2 = createMockResponse();

await transport.handlePostMessage(mockReq2, mockHandleRes2, { jsonrpc: '2.0', method: 'test' });

expect(mockHandleRes2.writeHead).toHaveBeenCalledWith(403);
expect(mockHandleRes2.end).toHaveBeenCalledWith('Invalid Host header: evil.com');

// Both valid
const mockReq3 = createMockRequest({
host: 'localhost:3000',
origin: 'http://localhost:3000',
'content-type': 'application/json',
});
const mockHandleRes3 = createMockResponse();

await transport.handlePostMessage(mockReq3, mockHandleRes3, { jsonrpc: '2.0', method: 'test' });

expect(mockHandleRes3.writeHead).toHaveBeenCalledWith(202);
expect(mockHandleRes3.end).toHaveBeenCalledWith('Accepted');
});
});
});
});
Loading