diff --git a/src/client/auth.test.ts b/src/client/auth.test.ts index 1b9fb071..6452dbee 100644 --- a/src/client/auth.test.ts +++ b/src/client/auth.test.ts @@ -9,6 +9,7 @@ import { auth, type OAuthClientProvider, } from "./auth.js"; +import {ServerError} from "../server/auth/errors.js"; // Mock fetch globally const mockFetch = jest.fn(); @@ -275,10 +276,7 @@ describe("OAuth Authorization", () => { }); it("throws on non-404 errors", async () => { - mockFetch.mockResolvedValueOnce({ - ok: false, - status: 500, - }); + mockFetch.mockResolvedValueOnce(new Response(null, { status: 500 })); await expect( discoverOAuthMetadata("https://auth.example.com") @@ -286,14 +284,15 @@ describe("OAuth Authorization", () => { }); it("validates metadata schema", async () => { - mockFetch.mockResolvedValueOnce({ - ok: true, - status: 200, - json: async () => ({ - // Missing required fields - issuer: "https://auth.example.com", - }), - }); + mockFetch.mockResolvedValueOnce( + Response.json( + { + // Missing required fields + issuer: "https://auth.example.com", + }, + { status: 200 } + ) + ); await expect( discoverOAuthMetadata("https://auth.example.com") @@ -510,10 +509,12 @@ describe("OAuth Authorization", () => { }); it("throws on error response", async () => { - mockFetch.mockResolvedValueOnce({ - ok: false, - status: 400, - }); + mockFetch.mockResolvedValueOnce( + Response.json( + new ServerError("Token exchange failed").toResponseObject(), + { status: 400 } + ) + ); await expect( exchangeAuthorization("https://auth.example.com", { @@ -611,10 +612,12 @@ describe("OAuth Authorization", () => { }); it("throws on error response", async () => { - mockFetch.mockResolvedValueOnce({ - ok: false, - status: 400, - }); + mockFetch.mockResolvedValueOnce( + Response.json( + new ServerError("Token refresh failed").toResponseObject(), + { status: 400 } + ) + ); await expect( refreshAuthorization("https://auth.example.com", { @@ -699,10 +702,12 @@ describe("OAuth Authorization", () => { }); it("throws on error response", async () => { - mockFetch.mockResolvedValueOnce({ - ok: false, - status: 400, - }); + mockFetch.mockResolvedValueOnce( + Response.json( + new ServerError("Dynamic client registration failed").toResponseObject(), + { status: 400 } + ) + ); await expect( registerClient("https://auth.example.com", { diff --git a/src/client/auth.ts b/src/client/auth.ts index 7a91eb25..8b9d9b9a 100644 --- a/src/client/auth.ts +++ b/src/client/auth.ts @@ -1,7 +1,23 @@ import pkceChallenge from "pkce-challenge"; import { LATEST_PROTOCOL_VERSION } from "../types.js"; -import type { OAuthClientMetadata, OAuthClientInformation, OAuthTokens, OAuthMetadata, OAuthClientInformationFull, OAuthProtectedResourceMetadata } from "../shared/auth.js"; +import { + OAuthClientMetadata, + OAuthClientInformation, + OAuthTokens, + OAuthMetadata, + OAuthClientInformationFull, + OAuthProtectedResourceMetadata, + OAuthErrorResponseSchema +} from "../shared/auth.js"; import { OAuthClientInformationFullSchema, OAuthMetadataSchema, OAuthProtectedResourceMetadataSchema, OAuthTokensSchema } from "../shared/auth.js"; +import { + InvalidClientError, + InvalidGrantError, + OAUTH_ERRORS, + OAuthError, + ServerError, + UnauthorizedClientError +} from "../server/auth/errors.js"; /** * Implements an end-to-end OAuth client to be used with one MCP server. @@ -71,6 +87,13 @@ export interface OAuthClientProvider { * the authorization result. */ codeVerifier(): string | Promise; + + /** + * If implemented, provides a way for the client to invalidate (e.g. delete) the specified + * credentials, in the case where the server has indicated that they are no longer valid. + * This avoids requiring the user to intervene manually. + */ + invalidateCredentials?(scope: 'all' | 'client' | 'tokens' | 'verifier'): void | Promise; } export type AuthResult = "AUTHORIZED" | "REDIRECT"; @@ -81,6 +104,33 @@ export class UnauthorizedError extends Error { } } +/** + * Parses an OAuth error response from a string or Response object. + * + * If the input is a standard OAuth2.0 error response, it will be parsed according to the spec + * and an instance of the appropriate OAuthError subclass will be returned. + * If parsing fails, it falls back to a generic ServerError that includes + * the response status (if available) and original content. + * + * @param input - A Response object or string containing the error response + * @returns A Promise that resolves to an OAuthError instance + */ +export async function parseErrorResponse(input: Response | string): Promise { + const statusCode = input instanceof Response ? input.status : undefined; + const body = input instanceof Response ? await input.text() : input; + + try { + const result = OAuthErrorResponseSchema.parse(JSON.parse(body)); + const { error, error_description, error_uri } = result; + const errorClass = OAUTH_ERRORS[error] || ServerError; + return new errorClass(error_description || '', error_uri); + } catch (error) { + // Not a valid OAuth error response, but try to inform the user of the raw data anyway + const errorMessage = `${statusCode ? `HTTP ${statusCode}: ` : ''}Invalid OAuth error response: ${error}. Raw body: ${body}`; + return new ServerError(errorMessage); + } +} + /** * Orchestrates the full auth flow with a server. * @@ -88,6 +138,31 @@ export class UnauthorizedError extends Error { * instead of linking together the other lower-level functions in this module. */ export async function auth( + provider: OAuthClientProvider, + options: { + serverUrl: string | URL; + authorizationCode?: string; + scope?: string; + resourceMetadataUrl?: URL }): Promise { + + try { + return await authInternal(provider, options); + } catch (error) { + // Handle recoverable error types by invalidating credentials and retrying + if (error instanceof InvalidClientError || error instanceof UnauthorizedClientError) { + await provider.invalidateCredentials?.('all'); + return await authInternal(provider, options); + } else if (error instanceof InvalidGrantError) { + await provider.invalidateCredentials?.('tokens'); + return await authInternal(provider, options); + } + + // Throw otherwise + throw error + } +} + +async function authInternal( provider: OAuthClientProvider, { serverUrl, authorizationCode, @@ -145,7 +220,7 @@ export async function auth( }); await provider.saveTokens(tokens); - return "AUTHORIZED"; + return "AUTHORIZED" } const tokens = await provider.tokens(); @@ -161,9 +236,15 @@ export async function auth( }); await provider.saveTokens(newTokens); - return "AUTHORIZED"; + return "AUTHORIZED" } catch (error) { - console.error("Could not refresh OAuth tokens:", error); + // If this is a ServerError, or an unknown type, log it out and try to continue. Otherwise, escalate so we can fix things and retry. + if (!(error instanceof OAuthError) || error instanceof ServerError) { + console.error("Could not refresh OAuth tokens:", error); + } else { + console.warn(`OAuth token refresh failed: ${JSON.stringify(error.toResponseObject())}`); + throw error; + } } } @@ -180,7 +261,7 @@ export async function auth( await provider.saveCodeVerifier(codeVerifier); await provider.redirectToAuthorization(authorizationUrl); - return "REDIRECT"; + return "REDIRECT" } /** @@ -427,7 +508,7 @@ export async function exchangeAuthorization( }); if (!response.ok) { - throw new Error(`Token exchange failed: HTTP ${response.status}`); + throw await parseErrorResponse(response); } return OAuthTokensSchema.parse(await response.json()); @@ -485,7 +566,7 @@ export async function refreshAuthorization( body: params, }); if (!response.ok) { - throw new Error(`Token refresh failed: HTTP ${response.status}`); + throw await parseErrorResponse(response); } return OAuthTokensSchema.parse({ refresh_token: refreshToken, ...(await response.json()) }); @@ -525,7 +606,7 @@ export async function registerClient( }); if (!response.ok) { - throw new Error(`Dynamic client registration failed: HTTP ${response.status}`); + throw await parseErrorResponse(response); } return OAuthClientInformationFullSchema.parse(await response.json()); diff --git a/src/client/sse.test.ts b/src/client/sse.test.ts index 714e1fdd..1b832937 100644 --- a/src/client/sse.test.ts +++ b/src/client/sse.test.ts @@ -4,6 +4,7 @@ import { JSONRPCMessage } from "../types.js"; import { SSEClientTransport } from "./sse.js"; import { OAuthClientProvider, UnauthorizedError } from "./auth.js"; import { OAuthTokens } from "../shared/auth.js"; +import { InvalidClientError, InvalidGrantError, UnauthorizedClientError } from "../server/auth/errors.js"; describe("SSEClientTransport", () => { let resourceServer: Server; @@ -331,6 +332,7 @@ describe("SSEClientTransport", () => { redirectToAuthorization: jest.fn(), saveCodeVerifier: jest.fn(), codeVerifier: jest.fn(), + invalidateCredentials: jest.fn(), }; }); @@ -879,5 +881,176 @@ describe("SSEClientTransport", () => { await expect(() => transport.start()).rejects.toThrow(UnauthorizedError); expect(mockAuthProvider.redirectToAuthorization).toHaveBeenCalled(); }); + + it("invalidates all credentials on InvalidClientError during token refresh", async () => { + // Mock tokens() to return token with refresh token + mockAuthProvider.tokens.mockResolvedValue({ + access_token: "expired-token", + token_type: "Bearer", + refresh_token: "refresh-token" + }); + + let baseUrl = resourceBaseUrl; + + // Create server that returns InvalidClientError on token refresh + const server = createServer((req, res) => { + lastServerRequest = req; + + // Handle OAuth metadata discovery + if (req.url === "/.well-known/oauth-authorization-server" && req.method === "GET") { + res.writeHead(200, { 'Content-Type': 'application/json' }); + res.end(JSON.stringify({ + issuer: baseUrl.href, + authorization_endpoint: `${baseUrl.href}authorize`, + token_endpoint: `${baseUrl.href}token`, + response_types_supported: ["code"], + code_challenge_methods_supported: ["S256"], + })); + return; + } + + if (req.url === "/token" && req.method === "POST") { + // Handle token refresh request - return InvalidClientError + const error = new InvalidClientError("Client authentication failed"); + res.writeHead(400, { 'Content-Type': 'application/json' }) + .end(JSON.stringify(error.toResponseObject())); + return; + } + + if (req.url !== "/") { + res.writeHead(404).end(); + return; + } + res.writeHead(401).end(); + }); + + await new Promise(resolve => { + server.listen(0, "127.0.0.1", () => { + const addr = server.address() as AddressInfo; + baseUrl = new URL(`http://127.0.0.1:${addr.port}`); + resolve(); + }); + }); + + transport = new SSEClientTransport(baseUrl, { + authProvider: mockAuthProvider, + }); + + await expect(() => transport.start()).rejects.toThrow(InvalidClientError); + expect(mockAuthProvider.invalidateCredentials).toHaveBeenCalledWith('all'); + }); + + it("invalidates all credentials on UnauthorizedClientError during token refresh", async () => { + // Mock tokens() to return token with refresh token + mockAuthProvider.tokens.mockResolvedValue({ + access_token: "expired-token", + token_type: "Bearer", + refresh_token: "refresh-token" + }); + + let baseUrl = resourceBaseUrl; + + const server = createServer((req, res) => { + lastServerRequest = req; + + // Handle OAuth metadata discovery + if (req.url === "/.well-known/oauth-authorization-server" && req.method === "GET") { + res.writeHead(200, { 'Content-Type': 'application/json' }); + res.end(JSON.stringify({ + issuer: baseUrl.href, + authorization_endpoint: `${baseUrl.href}authorize`, + token_endpoint: `${baseUrl.href}token`, + response_types_supported: ["code"], + code_challenge_methods_supported: ["S256"], + })); + return; + } + + if (req.url === "/token" && req.method === "POST") { + // Handle token refresh request - return UnauthorizedClientError + const error = new UnauthorizedClientError("Client not authorized"); + res.writeHead(400, { 'Content-Type': 'application/json' }) + .end(JSON.stringify(error.toResponseObject())); + return; + } + + if (req.url !== "/") { + res.writeHead(404).end(); + return; + } + res.writeHead(401).end(); + }); + + await new Promise(resolve => { + server.listen(0, "127.0.0.1", () => { + const addr = server.address() as AddressInfo; + baseUrl = new URL(`http://127.0.0.1:${addr.port}`); + resolve(); + }); + }); + + transport = new SSEClientTransport(baseUrl, { + authProvider: mockAuthProvider, + }); + + await expect(() => transport.start()).rejects.toThrow(UnauthorizedClientError); + expect(mockAuthProvider.invalidateCredentials).toHaveBeenCalledWith('all'); + }); + + it("invalidates tokens on InvalidGrantError during token refresh", async () => { + // Mock tokens() to return token with refresh token + mockAuthProvider.tokens.mockResolvedValue({ + access_token: "expired-token", + token_type: "Bearer", + refresh_token: "refresh-token" + }); + let baseUrl = resourceBaseUrl; + + const server = createServer((req, res) => { + lastServerRequest = req; + + // Handle OAuth metadata discovery + if (req.url === "/.well-known/oauth-authorization-server" && req.method === "GET") { + res.writeHead(200, { 'Content-Type': 'application/json' }); + res.end(JSON.stringify({ + issuer: baseUrl.href, + authorization_endpoint: `${baseUrl.href}authorize`, + token_endpoint: `${baseUrl.href}token`, + response_types_supported: ["code"], + code_challenge_methods_supported: ["S256"], + })); + return; + } + + if (req.url === "/token" && req.method === "POST") { + // Handle token refresh request - return InvalidGrantError + const error = new InvalidGrantError("Invalid refresh token"); + res.writeHead(400, { 'Content-Type': 'application/json' }) + .end(JSON.stringify(error.toResponseObject())); + return; + } + + if (req.url !== "/") { + res.writeHead(404).end(); + return; + } + res.writeHead(401).end(); + }); + + await new Promise(resolve => { + server.listen(0, "127.0.0.1", () => { + const addr = server.address() as AddressInfo; + baseUrl = new URL(`http://127.0.0.1:${addr.port}`); + resolve(); + }); + }); + + transport = new SSEClientTransport(baseUrl, { + authProvider: mockAuthProvider, + }); + + await expect(() => transport.start()).rejects.toThrow(InvalidGrantError); + expect(mockAuthProvider.invalidateCredentials).toHaveBeenCalledWith('tokens'); + }); }); }); diff --git a/src/client/streamableHttp.test.ts b/src/client/streamableHttp.test.ts index f748a2be..f2b55f8e 100644 --- a/src/client/streamableHttp.test.ts +++ b/src/client/streamableHttp.test.ts @@ -1,6 +1,7 @@ import { StreamableHTTPClientTransport, StreamableHTTPReconnectionOptions } from "./streamableHttp.js"; import { OAuthClientProvider, UnauthorizedError } from "./auth.js"; import { JSONRPCMessage } from "../types.js"; +import { InvalidClientError, InvalidGrantError, UnauthorizedClientError } from "../server/auth/errors.js"; describe("StreamableHTTPClientTransport", () => { @@ -17,6 +18,7 @@ describe("StreamableHTTPClientTransport", () => { redirectToAuthorization: jest.fn(), saveCodeVerifier: jest.fn(), codeVerifier: jest.fn(), + invalidateCredentials: jest.fn(), }; transport = new StreamableHTTPClientTransport(new URL("http://localhost:1234/mcp"), { authProvider: mockAuthProvider }); jest.spyOn(global, "fetch"); @@ -532,4 +534,160 @@ describe("StreamableHTTPClientTransport", () => { await expect(transport.send(message)).rejects.toThrow(UnauthorizedError); expect(mockAuthProvider.redirectToAuthorization.mock.calls).toHaveLength(1); }); + + it("invalidates all credentials on InvalidClientError during auth", async () => { + const message: JSONRPCMessage = { + jsonrpc: "2.0", + method: "test", + params: {}, + id: "test-id" + }; + + mockAuthProvider.tokens.mockResolvedValue({ + access_token: "test-token", + token_type: "Bearer", + refresh_token: "test-refresh" + }); + + const unauthedResponse = { + ok: false, + status: 401, + statusText: "Unauthorized", + headers: new Headers() + }; + (global.fetch as jest.Mock) + // Initial connection + .mockResolvedValueOnce(unauthedResponse) + // Resource discovery + .mockResolvedValueOnce(unauthedResponse) + // OAuth metadata discovery + .mockResolvedValueOnce({ + ok: true, + status: 200, + json: async () => ({ + issuer: "http://localhost:1234", + authorization_endpoint: "http://localhost:1234/authorize", + token_endpoint: "http://localhost:1234/token", + response_types_supported: ["code"], + code_challenge_methods_supported: ["S256"], + }), + }) + // Token refresh fails with InvalidClientError + .mockResolvedValueOnce(Response.json( + new InvalidClientError("Client authentication failed").toResponseObject(), + { status: 400 } + )) + // Fallback should fail to complete the flow + .mockResolvedValue({ + ok: false, + status: 404 + }); + + await expect(transport.send(message)).rejects.toThrow(UnauthorizedError); + expect(mockAuthProvider.invalidateCredentials).toHaveBeenCalledWith('all'); + }); + + it("invalidates all credentials on UnauthorizedClientError during auth", async () => { + const message: JSONRPCMessage = { + jsonrpc: "2.0", + method: "test", + params: {}, + id: "test-id" + }; + + mockAuthProvider.tokens.mockResolvedValue({ + access_token: "test-token", + token_type: "Bearer", + refresh_token: "test-refresh" + }); + + const unauthedResponse = { + ok: false, + status: 401, + statusText: "Unauthorized", + headers: new Headers() + }; + (global.fetch as jest.Mock) + // Initial connection + .mockResolvedValueOnce(unauthedResponse) + // Resource discovery + .mockResolvedValueOnce(unauthedResponse) + // OAuth metadata discovery + .mockResolvedValueOnce({ + ok: true, + status: 200, + json: async () => ({ + issuer: "http://localhost:1234", + authorization_endpoint: "http://localhost:1234/authorize", + token_endpoint: "http://localhost:1234/token", + response_types_supported: ["code"], + code_challenge_methods_supported: ["S256"], + }), + }) + // Token refresh fails with UnauthorizedClientError + .mockResolvedValueOnce(Response.json( + new UnauthorizedClientError("Client not authorized").toResponseObject(), + { status: 400 } + )) + // Fallback should fail to complete the flow + .mockResolvedValue({ + ok: false, + status: 404 + }); + + await expect(transport.send(message)).rejects.toThrow(UnauthorizedError); + expect(mockAuthProvider.invalidateCredentials).toHaveBeenCalledWith('all'); + }); + + it("invalidates tokens on InvalidGrantError during auth", async () => { + const message: JSONRPCMessage = { + jsonrpc: "2.0", + method: "test", + params: {}, + id: "test-id" + }; + + mockAuthProvider.tokens.mockResolvedValue({ + access_token: "test-token", + token_type: "Bearer", + refresh_token: "test-refresh" + }); + + const unauthedResponse = { + ok: false, + status: 401, + statusText: "Unauthorized", + headers: new Headers() + }; + (global.fetch as jest.Mock) + // Initial connection + .mockResolvedValueOnce(unauthedResponse) + // Resource discovery + .mockResolvedValueOnce(unauthedResponse) + // OAuth metadata discovery + .mockResolvedValueOnce({ + ok: true, + status: 200, + json: async () => ({ + issuer: "http://localhost:1234", + authorization_endpoint: "http://localhost:1234/authorize", + token_endpoint: "http://localhost:1234/token", + response_types_supported: ["code"], + code_challenge_methods_supported: ["S256"], + }), + }) + // Token refresh fails with InvalidGrantError + .mockResolvedValueOnce(Response.json( + new InvalidGrantError("Invalid refresh token").toResponseObject(), + { status: 400 } + )) + // Fallback should fail to complete the flow + .mockResolvedValue({ + ok: false, + status: 404 + }); + + await expect(transport.send(message)).rejects.toThrow(UnauthorizedError); + expect(mockAuthProvider.invalidateCredentials).toHaveBeenCalledWith('tokens'); + }); }); diff --git a/src/server/auth/errors.ts b/src/server/auth/errors.ts index 428199ce..83bb7fd3 100644 --- a/src/server/auth/errors.ts +++ b/src/server/auth/errors.ts @@ -4,12 +4,15 @@ import { OAuthErrorResponse } from "../../shared/auth.js"; * Base class for all OAuth errors */ export class OAuthError extends Error { + static errorCode: string; + public errorCode: string; + constructor( - public readonly errorCode: string, message: string, public readonly errorUri?: string ) { super(message); + this.errorCode = (this.constructor as typeof OAuthError).errorCode this.name = this.constructor.name; } @@ -36,9 +39,7 @@ export class OAuthError extends Error { * or is otherwise malformed. */ export class InvalidRequestError extends OAuthError { - constructor(message: string, errorUri?: string) { - super("invalid_request", message, errorUri); - } + static errorCode = "invalid_request"; } /** @@ -46,9 +47,7 @@ export class InvalidRequestError extends OAuthError { * authentication included, or unsupported authentication method). */ export class InvalidClientError extends OAuthError { - constructor(message: string, errorUri?: string) { - super("invalid_client", message, errorUri); - } + static errorCode = "invalid_client"; } /** @@ -57,9 +56,7 @@ export class InvalidClientError extends OAuthError { * authorization request, or was issued to another client. */ export class InvalidGrantError extends OAuthError { - constructor(message: string, errorUri?: string) { - super("invalid_grant", message, errorUri); - } + static errorCode = "invalid_grant"; } /** @@ -67,9 +64,7 @@ export class InvalidGrantError extends OAuthError { * this authorization grant type. */ export class UnauthorizedClientError extends OAuthError { - constructor(message: string, errorUri?: string) { - super("unauthorized_client", message, errorUri); - } + static errorCode = "unauthorized_client"; } /** @@ -77,9 +72,7 @@ export class UnauthorizedClientError extends OAuthError { * by the authorization server. */ export class UnsupportedGrantTypeError extends OAuthError { - constructor(message: string, errorUri?: string) { - super("unsupported_grant_type", message, errorUri); - } + static errorCode = "unsupported_grant_type"; } /** @@ -87,18 +80,14 @@ export class UnsupportedGrantTypeError extends OAuthError { * exceeds the scope granted by the resource owner. */ export class InvalidScopeError extends OAuthError { - constructor(message: string, errorUri?: string) { - super("invalid_scope", message, errorUri); - } + static errorCode = "invalid_scope"; } /** * Access denied error - The resource owner or authorization server denied the request. */ export class AccessDeniedError extends OAuthError { - constructor(message: string, errorUri?: string) { - super("access_denied", message, errorUri); - } + static errorCode = "access_denied"; } /** @@ -106,9 +95,7 @@ export class AccessDeniedError extends OAuthError { * that prevented it from fulfilling the request. */ export class ServerError extends OAuthError { - constructor(message: string, errorUri?: string) { - super("server_error", message, errorUri); - } + static errorCode = "server_error"; } /** @@ -116,9 +103,7 @@ export class ServerError extends OAuthError { * handle the request due to a temporary overloading or maintenance of the server. */ export class TemporarilyUnavailableError extends OAuthError { - constructor(message: string, errorUri?: string) { - super("temporarily_unavailable", message, errorUri); - } + static errorCode = "temporarily_unavailable"; } /** @@ -126,9 +111,7 @@ export class TemporarilyUnavailableError extends OAuthError { * obtaining an authorization code using this method. */ export class UnsupportedResponseTypeError extends OAuthError { - constructor(message: string, errorUri?: string) { - super("unsupported_response_type", message, errorUri); - } + static errorCode = "unsupported_response_type"; } /** @@ -136,9 +119,7 @@ export class UnsupportedResponseTypeError extends OAuthError { * the requested token type. */ export class UnsupportedTokenTypeError extends OAuthError { - constructor(message: string, errorUri?: string) { - super("unsupported_token_type", message, errorUri); - } + static errorCode = "unsupported_token_type"; } /** @@ -146,9 +127,7 @@ export class UnsupportedTokenTypeError extends OAuthError { * or invalid for other reasons. */ export class InvalidTokenError extends OAuthError { - constructor(message: string, errorUri?: string) { - super("invalid_token", message, errorUri); - } + static errorCode = "invalid_token"; } /** @@ -156,9 +135,7 @@ export class InvalidTokenError extends OAuthError { * (Custom, non-standard error) */ export class MethodNotAllowedError extends OAuthError { - constructor(message: string, errorUri?: string) { - super("method_not_allowed", message, errorUri); - } + static errorCode = "method_not_allowed"; } /** @@ -166,9 +143,7 @@ export class MethodNotAllowedError extends OAuthError { * (Custom, non-standard error based on RFC 6585) */ export class TooManyRequestsError extends OAuthError { - constructor(message: string, errorUri?: string) { - super("too_many_requests", message, errorUri); - } + static errorCode = "too_many_requests"; } /** @@ -176,16 +151,44 @@ export class TooManyRequestsError extends OAuthError { * (Custom error for dynamic client registration - RFC 7591) */ export class InvalidClientMetadataError extends OAuthError { - constructor(message: string, errorUri?: string) { - super("invalid_client_metadata", message, errorUri); - } + static errorCode = "invalid_client_metadata"; } /** * Insufficient scope error - The request requires higher privileges than provided by the access token. */ export class InsufficientScopeError extends OAuthError { - constructor(message: string, errorUri?: string) { - super("insufficient_scope", message, errorUri); + static errorCode = "insufficient_scope"; +} + +/** + * A utility class for defining one-off error codes + */ +export class CustomOAuthError extends OAuthError { + constructor(errorCode: string, message: string, errorUri?: string) { + super(message, errorUri); + this.errorCode = errorCode } } + +/** + * A full list of all OAuthErrors, enabling parsing from error responses + */ +export const OAUTH_ERRORS = { + [InvalidRequestError.errorCode]: InvalidRequestError, + [InvalidClientError.errorCode]: InvalidClientError, + [InvalidGrantError.errorCode]: InvalidGrantError, + [UnauthorizedClientError.errorCode]: UnauthorizedClientError, + [UnsupportedGrantTypeError.errorCode]: UnsupportedGrantTypeError, + [InvalidScopeError.errorCode]: InvalidScopeError, + [AccessDeniedError.errorCode]: AccessDeniedError, + [ServerError.errorCode]: ServerError, + [TemporarilyUnavailableError.errorCode]: TemporarilyUnavailableError, + [UnsupportedResponseTypeError.errorCode]: UnsupportedResponseTypeError, + [UnsupportedTokenTypeError.errorCode]: UnsupportedTokenTypeError, + [InvalidTokenError.errorCode]: InvalidTokenError, + [MethodNotAllowedError.errorCode]: MethodNotAllowedError, + [TooManyRequestsError.errorCode]: TooManyRequestsError, + [InvalidClientMetadataError.errorCode]: InvalidClientMetadataError, + [InsufficientScopeError.errorCode]: InsufficientScopeError, +} as const; diff --git a/src/server/auth/middleware/bearerAuth.test.ts b/src/server/auth/middleware/bearerAuth.test.ts index b8953e5c..4e5e1429 100644 --- a/src/server/auth/middleware/bearerAuth.test.ts +++ b/src/server/auth/middleware/bearerAuth.test.ts @@ -1,7 +1,7 @@ import { Request, Response } from "express"; import { requireBearerAuth } from "./bearerAuth.js"; import { AuthInfo } from "../types.js"; -import { InsufficientScopeError, InvalidTokenError, OAuthError, ServerError } from "../errors.js"; +import { InsufficientScopeError, InvalidTokenError, CustomOAuthError, ServerError } from "../errors.js"; import { OAuthTokenVerifier } from "../provider.js"; // Mock verifier @@ -268,7 +268,7 @@ describe("requireBearerAuth middleware", () => { authorization: "Bearer valid-token", }; - mockVerifyAccessToken.mockRejectedValue(new OAuthError("custom_error", "Some OAuth error")); + mockVerifyAccessToken.mockRejectedValue(new CustomOAuthError("custom_error", "Some OAuth error")); const middleware = requireBearerAuth({ verifier: mockVerifier }); await middleware(mockRequest as Request, mockResponse as Response, nextFunction);