diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthenticationProviderUtils.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthenticationProviderUtils.java index 97bc402f4..99224cd73 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthenticationProviderUtils.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthenticationProviderUtils.java @@ -17,6 +17,7 @@ import org.springframework.security.authentication.AuthenticationProvider; import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2ErrorCodes; import org.springframework.security.oauth2.core.OAuth2RefreshToken; @@ -55,6 +56,25 @@ static OAuth2Authorization invalidate( (metadata) -> metadata.put(OAuth2Authorization.Token.INVALIDATED_METADATA_NAME, true)); + if (OAuth2AuthorizationCode.class.isAssignableFrom(token.getClass())) { + OAuth2Authorization.Token accessToken = authorization.getAccessToken(); + if (accessToken != null && !accessToken.isInvalidated()) { + authorizationBuilder.token( + accessToken.getToken(), + (metadata) -> + metadata.put(OAuth2Authorization.Token.INVALIDATED_METADATA_NAME, true)); + } + + OAuth2Authorization.Token refreshToken = authorization.getRefreshToken(); + if (refreshToken != null && !refreshToken.isInvalidated()) { + authorizationBuilder.token( + refreshToken.getToken(), + (metadata) -> + metadata.put(OAuth2Authorization.Token.INVALIDATED_METADATA_NAME, true)); + } + + } + if (OAuth2RefreshToken.class.isAssignableFrom(token.getClass())) { authorizationBuilder.token( authorization.getAccessToken().getToken(), diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java index cb8fb9021..51c8af241 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java @@ -149,6 +149,13 @@ public Authentication authenticate(Authentication authentication) throws Authent } if (!authorizationCode.isActive()) { + if (authorizationCode.isInvalidated()) { + authorization = OAuth2AuthenticationProviderUtils.invalidate(authorization, authorizationCode.getToken()); + this.authorizationService.save(authorization); + if (this.logger.isWarnEnabled()) { + this.logger.warn(LogMessage.format("Invalidated authorization tokens previously issued based on the authorization code")); + } + } throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_GRANT); } @@ -169,7 +176,12 @@ public Authentication authenticate(Authentication authentication) throws Authent .authorizationGrant(authorizationCodeAuthentication); // @formatter:on - OAuth2Authorization.Builder authorizationBuilder = OAuth2Authorization.from(authorization); + // @formatter:off + OAuth2Authorization.Builder authorizationBuilder = OAuth2Authorization.from(authorization) + // Invalidate the authorization code as it can only be used once + .token(authorizationCode.getToken(), metadata -> + metadata.put(OAuth2Authorization.Token.INVALIDATED_METADATA_NAME, true)); + // @formatter:on // ----- Access token ----- OAuth2TokenContext tokenContext = tokenContextBuilder.tokenType(OAuth2TokenType.ACCESS_TOKEN).build(); @@ -250,9 +262,6 @@ public Authentication authenticate(Authentication authentication) throws Authent authorization = authorizationBuilder.build(); - // Invalidate the authorization code as it can only be used once - authorization = OAuth2AuthenticationProviderUtils.invalidate(authorization, authorizationCode.getToken()); - this.authorizationService.save(authorization); if (this.logger.isTraceEnabled()) { @@ -305,5 +314,4 @@ private SessionInformation getSessionInformation(Authentication principal) { } return sessionInformation; } - } diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java index d0bc3613e..9e82fd53a 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java @@ -271,6 +271,12 @@ public void authenticateWhenInvalidatedCodeThenThrowOAuth2AuthenticationExceptio .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) .extracting("errorCode") .isEqualTo(OAuth2ErrorCodes.INVALID_GRANT); + + ArgumentCaptor authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class); + verify(this.authorizationService).save(authorizationCaptor.capture()); + OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue(); + assertThat(updatedAuthorization.getAccessToken().isInvalidated()).isTrue(); + assertThat(updatedAuthorization.getRefreshToken().isInvalidated()).isTrue(); } // gh-290