diff --git a/src/client/streamableHttp.test.ts b/src/client/streamableHttp.test.ts index 9e8efa52..f748a2be 100644 --- a/src/client/streamableHttp.test.ts +++ b/src/client/streamableHttp.test.ts @@ -1,12 +1,24 @@ import { StreamableHTTPClientTransport, StreamableHTTPReconnectionOptions } from "./streamableHttp.js"; +import { OAuthClientProvider, UnauthorizedError } from "./auth.js"; import { JSONRPCMessage } from "../types.js"; describe("StreamableHTTPClientTransport", () => { let transport: StreamableHTTPClientTransport; + let mockAuthProvider: jest.Mocked; beforeEach(() => { - transport = new StreamableHTTPClientTransport(new URL("http://localhost:1234/mcp")); + mockAuthProvider = { + get redirectUrl() { return "http://localhost/callback"; }, + get clientMetadata() { return { redirect_uris: ["http://localhost/callback"] }; }, + clientInformation: jest.fn(() => ({ client_id: "test-client-id", client_secret: "test-client-secret" })), + tokens: jest.fn(), + saveTokens: jest.fn(), + redirectToAuthorization: jest.fn(), + saveCodeVerifier: jest.fn(), + codeVerifier: jest.fn(), + }; + transport = new StreamableHTTPClientTransport(new URL("http://localhost:1234/mcp"), { authProvider: mockAuthProvider }); jest.spyOn(global, "fetch"); }); @@ -497,4 +509,27 @@ describe("StreamableHTTPClientTransport", () => { expect(getDelay(10)).toBe(5000); }); + it("attempts auth flow on 401 during POST request", async () => { + const message: JSONRPCMessage = { + jsonrpc: "2.0", + method: "test", + params: {}, + id: "test-id" + }; + + (global.fetch as jest.Mock) + .mockResolvedValueOnce({ + ok: false, + status: 401, + statusText: "Unauthorized", + headers: new Headers() + }) + .mockResolvedValue({ + ok: false, + status: 404 + }); + + await expect(transport.send(message)).rejects.toThrow(UnauthorizedError); + expect(mockAuthProvider.redirectToAuthorization.mock.calls).toHaveLength(1); + }); });