diff --git a/src/server/auth/handlers/token.test.ts b/src/server/auth/handlers/token.test.ts index bf41b5eb..f3a91a39 100644 --- a/src/server/auth/handlers/token.test.ts +++ b/src/server/auth/handlers/token.test.ts @@ -21,7 +21,7 @@ describe('Token Handler', () => { const validClient: OAuthClientInformationFull = { client_id: 'valid-client', client_secret: 'valid-secret', - redirect_uris: ['https://example.com/callback'] + redirect_uris: ['https://valid.com/callback'] }; // Mock client store @@ -44,7 +44,7 @@ describe('Token Handler', () => { clientsStore: mockClientStore, async authorize(client: OAuthClientInformationFull, params: AuthorizationParams, res: Response): Promise { - res.redirect('https://example.com/callback?code=mock_auth_code'); + res.redirect('https://valid.com/callback?code=mock_auth_code'); }, async challengeForAuthorizationCode(client: OAuthClientInformationFull, authorizationCode: string): Promise { @@ -56,8 +56,8 @@ describe('Token Handler', () => { throw new InvalidGrantError('The authorization code is invalid'); }, - async exchangeAuthorizationCode(client: OAuthClientInformationFull, authorizationCode: string): Promise { - if (authorizationCode === 'valid_code') { + async exchangeAuthorizationCode(client: OAuthClientInformationFull, authorizationCode: string, redirectUri: string): Promise { + if (authorizationCode === 'valid_code' && redirectUri === 'https://valid.com/callback') { return { access_token: 'mock_access_token', token_type: 'bearer', @@ -123,7 +123,10 @@ describe('Token Handler', () => { .send({ client_id: 'valid-client', client_secret: 'valid-secret', - grant_type: 'authorization_code' + grant_type: 'authorization_code', + code: 'valid_code', + code_verifier: 'valid_verifier', + redirect_uri: 'https://valid.com/callback' }); expect(response.status).toBe(405); @@ -171,7 +174,10 @@ describe('Token Handler', () => { .send({ client_id: 'invalid-client', client_secret: 'wrong-secret', - grant_type: 'authorization_code' + grant_type: 'authorization_code', + code: 'valid_code', + code_verifier: 'valid_verifier', + redirect_uri: 'https://valid.com/callback' }); expect(response.status).toBe(400); @@ -187,7 +193,8 @@ describe('Token Handler', () => { client_secret: 'valid-secret', grant_type: 'authorization_code', code: 'valid_code', - code_verifier: 'valid_verifier' + code_verifier: 'valid_verifier', + redirect_uri: 'https://valid.com/callback' }); expect(response.status).toBe(200); @@ -204,7 +211,8 @@ describe('Token Handler', () => { client_secret: 'valid-secret', grant_type: 'authorization_code', // Missing code - code_verifier: 'valid_verifier' + code_verifier: 'valid_verifier', + redirect_uri: 'https://valid.com/callback' }); expect(response.status).toBe(400); @@ -219,8 +227,26 @@ describe('Token Handler', () => { client_id: 'valid-client', client_secret: 'valid-secret', grant_type: 'authorization_code', - code: 'valid_code' + code: 'valid_code', // Missing code_verifier + redirect_uri: 'https://valid.com/callback' + }); + + expect(response.status).toBe(400); + expect(response.body.error).toBe('invalid_request'); + }); + + it('requires redirect_uri parameter', async () => { + const response = await supertest(app) + .post('/token') + .type('form') + .send({ + client_id: 'valid-client', + client_secret: 'valid-secret', + grant_type: 'authorization_code', + code: 'valid_code', + code_verifier: 'valid_verifier' + // Missing redirect_uri }); expect(response.status).toBe(400); @@ -239,7 +265,8 @@ describe('Token Handler', () => { client_secret: 'valid-secret', grant_type: 'authorization_code', code: 'valid_code', - code_verifier: 'invalid_verifier' + code_verifier: 'invalid_verifier', + redirect_uri: 'https://valid.com/callback' }); expect(response.status).toBe(400); @@ -256,13 +283,31 @@ describe('Token Handler', () => { client_secret: 'valid-secret', grant_type: 'authorization_code', code: 'expired_code', - code_verifier: 'valid_verifier' + code_verifier: 'valid_verifier', + redirect_uri: 'https://valid.com/callback' }); expect(response.status).toBe(400); expect(response.body.error).toBe('invalid_grant'); }); + it('rejects unregistered redirect_uri', async () => { + const response = await supertest(app) + .post('/token') + .type('form') + .send({ + client_id: 'valid-client', + client_secret: 'valid-secret', + grant_type: 'authorization_code', + code: 'valid_code', + code_verifier: 'valid_verifier', + redirect_uri: 'https://wrong.com/callback' + }); + + expect(response.status).toBe(400); + expect(response.body.error).toBe('invalid_request'); + }); + it('returns tokens for valid code exchange', async () => { const response = await supertest(app) .post('/token') @@ -272,7 +317,8 @@ describe('Token Handler', () => { client_secret: 'valid-secret', grant_type: 'authorization_code', code: 'valid_code', - code_verifier: 'valid_verifier' + code_verifier: 'valid_verifier', + redirect_uri: 'https://valid.com/callback' }); expect(response.status).toBe(200); @@ -322,7 +368,8 @@ describe('Token Handler', () => { client_secret: 'valid-secret', grant_type: 'authorization_code', code: 'valid_code', - code_verifier: 'any_verifier' + code_verifier: 'any_verifier', + redirect_uri: 'https://valid.com/callback' }); expect(response.status).toBe(200); diff --git a/src/server/auth/handlers/token.ts b/src/server/auth/handlers/token.ts index 28412a01..f73eaa50 100644 --- a/src/server/auth/handlers/token.ts +++ b/src/server/auth/handlers/token.ts @@ -31,6 +31,7 @@ const TokenRequestSchema = z.object({ const AuthorizationCodeGrantSchema = z.object({ code: z.string(), code_verifier: z.string(), + redirect_uri: z.string(), }); const RefreshTokenGrantSchema = z.object({ @@ -88,7 +89,10 @@ export function tokenHandler({ provider, rateLimit: rateLimitConfig }: TokenHand throw new InvalidRequestError(parseResult.error.message); } - const { code, code_verifier } = parseResult.data; + const { code, code_verifier, redirect_uri } = parseResult.data; + if (!client.redirect_uris.includes(redirect_uri)) { + throw new InvalidRequestError("Unregistered redirect_uri"); + } const skipLocalPkceValidation = provider.skipLocalPkceValidation; @@ -102,7 +106,7 @@ export function tokenHandler({ provider, rateLimit: rateLimitConfig }: TokenHand } // Passes the code_verifier to the provider if PKCE validation didn't occur locally - const tokens = await provider.exchangeAuthorizationCode(client, code, skipLocalPkceValidation ? code_verifier : undefined); + const tokens = await provider.exchangeAuthorizationCode(client, code, redirect_uri, skipLocalPkceValidation ? code_verifier : undefined); res.status(200).json(tokens); break; } diff --git a/src/server/auth/provider.ts b/src/server/auth/provider.ts index dc186bca..5b308d18 100644 --- a/src/server/auth/provider.ts +++ b/src/server/auth/provider.ts @@ -36,7 +36,7 @@ export interface OAuthServerProvider { /** * Exchanges an authorization code for an access token. */ - exchangeAuthorizationCode(client: OAuthClientInformationFull, authorizationCode: string, codeVerifier?: string): Promise; + exchangeAuthorizationCode(client: OAuthClientInformationFull, authorizationCode: string, redirectUri: string, codeVerifier?: string): Promise; /** * Exchanges a refresh token for an access token. diff --git a/src/server/auth/providers/proxyProvider.test.ts b/src/server/auth/providers/proxyProvider.test.ts index 6e842ea3..e36762fb 100644 --- a/src/server/auth/providers/proxyProvider.test.ts +++ b/src/server/auth/providers/proxyProvider.test.ts @@ -126,6 +126,7 @@ describe("Proxy OAuth Server Provider", () => { const tokens = await provider.exchangeAuthorizationCode( validClient, "test-code", + "https://example.com/callback", "test-verifier" ); diff --git a/src/server/auth/providers/proxyProvider.ts b/src/server/auth/providers/proxyProvider.ts index be450305..fa62c4fe 100644 --- a/src/server/auth/providers/proxyProvider.ts +++ b/src/server/auth/providers/proxyProvider.ts @@ -151,12 +151,14 @@ export class ProxyOAuthServerProvider implements OAuthServerProvider { async exchangeAuthorizationCode( client: OAuthClientInformationFull, authorizationCode: string, + redirectUri: string, codeVerifier?: string ): Promise { const params = new URLSearchParams({ grant_type: "authorization_code", client_id: client.client_id, code: authorizationCode, + redirect_uri: redirectUri, }); if (client.client_secret) {