Skip to content

Adding invalidateCredentials() to OAuthClientProvider #570

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 29 additions & 24 deletions src/client/auth.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import {
auth,
type OAuthClientProvider,
} from "./auth.js";
import {ServerError} from "../server/auth/errors.js";

// Mock fetch globally
const mockFetch = jest.fn();
Expand Down Expand Up @@ -275,25 +276,23 @@ describe("OAuth Authorization", () => {
});

it("throws on non-404 errors", async () => {
mockFetch.mockResolvedValueOnce({
ok: false,
status: 500,
});
mockFetch.mockResolvedValueOnce(new Response(null, { status: 500 }));

await expect(
discoverOAuthMetadata("https://auth.example.com")
).rejects.toThrow("HTTP 500");
});

it("validates metadata schema", async () => {
mockFetch.mockResolvedValueOnce({
ok: true,
status: 200,
json: async () => ({
// Missing required fields
issuer: "https://auth.example.com",
}),
});
mockFetch.mockResolvedValueOnce(
Response.json(
{
// Missing required fields
issuer: "https://auth.example.com",
},
{ status: 200 }
)
);

await expect(
discoverOAuthMetadata("https://auth.example.com")
Expand Down Expand Up @@ -510,10 +509,12 @@ describe("OAuth Authorization", () => {
});

it("throws on error response", async () => {
mockFetch.mockResolvedValueOnce({
ok: false,
status: 400,
});
mockFetch.mockResolvedValueOnce(
Response.json(
new ServerError("Token exchange failed").toResponseObject(),
{ status: 400 }
)
);

await expect(
exchangeAuthorization("https://auth.example.com", {
Expand Down Expand Up @@ -611,10 +612,12 @@ describe("OAuth Authorization", () => {
});

it("throws on error response", async () => {
mockFetch.mockResolvedValueOnce({
ok: false,
status: 400,
});
mockFetch.mockResolvedValueOnce(
Response.json(
new ServerError("Token refresh failed").toResponseObject(),
{ status: 400 }
)
);

await expect(
refreshAuthorization("https://auth.example.com", {
Expand Down Expand Up @@ -699,10 +702,12 @@ describe("OAuth Authorization", () => {
});

it("throws on error response", async () => {
mockFetch.mockResolvedValueOnce({
ok: false,
status: 400,
});
mockFetch.mockResolvedValueOnce(
Response.json(
new ServerError("Dynamic client registration failed").toResponseObject(),
{ status: 400 }
)
);

await expect(
registerClient("https://auth.example.com", {
Expand Down
97 changes: 89 additions & 8 deletions src/client/auth.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,23 @@
import pkceChallenge from "pkce-challenge";
import { LATEST_PROTOCOL_VERSION } from "../types.js";
import type { OAuthClientMetadata, OAuthClientInformation, OAuthTokens, OAuthMetadata, OAuthClientInformationFull, OAuthProtectedResourceMetadata } from "../shared/auth.js";
import {
OAuthClientMetadata,
OAuthClientInformation,
OAuthTokens,
OAuthMetadata,
OAuthClientInformationFull,
OAuthProtectedResourceMetadata,
OAuthErrorResponseSchema
} from "../shared/auth.js";
import { OAuthClientInformationFullSchema, OAuthMetadataSchema, OAuthProtectedResourceMetadataSchema, OAuthTokensSchema } from "../shared/auth.js";
import {
InvalidClientError,
InvalidGrantError,
OAUTH_ERRORS,
OAuthError,
ServerError,
UnauthorizedClientError
} from "../server/auth/errors.js";

/**
* Implements an end-to-end OAuth client to be used with one MCP server.
Expand Down Expand Up @@ -71,6 +87,13 @@ export interface OAuthClientProvider {
* the authorization result.
*/
codeVerifier(): string | Promise<string>;

/**
* If implemented, provides a way for the client to invalidate (e.g. delete) the specified
* credentials, in the case where the server has indicated that they are no longer valid.
* This avoids requiring the user to intervene manually.
*/
invalidateCredentials?(scope: 'all' | 'client' | 'tokens' | 'verifier'): void | Promise<void>;
}

export type AuthResult = "AUTHORIZED" | "REDIRECT";
Expand All @@ -81,13 +104,65 @@ export class UnauthorizedError extends Error {
}
}

/**
* Parses an OAuth error response from a string or Response object.
*
* If the input is a standard OAuth2.0 error response, it will be parsed according to the spec
* and an instance of the appropriate OAuthError subclass will be returned.
* If parsing fails, it falls back to a generic ServerError that includes
* the response status (if available) and original content.
*
* @param input - A Response object or string containing the error response
* @returns A Promise that resolves to an OAuthError instance
*/
export async function parseErrorResponse(input: Response | string): Promise<OAuthError> {
const statusCode = input instanceof Response ? input.status : undefined;
const body = input instanceof Response ? await input.text() : input;

try {
const result = OAuthErrorResponseSchema.parse(JSON.parse(body));
const { error, error_description, error_uri } = result;
const errorClass = OAUTH_ERRORS[error] || ServerError;
return new errorClass(error_description || '', error_uri);
} catch (error) {
// Not a valid OAuth error response, but try to inform the user of the raw data anyway
const errorMessage = `${statusCode ? `HTTP ${statusCode}: ` : ''}Invalid OAuth error response: ${error}. Raw body: ${body}`;
return new ServerError(errorMessage);
}
}

/**
* Orchestrates the full auth flow with a server.
*
* This can be used as a single entry point for all authorization functionality,
* instead of linking together the other lower-level functions in this module.
*/
export async function auth(
provider: OAuthClientProvider,
options: {
serverUrl: string | URL;
authorizationCode?: string;
scope?: string;
resourceMetadataUrl?: URL }): Promise<AuthResult> {

try {
return await authInternal(provider, options);
} catch (error) {
// Handle recoverable error types by invalidating credentials and retrying
if (error instanceof InvalidClientError || error instanceof UnauthorizedClientError) {
await provider.invalidateCredentials?.('all');
return await authInternal(provider, options);
} else if (error instanceof InvalidGrantError) {
await provider.invalidateCredentials?.('tokens');
return await authInternal(provider, options);
}

// Throw otherwise
throw error
}
}

async function authInternal(
provider: OAuthClientProvider,
{ serverUrl,
authorizationCode,
Expand Down Expand Up @@ -145,7 +220,7 @@ export async function auth(
});

await provider.saveTokens(tokens);
return "AUTHORIZED";
return "AUTHORIZED"
}

const tokens = await provider.tokens();
Expand All @@ -161,9 +236,15 @@ export async function auth(
});

await provider.saveTokens(newTokens);
return "AUTHORIZED";
return "AUTHORIZED"
} catch (error) {
console.error("Could not refresh OAuth tokens:", error);
// If this is a ServerError, or an unknown type, log it out and try to continue. Otherwise, escalate so we can fix things and retry.
if (!(error instanceof OAuthError) || error instanceof ServerError) {
console.error("Could not refresh OAuth tokens:", error);
} else {
console.warn(`OAuth token refresh failed: ${JSON.stringify(error.toResponseObject())}`);
throw error;
}
}
}

Expand All @@ -180,7 +261,7 @@ export async function auth(

await provider.saveCodeVerifier(codeVerifier);
await provider.redirectToAuthorization(authorizationUrl);
return "REDIRECT";
return "REDIRECT"
}

/**
Expand Down Expand Up @@ -427,7 +508,7 @@ export async function exchangeAuthorization(
});

if (!response.ok) {
throw new Error(`Token exchange failed: HTTP ${response.status}`);
throw await parseErrorResponse(response);
}

return OAuthTokensSchema.parse(await response.json());
Expand Down Expand Up @@ -485,7 +566,7 @@ export async function refreshAuthorization(
body: params,
});
if (!response.ok) {
throw new Error(`Token refresh failed: HTTP ${response.status}`);
throw await parseErrorResponse(response);
}

return OAuthTokensSchema.parse({ refresh_token: refreshToken, ...(await response.json()) });
Expand Down Expand Up @@ -525,7 +606,7 @@ export async function registerClient(
});

if (!response.ok) {
throw new Error(`Dynamic client registration failed: HTTP ${response.status}`);
throw await parseErrorResponse(response);
}

return OAuthClientInformationFullSchema.parse(await response.json());
Expand Down
Loading