Skip to content

Commit ecb41f5

Browse files
committed
Refactoring OAuthErrors
This makes it possible to parse them from JSON, using OAUTH_ERRORS Invalidating credentials & retrying when server OAuth errors occur Updated existing tests Added some initial test coverage refactored to avoid recursion as recommended
1 parent 1878143 commit ecb41f5

File tree

6 files changed

+502
-82
lines changed

6 files changed

+502
-82
lines changed

src/client/auth.test.ts

Lines changed: 29 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import {
99
auth,
1010
type OAuthClientProvider,
1111
} from "./auth.js";
12+
import {ServerError} from "../server/auth/errors.js";
1213

1314
// Mock fetch globally
1415
const mockFetch = jest.fn();
@@ -275,25 +276,23 @@ describe("OAuth Authorization", () => {
275276
});
276277

277278
it("throws on non-404 errors", async () => {
278-
mockFetch.mockResolvedValueOnce({
279-
ok: false,
280-
status: 500,
281-
});
279+
mockFetch.mockResolvedValueOnce(new Response(null, { status: 500 }));
282280

283281
await expect(
284282
discoverOAuthMetadata("https://auth.example.com")
285283
).rejects.toThrow("HTTP 500");
286284
});
287285

288286
it("validates metadata schema", async () => {
289-
mockFetch.mockResolvedValueOnce({
290-
ok: true,
291-
status: 200,
292-
json: async () => ({
293-
// Missing required fields
294-
issuer: "https://auth.example.com",
295-
}),
296-
});
287+
mockFetch.mockResolvedValueOnce(
288+
Response.json(
289+
{
290+
// Missing required fields
291+
issuer: "https://auth.example.com",
292+
},
293+
{ status: 200 }
294+
)
295+
);
297296

298297
await expect(
299298
discoverOAuthMetadata("https://auth.example.com")
@@ -510,10 +509,12 @@ describe("OAuth Authorization", () => {
510509
});
511510

512511
it("throws on error response", async () => {
513-
mockFetch.mockResolvedValueOnce({
514-
ok: false,
515-
status: 400,
516-
});
512+
mockFetch.mockResolvedValueOnce(
513+
Response.json(
514+
new ServerError("Token exchange failed").toResponseObject(),
515+
{ status: 400 }
516+
)
517+
);
517518

518519
await expect(
519520
exchangeAuthorization("https://auth.example.com", {
@@ -611,10 +612,12 @@ describe("OAuth Authorization", () => {
611612
});
612613

613614
it("throws on error response", async () => {
614-
mockFetch.mockResolvedValueOnce({
615-
ok: false,
616-
status: 400,
617-
});
615+
mockFetch.mockResolvedValueOnce(
616+
Response.json(
617+
new ServerError("Token refresh failed").toResponseObject(),
618+
{ status: 400 }
619+
)
620+
);
618621

619622
await expect(
620623
refreshAuthorization("https://auth.example.com", {
@@ -699,10 +702,12 @@ describe("OAuth Authorization", () => {
699702
});
700703

701704
it("throws on error response", async () => {
702-
mockFetch.mockResolvedValueOnce({
703-
ok: false,
704-
status: 400,
705-
});
705+
mockFetch.mockResolvedValueOnce(
706+
Response.json(
707+
new ServerError("Dynamic client registration failed").toResponseObject(),
708+
{ status: 400 }
709+
)
710+
);
706711

707712
await expect(
708713
registerClient("https://auth.example.com", {

src/client/auth.ts

Lines changed: 89 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,23 @@
11
import pkceChallenge from "pkce-challenge";
22
import { LATEST_PROTOCOL_VERSION } from "../types.js";
3-
import type { OAuthClientMetadata, OAuthClientInformation, OAuthTokens, OAuthMetadata, OAuthClientInformationFull, OAuthProtectedResourceMetadata } from "../shared/auth.js";
3+
import {
4+
OAuthClientMetadata,
5+
OAuthClientInformation,
6+
OAuthTokens,
7+
OAuthMetadata,
8+
OAuthClientInformationFull,
9+
OAuthProtectedResourceMetadata,
10+
OAuthErrorResponseSchema
11+
} from "../shared/auth.js";
412
import { OAuthClientInformationFullSchema, OAuthMetadataSchema, OAuthProtectedResourceMetadataSchema, OAuthTokensSchema } from "../shared/auth.js";
13+
import {
14+
InvalidClientError,
15+
InvalidGrantError,
16+
OAUTH_ERRORS,
17+
OAuthError,
18+
ServerError,
19+
UnauthorizedClientError
20+
} from "../server/auth/errors.js";
521

622
/**
723
* Implements an end-to-end OAuth client to be used with one MCP server.
@@ -71,6 +87,13 @@ export interface OAuthClientProvider {
7187
* the authorization result.
7288
*/
7389
codeVerifier(): string | Promise<string>;
90+
91+
/**
92+
* If implemented, provides a way for the client to invalidate (e.g. delete) the specified
93+
* credentials, in the case where the server has indicated that they are no longer valid.
94+
* This avoids requiring the user to intervene manually.
95+
*/
96+
invalidateCredentials?(scope: 'all' | 'client' | 'tokens' | 'verifier'): void | Promise<void>;
7497
}
7598

7699
export type AuthResult = "AUTHORIZED" | "REDIRECT";
@@ -81,13 +104,65 @@ export class UnauthorizedError extends Error {
81104
}
82105
}
83106

107+
/**
108+
* Parses an OAuth error response from a string or Response object.
109+
*
110+
* If the input is a standard OAuth2.0 error response, it will be parsed according to the spec
111+
* and an instance of the appropriate OAuthError subclass will be returned.
112+
* If parsing fails, it falls back to a generic ServerError that includes
113+
* the response status (if available) and original content.
114+
*
115+
* @param input - A Response object or string containing the error response
116+
* @returns A Promise that resolves to an OAuthError instance
117+
*/
118+
export async function parseErrorResponse(input: Response | string): Promise<OAuthError> {
119+
const statusCode = input instanceof Response ? input.status : undefined;
120+
const body = input instanceof Response ? await input.text() : input;
121+
122+
try {
123+
const result = OAuthErrorResponseSchema.parse(JSON.parse(body));
124+
const { error, error_description, error_uri } = result;
125+
const errorClass = OAUTH_ERRORS[error] || ServerError;
126+
return new errorClass(error_description || '', error_uri);
127+
} catch (error) {
128+
// Not a valid OAuth error response, but try to inform the user of the raw data anyway
129+
const errorMessage = `${statusCode ? `HTTP ${statusCode}: ` : ''}Invalid OAuth error response: ${error}. Raw body: ${body}`;
130+
return new ServerError(errorMessage);
131+
}
132+
}
133+
84134
/**
85135
* Orchestrates the full auth flow with a server.
86136
*
87137
* This can be used as a single entry point for all authorization functionality,
88138
* instead of linking together the other lower-level functions in this module.
89139
*/
90140
export async function auth(
141+
provider: OAuthClientProvider,
142+
options: {
143+
serverUrl: string | URL;
144+
authorizationCode?: string;
145+
scope?: string;
146+
resourceMetadataUrl?: URL }): Promise<AuthResult> {
147+
148+
try {
149+
return await authInternal(provider, options);
150+
} catch (error) {
151+
// Handle recoverable error types by invalidating credentials and retrying
152+
if (error instanceof InvalidClientError || error instanceof UnauthorizedClientError) {
153+
await provider.invalidateCredentials?.('all');
154+
return await authInternal(provider, options);
155+
} else if (error instanceof InvalidGrantError) {
156+
await provider.invalidateCredentials?.('tokens');
157+
return await authInternal(provider, options);
158+
}
159+
160+
// Throw otherwise
161+
throw error
162+
}
163+
}
164+
165+
async function authInternal(
91166
provider: OAuthClientProvider,
92167
{ serverUrl,
93168
authorizationCode,
@@ -145,7 +220,7 @@ export async function auth(
145220
});
146221

147222
await provider.saveTokens(tokens);
148-
return "AUTHORIZED";
223+
return "AUTHORIZED"
149224
}
150225

151226
const tokens = await provider.tokens();
@@ -161,9 +236,15 @@ export async function auth(
161236
});
162237

163238
await provider.saveTokens(newTokens);
164-
return "AUTHORIZED";
239+
return "AUTHORIZED"
165240
} catch (error) {
166-
console.error("Could not refresh OAuth tokens:", error);
241+
// If this is a ServerError, or an unknown type, log it out and try to continue. Otherwise, escalate so we can fix things and retry.
242+
if (!(error instanceof OAuthError) || error instanceof ServerError) {
243+
console.error("Could not refresh OAuth tokens:", error);
244+
} else {
245+
console.warn(`OAuth token refresh failed: ${JSON.stringify(error.toResponseObject())}`);
246+
throw error;
247+
}
167248
}
168249
}
169250

@@ -180,7 +261,7 @@ export async function auth(
180261

181262
await provider.saveCodeVerifier(codeVerifier);
182263
await provider.redirectToAuthorization(authorizationUrl);
183-
return "REDIRECT";
264+
return "REDIRECT"
184265
}
185266

186267
/**
@@ -427,7 +508,7 @@ export async function exchangeAuthorization(
427508
});
428509

429510
if (!response.ok) {
430-
throw new Error(`Token exchange failed: HTTP ${response.status}`);
511+
throw await parseErrorResponse(response);
431512
}
432513

433514
return OAuthTokensSchema.parse(await response.json());
@@ -485,7 +566,7 @@ export async function refreshAuthorization(
485566
body: params,
486567
});
487568
if (!response.ok) {
488-
throw new Error(`Token refresh failed: HTTP ${response.status}`);
569+
throw await parseErrorResponse(response);
489570
}
490571

491572
return OAuthTokensSchema.parse({ refresh_token: refreshToken, ...(await response.json()) });
@@ -525,7 +606,7 @@ export async function registerClient(
525606
});
526607

527608
if (!response.ok) {
528-
throw new Error(`Dynamic client registration failed: HTTP ${response.status}`);
609+
throw await parseErrorResponse(response);
529610
}
530611

531612
return OAuthClientInformationFullSchema.parse(await response.json());

0 commit comments

Comments
 (0)