diff --git a/.github/dependabot.yml b/.github/dependabot.yml index 6dc7ac86e78..bd91e456f38 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -94,6 +94,15 @@ updates: - dependency-name: "*" update-types: [ "version-update:semver-major", "version-update:semver-minor" ] + - package-ecosystem: "gradle" + target-branch: "docs-build" + directory: "/" + schedule: + interval: "daily" + time: "03:00" + timezone: "Etc/UTC" + labels: [ "type: dependency-upgrade" ] + # GitHub Actions - package-ecosystem: github-actions diff --git a/.github/workflows/continuous-integration-workflow.yml b/.github/workflows/continuous-integration-workflow.yml index 39a4dfbfc91..9b5960b47ac 100644 --- a/.github/workflows/continuous-integration-workflow.yml +++ b/.github/workflows/continuous-integration-workflow.yml @@ -60,7 +60,7 @@ jobs: distribution: 'temurin' cache: 'gradle' - name: Set up Gradle - uses: gradle/gradle-build-action@v2 + uses: gradle/gradle-build-action@v3 - name: Set up gradle user name run: echo 'systemProp.user.name=spring-builds+github' >> gradle.properties - name: Build with Gradle @@ -254,7 +254,7 @@ jobs: ./gradlew createGitHubRelease -PnextVersion=$VERSION -Pbranch=$BRANCH -PcreateRelease=true -PgitHubAccessToken=$TOKEN - name: Announce Release on Slack id: spring-security-announcing - uses: slackapi/slack-github-action@v1.24.0 + uses: slackapi/slack-github-action@v1.25.0 with: payload: | { diff --git a/.github/workflows/edit-dependabot-pr.yml b/.github/workflows/edit-dependabot-pr.yml new file mode 100644 index 00000000000..a273c61d5e4 --- /dev/null +++ b/.github/workflows/edit-dependabot-pr.yml @@ -0,0 +1,55 @@ +# This workflow is an adaptation from https://github.com/spring-projects/spring-integration/blob/main/.github/workflows/merge-dependabot-pr.yml +# and https://github.com/spring-io/spring-github-workflows/blob/main/.github/workflows/spring-merge-dependabot-pr.yml + +name: Edit Dependabot PR + +on: + pull_request: + +run-name: Edit Dependabot PR ${{ github.ref_name }} + +env: + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} + +jobs: + edit-dependabot-pr: + runs-on: ubuntu-latest + if: github.actor == 'dependabot[bot]' + permissions: write-all + steps: + + - uses: actions/checkout@v4 + with: + show-progress: false + + - uses: actions/setup-java@v4 + with: + distribution: temurin + java-version: 17 + + - name: Dependabot metadata + id: metadata + uses: dependabot/fetch-metadata@v1 + with: + github-token: ${{ env.GH_TOKEN }} + + - name: Set Milestone to Dependabot pull request + id: set-milestone + run: | + if test -f pom.xml + then + CURRENT_VERSION=$(mvn help:evaluate -Dexpression="project.version" -q -DforceStdout) + else + CURRENT_VERSION=$(cat gradle.properties | sed -n '/^version=/ { s/^version=//;p }') + fi + export CANDIDATE_VERSION=${CURRENT_VERSION/-SNAPSHOT} + MILESTONE=$(gh api repos/$GITHUB_REPOSITORY/milestones --jq 'map(select(.due_on != null and (.title | startswith(env.CANDIDATE_VERSION)))) | .[0] | .title') + + if [ -z $MILESTONE ] + then + gh run cancel ${{ github.run_id }} + echo "::warning title=Cannot merge::No scheduled milestone for $CURRENT_VERSION version" + else + gh pr edit ${{ github.event.pull_request.number }} --milestone $MILESTONE + echo mergeEnabled=true >> $GITHUB_OUTPUT + fi diff --git a/CONTRIBUTING.adoc b/CONTRIBUTING.adoc index ad08a4116b1..f9bc67c9efb 100644 --- a/CONTRIBUTING.adoc +++ b/CONTRIBUTING.adoc @@ -79,7 +79,7 @@ See https://github.com/spring-projects/spring-security/tree/main#building-from-s The wiki pages https://github.com/spring-projects/spring-framework/wiki/Code-Style[Code Style] and https://github.com/spring-projects/spring-framework/wiki/IntelliJ-IDEA-Editor-Settings[IntelliJ IDEA Editor Settings] define the source file coding standards we use along with some IDEA editor settings we customize. -To format the code as well as check the style, run `./gradle format check`. +To format the code as well as check the style, run `./gradlew format check`. [[submit-a-pull-request]] === Submit a Pull Request diff --git a/README.adoc b/README.adoc index 969125ab470..71b556d21e0 100644 --- a/README.adoc +++ b/README.adoc @@ -2,7 +2,7 @@ image::https://badges.gitter.im/Join%20Chat.svg[Gitter,link=https://gitter.im/sp image:https://github.com/spring-projects/spring-security/actions/workflows/continuous-integration-workflow.yml/badge.svg?branch=main["Build Status", link="https://github.com/spring-projects/spring-security/actions/workflows/continuous-integration-workflow.yml"] -image:https://img.shields.io/badge/Revved%20up%20by-Gradle%20Enterprise-06A0CE?logo=Gradle&labelColor=02303A["Revved up by Gradle Enterprise", link="https://ge.spring.io/scans?search.rootProjectNames=spring-security"] +image:https://img.shields.io/badge/Revved%20up%20by-Develocity-06A0CE?logo=Gradle&labelColor=02303A["Revved up by Develocity", link="https://ge.spring.io/scans?search.rootProjectNames=spring-security"] = Spring Security diff --git a/config/src/main/java/org/springframework/security/config/web/server/OidcBackChannelLogoutReactiveAuthenticationManager.java b/config/src/main/java/org/springframework/security/config/web/server/OidcBackChannelLogoutReactiveAuthenticationManager.java index 1cd87fc830e..46e4c442062 100644 --- a/config/src/main/java/org/springframework/security/config/web/server/OidcBackChannelLogoutReactiveAuthenticationManager.java +++ b/config/src/main/java/org/springframework/security/config/web/server/OidcBackChannelLogoutReactiveAuthenticationManager.java @@ -85,17 +85,14 @@ public Mono authenticate(Authentication authentication) throws A private Mono decode(ClientRegistration registration, String token) { ReactiveJwtDecoder logoutTokenDecoder = this.logoutTokenDecoderFactory.createDecoder(registration); - try { - return logoutTokenDecoder.decode(token); - } - catch (BadJwtException failed) { - OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST, failed.getMessage(), - "https://openid.net/specs/openid-connect-backchannel-1_0.html#Validation"); - return Mono.error(new OAuth2AuthenticationException(error, failed)); - } - catch (Exception failed) { - return Mono.error(new AuthenticationServiceException(failed.getMessage(), failed)); - } + return logoutTokenDecoder.decode(token).onErrorResume(Exception.class, (ex) -> { + if (ex instanceof BadJwtException) { + OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST, ex.getMessage(), + "https://openid.net/specs/openid-connect-backchannel-1_0.html#Validation"); + return Mono.error(new OAuth2AuthenticationException(error, ex)); + } + return Mono.error(new AuthenticationServiceException(ex.getMessage(), ex)); + }); } /** diff --git a/config/src/main/kotlin/org/springframework/security/config/annotation/web/FormLoginDsl.kt b/config/src/main/kotlin/org/springframework/security/config/annotation/web/FormLoginDsl.kt index 3a03ddf170a..0da773f570a 100644 --- a/config/src/main/kotlin/org/springframework/security/config/annotation/web/FormLoginDsl.kt +++ b/config/src/main/kotlin/org/springframework/security/config/annotation/web/FormLoginDsl.kt @@ -1,5 +1,5 @@ /* - * Copyright 2002-2023 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -38,6 +38,8 @@ import jakarta.servlet.http.HttpServletRequest * @property loginProcessingUrl the URL to validate the credentials * @property permitAll whether to grant access to the urls for [failureUrl] as well as * for the [HttpSecurityBuilder], the [loginPage] and [loginProcessingUrl] for every user + * @property usernameParameter the HTTP parameter to look for the username when performing authentication + * @property passwordParameter the HTTP parameter to look for the password when performing authentication */ @SecurityMarker class FormLoginDsl { @@ -48,6 +50,8 @@ class FormLoginDsl { var loginProcessingUrl: String? = null var permitAll: Boolean? = null var authenticationDetailsSource: AuthenticationDetailsSource? = null + var usernameParameter: String? = null + var passwordParameter: String? = null private var defaultSuccessUrlOption: Pair? = null @@ -95,6 +99,8 @@ class FormLoginDsl { authenticationSuccessHandler?.also { login.successHandler(authenticationSuccessHandler) } authenticationFailureHandler?.also { login.failureHandler(authenticationFailureHandler) } authenticationDetailsSource?.also { login.authenticationDetailsSource(authenticationDetailsSource) } + usernameParameter?.also { login.usernameParameter(usernameParameter) } + passwordParameter?.also { login.passwordParameter(passwordParameter) } if (disabled) { login.disable() } diff --git a/config/src/test/kotlin/org/springframework/security/config/annotation/web/FormLoginDslTests.kt b/config/src/test/kotlin/org/springframework/security/config/annotation/web/FormLoginDslTests.kt index 8c44ef8524a..965c361b4a3 100644 --- a/config/src/test/kotlin/org/springframework/security/config/annotation/web/FormLoginDslTests.kt +++ b/config/src/test/kotlin/org/springframework/security/config/annotation/web/FormLoginDslTests.kt @@ -1,5 +1,5 @@ /* - * Copyright 2002-2023 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -33,6 +33,7 @@ import org.springframework.security.config.test.SpringTestContextExtension import org.springframework.security.core.userdetails.User import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestBuilders.formLogin import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf +import org.springframework.security.test.web.servlet.response.SecurityMockMvcResultMatchers.authenticated import org.springframework.security.web.SecurityFilterChain import org.springframework.security.web.authentication.SimpleUrlAuthenticationFailureHandler import org.springframework.security.web.authentication.SimpleUrlAuthenticationSuccessHandler @@ -367,6 +368,50 @@ class FormLoginDslTests { verify(exactly = 1) { CustomAuthenticationDetailsSourceConfig.AUTHENTICATION_DETAILS_SOURCE.buildDetails(any()) } } + @Configuration + @EnableWebSecurity + open class CustomUsernameParameterConfig { + @Bean + open fun securityFilterChain(http: HttpSecurity): SecurityFilterChain { + http { + formLogin { + usernameParameter = "custom-username" + } + } + return http.build() + } + } + + @Test + fun `form login when custom username parameter then used`() { + this.spring.register(CustomUsernameParameterConfig::class.java, UserConfig::class.java).autowire() + + this.mockMvc.perform(formLogin().userParameter("custom-username")) + .andExpect(authenticated()) + } + + @Configuration + @EnableWebSecurity + open class CustomPasswordParameterConfig { + @Bean + open fun securityFilterChain(http: HttpSecurity): SecurityFilterChain { + http { + formLogin { + passwordParameter = "custom-password" + } + } + return http.build() + } + } + + @Test + fun `form login when custom password parameter then used`() { + this.spring.register(CustomPasswordParameterConfig::class.java, UserConfig::class.java).autowire() + + this.mockMvc.perform(formLogin().passwordParam("custom-password")) + .andExpect(authenticated()) + } + @Configuration @EnableWebSecurity open class CustomAuthenticationDetailsSourceConfig { diff --git a/core/src/main/java/org/springframework/security/authorization/method/AbstractExpressionAttributeRegistry.java b/core/src/main/java/org/springframework/security/authorization/method/AbstractExpressionAttributeRegistry.java index 42b7cd92c03..23693f7c93d 100644 --- a/core/src/main/java/org/springframework/security/authorization/method/AbstractExpressionAttributeRegistry.java +++ b/core/src/main/java/org/springframework/security/authorization/method/AbstractExpressionAttributeRegistry.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2021 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -29,6 +29,7 @@ * For internal use only, as this contract is likely to change * * @author Evgeniy Cheban + * @author DingHao */ abstract class AbstractExpressionAttributeRegistry { @@ -67,4 +68,8 @@ final T getAttribute(Method method, Class targetClass) { @NonNull abstract T resolveAttribute(Method method, Class targetClass); + Class targetClass(Method method, Class targetClass) { + return (targetClass != null) ? targetClass : method.getDeclaringClass(); + } + } diff --git a/core/src/main/java/org/springframework/security/authorization/method/Jsr250AuthorizationManager.java b/core/src/main/java/org/springframework/security/authorization/method/Jsr250AuthorizationManager.java index beb318ed15f..f913db85f68 100644 --- a/core/src/main/java/org/springframework/security/authorization/method/Jsr250AuthorizationManager.java +++ b/core/src/main/java/org/springframework/security/authorization/method/Jsr250AuthorizationManager.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2023 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -44,6 +44,7 @@ * * @author Evgeniy Cheban * @author Josh Cummings + * @author DingHao * @since 5.6 */ public final class Jsr250AuthorizationManager implements AuthorizationManager { @@ -121,7 +122,8 @@ AuthorizationManager resolveManager(Method method, Class ta private Annotation findJsr250Annotation(Method method, Class targetClass) { Method specificMethod = AopUtils.getMostSpecificMethod(method, targetClass); Annotation annotation = findAnnotation(specificMethod); - return (annotation != null) ? annotation : findAnnotation(specificMethod.getDeclaringClass()); + return (annotation != null) ? annotation + : findAnnotation((targetClass != null) ? targetClass : specificMethod.getDeclaringClass()); } private Annotation findAnnotation(Method method) { diff --git a/core/src/main/java/org/springframework/security/authorization/method/PostAuthorizeExpressionAttributeRegistry.java b/core/src/main/java/org/springframework/security/authorization/method/PostAuthorizeExpressionAttributeRegistry.java index c89bbc3e312..300fd08f7ee 100644 --- a/core/src/main/java/org/springframework/security/authorization/method/PostAuthorizeExpressionAttributeRegistry.java +++ b/core/src/main/java/org/springframework/security/authorization/method/PostAuthorizeExpressionAttributeRegistry.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -31,6 +31,7 @@ * For internal use only, as this contract is likely to change. * * @author Evgeniy Cheban + * @author DingHao * @since 5.8 */ final class PostAuthorizeExpressionAttributeRegistry extends AbstractExpressionAttributeRegistry { @@ -54,7 +55,7 @@ MethodSecurityExpressionHandler getExpressionHandler() { @Override ExpressionAttribute resolveAttribute(Method method, Class targetClass) { Method specificMethod = AopUtils.getMostSpecificMethod(method, targetClass); - PostAuthorize postAuthorize = findPostAuthorizeAnnotation(specificMethod); + PostAuthorize postAuthorize = findPostAuthorizeAnnotation(specificMethod, targetClass); if (postAuthorize == null) { return ExpressionAttribute.NULL_ATTRIBUTE; } @@ -63,10 +64,10 @@ ExpressionAttribute resolveAttribute(Method method, Class targetClass) { return new ExpressionAttribute(postAuthorizeExpression); } - private PostAuthorize findPostAuthorizeAnnotation(Method method) { + private PostAuthorize findPostAuthorizeAnnotation(Method method, Class targetClass) { PostAuthorize postAuthorize = AuthorizationAnnotationUtils.findUniqueAnnotation(method, PostAuthorize.class); - return (postAuthorize != null) ? postAuthorize - : AuthorizationAnnotationUtils.findUniqueAnnotation(method.getDeclaringClass(), PostAuthorize.class); + return (postAuthorize != null) ? postAuthorize : AuthorizationAnnotationUtils + .findUniqueAnnotation(targetClass(method, targetClass), PostAuthorize.class); } } diff --git a/core/src/main/java/org/springframework/security/authorization/method/PostFilterExpressionAttributeRegistry.java b/core/src/main/java/org/springframework/security/authorization/method/PostFilterExpressionAttributeRegistry.java index 4bc33bc493d..541d2caae54 100644 --- a/core/src/main/java/org/springframework/security/authorization/method/PostFilterExpressionAttributeRegistry.java +++ b/core/src/main/java/org/springframework/security/authorization/method/PostFilterExpressionAttributeRegistry.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -30,6 +30,7 @@ * For internal use only, as this contract is likely to change. * * @author Evgeniy Cheban + * @author DingHao * @since 5.8 */ final class PostFilterExpressionAttributeRegistry extends AbstractExpressionAttributeRegistry { @@ -53,7 +54,7 @@ MethodSecurityExpressionHandler getExpressionHandler() { @Override ExpressionAttribute resolveAttribute(Method method, Class targetClass) { Method specificMethod = AopUtils.getMostSpecificMethod(method, targetClass); - PostFilter postFilter = findPostFilterAnnotation(specificMethod); + PostFilter postFilter = findPostFilterAnnotation(specificMethod, targetClass); if (postFilter == null) { return ExpressionAttribute.NULL_ATTRIBUTE; } @@ -62,10 +63,10 @@ ExpressionAttribute resolveAttribute(Method method, Class targetClass) { return new ExpressionAttribute(postFilterExpression); } - private PostFilter findPostFilterAnnotation(Method method) { + private PostFilter findPostFilterAnnotation(Method method, Class targetClass) { PostFilter postFilter = AuthorizationAnnotationUtils.findUniqueAnnotation(method, PostFilter.class); return (postFilter != null) ? postFilter - : AuthorizationAnnotationUtils.findUniqueAnnotation(method.getDeclaringClass(), PostFilter.class); + : AuthorizationAnnotationUtils.findUniqueAnnotation(targetClass(method, targetClass), PostFilter.class); } } diff --git a/core/src/main/java/org/springframework/security/authorization/method/PreAuthorizeExpressionAttributeRegistry.java b/core/src/main/java/org/springframework/security/authorization/method/PreAuthorizeExpressionAttributeRegistry.java index dcae13eb205..0c445a37d1f 100644 --- a/core/src/main/java/org/springframework/security/authorization/method/PreAuthorizeExpressionAttributeRegistry.java +++ b/core/src/main/java/org/springframework/security/authorization/method/PreAuthorizeExpressionAttributeRegistry.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -31,6 +31,7 @@ * For internal use only, as this contract is likely to change. * * @author Evgeniy Cheban + * @author DingHao * @since 5.8 */ final class PreAuthorizeExpressionAttributeRegistry extends AbstractExpressionAttributeRegistry { @@ -58,7 +59,7 @@ MethodSecurityExpressionHandler getExpressionHandler() { @Override ExpressionAttribute resolveAttribute(Method method, Class targetClass) { Method specificMethod = AopUtils.getMostSpecificMethod(method, targetClass); - PreAuthorize preAuthorize = findPreAuthorizeAnnotation(specificMethod); + PreAuthorize preAuthorize = findPreAuthorizeAnnotation(specificMethod, targetClass); if (preAuthorize == null) { return ExpressionAttribute.NULL_ATTRIBUTE; } @@ -67,10 +68,10 @@ ExpressionAttribute resolveAttribute(Method method, Class targetClass) { return new ExpressionAttribute(preAuthorizeExpression); } - private PreAuthorize findPreAuthorizeAnnotation(Method method) { + private PreAuthorize findPreAuthorizeAnnotation(Method method, Class targetClass) { PreAuthorize preAuthorize = AuthorizationAnnotationUtils.findUniqueAnnotation(method, PreAuthorize.class); - return (preAuthorize != null) ? preAuthorize - : AuthorizationAnnotationUtils.findUniqueAnnotation(method.getDeclaringClass(), PreAuthorize.class); + return (preAuthorize != null) ? preAuthorize : AuthorizationAnnotationUtils + .findUniqueAnnotation(targetClass(method, targetClass), PreAuthorize.class); } } diff --git a/core/src/main/java/org/springframework/security/authorization/method/PreFilterExpressionAttributeRegistry.java b/core/src/main/java/org/springframework/security/authorization/method/PreFilterExpressionAttributeRegistry.java index 6fa8448355a..67bab2c7ff7 100644 --- a/core/src/main/java/org/springframework/security/authorization/method/PreFilterExpressionAttributeRegistry.java +++ b/core/src/main/java/org/springframework/security/authorization/method/PreFilterExpressionAttributeRegistry.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -30,6 +30,7 @@ * For internal use only, as this contract is likely to change. * * @author Evgeniy Cheban + * @author DingHao * @since 5.8 */ final class PreFilterExpressionAttributeRegistry @@ -54,7 +55,7 @@ MethodSecurityExpressionHandler getExpressionHandler() { @Override PreFilterExpressionAttribute resolveAttribute(Method method, Class targetClass) { Method specificMethod = AopUtils.getMostSpecificMethod(method, targetClass); - PreFilter preFilter = findPreFilterAnnotation(specificMethod); + PreFilter preFilter = findPreFilterAnnotation(specificMethod, targetClass); if (preFilter == null) { return PreFilterExpressionAttribute.NULL_ATTRIBUTE; } @@ -63,10 +64,10 @@ PreFilterExpressionAttribute resolveAttribute(Method method, Class targetClas return new PreFilterExpressionAttribute(preFilterExpression, preFilter.filterTarget()); } - private PreFilter findPreFilterAnnotation(Method method) { + private PreFilter findPreFilterAnnotation(Method method, Class targetClass) { PreFilter preFilter = AuthorizationAnnotationUtils.findUniqueAnnotation(method, PreFilter.class); return (preFilter != null) ? preFilter - : AuthorizationAnnotationUtils.findUniqueAnnotation(method.getDeclaringClass(), PreFilter.class); + : AuthorizationAnnotationUtils.findUniqueAnnotation(targetClass(method, targetClass), PreFilter.class); } static final class PreFilterExpressionAttribute extends ExpressionAttribute { diff --git a/core/src/main/java/org/springframework/security/authorization/method/SecuredAuthorizationManager.java b/core/src/main/java/org/springframework/security/authorization/method/SecuredAuthorizationManager.java index dcfc8a8511f..63553503d21 100644 --- a/core/src/main/java/org/springframework/security/authorization/method/SecuredAuthorizationManager.java +++ b/core/src/main/java/org/springframework/security/authorization/method/SecuredAuthorizationManager.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2023 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -41,6 +41,7 @@ * contains a specified authority from the Spring Security's {@link Secured} annotation. * * @author Evgeniy Cheban + * @author DingHao * @since 5.6 */ public final class SecuredAuthorizationManager implements AuthorizationManager { @@ -86,14 +87,14 @@ private Set getAuthorities(MethodInvocation methodInvocation) { private Set resolveAuthorities(Method method, Class targetClass) { Method specificMethod = AopUtils.getMostSpecificMethod(method, targetClass); - Secured secured = findSecuredAnnotation(specificMethod); + Secured secured = findSecuredAnnotation(specificMethod, targetClass); return (secured != null) ? Set.of(secured.value()) : Collections.emptySet(); } - private Secured findSecuredAnnotation(Method method) { + private Secured findSecuredAnnotation(Method method, Class targetClass) { Secured secured = AuthorizationAnnotationUtils.findUniqueAnnotation(method, Secured.class); - return (secured != null) ? secured - : AuthorizationAnnotationUtils.findUniqueAnnotation(method.getDeclaringClass(), Secured.class); + return (secured != null) ? secured : AuthorizationAnnotationUtils + .findUniqueAnnotation((targetClass != null) ? targetClass : method.getDeclaringClass(), Secured.class); } } diff --git a/core/src/test/java/org/springframework/security/authorization/method/Jsr250AuthorizationManagerTests.java b/core/src/test/java/org/springframework/security/authorization/method/Jsr250AuthorizationManagerTests.java index 5a35b884895..fec14a5d2c8 100644 --- a/core/src/test/java/org/springframework/security/authorization/method/Jsr250AuthorizationManagerTests.java +++ b/core/src/test/java/org/springframework/security/authorization/method/Jsr250AuthorizationManagerTests.java @@ -225,6 +225,48 @@ public void checkInheritedAnnotationsWhenConflictingThenAnnotationConfigurationE .isThrownBy(() -> manager.check(authentication, methodInvocation)); } + @Test + public void checkRequiresUserWhenMethodsFromInheritThenApplies() throws Exception { + MockMethodInvocation methodInvocation = new MockMethodInvocation(new RolesAllowedClass(), + RolesAllowedClass.class, "securedUser"); + Jsr250AuthorizationManager manager = new Jsr250AuthorizationManager(); + AuthorizationDecision decision = manager.check(TestAuthentication::authenticatedUser, methodInvocation); + assertThat(decision.isGranted()).isTrue(); + } + + @Test + public void checkPermitAllWhenMethodsFromInheritThenApplies() throws Exception { + MockMethodInvocation methodInvocation = new MockMethodInvocation(new PermitAllClass(), PermitAllClass.class, + "securedUser"); + Jsr250AuthorizationManager manager = new Jsr250AuthorizationManager(); + AuthorizationDecision decision = manager.check(TestAuthentication::authenticatedUser, methodInvocation); + assertThat(decision.isGranted()).isTrue(); + } + + @Test + public void checkDenyAllWhenMethodsFromInheritThenApplies() throws Exception { + MockMethodInvocation methodInvocation = new MockMethodInvocation(new DenyAllClass(), DenyAllClass.class, + "securedUser"); + Jsr250AuthorizationManager manager = new Jsr250AuthorizationManager(); + AuthorizationDecision decision = manager.check(TestAuthentication::authenticatedUser, methodInvocation); + assertThat(decision.isGranted()).isFalse(); + } + + @RolesAllowed("USER") + public static class RolesAllowedClass extends SecuredAuthorizationManagerTests.ParentClass { + + } + + @PermitAll + public static class PermitAllClass extends SecuredAuthorizationManagerTests.ParentClass { + + } + + @DenyAll + public static class DenyAllClass extends SecuredAuthorizationManagerTests.ParentClass { + + } + public static class TestClass implements InterfaceAnnotationsOne, InterfaceAnnotationsTwo { public void doSomething() { diff --git a/core/src/test/java/org/springframework/security/authorization/method/PostAuthorizeAuthorizationManagerTests.java b/core/src/test/java/org/springframework/security/authorization/method/PostAuthorizeAuthorizationManagerTests.java index 37383b40d57..e345eb97da1 100644 --- a/core/src/test/java/org/springframework/security/authorization/method/PostAuthorizeAuthorizationManagerTests.java +++ b/core/src/test/java/org/springframework/security/authorization/method/PostAuthorizeAuthorizationManagerTests.java @@ -167,6 +167,21 @@ public void checkInheritedAnnotationsWhenConflictingThenAnnotationConfigurationE .isThrownBy(() -> manager.check(authentication, result)); } + @Test + public void checkRequiresUserWhenMethodsFromInheritThenApplies() throws Exception { + MockMethodInvocation methodInvocation = new MockMethodInvocation(new PostAuthorizeClass(), + PostAuthorizeClass.class, "securedUser"); + MethodInvocationResult result = new MethodInvocationResult(methodInvocation, null); + PostAuthorizeAuthorizationManager manager = new PostAuthorizeAuthorizationManager(); + AuthorizationDecision decision = manager.check(TestAuthentication::authenticatedUser, result); + assertThat(decision.isGranted()).isTrue(); + } + + @PostAuthorize("hasRole('USER')") + public static class PostAuthorizeClass extends SecuredAuthorizationManagerTests.ParentClass { + + } + public static class TestClass implements InterfaceAnnotationsOne, InterfaceAnnotationsTwo { public void doSomething() { diff --git a/core/src/test/java/org/springframework/security/authorization/method/PostFilterAuthorizationMethodInterceptorTests.java b/core/src/test/java/org/springframework/security/authorization/method/PostFilterAuthorizationMethodInterceptorTests.java index 00f2aed42dd..fe1f1803125 100644 --- a/core/src/test/java/org/springframework/security/authorization/method/PostFilterAuthorizationMethodInterceptorTests.java +++ b/core/src/test/java/org/springframework/security/authorization/method/PostFilterAuthorizationMethodInterceptorTests.java @@ -170,6 +170,34 @@ public Object proceed() { SecurityContextHolder.setContextHolderStrategy(saved); } + @Test + public void checkPostFilterWhenMethodsFromInheritThenApplies() throws Throwable { + String[] array = { "john", "bob" }; + MockMethodInvocation methodInvocation = new MockMethodInvocation(new PostFilterClass(), PostFilterClass.class, + "inheritMethod", new Class[] { String[].class }, new Object[] { array }) { + @Override + public Object proceed() { + return array; + } + }; + PostFilterAuthorizationMethodInterceptor advice = new PostFilterAuthorizationMethodInterceptor(); + Object result = advice.invoke(methodInvocation); + assertThat(result).asInstanceOf(InstanceOfAssertFactories.array(String[].class)).containsOnly("john"); + } + + @PostFilter("filterObject == 'john'") + public static class PostFilterClass extends ParentClass { + + } + + public static class ParentClass { + + public String[] inheritMethod(String[] array) { + return array; + } + + } + @PostFilter("filterObject == 'john'") public static class TestClass implements InterfaceAnnotationsOne, InterfaceAnnotationsTwo { diff --git a/core/src/test/java/org/springframework/security/authorization/method/PreAuthorizeAuthorizationManagerTests.java b/core/src/test/java/org/springframework/security/authorization/method/PreAuthorizeAuthorizationManagerTests.java index cb43868dbf0..ff191b845f9 100644 --- a/core/src/test/java/org/springframework/security/authorization/method/PreAuthorizeAuthorizationManagerTests.java +++ b/core/src/test/java/org/springframework/security/authorization/method/PreAuthorizeAuthorizationManagerTests.java @@ -147,6 +147,20 @@ public void checkTargetClassAwareWhenInterfaceLevelAnnotationsThenApplies() thro assertThat(decision.isGranted()).isTrue(); } + @Test + public void checkRequiresUserWhenMethodsFromInheritThenApplies() throws Exception { + MockMethodInvocation methodInvocation = new MockMethodInvocation(new PreAuthorizeClass(), + PreAuthorizeClass.class, "securedUser"); + PreAuthorizeAuthorizationManager manager = new PreAuthorizeAuthorizationManager(); + AuthorizationDecision decision = manager.check(TestAuthentication::authenticatedUser, methodInvocation); + assertThat(decision.isGranted()).isTrue(); + } + + @PreAuthorize("hasRole('USER')") + public static class PreAuthorizeClass extends SecuredAuthorizationManagerTests.ParentClass { + + } + public static class TestClass implements InterfaceAnnotationsOne, InterfaceAnnotationsTwo { public void doSomething() { diff --git a/core/src/test/java/org/springframework/security/authorization/method/PreFilterAuthorizationMethodInterceptorTests.java b/core/src/test/java/org/springframework/security/authorization/method/PreFilterAuthorizationMethodInterceptorTests.java index 4f1d56fb146..b750c72c512 100644 --- a/core/src/test/java/org/springframework/security/authorization/method/PreFilterAuthorizationMethodInterceptorTests.java +++ b/core/src/test/java/org/springframework/security/authorization/method/PreFilterAuthorizationMethodInterceptorTests.java @@ -224,6 +224,32 @@ public void preFilterWhenStaticSecurityContextHolderStrategyAfterConstructorThen SecurityContextHolder.setContextHolderStrategy(saved); } + @Test + public void checkPreFilterWhenMethodsFromInheritThenApplies() throws Throwable { + List list = new ArrayList<>(); + list.add("john"); + list.add("bob"); + MockMethodInvocation invocation = new MockMethodInvocation(new PreFilterClass(), PreFilterClass.class, + "inheritMethod", new Class[] { List.class }, new Object[] { list }); + PreFilterAuthorizationMethodInterceptor advice = new PreFilterAuthorizationMethodInterceptor(); + advice.invoke(invocation); + assertThat(list).hasSize(1); + assertThat(list.get(0)).isEqualTo("john"); + } + + @PreFilter("filterObject == 'john'") + public static class PreFilterClass extends ParentClass { + + } + + public static class ParentClass { + + public void inheritMethod(List list) { + + } + + } + @PreFilter("filterObject == 'john'") public static class TestClass implements InterfaceAnnotationsOne, InterfaceAnnotationsTwo { diff --git a/core/src/test/java/org/springframework/security/authorization/method/SecuredAuthorizationManagerTests.java b/core/src/test/java/org/springframework/security/authorization/method/SecuredAuthorizationManagerTests.java index 5d8651df9e8..117d9935569 100644 --- a/core/src/test/java/org/springframework/security/authorization/method/SecuredAuthorizationManagerTests.java +++ b/core/src/test/java/org/springframework/security/authorization/method/SecuredAuthorizationManagerTests.java @@ -167,6 +167,28 @@ public void checkTargetClassAwareWhenInterfaceLevelAnnotationsThenApplies() thro assertThat(decision.isGranted()).isTrue(); } + @Test + public void checkRequiresUserWhenMethodsFromInheritThenApplies() throws Exception { + MockMethodInvocation methodInvocation = new MockMethodInvocation(new SecuredSonClass(), SecuredSonClass.class, + "securedUser"); + SecuredAuthorizationManager manager = new SecuredAuthorizationManager(); + AuthorizationDecision decision = manager.check(TestAuthentication::authenticatedUser, methodInvocation); + assertThat(decision.isGranted()).isTrue(); + } + + @Secured("ROLE_USER") + public static class SecuredSonClass extends ParentClass { + + } + + public static class ParentClass { + + public void securedUser() { + + } + + } + public static class TestClass implements InterfaceAnnotationsOne, InterfaceAnnotationsTwo { public void doSomething() { diff --git a/crypto/spring-security-crypto.gradle b/crypto/spring-security-crypto.gradle index b45a20b0d87..0894300dbcd 100644 --- a/crypto/spring-security-crypto.gradle +++ b/crypto/spring-security-crypto.gradle @@ -3,8 +3,9 @@ apply plugin: 'io.spring.convention.spring-module' dependencies { management platform(project(":spring-security-dependencies")) optional 'org.springframework:spring-jcl' + optional 'org.springframework:spring-core' optional 'org.bouncycastle:bcpkix-jdk15on' - + testImplementation "org.assertj:assertj-core" testImplementation "org.junit.jupiter:junit-jupiter-api" testImplementation "org.junit.jupiter:junit-jupiter-params" diff --git a/crypto/src/main/java/org/springframework/security/crypto/encrypt/KeyStoreKeyFactory.java b/crypto/src/main/java/org/springframework/security/crypto/encrypt/KeyStoreKeyFactory.java new file mode 100644 index 00000000000..9c226042f2b --- /dev/null +++ b/crypto/src/main/java/org/springframework/security/crypto/encrypt/KeyStoreKeyFactory.java @@ -0,0 +1,96 @@ +/* + * Copyright 2013-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.crypto.encrypt; + +import java.io.InputStream; +import java.security.KeyFactory; +import java.security.KeyPair; +import java.security.KeyStore; +import java.security.PublicKey; +import java.security.cert.Certificate; +import java.security.interfaces.RSAPrivateCrtKey; +import java.security.spec.RSAPublicKeySpec; + +import org.springframework.core.io.Resource; +import org.springframework.util.StringUtils; + +/** + * @author Dave Syer + * @author Tim Ysewyn + * @since 6.3 + */ +public class KeyStoreKeyFactory { + + private final Resource resource; + + private final char[] password; + + private KeyStore store; + + private final Object lock = new Object(); + + private final String type; + + public KeyStoreKeyFactory(Resource resource, char[] password) { + this(resource, password, type(resource)); + } + + private static String type(Resource resource) { + String ext = StringUtils.getFilenameExtension(resource.getFilename()); + return (ext != null) ? ext : "jks"; + } + + public KeyStoreKeyFactory(Resource resource, char[] password, String type) { + this.resource = resource; + this.password = password; + this.type = type; + } + + public KeyPair getKeyPair(String alias) { + return getKeyPair(alias, this.password); + } + + public KeyPair getKeyPair(String alias, char[] password) { + try { + synchronized (this.lock) { + if (this.store == null) { + synchronized (this.lock) { + this.store = KeyStore.getInstance(this.type); + try (InputStream stream = this.resource.getInputStream()) { + this.store.load(stream, this.password); + } + } + } + } + RSAPrivateCrtKey key = (RSAPrivateCrtKey) this.store.getKey(alias, password); + Certificate certificate = this.store.getCertificate(alias); + PublicKey publicKey = null; + if (certificate != null) { + publicKey = certificate.getPublicKey(); + } + else if (key != null) { + RSAPublicKeySpec spec = new RSAPublicKeySpec(key.getModulus(), key.getPublicExponent()); + publicKey = KeyFactory.getInstance("RSA").generatePublic(spec); + } + return new KeyPair(publicKey, key); + } + catch (Exception ex) { + throw new IllegalStateException("Cannot load keys from store: " + this.resource, ex); + } + } + +} diff --git a/crypto/src/main/java/org/springframework/security/crypto/encrypt/RsaAlgorithm.java b/crypto/src/main/java/org/springframework/security/crypto/encrypt/RsaAlgorithm.java new file mode 100644 index 00000000000..c22a173df46 --- /dev/null +++ b/crypto/src/main/java/org/springframework/security/crypto/encrypt/RsaAlgorithm.java @@ -0,0 +1,44 @@ +/* + * Copyright 2013-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.crypto.encrypt; + +/** + * @author Dave Syer + * @since 6.3 + */ +public enum RsaAlgorithm { + + DEFAULT("RSA", 117), OAEP("RSA/ECB/OAEPPadding", 86); + + private final String name; + + private final int maxLength; + + RsaAlgorithm(String name, int maxLength) { + this.name = name; + this.maxLength = maxLength; + } + + public String getJceName() { + return this.name; + } + + public int getMaxLength() { + return this.maxLength; + } + +} diff --git a/crypto/src/main/java/org/springframework/security/crypto/encrypt/RsaKeyHelper.java b/crypto/src/main/java/org/springframework/security/crypto/encrypt/RsaKeyHelper.java new file mode 100644 index 00000000000..fd3e058a119 --- /dev/null +++ b/crypto/src/main/java/org/springframework/security/crypto/encrypt/RsaKeyHelper.java @@ -0,0 +1,284 @@ +/* + * Copyright 2013-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.crypto.encrypt; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.StringWriter; +import java.math.BigInteger; +import java.nio.ByteBuffer; +import java.nio.CharBuffer; +import java.nio.charset.CharacterCodingException; +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; +import java.security.KeyFactory; +import java.security.KeyPair; +import java.security.KeyPairGenerator; +import java.security.NoSuchAlgorithmException; +import java.security.PrivateKey; +import java.security.PublicKey; +import java.security.interfaces.RSAPublicKey; +import java.security.spec.InvalidKeySpecException; +import java.security.spec.KeySpec; +import java.security.spec.RSAPrivateCrtKeySpec; +import java.security.spec.RSAPublicKeySpec; +import java.security.spec.X509EncodedKeySpec; +import java.util.Arrays; +import java.util.Base64; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import org.bouncycastle.asn1.ASN1Sequence; + +/** + * Reads RSA key pairs using BC provider classes but without the need to specify a crypto + * provider or have BC added as one. + * + * @author Luke Taylor + * @author Dave Syer + */ +final class RsaKeyHelper { + + private static final Charset UTF8 = StandardCharsets.UTF_8; + + private static final String BEGIN = "-----BEGIN"; + + private static final Pattern PEM_DATA = Pattern.compile(".*-----BEGIN (.*)-----(.*)-----END (.*)-----", + Pattern.DOTALL); + + private static final byte[] PREFIX = new byte[] { 0, 0, 0, 7, 's', 's', 'h', '-', 'r', 's', 'a' }; + + private RsaKeyHelper() { + } + + static KeyPair parseKeyPair(String pemData) { + Matcher m = PEM_DATA.matcher(pemData.replaceAll("\n *", "").trim()); + + if (!m.matches()) { + try { + RSAPublicKey publicValue = extractPublicKey(pemData); + if (publicValue != null) { + return new KeyPair(publicValue, null); + } + } + catch (Exception ex) { + // Ignore + } + throw new IllegalArgumentException("String is not PEM encoded data, nor a public key encoded for ssh"); + } + + String type = m.group(1); + final byte[] content = base64Decode(m.group(2)); + + PublicKey publicKey; + PrivateKey privateKey = null; + + try { + KeyFactory fact = KeyFactory.getInstance("RSA"); + switch (type) { + case "RSA PRIVATE KEY" -> { + ASN1Sequence seq = ASN1Sequence.getInstance(content); + if (seq.size() != 9) { + throw new IllegalArgumentException("Invalid RSA Private Key ASN1 sequence."); + } + org.bouncycastle.asn1.pkcs.RSAPrivateKey key = org.bouncycastle.asn1.pkcs.RSAPrivateKey + .getInstance(seq); + RSAPublicKeySpec pubSpec = new RSAPublicKeySpec(key.getModulus(), key.getPublicExponent()); + RSAPrivateCrtKeySpec privSpec = new RSAPrivateCrtKeySpec(key.getModulus(), key.getPublicExponent(), + key.getPrivateExponent(), key.getPrime1(), key.getPrime2(), key.getExponent1(), + key.getExponent2(), key.getCoefficient()); + publicKey = fact.generatePublic(pubSpec); + privateKey = fact.generatePrivate(privSpec); + } + case "PUBLIC KEY" -> { + KeySpec keySpec = new X509EncodedKeySpec(content); + publicKey = fact.generatePublic(keySpec); + } + case "RSA PUBLIC KEY" -> { + ASN1Sequence seq = ASN1Sequence.getInstance(content); + org.bouncycastle.asn1.pkcs.RSAPublicKey key = org.bouncycastle.asn1.pkcs.RSAPublicKey + .getInstance(seq); + RSAPublicKeySpec pubSpec = new RSAPublicKeySpec(key.getModulus(), key.getPublicExponent()); + publicKey = fact.generatePublic(pubSpec); + } + default -> throw new IllegalArgumentException(type + " is not a supported format"); + } + + return new KeyPair(publicKey, privateKey); + } + catch (InvalidKeySpecException ex) { + throw new RuntimeException(ex); + } + catch (NoSuchAlgorithmException ex) { + throw new IllegalStateException(ex); + } + } + + private static byte[] base64Decode(String string) { + try { + ByteBuffer bytes = UTF8.newEncoder().encode(CharBuffer.wrap(string)); + byte[] bytesCopy = new byte[bytes.limit()]; + System.arraycopy(bytes.array(), 0, bytesCopy, 0, bytes.limit()); + return Base64.getDecoder().decode(bytesCopy); + } + catch (CharacterCodingException ex) { + throw new RuntimeException(ex); + } + } + + static String base64Encode(byte[] bytes) { + try { + return UTF8.newDecoder().decode(ByteBuffer.wrap(Base64.getEncoder().encode(bytes))).toString(); + } + catch (CharacterCodingException ex) { + throw new RuntimeException(ex); + } + } + + static KeyPair generateKeyPair() { + try { + final KeyPairGenerator keyGen = KeyPairGenerator.getInstance("RSA"); + keyGen.initialize(1024); + return keyGen.generateKeyPair(); + } + catch (NoSuchAlgorithmException ex) { + throw new IllegalStateException(ex); + } + + } + + private static final Pattern SSH_PUB_KEY = Pattern.compile("ssh-(rsa|dsa) ([A-Za-z0-9/+]+=*) (.*)"); + + private static RSAPublicKey extractPublicKey(String key) { + + Matcher m = SSH_PUB_KEY.matcher(key); + + if (m.matches()) { + String alg = m.group(1); + String encKey = m.group(2); + // String id = m.group(3); + + if (!"rsa".equalsIgnoreCase(alg)) { + throw new IllegalArgumentException("Only RSA is currently supported, but algorithm was " + alg); + } + + return parseSSHPublicKey(encKey); + } + else if (!key.startsWith(BEGIN)) { + // Assume it's the plain Base64 encoded ssh key without the + // "ssh-rsa" at the start + return parseSSHPublicKey(key); + } + + return null; + } + + static RSAPublicKey parsePublicKey(String key) { + + RSAPublicKey publicKey = extractPublicKey(key); + + if (publicKey != null) { + return publicKey; + } + + KeyPair kp = parseKeyPair(key); + + if (kp.getPublic() == null) { + throw new IllegalArgumentException("Key data does not contain a public key"); + } + + return (RSAPublicKey) kp.getPublic(); + + } + + static String encodePublicKey(RSAPublicKey key, String id) { + StringWriter output = new StringWriter(); + output.append("ssh-rsa "); + ByteArrayOutputStream stream = new ByteArrayOutputStream(); + try { + stream.write(PREFIX); + writeBigInteger(stream, key.getPublicExponent()); + writeBigInteger(stream, key.getModulus()); + } + catch (IOException ex) { + throw new IllegalStateException("Cannot encode key", ex); + } + output.append(base64Encode(stream.toByteArray())); + output.append(" " + id); + return output.toString(); + } + + private static RSAPublicKey parseSSHPublicKey(String encKey) { + ByteArrayInputStream in = new ByteArrayInputStream(base64Decode(encKey)); + + byte[] prefix = new byte[11]; + + try { + if (in.read(prefix) != 11 || !Arrays.equals(PREFIX, prefix)) { + throw new IllegalArgumentException("SSH key prefix not found"); + } + + BigInteger e = new BigInteger(readBigInteger(in)); + BigInteger n = new BigInteger(readBigInteger(in)); + + return createPublicKey(n, e); + } + catch (IOException ex) { + throw new RuntimeException(ex); + } + } + + static RSAPublicKey createPublicKey(BigInteger n, BigInteger e) { + try { + return (RSAPublicKey) KeyFactory.getInstance("RSA").generatePublic(new RSAPublicKeySpec(n, e)); + } + catch (Exception ex) { + throw new RuntimeException(ex); + } + } + + private static void writeBigInteger(ByteArrayOutputStream stream, BigInteger num) throws IOException { + int length = num.toByteArray().length; + byte[] data = new byte[4]; + data[0] = (byte) ((length >> 24) & 0xFF); + data[1] = (byte) ((length >> 16) & 0xFF); + data[2] = (byte) ((length >> 8) & 0xFF); + data[3] = (byte) (length & 0xFF); + stream.write(data); + stream.write(num.toByteArray()); + } + + private static byte[] readBigInteger(ByteArrayInputStream in) throws IOException { + byte[] b = new byte[4]; + + if (in.read(b) != 4) { + throw new IOException("Expected length data as 4 bytes"); + } + + int l = ((b[0] & 0xFF) << 24) | ((b[1] & 0xFF) << 16) | ((b[2] & 0xFF) << 8) | (b[3] & 0xFF); + + b = new byte[l]; + + if (in.read(b) != l) { + throw new IOException("Expected " + l + " key bytes"); + } + + return b; + } + +} diff --git a/crypto/src/main/java/org/springframework/security/crypto/encrypt/RsaKeyHolder.java b/crypto/src/main/java/org/springframework/security/crypto/encrypt/RsaKeyHolder.java new file mode 100644 index 00000000000..49ae22e62f8 --- /dev/null +++ b/crypto/src/main/java/org/springframework/security/crypto/encrypt/RsaKeyHolder.java @@ -0,0 +1,27 @@ +/* + * Copyright 2013-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.crypto.encrypt; + +/** + * @author Dave Syer + * @since 6.3 + */ +public interface RsaKeyHolder { + + String getPublicKey(); + +} diff --git a/crypto/src/main/java/org/springframework/security/crypto/encrypt/RsaRawEncryptor.java b/crypto/src/main/java/org/springframework/security/crypto/encrypt/RsaRawEncryptor.java new file mode 100644 index 00000000000..655ea45b08e --- /dev/null +++ b/crypto/src/main/java/org/springframework/security/crypto/encrypt/RsaRawEncryptor.java @@ -0,0 +1,168 @@ +/* + * Copyright 2013-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.crypto.encrypt; + +import java.io.ByteArrayOutputStream; +import java.nio.charset.Charset; +import java.security.KeyPair; +import java.security.PrivateKey; +import java.security.PublicKey; +import java.security.interfaces.RSAKey; +import java.security.interfaces.RSAPrivateKey; +import java.security.interfaces.RSAPublicKey; +import java.util.Base64; + +import javax.crypto.Cipher; + +/** + * @author Dave Syer + * @since 6.3 + */ +public class RsaRawEncryptor implements BytesEncryptor, TextEncryptor, RsaKeyHolder { + + private static final String DEFAULT_ENCODING = "UTF-8"; + + private RsaAlgorithm algorithm = RsaAlgorithm.DEFAULT; + + private Charset charset; + + private RSAPublicKey publicKey; + + private RSAPrivateKey privateKey; + + private Charset defaultCharset; + + public RsaRawEncryptor(RsaAlgorithm algorithm) { + this(RsaKeyHelper.generateKeyPair(), algorithm); + } + + public RsaRawEncryptor() { + this(RsaKeyHelper.generateKeyPair()); + } + + public RsaRawEncryptor(KeyPair keyPair, RsaAlgorithm algorithm) { + this(DEFAULT_ENCODING, keyPair.getPublic(), keyPair.getPrivate(), algorithm); + } + + public RsaRawEncryptor(KeyPair keyPair) { + this(DEFAULT_ENCODING, keyPair.getPublic(), keyPair.getPrivate()); + } + + public RsaRawEncryptor(String pemData) { + this(RsaKeyHelper.parseKeyPair(pemData)); + } + + public RsaRawEncryptor(PublicKey publicKey) { + this(DEFAULT_ENCODING, publicKey, null); + } + + public RsaRawEncryptor(String encoding, PublicKey publicKey, PrivateKey privateKey) { + this(encoding, publicKey, privateKey, RsaAlgorithm.DEFAULT); + } + + public RsaRawEncryptor(String encoding, PublicKey publicKey, PrivateKey privateKey, RsaAlgorithm algorithm) { + this.charset = Charset.forName(encoding); + this.publicKey = (RSAPublicKey) publicKey; + this.privateKey = (RSAPrivateKey) privateKey; + this.defaultCharset = Charset.forName(DEFAULT_ENCODING); + this.algorithm = algorithm; + } + + @Override + public String getPublicKey() { + return RsaKeyHelper.encodePublicKey(this.publicKey, "application"); + } + + @Override + public String encrypt(String text) { + return new String(Base64.getEncoder().encode(encrypt(text.getBytes(this.charset))), this.defaultCharset); + } + + @Override + public String decrypt(String encryptedText) { + if (this.privateKey == null) { + throw new IllegalStateException("Private key must be provided for decryption"); + } + return new String(decrypt(Base64.getDecoder().decode(encryptedText.getBytes(this.defaultCharset))), + this.charset); + } + + @Override + public byte[] encrypt(byte[] byteArray) { + return encrypt(byteArray, this.publicKey, this.algorithm); + } + + @Override + public byte[] decrypt(byte[] encryptedByteArray) { + return decrypt(encryptedByteArray, this.privateKey, this.algorithm); + } + + private static byte[] encrypt(byte[] text, PublicKey key, RsaAlgorithm alg) { + ByteArrayOutputStream output = new ByteArrayOutputStream(text.length); + try { + final Cipher cipher = Cipher.getInstance(alg.getJceName()); + int limit = Math.min(text.length, alg.getMaxLength()); + int pos = 0; + while (pos < text.length) { + cipher.init(Cipher.ENCRYPT_MODE, key); + cipher.update(text, pos, limit); + pos += limit; + limit = Math.min(text.length - pos, alg.getMaxLength()); + byte[] buffer = cipher.doFinal(); + output.write(buffer, 0, buffer.length); + } + return output.toByteArray(); + } + catch (RuntimeException ex) { + throw ex; + } + catch (Exception ex) { + throw new IllegalStateException("Cannot encrypt", ex); + } + } + + private static byte[] decrypt(byte[] text, RSAPrivateKey key, RsaAlgorithm alg) { + ByteArrayOutputStream output = new ByteArrayOutputStream(text.length); + try { + final Cipher cipher = Cipher.getInstance(alg.getJceName()); + int maxLength = getByteLength(key); + int pos = 0; + while (pos < text.length) { + int limit = Math.min(text.length - pos, maxLength); + cipher.init(Cipher.DECRYPT_MODE, key); + cipher.update(text, pos, limit); + pos += limit; + byte[] buffer = cipher.doFinal(); + output.write(buffer, 0, buffer.length); + } + return output.toByteArray(); + } + catch (RuntimeException ex) { + throw ex; + } + catch (Exception ex) { + throw new IllegalStateException("Cannot decrypt", ex); + } + } + + // copied from sun.security.rsa.RSACore.getByteLength(java.math.BigInteger) + public static int getByteLength(RSAKey key) { + int n = key.getModulus().bitLength(); + return (n + 7) >> 3; + } + +} diff --git a/crypto/src/main/java/org/springframework/security/crypto/encrypt/RsaSecretEncryptor.java b/crypto/src/main/java/org/springframework/security/crypto/encrypt/RsaSecretEncryptor.java new file mode 100644 index 00000000000..ad8b76d4fb4 --- /dev/null +++ b/crypto/src/main/java/org/springframework/security/crypto/encrypt/RsaSecretEncryptor.java @@ -0,0 +1,247 @@ +/* + * Copyright 2013-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.crypto.encrypt; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.nio.charset.Charset; +import java.security.KeyPair; +import java.security.PrivateKey; +import java.security.PublicKey; +import java.security.interfaces.RSAPublicKey; +import java.util.Base64; + +import javax.crypto.Cipher; + +import org.springframework.security.crypto.codec.Hex; +import org.springframework.security.crypto.keygen.KeyGenerators; + +/** + * @author Dave Syer + * @since 6.3 + */ +public class RsaSecretEncryptor implements BytesEncryptor, TextEncryptor, RsaKeyHolder { + + private static final String DEFAULT_ENCODING = "UTF-8"; + + // The secret for encryption is random (so dictionary attack is not a danger) + private static final String DEFAULT_SALT = "deadbeef"; + + private final String salt; + + private RsaAlgorithm algorithm = RsaAlgorithm.DEFAULT; + + private final Charset charset; + + private final PublicKey publicKey; + + private final PrivateKey privateKey; + + private final Charset defaultCharset; + + private final boolean gcm; + + public RsaSecretEncryptor(RsaAlgorithm algorithm, String salt, boolean gcm) { + this(RsaKeyHelper.generateKeyPair(), algorithm, salt, gcm); + } + + public RsaSecretEncryptor(RsaAlgorithm algorithm, String salt) { + this(RsaKeyHelper.generateKeyPair(), algorithm, salt); + } + + public RsaSecretEncryptor(RsaAlgorithm algorithm, boolean gcm) { + this(RsaKeyHelper.generateKeyPair(), algorithm, DEFAULT_SALT, gcm); + } + + public RsaSecretEncryptor(RsaAlgorithm algorithm) { + this(RsaKeyHelper.generateKeyPair(), algorithm); + } + + public RsaSecretEncryptor() { + this(RsaKeyHelper.generateKeyPair()); + } + + public RsaSecretEncryptor(KeyPair keyPair, RsaAlgorithm algorithm, String salt, boolean gcm) { + this(DEFAULT_ENCODING, keyPair.getPublic(), keyPair.getPrivate(), algorithm, salt, gcm); + } + + public RsaSecretEncryptor(KeyPair keyPair, RsaAlgorithm algorithm, String salt) { + this(DEFAULT_ENCODING, keyPair.getPublic(), keyPair.getPrivate(), algorithm, salt, false); + } + + public RsaSecretEncryptor(KeyPair keyPair, RsaAlgorithm algorithm) { + this(DEFAULT_ENCODING, keyPair.getPublic(), keyPair.getPrivate(), algorithm); + } + + public RsaSecretEncryptor(KeyPair keyPair) { + this(DEFAULT_ENCODING, keyPair.getPublic(), keyPair.getPrivate()); + } + + public RsaSecretEncryptor(String pemData, RsaAlgorithm algorithm, String salt) { + this(RsaKeyHelper.parseKeyPair(pemData), algorithm, salt); + } + + public RsaSecretEncryptor(String pemData, RsaAlgorithm algorithm) { + this(RsaKeyHelper.parseKeyPair(pemData), algorithm); + } + + public RsaSecretEncryptor(String pemData) { + this(RsaKeyHelper.parseKeyPair(pemData)); + } + + public RsaSecretEncryptor(PublicKey publicKey, RsaAlgorithm algorithm, String salt, boolean gcm) { + this(DEFAULT_ENCODING, publicKey, null, algorithm, salt, gcm); + } + + public RsaSecretEncryptor(PublicKey publicKey, RsaAlgorithm algorithm, String salt) { + this(DEFAULT_ENCODING, publicKey, null, algorithm, salt, false); + } + + public RsaSecretEncryptor(PublicKey publicKey, RsaAlgorithm algorithm) { + this(DEFAULT_ENCODING, publicKey, null, algorithm); + } + + public RsaSecretEncryptor(PublicKey publicKey) { + this(DEFAULT_ENCODING, publicKey, null); + } + + public RsaSecretEncryptor(String encoding, PublicKey publicKey, PrivateKey privateKey) { + this(encoding, publicKey, privateKey, RsaAlgorithm.DEFAULT); + } + + public RsaSecretEncryptor(String encoding, PublicKey publicKey, PrivateKey privateKey, RsaAlgorithm algorithm) { + this(encoding, publicKey, privateKey, algorithm, DEFAULT_SALT, false); + } + + public RsaSecretEncryptor(String encoding, PublicKey publicKey, PrivateKey privateKey, RsaAlgorithm algorithm, + String salt, boolean gcm) { + this.charset = Charset.forName(encoding); + this.publicKey = publicKey; + this.privateKey = privateKey; + this.defaultCharset = Charset.forName(DEFAULT_ENCODING); + this.algorithm = algorithm; + this.salt = isHex(salt) ? salt : new String(Hex.encode(salt.getBytes(this.defaultCharset))); + this.gcm = gcm; + } + + @Override + public String getPublicKey() { + return RsaKeyHelper.encodePublicKey((RSAPublicKey) this.publicKey, "application"); + } + + @Override + public String encrypt(String text) { + return new String(Base64.getEncoder().encode(encrypt(text.getBytes(this.charset))), this.defaultCharset); + } + + @Override + public String decrypt(String encryptedText) { + if (!canDecrypt()) { + throw new IllegalStateException("Encryptor is not configured for decryption"); + } + return new String(decrypt(Base64.getDecoder().decode(encryptedText.getBytes(this.defaultCharset))), + this.charset); + } + + @Override + public byte[] encrypt(byte[] byteArray) { + return encrypt(byteArray, this.publicKey, this.algorithm, this.salt, this.gcm); + } + + @Override + public byte[] decrypt(byte[] encryptedByteArray) { + if (!canDecrypt()) { + throw new IllegalStateException("Encryptor is not configured for decryption"); + } + return decrypt(encryptedByteArray, this.privateKey, this.algorithm, this.salt, this.gcm); + } + + private static byte[] encrypt(byte[] text, PublicKey key, RsaAlgorithm alg, String salt, boolean gcm) { + byte[] random = KeyGenerators.secureRandom(16).generateKey(); + BytesEncryptor aes = gcm ? Encryptors.stronger(new String(Hex.encode(random)), salt) + : Encryptors.standard(new String(Hex.encode(random)), salt); + try { + final Cipher cipher = Cipher.getInstance(alg.getJceName()); + cipher.init(Cipher.ENCRYPT_MODE, key); + byte[] secret = cipher.doFinal(random); + ByteArrayOutputStream result = new ByteArrayOutputStream(text.length + 20); + writeInt(result, secret.length); + result.write(secret); + result.write(aes.encrypt(text)); + return result.toByteArray(); + } + catch (RuntimeException ex) { + throw ex; + } + catch (Exception ex) { + throw new IllegalStateException("Cannot encrypt", ex); + } + } + + private static void writeInt(ByteArrayOutputStream result, int length) throws IOException { + byte[] data = new byte[2]; + data[0] = (byte) ((length >> 8) & 0xFF); + data[1] = (byte) (length & 0xFF); + result.write(data); + } + + private static int readInt(ByteArrayInputStream result) throws IOException { + byte[] b = new byte[2]; + result.read(b); + return ((b[0] & 0xFF) << 8) | (b[1] & 0xFF); + } + + private static byte[] decrypt(byte[] text, PrivateKey key, RsaAlgorithm alg, String salt, boolean gcm) { + ByteArrayInputStream input = new ByteArrayInputStream(text); + ByteArrayOutputStream output = new ByteArrayOutputStream(text.length); + try { + int length = readInt(input); + byte[] random = new byte[length]; + input.read(random); + final Cipher cipher = Cipher.getInstance(alg.getJceName()); + cipher.init(Cipher.DECRYPT_MODE, key); + String secret = new String(Hex.encode(cipher.doFinal(random))); + byte[] buffer = new byte[text.length - random.length - 2]; + input.read(buffer); + BytesEncryptor aes = gcm ? Encryptors.stronger(secret, salt) : Encryptors.standard(secret, salt); + output.write(aes.decrypt(buffer)); + return output.toByteArray(); + } + catch (RuntimeException ex) { + throw ex; + } + catch (Exception ex) { + throw new IllegalStateException("Cannot decrypt", ex); + } + } + + private static boolean isHex(String input) { + try { + Hex.decode(input); + return true; + } + catch (Exception ex) { + return false; + } + } + + public boolean canDecrypt() { + return this.privateKey != null; + } + +} diff --git a/crypto/src/test/java/org/springframework/security/crypto/encrypt/KeyStoreKeyFactoryTests.java b/crypto/src/test/java/org/springframework/security/crypto/encrypt/KeyStoreKeyFactoryTests.java new file mode 100644 index 00000000000..79868264db8 --- /dev/null +++ b/crypto/src/test/java/org/springframework/security/crypto/encrypt/KeyStoreKeyFactoryTests.java @@ -0,0 +1,70 @@ +/* + * Copyright 2013-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.crypto.encrypt; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.DisabledOnOs; +import org.junit.jupiter.api.condition.OS; + +import org.springframework.core.io.ClassPathResource; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author Dave Syer + * + */ +@DisabledOnOs(OS.WINDOWS) +public class KeyStoreKeyFactoryTests { + + @Test + public void initializeEncryptorFromKeyStore() { + char[] password = "foobar".toCharArray(); + KeyStoreKeyFactory factory = new KeyStoreKeyFactory(new ClassPathResource("keystore.jks"), password); + RsaSecretEncryptor encryptor = new RsaSecretEncryptor(factory.getKeyPair("test")); + assertThat(encryptor.canDecrypt()).as("Should be able to decrypt").isTrue(); + assertThat(encryptor.decrypt(encryptor.encrypt("foo"))).isEqualTo("foo"); + } + + @Test + public void initializeEncryptorFromPkcs12KeyStore() { + char[] password = "letmein".toCharArray(); + KeyStoreKeyFactory factory = new KeyStoreKeyFactory(new ClassPathResource("keystore.pkcs12"), password); + RsaSecretEncryptor encryptor = new RsaSecretEncryptor(factory.getKeyPair("mytestkey")); + assertThat(encryptor.canDecrypt()).as("Should be able to decrypt").isTrue(); + assertThat(encryptor.decrypt(encryptor.encrypt("foo"))).isEqualTo("foo"); + } + + @Test + public void initializeEncryptorFromTrustedCertificateInKeyStore() { + char[] password = "foobar".toCharArray(); + KeyStoreKeyFactory factory = new KeyStoreKeyFactory(new ClassPathResource("keystore.jks"), password); + RsaSecretEncryptor encryptor = new RsaSecretEncryptor(factory.getKeyPair("testcertificate")); + assertThat(encryptor.canDecrypt()).as("Should not be able to decrypt").isFalse(); + assertThat(encryptor.encrypt("foo")).isNotEqualTo("foo"); + } + + @Test + public void initializeEncryptorFromTrustedCertificateInPkcs12KeyStore() { + char[] password = "letmein".toCharArray(); + KeyStoreKeyFactory factory = new KeyStoreKeyFactory(new ClassPathResource("keystore.pkcs12"), password); + RsaSecretEncryptor encryptor = new RsaSecretEncryptor(factory.getKeyPair("mytestcertificate")); + assertThat(encryptor.canDecrypt()).as("Should not be able to decrypt").isFalse(); + assertThat(encryptor.encrypt("foo")).isNotEqualTo("foo"); + } + +} diff --git a/crypto/src/test/java/org/springframework/security/crypto/encrypt/RsaKeyHelperTests.java b/crypto/src/test/java/org/springframework/security/crypto/encrypt/RsaKeyHelperTests.java new file mode 100644 index 00000000000..593681fe1ef --- /dev/null +++ b/crypto/src/test/java/org/springframework/security/crypto/encrypt/RsaKeyHelperTests.java @@ -0,0 +1,67 @@ +/* + * Copyright 2013-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.crypto.encrypt; + +import java.nio.charset.StandardCharsets; +import java.security.KeyPair; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.DisabledOnOs; +import org.junit.jupiter.api.condition.OS; + +import org.springframework.core.io.ClassPathResource; +import org.springframework.util.StreamUtils; + +import static org.assertj.core.api.Assertions.assertThat; + +@DisabledOnOs(OS.WINDOWS) +public class RsaKeyHelperTests { + + @Test + public void parsePrivateKey() throws Exception { + // ssh-keygen -m pem -b 1024 -f src/test/resources/fake.pem + String pem = StreamUtils.copyToString(new ClassPathResource("/fake.pem", getClass()).getInputStream(), + StandardCharsets.UTF_8); + KeyPair result = RsaKeyHelper.parseKeyPair(pem); + assertThat(result.getPrivate().getEncoded().length > 0).isTrue(); + assertThat(result.getPrivate().getAlgorithm()).isEqualTo("RSA"); + } + + @Test + public void parseSpaceyKey() throws Exception { + String pem = StreamUtils.copyToString(new ClassPathResource("/spacey.pem", getClass()).getInputStream(), + StandardCharsets.UTF_8); + KeyPair result = RsaKeyHelper.parseKeyPair(pem); + assertThat(result.getPrivate().getEncoded().length > 0).isTrue(); + assertThat(result.getPrivate().getAlgorithm()).isEqualTo("RSA"); + } + + @Test + public void parseBadKey() throws Exception { + // ssh-keygen -m pem -b 1024 -f src/test/resources/fake.pem + String pem = StreamUtils.copyToString(new ClassPathResource("/bad.pem", getClass()).getInputStream(), + StandardCharsets.UTF_8); + try { + RsaKeyHelper.parseKeyPair(pem); + throw new IllegalStateException("Expected IllegalArgumentException"); + } + catch (IllegalArgumentException ex) { + assertThat(ex.getMessage().contains("PEM")).isTrue(); + } + } + +} diff --git a/crypto/src/test/java/org/springframework/security/crypto/encrypt/RsaRawEncryptorTests.java b/crypto/src/test/java/org/springframework/security/crypto/encrypt/RsaRawEncryptorTests.java new file mode 100644 index 00000000000..edcc4ee420a --- /dev/null +++ b/crypto/src/test/java/org/springframework/security/crypto/encrypt/RsaRawEncryptorTests.java @@ -0,0 +1,154 @@ +/* + * Copyright 2013-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.crypto.encrypt; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author Dave Syer + * + */ +public class RsaRawEncryptorTests { + + private RsaRawEncryptor encryptor = new RsaRawEncryptor(); + + @BeforeEach + public void init() { + LONG_STRING = SHORT_STRING + SHORT_STRING + SHORT_STRING + SHORT_STRING; + for (int i = 0; i < 4; i++) { + LONG_STRING = LONG_STRING + LONG_STRING; + } + } + + @Test + public void roundTrip() { + assertThat(this.encryptor.decrypt(this.encryptor.encrypt("encryptor"))).isEqualTo("encryptor"); + } + + @Test + public void roundTripOeap() { + this.encryptor = new RsaRawEncryptor(RsaAlgorithm.OAEP); + assertThat(this.encryptor.decrypt(this.encryptor.encrypt("encryptor"))).isEqualTo("encryptor"); + } + + @Test + public void roundTripLongString() { + assertThat(this.encryptor.decrypt(this.encryptor.encrypt(LONG_STRING))).isEqualTo(LONG_STRING); + } + + @Test + public void roundTripLongStringOeap() { + this.encryptor = new RsaRawEncryptor(RsaAlgorithm.OAEP); + assertThat(this.encryptor.decrypt(this.encryptor.encrypt(LONG_STRING))).isEqualTo(LONG_STRING); + } + + @Test + public void roundTrip2048Key() { + String pemData = "-----BEGIN RSA PRIVATE KEY-----" + + "MIIEpQIBAAKCAQEA5KHEkCudAHCKIUHKyW6Z8dMyQsKrLbpDe0wDzx9MBARcOoS9" + + "ZUjzXwK6p/0RM6aCp+b9kkr37QKQ9K/Am13sr0z8Mkn1Q2cvXiL5gbnY1nYGk8/m" + + "CBX3QEhH2UII4yJsDVx1xmcSorZaWmeNKor7Zl3SZaQpWTvlkMgQKwY8DZL6PPxt" + + "JRPeKmuUY6B59u5okh1G6Y9OnT2dVxAkqT8WgLHu6StxBmueJ272x2sUWUzoDhnP" + + "7JRqa7h7t6fml3o3Op1iCywCOFzCIcK6G/oG/WZ7tbBYkwQdDjn/9VMdKkkPufwq" + + "zt4S75NJygXDwDnNPiTVoaOwrRrL8ahgw6bFCQIDAQABAoIBAECIMHUI+l2fZj2Q" + + "1m4Ym7cYB320eKCFjHqGsCSMDuarXGTgBp1KA/dzS8ASvAI6I3LEzhm2s1fge420" + + "9cZksmOgdSa0nVeTDlmhwY8OJ9gQpDagXas2l/066Zy2+M8zbhAvYsbHXQk0MziF" + + "NeEmLWNtY+9wcINRVrCQ549dSSIDK6UX21oU6d1mrlnF5/bbbdDIM3dKok355jwx" + + "0HFY0tJIs1zArsBVoz3Ccu1MQEfnxEFM1LLPi5rE6cuHIOBinbD1OQ2R/HM2aukG" + + "Rk2m6F3wAieJ7zpt5yaHuuIedn8p8m2NVulXAjgkY2oQl3GGiDH/H7eZlrvQRg6E" + + "D8Bq+ykCgYEA+AfPXVeeVg3Qu0KsNrACek/o92BMY9g3GyPVGULGvq9seoNB86hj" + + "nXasqngBfTlOfJFiahoEzRBB9hIyo1zMw4x99pR8nGxhR3aU+v8EGftMABGHWsB9" + + "Jxj4YQH4fhi57iBa72QmNPbu/1o7y3SEe68E5PJ8KY3jc4xos8Vl658CgYEA6/pk" + + "t6WZII+9lpxQfePQDIlBWAphiQceh995bGXfDmX3vOVmPozix9/fUtF1TeKS/ypw" + + "u++Qmvj5oMsBVrjCyoOYfHKE2vGrLoEzkX/sPO65IsV00geZZoyCEKEE3USJfY46" + + "u0hs61oP8HJyLhLiYiGcFTzZ4nEvvEbiM4E/DlcCgYEA6S0OecZhiK08SpAHrvIR" + + "okN11PqnVkZyqAUr1a+9gI8TAKpdWmA4JlTnRuvDGqLBcsKLLwx+7voVyOyaxpH7" + + "vutZkHNQIw6Q9co5jS4qAPMLJBVWlq7X+eWzvB9KKeG9Cm1IkD4q3Sg4z79Y75D+" + + "6/hCNarxp29JIdwior81bikCgYEApp1P+b7pxGzZPvs1df2hCwjqY0BJJ5goPWVT" + + "dW7kNGVYqz4JmAafpOJz6yTLP2fHxHRxzrBSmKlMj/RmCJZBqv2Jb+zn0zMpW5eM" + + "EqKQ6WDgxSVH23fUHuz8dMNMDPL0ZPtEirGTfgVEFdCov9FDmGgErZYefVzPiI8/" + + "7X/HRtcCgYEApQ2YS+0DLPqaM0cC6/6hDr/jmHLFhHaV6DZR7M9HHDnMN2uMlOEa" + + "RYvXRMBjyQ7LQkwOj6K5k8MVrsDDM5dbekTBgcJMHfM9uViDkB0VPYULORmDJ20N" + + "MLowIAiSon2B2/isatY80YtFq+bRyvPOzjGvinHN3MU1GH/gFuS0fiw=" + "-----END RSA PRIVATE KEY-----"; + RsaRawEncryptor encryptor_2048 = new RsaRawEncryptor(pemData); + assertThat(encryptor_2048.decrypt(encryptor_2048.encrypt("encryptor"))).isEqualTo("encryptor"); + } + + @Test + public void roundTrip4096Key() { + String pemData = "-----BEGIN RSA PRIVATE KEY-----" + + "MIIJKAIBAAKCAgEAw/OIcO1pv8t/lhXwzc+CqCqAE8+2+BTWd6fHy8P2oGKZK0s3" + + "jxPWdZEbp1soGZobCIjEIuYuuPeinrTFOxtnf/JVfmzGnixRjWzQK0UiM/4z8GW6" + + "7+dzB0+QZlU+PGCL6xra4d3+5EsPQwTDjPJ4OhcA66hWACd3UJpvE2C14YdFkCP/" + + "CUxubz1l+8rFwEtMcw2bVUL/Mt+Sx1CHPFer17VK/sT4urwNG7y9R8WWvNQXgEwg" + + "0im+iJ0zf1u0SdUVj+Q1LwgNRoIx4vec2xAJ6xdqSx3Y3g2twWqUXUBb5K09ajIW" + + "Vuko5kWJVyx1x8LazU+0wQRLVJRYAiUOPLg7PdPAJWaAWmagnkAvl5bqCKi6sIc8" + + "+vKyrPx4VJH5KLsHx8020Wgch/LfHl/vvoHE7Oa81hnyMVsApvNCJdFbiMJ6r2z/" + + "eHqzjY8lzBQHNxh1XJys5teTJsi6N06gCc+OQRyw1FQ8KLgFlLPHNamfMnP5Ju0d" + + "Jv8GzQiMFjudjEYhkh2GPmRus1VYWDwDWhXwp28koWAanfih+Ujc2ZqNUS23hGWz" + + "KbCxRaAwSLqn3vkoYBeDyWWs1r0HnB6gACFaZIk38aiGyg7GjF0286Aq7USqNwKu" + + "Izm4kzIPFrHIbywKq7804J7wXUlaAgf0pNSndMD5OnwudzD+JHLTuOGFNdUCAwEA" + + "AQKCAgBYh2mIY6rYTS9adpUx1uPX6EOvL7QhhwCSVMoupF2Dfqhm5/e0+6hzu1h8" + + "FvIaBwbZpzi977MCPFdLTq6hErODGdBIawqdIbbCp3uxYO2gAeQjY0K+6pmMnwTF" + + "RxP0IUZ1tM9ZJnvnVoYRqFBVGKL607PFxGr+bNY6I1u1rIbf2sax5aFu6Qon1dyC" + + "ks0fIKXsgSRBtCAqMtpUlGxU9eMcdLrqOcGKVDWz52S4zWtZ6pSnkT1u1g9QF33R" + + "t3PPu6afOOJSWlftGBtDyM0kJ63jedO7FkQJprJu5SEctFwQB7jshq6TG4ov5xCy" + + "wtJ/quhBxBYM8ky6bL8KUQWKp02Tyfq0Fo+iwuLxM4N6LxVPFZ6R6jwvazm+ka4S" + + "sZAW/hnH3FdJEAyFcxzhelLdLUrjwrsWjmJBk0pMP5cEleYR8PQh2sHM8ZOX1T5f" + + "4zfyR66+tl1O81T7anbma8l1Wm/QSNZz+8QAM1iNuV+uLsWvmxLAc7NRgjDmiAMn" + + "8VhfUtl0ooOZYkDexqSNaWvIQG+S8Pl28gNxVXkXrXqBGPJn2ptROEJ1/AN1h4cv" + + "2CktVylRFpEI/hxXvKMaAu/tXtvoakvaTA8msl8Otrldsy3EGhgHrDTYIJUg/rRT" + + "TlbRkN/ycaOhA0d4HAewOGul3ss+EtBz+SQBzaWm2Inr8XOJoQKCAQEA4LwW7eGm" + + "MOYspFUbn2tMlnJAng9HKK42o2m6ShYAaQAoLX7LIkQYVS++9CiGCPpoSlwIJWE3" + + "N/qGx0i7REDm+wNu0/4acaMFI+qYtvjKiWwtMOBH3bw1C4/Isc60tFPkI7FEFCiF" + + "SiW3c+Z8B0/IRMb/YF5tZeuWUlAl7PQJ1rMcPUE4O4LXM4BG29hghVGGnp39YsOY" + + "b/6oBApTgdxCaSZhmhDwTMu97n75CK0xzA2vDtHn2Gu3zf4j6bsNot6/7wRtQBMg" + + "1e3kXuwGUZ08QZ7OqATUIZdCeK1PfxypontVh+0LeNjiDU8pW3Q8IMlDT96Fd5U+" + + "BgtjfHmwHXeBmQKCAQEA3zZS619O/IUoWN3rWT4hUSJE3S+FXXcaBaJ7H6r897cl" + + "ju+HSS2CLp/C9ftcQ9ef+pG2arLRZpONd5KhfRyjo0pNp3SwxklnIhNS9abBBCnN" + + "ojeYcVHOcSfmWGlUCQAvv5LeBPSS02pbCE5t/qadglvgKhHqSb2u+FgkdKrV0Mme" + + "sbVy+tyd4F1oBIS0wg1p3mHKvKfb4MEnUDvIvG8rCBUMvAWQmTiuyqFUiuqSwEMy" + + "LANFFV/ZoJ5194ruTXdelcoZjXhd128JJFNp6Jh4eg5OWoBS7e08QHbvUYBppDYO" + + "Iz0N1TipVK9uCqHHtbwIqqxyPVev3QJUYkpl5/tznQKCAQB9izV38F2J5Zu8tbq3" + + "pRZk2TCV280RwbjOMysZZg8WmTrYp4NNAiNhu0l+VgEClPibyavXTeauA+s0+sF6" + + "kJM4WKOaE9Kr9rjRZqWnWXazrFXWfwRGr3QmoE0qX2H9dvv0oHt6k2RalpVUTsas" + + "wvoKyewx5q5QiHoyQ4ncRDwWz3oQEhYa0K3tnFR5TfglofSFOZcqjD/lGKq9jxM1" + + "cVk8Km/NxHapQAw7Zn0yRqaR6ncH3WUaNpq4nadsU817Vdp86MkrSURHnhy8lje1" + + "chQOSGwD2qaymTBN/+twBBATr7iJNXf6K5akfruI1nccjbJntNR0iE/cypHqIISt" + + "AWzJAoIBAFDV5ZWkAIDm4EO+qpq5K2usk2/e49eDaIMd4qUHUXGMfCeVi1LvDjRA" + + "W2Sl0TYogqFF3+AoPjl9uj/RdHZQxto98H1yfwpwTs9CXErmRwRw9y2GIMj5LWBB" + + "aOQf0PUpgiFI2OrGf93cqHcLoD4WrPgmubnCnyxxa0o48Yrmy2Q/gB8vbSJ4fxxf" + + "92mbfbLBFNQaakeEKtbsXIZsADhtshHNPb1h7onuwy5S2sEsTlUegK77yCsDeVb3" + + "zBUH1WFsl257sGFRc/qvFYp4QuSfQxJA2BNiYaYUwjs+V1EWxitYACq206miSYCH" + + "v7xN9ntUS3cz2HNqrB/H1jN6aglnQOkCggEBAJb5FYvQCvw5PJM44nR6/U1cSlr4" + + "lRWcuFp7Xv5kWxSwM5115qic14fByh7DbaTHxxoPEhEA4aJ2QcDa7YWvabVc/VEV" + + "VacAAdg44+WSw6FNni18K53oOKAONgzSQlYUm/jgENIXi+5L0Yq7qAbnldiC6jXr" + + "yqbEwZjmpt8xsBLnl37k/LSLG1GUaYV8AK3s9UDs9/jv5RUrV96jiXed+7pYrjmj" + + "o1yJ4WAqouYHmOQCI3SeFCLT8GCdQ+uE74G5q+Yte6YT9jqSiGDjrst0bjtN640v" + + "YKRG3XK4AE9i4Oinnv/Ua95ql0syphn+CPW2ksmGon5/0mbK5qYsg47Hdls=" + "-----END RSA PRIVATE KEY-----"; + RsaRawEncryptor encryptor_4096 = new RsaRawEncryptor(pemData); + assertThat(encryptor_4096.decrypt(encryptor_4096.encrypt("encryptor"))).isEqualTo("encryptor"); + } + + private static final String SHORT_STRING = "Bacon ipsum dolor sit amet tail pork loin pork chop filet mignon flank fatback tenderloin boudin shankle corned beef t-bone short ribs. Meatball capicola ball tip short loin beef ribs shoulder, kielbasa pork chop meatloaf biltong porchetta bresaola t-bone spare ribs. Andouille t-bone sausage ground round frankfurter venison. Ground round meatball chicken ribeye doner tongue porchetta."; + + private static String LONG_STRING; + +} diff --git a/crypto/src/test/java/org/springframework/security/crypto/encrypt/RsaSecretEncryptorTests.java b/crypto/src/test/java/org/springframework/security/crypto/encrypt/RsaSecretEncryptorTests.java new file mode 100644 index 00000000000..48087a4e8d3 --- /dev/null +++ b/crypto/src/test/java/org/springframework/security/crypto/encrypt/RsaSecretEncryptorTests.java @@ -0,0 +1,121 @@ +/* + * Copyright 2013-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.crypto.encrypt; + +import java.security.PublicKey; +import java.security.interfaces.RSAPublicKey; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatIllegalStateException; + +/** + * @author Dave Syer + * + */ +public class RsaSecretEncryptorTests { + + private RsaSecretEncryptor encryptor = new RsaSecretEncryptor(); + + @BeforeEach + public void init() { + LONG_STRING = SHORT_STRING + SHORT_STRING + SHORT_STRING + SHORT_STRING; + for (int i = 0; i < 4; i++) { + LONG_STRING = LONG_STRING + LONG_STRING; + } + } + + @Test + public void roundTripKey() { + PublicKey key = RsaKeyHelper.generateKeyPair().getPublic(); + String encoded = RsaKeyHelper.encodePublicKey((RSAPublicKey) key, "application"); + assertThat(RsaKeyHelper.parsePublicKey(encoded)).isEqualTo(key); + } + + @Test + public void roundTrip() { + assertThat(this.encryptor.decrypt(this.encryptor.encrypt("encryptor"))).isEqualTo("encryptor"); + } + + @Test + public void roundTripWithSalt() { + this.encryptor = new RsaSecretEncryptor(RsaAlgorithm.OAEP, "somesalt"); + assertThat(this.encryptor.decrypt(this.encryptor.encrypt("encryptor"))).isEqualTo("encryptor"); + } + + @Test + public void roundTripWithHexSalt() { + this.encryptor = new RsaSecretEncryptor(RsaAlgorithm.OAEP, "beefea"); + assertThat(this.encryptor.decrypt(this.encryptor.encrypt("encryptor"))).isEqualTo("encryptor"); + } + + @Test + public void roundTripWithLongSalt() { + this.encryptor = new RsaSecretEncryptor(RsaAlgorithm.OAEP, "somesaltsomesaltsomesaltsomesaltsomesalt"); + assertThat(this.encryptor.decrypt(this.encryptor.encrypt("encryptor"))).isEqualTo("encryptor"); + } + + @Test + public void roundTripOaep() { + this.encryptor = new RsaSecretEncryptor(RsaAlgorithm.OAEP); + assertThat(this.encryptor.decrypt(this.encryptor.encrypt("encryptor"))).isEqualTo("encryptor"); + } + + @Test + public void roundTripOaepGcm() { + this.encryptor = new RsaSecretEncryptor(RsaAlgorithm.OAEP, true); + assertThat(this.encryptor.decrypt(this.encryptor.encrypt("encryptor"))).isEqualTo("encryptor"); + } + + @Test + public void roundTripWithMixedAlgorithm() { + RsaSecretEncryptor oaep = new RsaSecretEncryptor(RsaAlgorithm.OAEP); + assertThatIllegalStateException().isThrownBy(() -> oaep.decrypt(this.encryptor.encrypt("encryptor"))); + } + + @Test + public void roundTripWithMixedSalt() { + RsaSecretEncryptor other = new RsaSecretEncryptor(this.encryptor.getPublicKey(), RsaAlgorithm.DEFAULT, "salt"); + assertThatIllegalStateException().isThrownBy(() -> this.encryptor.decrypt(other.encrypt("encryptor"))); + } + + @Test + public void roundTripWithPublicKeyEncryption() { + RsaSecretEncryptor encryptor = new RsaSecretEncryptor(this.encryptor.getPublicKey()); + RsaSecretEncryptor decryptor = this.encryptor; + assertThat(decryptor.decrypt(encryptor.encrypt("encryptor"))).isEqualTo("encryptor"); + } + + @Test + public void publicKeyCannotDecrypt() { + RsaSecretEncryptor encryptor = new RsaSecretEncryptor(this.encryptor.getPublicKey()); + assertThat(encryptor.canDecrypt()).as("Encryptor schould not be able to decrypt").isFalse(); + assertThatIllegalStateException().isThrownBy(() -> encryptor.decrypt(encryptor.encrypt("encryptor"))); + } + + @Test + public void roundTripLongString() { + assertThat(this.encryptor.decrypt(this.encryptor.encrypt(LONG_STRING))).isEqualTo(LONG_STRING); + } + + private static final String SHORT_STRING = "Bacon ipsum dolor sit amet tail pork loin pork chop filet mignon flank fatback tenderloin boudin shankle corned beef t-bone short ribs. Meatball capicola ball tip short loin beef ribs shoulder, kielbasa pork chop meatloaf biltong porchetta bresaola t-bone spare ribs. Andouille t-bone sausage ground round frankfurter venison. Ground round meatball chicken ribeye doner tongue porchetta."; + + private static String LONG_STRING; + +} diff --git a/crypto/src/test/resources/bad.pem b/crypto/src/test/resources/bad.pem new file mode 100644 index 00000000000..653a6eaea7d --- /dev/null +++ b/crypto/src/test/resources/bad.pem @@ -0,0 +1,2 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIEowIBAAKCAQEAwClFgrRa/PUHPIJr9gvIPL6g6Rjp/TVZmVNOf2fL96DYbkj5 \ No newline at end of file diff --git a/crypto/src/test/resources/fake.pem b/crypto/src/test/resources/fake.pem new file mode 100644 index 00000000000..931e57889d8 --- /dev/null +++ b/crypto/src/test/resources/fake.pem @@ -0,0 +1,15 @@ +-----BEGIN RSA PRIVATE KEY----- +MIICWwIBAAKBgQDMWnfaQ0yLFXelprq2S8UurnaGvxFNUdbmTyJeycem5vGLycEY +T4KcdVCTU5491cjbk5GcHjoj2efRSO0y0aXIlUJpLofDdML/SuGLZWp/GbEv978M +pZIztK8iaIm7D/D7by8aws1RJyD9T+lZDAGY7eFfMp0EQyHOcEL0NGFLuwIDAQAB +AoGAWwC6uO8ZaiKwOouqQD4z3FsDG3SA/v7ABaYd9zpCd9gGnyrEm8/kqUoxDLrD +EGRg4y+vO2fWmlqSuoeQYf4spf+vi2di+mGIb6nGe7TpMLPa7lFLOSQHZRx5M5H6 +JDhfhAHlKmF9gLGvDHbpyErzn5YXjcu0PoFiNC1y445D8iECQQDvJzkGbJ9l9vb0 +oRyGXRDpddUcVMECLLB9NKmTl/zKy/qVPD+zYNoi87ePBJFbgmAXRjhhTk2uSBRP +NtVaMoXLAkEA2r+ugzjsLZQIYz/9gxdzdbKWDgpSPbhKCR4bOmrDgJMcOVjtwW+n ++liaX6zwI0QEgCAWLzCbbYDmj3kJrRwT0QJAaowg/dm7EmR7FfYJjVs9Q6X5skuY +Se27G60wt88JExjZpU9YWgSWaugGKbOxRwHI6dWhHMkUFseKNNiLKUpFDQJALIGP +ahdsxiE2S6s7Uy60SSAas6SZ8wDJ320GsS4DtOc5eNmFFjQ3gxH/5rNy8FnoaIEe +wl8rYG43er1voI7z4QJAB4qaqBo7eeiRgnUVIccaSZkNIMSrZ9QUjVFRgfLwAXDO +Ae+t6V+eB0oaIXczA+BLj3Oe6D3iHRGHrxGlcvDdHw== +-----END RSA PRIVATE KEY----- diff --git a/crypto/src/test/resources/keystore.jks b/crypto/src/test/resources/keystore.jks new file mode 100644 index 00000000000..c13b189d0b8 Binary files /dev/null and b/crypto/src/test/resources/keystore.jks differ diff --git a/crypto/src/test/resources/keystore.pkcs12 b/crypto/src/test/resources/keystore.pkcs12 new file mode 100644 index 00000000000..a6d7c0e7598 Binary files /dev/null and b/crypto/src/test/resources/keystore.pkcs12 differ diff --git a/crypto/src/test/resources/spacey.pem b/crypto/src/test/resources/spacey.pem new file mode 100644 index 00000000000..1050b1dc1aa --- /dev/null +++ b/crypto/src/test/resources/spacey.pem @@ -0,0 +1,25 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIEowIBAAKCAQEAwClFgrRa/PUHPIJr9gvIPL6g6Rjp/TVZmVNOf2fL96DYbkj5 +4YbrwfKwjoTjk1M6gLQpOA4Blocx6zN5OnICnVGlVM9xymWxTxxCfc2tE2Fai9I1wchULCChhwm/UU5ZNi3KpXinlyamSYw+lMQkZ8gTXCgOEvs2j9E1quF4pvy1BZKvbD8tUnUQlyiKRnI6gOxQL8B6OAYPRdaa9FVNmrs1B4eDPG918L2f1pT090P1n+tw +iejNgQvtSD78/A88qt89OhzscsufALTrBjycn89kkfBd0zbVLF0W6+ZVLZrf97/y +LCoGSCcZL9LFPNvNqxOnleviDco7aOs4stQ9jQIDAQABAoIBAQC1TbthyN0YUe+T + 7dIDAbbZaVrU00biOtXgzjMADmTprP7Hf18UpIIIKfzfWw6FUD+gc1t4oe5pogE9 +UwGMXUmOORxu2pMYTb5vT9CEdexYnsAZsCo8PdD9GYSNrmquQef2MFpEqYQmHrdC + KWpaXn2i1ak+iCRPUGp4YwHpynZVxfE8z/AIsPn6NPDh6SnCXb1rTgQe2UCfXm93 +UJe5F/OR2kQi5KFO+dxLmCOBCwr6SGCLH+VotGpuxCVRUd9sJ/d4QpDZEgjuf7Ug +eQHfgMDS/tc09B9rl0dwKnEa31kcQ9X9KLkKP+w0Pqhh0Emny20eg9jS6XNayg61 +p/LQtW9BAoGBAO5veKMIcXfZmuh11WIIhdhLKkNtYyt5NDmrV8/IVScLFvjB0ftt +8PAtXo/ekOHkyITyIumQ9l4VCvacNw7DyV9FYk4WvrvVYOCL8aZi+O5+12NT67eO +Rr/voGlRoV05X7+inc90qbbYJ8lRmLSqvzmsm98mkuhw/FKGRhVZIfAJAoGBAM5R + I5vK6cJxOwXQOEGOd5/8B9JMFXyuendXo/N2/NxSQsbx4pc3v2rv/eGJYaY7Nx/y +2M/vdWYkpG59PAS3k2TrCA/0SGmyVqY+c8BomKisU5VaBlIPfGuec9tDPgWCp8Ur +3Jjt/2sVoa0vMkqymUqMb9HyH9tdI9oyh7EOOrplAoGAR6DlNNUMgVy11K/Rcqns +y5WJFMh/ykeXENwQfTNJoXkLZZ+UXVwhzYVTqxTJoZMBSi8TnecWnBzmNj+nqp/W + lvBZH+xlUDhB6jMgXUPOVJd2TTigz3vGdVKfdgQ33bGmugM4NWJuuacmDKyem2fQ + GptoGBmWeI24v3HnC/LC50ECgYAz0iN8hRnz0db+Xc9TgAJB997LDnszJuvxv9yZ + UWCvwiWtrKG6U7FLnd4J4STayPLOnoOgrsexETEP43rIwIdQCMysnTH3AmlLNlKC + mIMHksknsUX3JJaevVziTOBuJ+QV3S96ZgUKk5NZWYprQrLIC8AmXodr5NgVfS2h + 5i4QFQKBgFfbYHiMw5AAUQrBNkrAjLd1wIaO/6qS3w4OsCWKowhfaJLEXAbIRV7s +vAtgtlCovdasVj4RRLXFf+73naVTQjBZI+3jWHHyFk3+Zy86mQCSGv9WuDVV1IhS +h8InTVvK8wgdgX7qiw3pvU0roqNW4/j4j8OqJO3Zt4KO2iX8htsO +-----END RSA PRIVATE KEY----- diff --git a/docs/modules/ROOT/pages/servlet/authentication/passwords/ldap.adoc b/docs/modules/ROOT/pages/servlet/authentication/passwords/ldap.adoc index fed7b325b6d..33661f72d63 100644 --- a/docs/modules/ROOT/pages/servlet/authentication/passwords/ldap.adoc +++ b/docs/modules/ROOT/pages/servlet/authentication/passwords/ldap.adoc @@ -614,7 +614,7 @@ You need only supply the domain name and an LDAP URL that supplies the address o [NOTE] ==== -It is also possible to obtain the server's IP address byusing a DNS lookup. +It is also possible to obtain the server's IP address by using a DNS lookup. This is not currently supported, but hopefully will be in a future version. ==== diff --git a/docs/modules/ROOT/pages/servlet/authentication/session-management.adoc b/docs/modules/ROOT/pages/servlet/authentication/session-management.adoc index f9434d38264..f269c960bc6 100644 --- a/docs/modules/ROOT/pages/servlet/authentication/session-management.adoc +++ b/docs/modules/ROOT/pages/servlet/authentication/session-management.adoc @@ -788,7 +788,7 @@ Java:: @Bean public SecurityFilterChain filterChain(HttpSecurity http) { http - .sessionManagement((session) - session + .sessionManagement((session) -> session .sessionFixation((sessionFixation) -> sessionFixation .newSession() ) diff --git a/docs/modules/ROOT/pages/servlet/authorization/method-security.adoc b/docs/modules/ROOT/pages/servlet/authorization/method-security.adoc index 920558fc275..4d3f55122fc 100644 --- a/docs/modules/ROOT/pages/servlet/authorization/method-security.adoc +++ b/docs/modules/ROOT/pages/servlet/authorization/method-security.adoc @@ -160,12 +160,12 @@ For example, if needed, you can disable the Spring Security defaults and <<_enab The method interceptors are as follows: -* For <>, Spring Security uses {security-api-url}org/springframework/security/authorization/method/AuthorizationManagerBeforeMethodInterceptor.html[`AuthenticationManagerBeforeMethodInterceptor#preAuthorize`], which in turn uses {security-api-url}org/springframework/security/authorization/method/PreAuthorizeAuthorizationManager.html[`PreAuthorizeAuthorizationManager`] -* For <>, Spring Security uses {security-api-url}org/springframework/security/authorization/method/AuthorizationManagerAfterMethodInterceptor.html[`AuthenticationManagerAfterMethodInterceptor#postAuthorize`], which in turn uses {security-api-url}org/springframework/security/authorization/method/PostAuthorizeAuthorizationManager.html[`PostAuthorizeAuthorizationManager`] +* For <>, Spring Security uses {security-api-url}org/springframework/security/authorization/method/AuthorizationManagerBeforeMethodInterceptor.html[`AuthorizationManagerBeforeMethodInterceptor#preAuthorize`], which in turn uses {security-api-url}org/springframework/security/authorization/method/PreAuthorizeAuthorizationManager.html[`PreAuthorizeAuthorizationManager`] +* For <>, Spring Security uses {security-api-url}org/springframework/security/authorization/method/AuthorizationManagerAfterMethodInterceptor.html[`AuthorizationManagerBeforeMethodInterceptor#postAuthorize`], which in turn uses {security-api-url}org/springframework/security/authorization/method/PostAuthorizeAuthorizationManager.html[`PostAuthorizeAuthorizationManager`] * For <>, Spring Security uses {security-api-url}org/springframework/security/authorization/method/PreFilterAuthorizationMethodInterceptor.html[`PreFilterAuthorizationMethodInterceptor`] * For <>, Spring Security uses {security-api-url}org/springframework/security/authorization/method/PostFilterAuthorizationMethodInterceptor.html[`PostFilterAuthorizationMethodInterceptor`] -* For <>, Spring Security uses {security-api-url}org/springframework/security/authorization/method/AuthorizationManagerBeforeMethodInterceptor.html[`AuthenticationManagerBeforeMethodInterceptor#secured`], which in turn uses {security-api-url}org/springframework/security/authorization/method/SecuredAuthorizationManager.html[`SecuredAuthorizationManager`] -* For JSR-250 annotations, Spring Security uses {security-api-url}org/springframework/security/authorization/method/AuthorizationManagerBeforeMethodInterceptor.html[`AuthenticationManagerBeforeMethodInterceptor#jsr250`], which in turn uses {security-api-url}org/springframework/security/authorization/method/Jsr250AuthorizationManager.html[`Jsr250AuthorizationManager`] +* For <>, Spring Security uses {security-api-url}org/springframework/security/authorization/method/AuthorizationManagerBeforeMethodInterceptor.html[`AuthorizationManagerBeforeMethodInterceptor#secured`], which in turn uses {security-api-url}org/springframework/security/authorization/method/SecuredAuthorizationManager.html[`SecuredAuthorizationManager`] +* For JSR-250 annotations, Spring Security uses {security-api-url}org/springframework/security/authorization/method/AuthorizationManagerBeforeMethodInterceptor.html[`AuthorizationManagerBeforeMethodInterceptor#jsr250`], which in turn uses {security-api-url}org/springframework/security/authorization/method/Jsr250AuthorizationManager.html[`Jsr250AuthorizationManager`] Generally speaking, you can consider the following listing as representative of what interceptors Spring Security publishes when you add `@EnableMethodSecurity`: diff --git a/docs/modules/ROOT/pages/whats-new.adoc b/docs/modules/ROOT/pages/whats-new.adoc index 9fa19e003d4..06b919de643 100644 --- a/docs/modules/ROOT/pages/whats-new.adoc +++ b/docs/modules/ROOT/pages/whats-new.adoc @@ -4,6 +4,10 @@ Spring Security 6.3 provides a number of new features. Below are the highlights of the release. +== General + +- https://spring.io/blog/2024/01/19/spring-security-6-3-adds-passive-jdk-serialization-deserialization-for[blog post] - Added Passive JDK Serialization/Deserialization for Seamless Upgrades + == Configuration - https://github.com/spring-projects/spring-security/issues/6192[gh-6192] - xref:reactive/authentication/concurrent-sessions-control.adoc[docs] Add Concurrent Sessions Control on WebFlux @@ -11,3 +15,7 @@ Below are the highlights of the release. == CAS - https://github.com/spring-projects/spring-security/pull/14193[gh-14193] - Added support for CAS Gateway Authentication + +== Crypto + +- https://github.com/spring-projects/spring-security/issues/14202[gh-14202] - Migrated spring-security-rsa into spring-security-crypto diff --git a/messaging/src/main/java/org/springframework/security/messaging/handler/invocation/reactive/CurrentSecurityContextArgumentResolver.java b/messaging/src/main/java/org/springframework/security/messaging/handler/invocation/reactive/CurrentSecurityContextArgumentResolver.java index 3b67e85a437..ec69e1d389f 100644 --- a/messaging/src/main/java/org/springframework/security/messaging/handler/invocation/reactive/CurrentSecurityContextArgumentResolver.java +++ b/messaging/src/main/java/org/springframework/security/messaging/handler/invocation/reactive/CurrentSecurityContextArgumentResolver.java @@ -118,7 +118,21 @@ public void setAdapterRegistry(ReactiveAdapterRegistry adapterRegistry) { @Override public boolean supportsParameter(MethodParameter parameter) { - return findMethodAnnotation(CurrentSecurityContext.class, parameter) != null; + return isMonoSecurityContext(parameter) + || findMethodAnnotation(CurrentSecurityContext.class, parameter) != null; + } + + private boolean isMonoSecurityContext(MethodParameter parameter) { + boolean isParameterPublisher = Publisher.class.isAssignableFrom(parameter.getParameterType()); + if (isParameterPublisher) { + ResolvableType resolvableType = ResolvableType.forMethodParameter(parameter); + Class genericType = resolvableType.resolveGeneric(0); + if (genericType == null) { + return false; + } + return SecurityContext.class.isAssignableFrom(genericType); + } + return false; } @Override @@ -136,6 +150,14 @@ public Mono resolveArgument(MethodParameter parameter, Message messag private Object resolveSecurityContext(MethodParameter parameter, Object securityContext) { CurrentSecurityContext contextAnno = findMethodAnnotation(CurrentSecurityContext.class, parameter); + if (contextAnno != null) { + return resolveSecurityContextFromAnnotation(contextAnno, parameter, securityContext); + } + return securityContext; + } + + private Object resolveSecurityContextFromAnnotation(CurrentSecurityContext contextAnno, MethodParameter parameter, + Object securityContext) { String expressionToParse = contextAnno.expression(); if (StringUtils.hasLength(expressionToParse)) { StandardEvaluationContext context = new StandardEvaluationContext(); diff --git a/messaging/src/test/java/org/springframework/security/messaging/handler/invocation/reactive/CurrentSecurityContextArgumentResolverTests.java b/messaging/src/test/java/org/springframework/security/messaging/handler/invocation/reactive/CurrentSecurityContextArgumentResolverTests.java index 9b715b65451..22876bde63e 100644 --- a/messaging/src/test/java/org/springframework/security/messaging/handler/invocation/reactive/CurrentSecurityContextArgumentResolverTests.java +++ b/messaging/src/test/java/org/springframework/security/messaging/handler/invocation/reactive/CurrentSecurityContextArgumentResolverTests.java @@ -46,6 +46,24 @@ public void supportsParameterWhenAuthenticationPrincipalThenTrue() { assertThat(this.resolver.supportsParameter(arg0("currentSecurityContextOnMonoSecurityContext"))).isTrue(); } + @Test + public void supportsParameterWhenMonoSecurityContextNoAnnotationThenTrue() { + assertThat(this.resolver.supportsParameter(arg0("currentSecurityContextOnMonoSecurityContextNoAnnotation"))) + .isTrue(); + } + + @Test + public void supportsParameterWhenMonoCustomSecurityContextNoAnnotationThenTrue() { + assertThat( + this.resolver.supportsParameter(arg0("currentCustomSecurityContextOnMonoSecurityContextNoAnnotation"))) + .isTrue(); + } + + @Test + public void supportsParameterWhenNoSecurityContextNoAnnotationThenFalse() { + assertThat(this.resolver.supportsParameter(arg0("currentSecurityContextOnMonoStringNoAnnotation"))).isFalse(); + } + @Test public void resolveArgumentWhenAuthenticationPrincipalAndEmptyContextThenNull() { Object result = this.resolver.resolveArgument(arg0("currentSecurityContextOnMonoSecurityContext"), null) @@ -67,6 +85,18 @@ public void resolveArgumentWhenAuthenticationPrincipalThenFound() { private void currentSecurityContextOnMonoSecurityContext(@CurrentSecurityContext Mono context) { } + @SuppressWarnings("unused") + private void currentSecurityContextOnMonoSecurityContextNoAnnotation(Mono context) { + } + + @SuppressWarnings("unused") + private void currentCustomSecurityContextOnMonoSecurityContextNoAnnotation(Mono context) { + } + + @SuppressWarnings("unused") + private void currentSecurityContextOnMonoStringNoAnnotation(Mono context) { + } + @Test public void supportsParameterWhenCurrentUserThenTrue() { assertThat(this.resolver.supportsParameter(arg0("currentUserOnMonoUserDetails"))).isTrue(); @@ -110,6 +140,41 @@ public void supportsParameterWhenNotAnnotatedThenFalse() { private void monoUserDetails(Mono user) { } + @Test + public void supportsParameterWhenSecurityContextNotAnnotatedThenTrue() { + assertThat(this.resolver.supportsParameter(arg0("monoSecurityContext"))).isTrue(); + } + + @Test + public void resolveArgumentWhenMonoSecurityContextNoAnnotationThenFound() { + Authentication authentication = TestAuthentication.authenticatedUser(); + Mono result = (Mono) this.resolver + .resolveArgument(arg0("monoSecurityContext"), null) + .contextWrite(ReactiveSecurityContextHolder.withAuthentication(authentication)) + .block(); + assertThat(result.block().getAuthentication().getPrincipal()).isEqualTo(authentication.getPrincipal()); + } + + @SuppressWarnings("unused") + private void monoSecurityContext(Mono securityContext) { + } + + @Test + public void resolveArgumentWhenMonoCustomSecurityContextNoAnnotationThenFound() { + Authentication authentication = TestAuthentication.authenticatedUser(); + CustomSecurityContext securityContext = new CustomSecurityContext(); + securityContext.setAuthentication(authentication); + Mono result = (Mono) this.resolver + .resolveArgument(arg0("monoCustomSecurityContext"), null) + .contextWrite(ReactiveSecurityContextHolder.withSecurityContext(Mono.just(securityContext))) + .block(); + assertThat(result.block().getAuthentication().getPrincipal()).isEqualTo(authentication.getPrincipal()); + } + + @SuppressWarnings("unused") + private void monoCustomSecurityContext(Mono securityContext) { + } + private MethodParameter arg0(String methodName) { ResolvableMethod method = ResolvableMethod.on(getClass()).named(methodName).method(); return new SynthesizingMethodParameter(method.method(), 0); @@ -121,4 +186,20 @@ private MethodParameter arg0(String methodName) { } + static class CustomSecurityContext implements SecurityContext { + + private Authentication authentication; + + @Override + public Authentication getAuthentication() { + return this.authentication; + } + + @Override + public void setAuthentication(Authentication authentication) { + this.authentication = authentication; + } + + } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenDecoderFactory.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenDecoderFactory.java index cf3e5fb206b..fe6996aedf1 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenDecoderFactory.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenDecoderFactory.java @@ -171,7 +171,7 @@ private NimbusJwtDecoder buildDecoder(ClientRegistration clientRegistration) { } return NimbusJwtDecoder.withJwkSetUri(jwkSetUri) .jwsAlgorithm((SignatureAlgorithm) jwsAlgorithm) - .restOperations(restOperationsFactory.apply(clientRegistration)) + .restOperations(this.restOperationsFactory.apply(clientRegistration)) .build(); } if (jwsAlgorithm != null && MacAlgorithm.class.isAssignableFrom(jwsAlgorithm.getClass())) { diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/ReactiveOidcIdTokenDecoderFactory.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/ReactiveOidcIdTokenDecoderFactory.java index 0e2b2b2ddd9..1cc8e5a6e44 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/ReactiveOidcIdTokenDecoderFactory.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/ReactiveOidcIdTokenDecoderFactory.java @@ -168,7 +168,7 @@ private NimbusReactiveJwtDecoder buildDecoder(ClientRegistration clientRegistrat } return NimbusReactiveJwtDecoder.withJwkSetUri(jwkSetUri) .jwsAlgorithm((SignatureAlgorithm) jwsAlgorithm) - .webClient(webClientFactory.apply(clientRegistration)) + .webClient(this.webClientFactory.apply(clientRegistration)) .build(); } if (jwsAlgorithm != null && MacAlgorithm.class.isAssignableFrom(jwsAlgorithm.getClass())) { diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/DefaultOAuth2UserService.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/DefaultOAuth2UserService.java index 0851006de3d..95336084b7b 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/DefaultOAuth2UserService.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/DefaultOAuth2UserService.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2020 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,9 +16,9 @@ package org.springframework.security.oauth2.client.userinfo; +import java.util.Collection; import java.util.LinkedHashSet; import java.util.Map; -import java.util.Set; import org.springframework.core.ParameterizedTypeReference; import org.springframework.core.convert.converter.Converter; @@ -76,6 +76,9 @@ public class DefaultOAuth2UserService implements OAuth2UserService> requestEntityConverter = new OAuth2UserRequestEntityConverter(); + private Converter, Map>> attributesConverter = ( + request) -> (attributes) -> attributes; + private RestOperations restOperations; public DefaultOAuth2UserService() { @@ -87,35 +90,39 @@ public DefaultOAuth2UserService() { @Override public OAuth2User loadUser(OAuth2UserRequest userRequest) throws OAuth2AuthenticationException { Assert.notNull(userRequest, "userRequest cannot be null"); - if (!StringUtils - .hasText(userRequest.getClientRegistration().getProviderDetails().getUserInfoEndpoint().getUri())) { - OAuth2Error oauth2Error = new OAuth2Error(MISSING_USER_INFO_URI_ERROR_CODE, - "Missing required UserInfo Uri in UserInfoEndpoint for Client Registration: " - + userRequest.getClientRegistration().getRegistrationId(), - null); - throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); - } - String userNameAttributeName = userRequest.getClientRegistration() - .getProviderDetails() - .getUserInfoEndpoint() - .getUserNameAttributeName(); - if (!StringUtils.hasText(userNameAttributeName)) { - OAuth2Error oauth2Error = new OAuth2Error(MISSING_USER_NAME_ATTRIBUTE_ERROR_CODE, - "Missing required \"user name\" attribute name in UserInfoEndpoint for Client Registration: " - + userRequest.getClientRegistration().getRegistrationId(), - null); - throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); - } + String userNameAttributeName = getUserNameAttributeName(userRequest); RequestEntity request = this.requestEntityConverter.convert(userRequest); ResponseEntity> response = getResponse(userRequest, request); - Map userAttributes = response.getBody(); - Set authorities = new LinkedHashSet<>(); - authorities.add(new OAuth2UserAuthority(userAttributes)); OAuth2AccessToken token = userRequest.getAccessToken(); - for (String authority : token.getScopes()) { - authorities.add(new SimpleGrantedAuthority("SCOPE_" + authority)); - } - return new DefaultOAuth2User(authorities, userAttributes, userNameAttributeName); + Map attributes = this.attributesConverter.convert(userRequest).convert(response.getBody()); + Collection authorities = getAuthorities(token, attributes); + return new DefaultOAuth2User(authorities, attributes, userNameAttributeName); + } + + /** + * Use this strategy to adapt user attributes into a format understood by Spring + * Security; by default, the original attributes are preserved. + * + *

+ * This can be helpful, for example, if the user attribute is nested. Since Spring + * Security needs the username attribute to be at the top level, you can use this + * method to do: + * + *

+	 *     DefaultOAuth2UserService userService = new DefaultOAuth2UserService();
+	 *     userService.setAttributesConverter((userRequest) -> (attributes) ->
+	 *         Map<String, Object> userObject = (Map<String, Object>) attributes.get("user");
+	 *         attributes.put("user-name", userObject.get("user-name"));
+	 *         return attributes;
+	 *     });
+	 * 
+ * @param attributesConverter the attribute adaptation strategy to use + * @since 6.3 + */ + public void setAttributesConverter( + Converter, Map>> attributesConverter) { + Assert.notNull(attributesConverter, "attributesConverter cannot be null"); + this.attributesConverter = attributesConverter; } private ResponseEntity> getResponse(OAuth2UserRequest userRequest, RequestEntity request) { @@ -157,6 +164,38 @@ private ResponseEntity> getResponse(OAuth2UserRequest userRe } } + private String getUserNameAttributeName(OAuth2UserRequest userRequest) { + if (!StringUtils + .hasText(userRequest.getClientRegistration().getProviderDetails().getUserInfoEndpoint().getUri())) { + OAuth2Error oauth2Error = new OAuth2Error(MISSING_USER_INFO_URI_ERROR_CODE, + "Missing required UserInfo Uri in UserInfoEndpoint for Client Registration: " + + userRequest.getClientRegistration().getRegistrationId(), + null); + throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); + } + String userNameAttributeName = userRequest.getClientRegistration() + .getProviderDetails() + .getUserInfoEndpoint() + .getUserNameAttributeName(); + if (!StringUtils.hasText(userNameAttributeName)) { + OAuth2Error oauth2Error = new OAuth2Error(MISSING_USER_NAME_ATTRIBUTE_ERROR_CODE, + "Missing required \"user name\" attribute name in UserInfoEndpoint for Client Registration: " + + userRequest.getClientRegistration().getRegistrationId(), + null); + throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); + } + return userNameAttributeName; + } + + private Collection getAuthorities(OAuth2AccessToken token, Map attributes) { + Collection authorities = new LinkedHashSet<>(); + authorities.add(new OAuth2UserAuthority(attributes)); + for (String authority : token.getScopes()) { + authorities.add(new SimpleGrantedAuthority("SCOPE_" + authority)); + } + return authorities; + } + /** * Sets the {@link Converter} used for converting the {@link OAuth2UserRequest} to a * {@link RequestEntity} representation of the UserInfo Request. diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/DefaultReactiveOAuth2UserService.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/DefaultReactiveOAuth2UserService.java index 90c33fd41b1..920119baab3 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/DefaultReactiveOAuth2UserService.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/userinfo/DefaultReactiveOAuth2UserService.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -26,6 +26,7 @@ import reactor.core.publisher.Mono; import org.springframework.core.ParameterizedTypeReference; +import org.springframework.core.convert.converter.Converter; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpStatusCode; import org.springframework.http.MediaType; @@ -78,6 +79,9 @@ public class DefaultReactiveOAuth2UserService implements ReactiveOAuth2UserServi private static final ParameterizedTypeReference> STRING_STRING_MAP = new ParameterizedTypeReference>() { }; + private Converter, Map>> attributesConverter = ( + request) -> (attributes) -> attributes; + private WebClient webClient = WebClient.create(); @Override @@ -123,7 +127,8 @@ public Mono loadUser(OAuth2UserRequest userRequest) throws OAuth2Aut throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); }) ) - .bodyToMono(DefaultReactiveOAuth2UserService.STRING_OBJECT_MAP); + .bodyToMono(DefaultReactiveOAuth2UserService.STRING_OBJECT_MAP) + .mapNotNull((attributes) -> this.attributesConverter.convert(userRequest).convert(attributes)); return userAttributes.map((attrs) -> { GrantedAuthority authority = new OAuth2UserAuthority(attrs); Set authorities = new HashSet<>(); @@ -184,6 +189,32 @@ private WebClient.RequestHeadersSpec getRequestHeaderSpec(OAuth2UserRequest u // @formatter:on } + /** + * Use this strategy to adapt user attributes into a format understood by Spring + * Security; by default, the original attributes are preserved. + * + *

+ * This can be helpful, for example, if the user attribute is nested. Since Spring + * Security needs the username attribute to be at the top level, you can use this + * method to do: + * + *

+	 *     DefaultReactiveOAuth2UserService userService = new DefaultReactiveOAuth2UserService();
+	 *     userService.setAttributesConverter((userRequest) -> (attributes) ->
+	 *         Map<String, Object> userObject = (Map<String, Object>) attributes.get("user");
+	 *         attributes.put("user-name", userObject.get("user-name"));
+	 *         return attributes;
+	 *     });
+	 * 
+ * @param attributesConverter the attribute adaptation strategy to use + * @since 6.3 + */ + public void setAttributesConverter( + Converter, Map>> attributesConverter) { + Assert.notNull(attributesConverter, "attributesConverter cannot be null"); + this.attributesConverter = attributesConverter; + } + /** * Sets the {@link WebClient} used for retrieving the user endpoint * @param webClient the client to use diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenDecoderFactoryTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenDecoderFactoryTests.java index b417d17b676..52db517f5c9 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenDecoderFactoryTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenDecoderFactoryTests.java @@ -100,7 +100,7 @@ public void setClaimTypeConverterFactoryWhenNullThenThrowIllegalArgumentExceptio @Test public void setRestOperationsFactoryWhenNullThenThrowIllegalArgumentException() { assertThatIllegalArgumentException() - .isThrownBy(() -> this.idTokenDecoderFactory.setRestOperationsFactory(null)); + .isThrownBy(() -> this.idTokenDecoderFactory.setRestOperationsFactory(null)); } @Test @@ -187,13 +187,12 @@ public void createDecoderWhenCustomClaimTypeConverterFactorySetThenApplied() { @Test public void createDecoderWhenCustomRestOperationsFactorySetThenApplied() { - Function customRestOperationsFactory = mock( - Function.class); + Function customRestOperationsFactory = mock(Function.class); this.idTokenDecoderFactory.setRestOperationsFactory(customRestOperationsFactory); ClientRegistration clientRegistration = this.registration.build(); - given(customRestOperationsFactory.apply(same(clientRegistration))) - .willReturn(new RestTemplate()); + given(customRestOperationsFactory.apply(same(clientRegistration))).willReturn(new RestTemplate()); this.idTokenDecoderFactory.createDecoder(clientRegistration); verify(customRestOperationsFactory).apply(same(clientRegistration)); } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/ReactiveOidcIdTokenDecoderFactoryTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/ReactiveOidcIdTokenDecoderFactoryTests.java index 87f7dd67f23..8f304e098d2 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/ReactiveOidcIdTokenDecoderFactoryTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/ReactiveOidcIdTokenDecoderFactoryTests.java @@ -97,8 +97,7 @@ public void setClaimTypeConverterFactoryWhenNullThenThrowIllegalArgumentExceptio @Test public void setWebClientFactoryWhenNullThenThrowIllegalArgumentException() { - assertThatIllegalArgumentException() - .isThrownBy(() -> this.idTokenDecoderFactory.setWebClientFactory(null)); + assertThatIllegalArgumentException().isThrownBy(() -> this.idTokenDecoderFactory.setWebClientFactory(null)); } @Test @@ -185,13 +184,12 @@ public void createDecoderWhenCustomClaimTypeConverterFactorySetThenApplied() { @Test public void createDecoderWhenCustomWebClientFactorySetThenApplied() { - Function customWebClientFactory = mock( - Function.class); + Function customWebClientFactory = mock(Function.class); this.idTokenDecoderFactory.setWebClientFactory(customWebClientFactory); ClientRegistration clientRegistration = this.registration.build(); - given(customWebClientFactory.apply(same(clientRegistration))) - .willReturn(WebClient.create()); + given(customWebClientFactory.apply(same(clientRegistration))).willReturn(WebClient.create()); this.idTokenDecoderFactory.createDecoder(clientRegistration); verify(customWebClientFactory).apply(same(clientRegistration)); } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcReactiveOAuth2UserServiceTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcReactiveOAuth2UserServiceTests.java index 2b8e6180dfc..14acfdea160 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcReactiveOAuth2UserServiceTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcReactiveOAuth2UserServiceTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ package org.springframework.security.oauth2.client.oidc.userinfo; +import java.io.IOException; import java.time.Duration; import java.time.Instant; import java.util.Collections; @@ -24,6 +25,8 @@ import java.util.Map; import java.util.function.Function; +import okhttp3.mockwebserver.MockResponse; +import okhttp3.mockwebserver.MockWebServer; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; @@ -32,13 +35,17 @@ import reactor.core.publisher.Mono; import org.springframework.core.convert.converter.Converter; +import org.springframework.http.HttpHeaders; +import org.springframework.http.MediaType; import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.core.authority.SimpleGrantedAuthority; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.client.userinfo.DefaultReactiveOAuth2UserService; import org.springframework.security.oauth2.client.userinfo.OAuth2UserRequest; import org.springframework.security.oauth2.client.userinfo.ReactiveOAuth2UserService; +import org.springframework.security.oauth2.core.AuthenticationMethod; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; @@ -203,8 +210,62 @@ public void loadUserWhenTokenDoesNotContainScopesThenNoScopeAuthorities() { assertThat(userAuthority.getAttributes()).isEqualTo(user.getAttributes()); } + @Test + public void loadUserWhenNestedUserInfoSuccessThenReturnUser() throws IOException { + // @formatter:off + String userInfoResponse = "{\n" + + " \"user\": {\"user-name\": \"user1\"},\n" + + " \"sub\" : \"" + this.idToken.getSubject() + "\",\n" + + " \"first-name\": \"first\",\n" + + " \"last-name\": \"last\",\n" + + " \"middle-name\": \"middle\",\n" + + " \"address\": \"address\",\n" + + " \"email\": \"user1@example.com\"\n" + + "}\n"; + // @formatter:on + try (MockWebServer server = new MockWebServer()) { + server.start(); + enqueueApplicationJsonBody(server, userInfoResponse); + String userInfoUri = server.url("/user").toString(); + ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration() + .userInfoUri(userInfoUri) + .userInfoAuthenticationMethod(AuthenticationMethod.HEADER) + .userNameAttributeName("user-name") + .build(); + OidcReactiveOAuth2UserService userService = new OidcReactiveOAuth2UserService(); + DefaultReactiveOAuth2UserService oAuth2UserService = new DefaultReactiveOAuth2UserService(); + oAuth2UserService.setAttributesConverter((request) -> (attributes) -> { + Map user = (Map) attributes.get("user"); + attributes.put("user-name", user.get("user-name")); + return attributes; + }); + userService.setOauth2UserService(oAuth2UserService); + OAuth2User user = userService + .loadUser(new OidcUserRequest(clientRegistration, this.accessToken, this.idToken)) + .block(); + assertThat(user.getName()).isEqualTo("user1"); + assertThat(user.getAttributes()).hasSize(13); + assertThat(((Map) user.getAttribute("user")).get("user-name")).isEqualTo("user1"); + assertThat((String) user.getAttribute("first-name")).isEqualTo("first"); + assertThat((String) user.getAttribute("last-name")).isEqualTo("last"); + assertThat((String) user.getAttribute("middle-name")).isEqualTo("middle"); + assertThat((String) user.getAttribute("address")).isEqualTo("address"); + assertThat((String) user.getAttribute("email")).isEqualTo("user1@example.com"); + assertThat(user.getAuthorities()).hasSize(2); + assertThat(user.getAuthorities().iterator().next()).isInstanceOf(OAuth2UserAuthority.class); + OAuth2UserAuthority userAuthority = (OAuth2UserAuthority) user.getAuthorities().iterator().next(); + assertThat(userAuthority.getAuthority()).isEqualTo("OIDC_USER"); + assertThat(userAuthority.getAttributes()).isEqualTo(user.getAttributes()); + } + } + private OidcUserRequest userRequest() { return new OidcUserRequest(this.registration.build(), this.accessToken, this.idToken); } + private void enqueueApplicationJsonBody(MockWebServer server, String json) { + server.enqueue( + new MockResponse().setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE).setBody(json)); + } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserServiceTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserServiceTests.java index 310667e2fff..6d63e8a5c3f 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserServiceTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/userinfo/OidcUserServiceTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -52,6 +52,8 @@ import org.springframework.security.oauth2.core.oidc.TestOidcIdTokens; import org.springframework.security.oauth2.core.oidc.user.OidcUser; import org.springframework.security.oauth2.core.oidc.user.OidcUserAuthority; +import org.springframework.security.oauth2.core.user.OAuth2User; +import org.springframework.security.oauth2.core.user.OAuth2UserAuthority; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; @@ -492,6 +494,49 @@ public void loadUserWhenTokenDoesNotContainScopesAndUserInfoUriThenUserInfoReque assertThat(user.getUserInfo()).isNotNull(); } + @Test + public void loadUserWhenNestedUserInfoSuccessThenReturnUser() { + // @formatter:off + String userInfoResponse = "{\n" + + " \"user\": {\"user-name\": \"user1\"},\n" + + " \"sub\" : \"subject1\",\n" + + " \"first-name\": \"first\",\n" + + " \"last-name\": \"last\",\n" + + " \"middle-name\": \"middle\",\n" + + " \"address\": \"address\",\n" + + " \"email\": \"user1@example.com\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(userInfoResponse)); + String userInfoUri = this.server.url("/user").toString(); + ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri) + .userInfoAuthenticationMethod(AuthenticationMethod.HEADER) + .userNameAttributeName("user-name") + .build(); + OidcUserService userService = new OidcUserService(); + DefaultOAuth2UserService oAuth2UserService = new DefaultOAuth2UserService(); + oAuth2UserService.setAttributesConverter((request) -> (attributes) -> { + Map user = (Map) attributes.get("user"); + attributes.put("user-name", user.get("user-name")); + return attributes; + }); + userService.setOauth2UserService(oAuth2UserService); + OAuth2User user = userService.loadUser(new OidcUserRequest(clientRegistration, this.accessToken, this.idToken)); + assertThat(user.getName()).isEqualTo("user1"); + assertThat(user.getAttributes()).hasSize(9); + assertThat(((Map) user.getAttribute("user")).get("user-name")).isEqualTo("user1"); + assertThat((String) user.getAttribute("first-name")).isEqualTo("first"); + assertThat((String) user.getAttribute("last-name")).isEqualTo("last"); + assertThat((String) user.getAttribute("middle-name")).isEqualTo("middle"); + assertThat((String) user.getAttribute("address")).isEqualTo("address"); + assertThat((String) user.getAttribute("email")).isEqualTo("user1@example.com"); + assertThat(user.getAuthorities()).hasSize(3); + assertThat(user.getAuthorities().iterator().next()).isInstanceOf(OAuth2UserAuthority.class); + OAuth2UserAuthority userAuthority = (OAuth2UserAuthority) user.getAuthorities().iterator().next(); + assertThat(userAuthority.getAuthority()).isEqualTo("OIDC_USER"); + assertThat(userAuthority.getAttributes()).isEqualTo(user.getAttributes()); + } + private MockResponse jsonResponse(String json) { // @formatter:off return new MockResponse() diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/DefaultOAuth2UserServiceTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/DefaultOAuth2UserServiceTests.java index 361100ec6f3..99457d4e3e7 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/DefaultOAuth2UserServiceTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/DefaultOAuth2UserServiceTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -158,6 +158,46 @@ public void loadUserWhenUserInfoSuccessResponseThenReturnUser() { assertThat(userAuthority.getAttributes()).isEqualTo(user.getAttributes()); } + @Test + public void loadUserWhenNestedUserInfoSuccessThenReturnUser() { + // @formatter:off + String userInfoResponse = "{\n" + + " \"user\": {\"user-name\": \"user1\"},\n" + + " \"first-name\": \"first\",\n" + + " \"last-name\": \"last\",\n" + + " \"middle-name\": \"middle\",\n" + + " \"address\": \"address\",\n" + + " \"email\": \"user1@example.com\"\n" + + "}\n"; + // @formatter:on + this.server.enqueue(jsonResponse(userInfoResponse)); + String userInfoUri = this.server.url("/user").toString(); + ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri) + .userInfoAuthenticationMethod(AuthenticationMethod.HEADER) + .userNameAttributeName("user-name") + .build(); + DefaultOAuth2UserService userService = new DefaultOAuth2UserService(); + userService.setAttributesConverter((request) -> (attributes) -> { + Map user = (Map) attributes.get("user"); + attributes.put("user-name", user.get("user-name")); + return attributes; + }); + OAuth2User user = userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken)); + assertThat(user.getName()).isEqualTo("user1"); + assertThat(user.getAttributes()).hasSize(7); + assertThat(((Map) user.getAttribute("user")).get("user-name")).isEqualTo("user1"); + assertThat((String) user.getAttribute("first-name")).isEqualTo("first"); + assertThat((String) user.getAttribute("last-name")).isEqualTo("last"); + assertThat((String) user.getAttribute("middle-name")).isEqualTo("middle"); + assertThat((String) user.getAttribute("address")).isEqualTo("address"); + assertThat((String) user.getAttribute("email")).isEqualTo("user1@example.com"); + assertThat(user.getAuthorities()).hasSize(1); + assertThat(user.getAuthorities().iterator().next()).isInstanceOf(OAuth2UserAuthority.class); + OAuth2UserAuthority userAuthority = (OAuth2UserAuthority) user.getAuthorities().iterator().next(); + assertThat(userAuthority.getAuthority()).isEqualTo("OAUTH2_USER"); + assertThat(userAuthority.getAttributes()).isEqualTo(user.getAttributes()); + } + @Test public void loadUserWhenUserInfoSuccessResponseInvalidThenThrowOAuth2AuthenticationException() { // @formatter:off @@ -373,6 +413,12 @@ public void loadUserWhenUserInfoSuccessResponseInvalidContentTypeThenThrowOAuth2 + "from '" + userInfoUri + "': response contains invalid content type 'text/plain'."); } + @Test + public void setAttributesConverterWhenNullThenException() { + assertThatExceptionOfType(IllegalArgumentException.class) + .isThrownBy(() -> this.userService.setAttributesConverter(null)); + } + private DefaultOAuth2UserService withMockResponse(Map response) { ResponseEntity> responseEntity = new ResponseEntity<>(response, HttpStatus.OK); Converter> requestEntityConverter = mock(Converter.class); diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/DefaultReactiveOAuth2UserServiceTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/DefaultReactiveOAuth2UserServiceTests.java index c9989ae3202..68aa1a31d3a 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/DefaultReactiveOAuth2UserServiceTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/userinfo/DefaultReactiveOAuth2UserServiceTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -165,6 +165,46 @@ public void loadUserWhenUserInfo201CreatedResponseThenReturnUser() { assertThatNoException().isThrownBy(() -> this.userService.loadUser(oauth2UserRequest()).block()); } + @Test + public void loadUserWhenNestedUserInfoSuccessThenReturnUser() { + // @formatter:off + String userInfoResponse = "{\n" + + " \"user\": {\"user-name\": \"user1\"},\n" + + " \"first-name\": \"first\",\n" + + " \"last-name\": \"last\",\n" + + " \"middle-name\": \"middle\",\n" + + " \"address\": \"address\",\n" + + " \"email\": \"user1@example.com\"\n" + + "}\n"; + // @formatter:on + enqueueApplicationJsonBody(userInfoResponse); + String userInfoUri = this.server.url("/user").toString(); + ClientRegistration clientRegistration = this.clientRegistration.userInfoUri(userInfoUri) + .userInfoAuthenticationMethod(AuthenticationMethod.HEADER) + .userNameAttributeName("user-name") + .build(); + DefaultReactiveOAuth2UserService userService = new DefaultReactiveOAuth2UserService(); + userService.setAttributesConverter((request) -> (attributes) -> { + Map user = (Map) attributes.get("user"); + attributes.put("user-name", user.get("user-name")); + return attributes; + }); + OAuth2User user = userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken)).block(); + assertThat(user.getName()).isEqualTo("user1"); + assertThat(user.getAttributes()).hasSize(7); + assertThat(((Map) user.getAttribute("user")).get("user-name")).isEqualTo("user1"); + assertThat((String) user.getAttribute("first-name")).isEqualTo("first"); + assertThat((String) user.getAttribute("last-name")).isEqualTo("last"); + assertThat((String) user.getAttribute("middle-name")).isEqualTo("middle"); + assertThat((String) user.getAttribute("address")).isEqualTo("address"); + assertThat((String) user.getAttribute("email")).isEqualTo("user1@example.com"); + assertThat(user.getAuthorities()).hasSize(1); + assertThat(user.getAuthorities().iterator().next()).isInstanceOf(OAuth2UserAuthority.class); + OAuth2UserAuthority userAuthority = (OAuth2UserAuthority) user.getAuthorities().iterator().next(); + assertThat(userAuthority.getAuthority()).isEqualTo("OAUTH2_USER"); + assertThat(userAuthority.getAttributes()).isEqualTo(user.getAttributes()); + } + // gh-5500 @Test public void loadUserWhenAuthenticationMethodHeaderSuccessResponseThenHttpMethodGet() throws Exception { @@ -290,6 +330,12 @@ public void loadUserWhenUserInfoSuccessResponseInvalidContentTypeThenThrowOAuth2 + "response contains invalid content type 'text/plain'"); } + @Test + public void setAttributesConverterWhenNullThenException() { + assertThatExceptionOfType(IllegalArgumentException.class) + .isThrownBy(() -> this.userService.setAttributesConverter(null)); + } + private DefaultReactiveOAuth2UserService withMockResponse(Map body) { WebClient real = WebClient.builder().build(); WebClient.RequestHeadersUriSpec spec = spy(real.post()); diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoder.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoder.java index 2da97ab96b6..ffbed62c29d 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoder.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoder.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2023 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -144,20 +144,17 @@ public void setClaimSetConverter(Converter, Map decode(String token) throws JwtException { - JWT jwt = parse(token); - if (jwt instanceof PlainJWT) { - throw new BadJwtException("Unsupported algorithm of " + jwt.getHeader().getAlgorithm()); - } - return this.decode(jwt); - } - - private JWT parse(String token) { + public Mono decode(String token) { try { - return JWTParser.parse(token); + JWT jwt = JWTParser.parse(token); + if (jwt instanceof PlainJWT) { + return Mono.error(new BadJwtException("Unsupported algorithm of " + jwt.getHeader().getAlgorithm())); + } + return this.decode(jwt); } catch (Exception ex) { - throw new BadJwtException("An error occurred while attempting to decode the Jwt: " + ex.getMessage(), ex); + return Mono.error(new BadJwtException( + "An error occurred while attempting to decode the Jwt: " + ex.getMessage(), ex)); } } diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/Saml2MetadataFilter.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/Saml2MetadataFilter.java index b195e4945a5..723acfeba76 100644 --- a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/Saml2MetadataFilter.java +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/Saml2MetadataFilter.java @@ -106,7 +106,7 @@ private void writeMetadataToResponse(HttpServletResponse response, Saml2Metadata response.setContentType(MediaType.APPLICATION_XML_VALUE); String format = "attachment; filename=\"%s\"; filename*=UTF-8''%s"; String fileName = metadata.getFileName(); - String encodedFileName = URLEncoder.encode(fileName, StandardCharsets.UTF_8.name()); + String encodedFileName = URLEncoder.encode(fileName, StandardCharsets.UTF_8); response.setHeader(HttpHeaders.CONTENT_DISPOSITION, String.format(format, fileName, encodedFileName)); response.setContentLength(metadata.getMetadata().getBytes(StandardCharsets.UTF_8).length); response.setCharacterEncoding(StandardCharsets.UTF_8.name()); diff --git a/web/src/main/java/org/springframework/security/web/method/annotation/CurrentSecurityContextArgumentResolver.java b/web/src/main/java/org/springframework/security/web/method/annotation/CurrentSecurityContextArgumentResolver.java index 05ef53a59ca..d1a6ba12a7d 100644 --- a/web/src/main/java/org/springframework/security/web/method/annotation/CurrentSecurityContextArgumentResolver.java +++ b/web/src/main/java/org/springframework/security/web/method/annotation/CurrentSecurityContextArgumentResolver.java @@ -85,7 +85,8 @@ public final class CurrentSecurityContextArgumentResolver implements HandlerMeth @Override public boolean supportsParameter(MethodParameter parameter) { - return findMethodAnnotation(CurrentSecurityContext.class, parameter) != null; + return SecurityContext.class.isAssignableFrom(parameter.getParameterType()) + || findMethodAnnotation(CurrentSecurityContext.class, parameter) != null; } @Override @@ -95,26 +96,12 @@ public Object resolveArgument(MethodParameter parameter, ModelAndViewContainer m if (securityContext == null) { return null; } - Object securityContextResult = securityContext; CurrentSecurityContext annotation = findMethodAnnotation(CurrentSecurityContext.class, parameter); - String expressionToParse = annotation.expression(); - if (StringUtils.hasLength(expressionToParse)) { - StandardEvaluationContext context = new StandardEvaluationContext(); - context.setRootObject(securityContext); - context.setVariable("this", securityContext); - context.setBeanResolver(this.beanResolver); - Expression expression = this.parser.parseExpression(expressionToParse); - securityContextResult = expression.getValue(context); - } - if (securityContextResult != null - && !parameter.getParameterType().isAssignableFrom(securityContextResult.getClass())) { - if (annotation.errorOnInvalidType()) { - throw new ClassCastException( - securityContextResult + " is not assignable to " + parameter.getParameterType()); - } - return null; + if (annotation != null) { + return resolveSecurityContextFromAnnotation(parameter, annotation, securityContext); } - return securityContextResult; + + return securityContext; } /** @@ -137,6 +124,29 @@ public void setBeanResolver(BeanResolver beanResolver) { this.beanResolver = beanResolver; } + private Object resolveSecurityContextFromAnnotation(MethodParameter parameter, CurrentSecurityContext annotation, + SecurityContext securityContext) { + Object securityContextResult = securityContext; + String expressionToParse = annotation.expression(); + if (StringUtils.hasLength(expressionToParse)) { + StandardEvaluationContext context = new StandardEvaluationContext(); + context.setRootObject(securityContext); + context.setVariable("this", securityContext); + context.setBeanResolver(this.beanResolver); + Expression expression = this.parser.parseExpression(expressionToParse); + securityContextResult = expression.getValue(context); + } + if (securityContextResult != null + && !parameter.getParameterType().isAssignableFrom(securityContextResult.getClass())) { + if (annotation.errorOnInvalidType()) { + throw new ClassCastException( + securityContextResult + " is not assignable to " + parameter.getParameterType()); + } + return null; + } + return securityContextResult; + } + /** * Obtain the specified {@link Annotation} on the specified {@link MethodParameter}. * @param annotationClass the class of the {@link Annotation} to find on the diff --git a/web/src/main/java/org/springframework/security/web/reactive/result/method/annotation/CurrentSecurityContextArgumentResolver.java b/web/src/main/java/org/springframework/security/web/reactive/result/method/annotation/CurrentSecurityContextArgumentResolver.java index a02a9c30b9f..fd51d8ac533 100644 --- a/web/src/main/java/org/springframework/security/web/reactive/result/method/annotation/CurrentSecurityContextArgumentResolver.java +++ b/web/src/main/java/org/springframework/security/web/reactive/result/method/annotation/CurrentSecurityContextArgumentResolver.java @@ -67,7 +67,21 @@ public void setBeanResolver(BeanResolver beanResolver) { @Override public boolean supportsParameter(MethodParameter parameter) { - return findMethodAnnotation(CurrentSecurityContext.class, parameter) != null; + return isMonoSecurityContext(parameter) + || findMethodAnnotation(CurrentSecurityContext.class, parameter) != null; + } + + private boolean isMonoSecurityContext(MethodParameter parameter) { + boolean isParameterPublisher = Publisher.class.isAssignableFrom(parameter.getParameterType()); + if (isParameterPublisher) { + ResolvableType resolvableType = ResolvableType.forMethodParameter(parameter); + Class genericType = resolvableType.resolveGeneric(0); + if (genericType == null) { + return false; + } + return SecurityContext.class.isAssignableFrom(genericType); + } + return false; } @Override @@ -95,6 +109,14 @@ public Mono resolveArgument(MethodParameter parameter, BindingContext bi */ private Object resolveSecurityContext(MethodParameter parameter, SecurityContext securityContext) { CurrentSecurityContext annotation = findMethodAnnotation(CurrentSecurityContext.class, parameter); + if (annotation != null) { + return resolveSecurityContextFromAnnotation(annotation, parameter, securityContext); + } + return securityContext; + } + + private Object resolveSecurityContextFromAnnotation(CurrentSecurityContext annotation, MethodParameter parameter, + Object securityContext) { Object securityContextResult = securityContext; String expressionToParse = annotation.expression(); if (StringUtils.hasLength(expressionToParse)) { diff --git a/web/src/main/java/org/springframework/security/web/util/matcher/IpAddressMatcher.java b/web/src/main/java/org/springframework/security/web/util/matcher/IpAddressMatcher.java index 3be7851094d..80c344fb275 100644 --- a/web/src/main/java/org/springframework/security/web/util/matcher/IpAddressMatcher.java +++ b/web/src/main/java/org/springframework/security/web/util/matcher/IpAddressMatcher.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -47,6 +47,7 @@ public final class IpAddressMatcher implements RequestMatcher { * come. */ public IpAddressMatcher(String ipAddress) { + assertStartsWithHexa(ipAddress); if (ipAddress.indexOf('/') > 0) { String[] addressAndMask = StringUtils.split(ipAddress, "/"); ipAddress = addressAndMask[0]; @@ -56,8 +57,9 @@ public IpAddressMatcher(String ipAddress) { this.nMaskBits = -1; } this.requiredAddress = parseAddress(ipAddress); - Assert.isTrue(this.requiredAddress.getAddress().length * 8 >= this.nMaskBits, - String.format("IP address %s is too short for bitmask of length %d", ipAddress, this.nMaskBits)); + String finalIpAddress = ipAddress; + Assert.isTrue(this.requiredAddress.getAddress().length * 8 >= this.nMaskBits, () -> String + .format("IP address %s is too short for bitmask of length %d", finalIpAddress, this.nMaskBits)); } @Override @@ -66,6 +68,7 @@ public boolean matches(HttpServletRequest request) { } public boolean matches(String address) { + assertStartsWithHexa(address); InetAddress remoteAddress = parseAddress(address); if (!this.requiredAddress.getClass().equals(remoteAddress.getClass())) { return false; @@ -88,6 +91,13 @@ public boolean matches(String address) { return true; } + private void assertStartsWithHexa(String ipAddress) { + Assert.isTrue( + ipAddress.charAt(0) == '[' || ipAddress.charAt(0) == ':' + || Character.digit(ipAddress.charAt(0), 16) != -1, + "ipAddress must start with a [, :, or a hexadecimal digit"); + } + private InetAddress parseAddress(String address) { try { return InetAddress.getByName(address); diff --git a/web/src/test/java/org/springframework/security/web/method/annotation/CurrentSecurityContextArgumentResolverTests.java b/web/src/test/java/org/springframework/security/web/method/annotation/CurrentSecurityContextArgumentResolverTests.java index f04b5269e5e..80c33ac9eca 100644 --- a/web/src/test/java/org/springframework/security/web/method/annotation/CurrentSecurityContextArgumentResolverTests.java +++ b/web/src/test/java/org/springframework/security/web/method/annotation/CurrentSecurityContextArgumentResolverTests.java @@ -69,9 +69,26 @@ public void cleanup() { SecurityContextHolder.clearContext(); } + @Test + public void supportsParameterNoAnnotationWrongType() { + assertThat(this.resolver.supportsParameter(showSecurityContextNoAnnotationTypeMismatch())).isFalse(); + } + @Test public void supportsParameterNoAnnotation() { - assertThat(this.resolver.supportsParameter(showSecurityContextNoAnnotation())).isFalse(); + assertThat(this.resolver.supportsParameter(showSecurityContextNoAnnotation())).isTrue(); + } + + @Test + public void supportsParameterCustomSecurityContextNoAnnotation() { + assertThat(this.resolver.supportsParameter(showSecurityContextWithCustomSecurityContextNoAnnotation())) + .isTrue(); + } + + @Test + public void supportsParameterNoAnnotationCustomType() { + assertThat(this.resolver.supportsParameter(showSecurityContextWithCustomSecurityContextNoAnnotation())) + .isTrue(); } @Test @@ -88,6 +105,24 @@ public void resolveArgumentWithCustomSecurityContext() { assertThat(customSecurityContext.getAuthentication().getPrincipal()).isEqualTo(principal); } + @Test + public void resolveArgumentWithCustomSecurityContextNoAnnotation() { + String principal = "custom_security_context"; + setAuthenticationPrincipalWithCustomSecurityContext(principal); + CustomSecurityContext customSecurityContext = (CustomSecurityContext) this.resolver + .resolveArgument(showSecurityContextWithCustomSecurityContextNoAnnotation(), null, null, null); + assertThat(customSecurityContext.getAuthentication().getPrincipal()).isEqualTo(principal); + } + + @Test + public void resolveArgumentWithNoAnnotation() { + String principal = "custom_security_context"; + setAuthenticationPrincipal(principal); + SecurityContext securityContext = (SecurityContext) this.resolver + .resolveArgument(showSecurityContextNoAnnotation(), null, null, null); + assertThat(securityContext.getAuthentication().getPrincipal()).isEqualTo(principal); + } + @Test public void resolveArgumentWithCustomSecurityContextTypeMatch() { String principal = "custom_security_context_type_match"; @@ -212,10 +247,14 @@ public void metaAnnotationWhenCurrentSecurityWithErrorOnInvalidTypeThenMisMatch( .resolveArgument(showCurrentSecurityWithErrorOnInvalidTypeMisMatch(), null, null, null)); } - private MethodParameter showSecurityContextNoAnnotation() { + private MethodParameter showSecurityContextNoAnnotationTypeMismatch() { return getMethodParameter("showSecurityContextNoAnnotation", String.class); } + private MethodParameter showSecurityContextNoAnnotation() { + return getMethodParameter("showSecurityContextNoAnnotation", SecurityContext.class); + } + private MethodParameter showSecurityContextAnnotation() { return getMethodParameter("showSecurityContextAnnotation", SecurityContext.class); } @@ -276,6 +315,11 @@ public MethodParameter showCurrentSecurityWithErrorOnInvalidTypeMisMatch() { return getMethodParameter("showCurrentSecurityWithErrorOnInvalidTypeMisMatch", String.class); } + public MethodParameter showSecurityContextWithCustomSecurityContextNoAnnotation() { + return getMethodParameter("showSecurityContextWithCustomSecurityContextNoAnnotation", + CustomSecurityContext.class); + } + private MethodParameter getMethodParameter(String methodName, Class... paramTypes) { Method method = ReflectionUtils.findMethod(TestController.class, methodName, paramTypes); return new MethodParameter(method, 0); @@ -358,6 +402,12 @@ public void showCurrentSecurityWithErrorOnInvalidTypeMisMatch( @CurrentSecurityWithErrorOnInvalidType String typeMisMatch) { } + public void showSecurityContextNoAnnotation(SecurityContext context) { + } + + public void showSecurityContextWithCustomSecurityContextNoAnnotation(CustomSecurityContext context) { + } + } static class CustomSecurityContext implements SecurityContext { diff --git a/web/src/test/java/org/springframework/security/web/reactive/result/method/annotation/CurrentSecurityContextArgumentResolverTests.java b/web/src/test/java/org/springframework/security/web/reactive/result/method/annotation/CurrentSecurityContextArgumentResolverTests.java index 06cc9282bff..5556a25ed1b 100644 --- a/web/src/test/java/org/springframework/security/web/reactive/result/method/annotation/CurrentSecurityContextArgumentResolverTests.java +++ b/web/src/test/java/org/springframework/security/web/reactive/result/method/annotation/CurrentSecurityContextArgumentResolverTests.java @@ -69,6 +69,14 @@ public class CurrentSecurityContextArgumentResolverTests { ResolvableMethod securityContextMethod = ResolvableMethod.on(getClass()).named("securityContext").build(); + ResolvableMethod securityContextNoAnnotationMethod = ResolvableMethod.on(getClass()) + .named("securityContextNoAnnotation") + .build(); + + ResolvableMethod customSecurityContextNoAnnotationMethod = ResolvableMethod.on(getClass()) + .named("customSecurityContextNoAnnotation") + .build(); + ResolvableMethod securityContextWithAuthentication = ResolvableMethod.on(getClass()) .named("securityContextWithAuthentication") .build(); @@ -87,6 +95,19 @@ public void supportsParameterCurrentSecurityContext() { .isTrue(); } + @Test + public void supportsParameterCurrentSecurityContextNoAnnotation() { + assertThat(this.resolver + .supportsParameter(this.securityContextNoAnnotationMethod.arg(Mono.class, SecurityContext.class))).isTrue(); + } + + @Test + public void supportsParameterCurrentCustomSecurityContextNoAnnotation() { + assertThat(this.resolver.supportsParameter( + this.customSecurityContextNoAnnotationMethod.arg(Mono.class, CustomSecurityContext.class))) + .isTrue(); + } + @Test public void supportsParameterWithAuthentication() { assertThat(this.resolver @@ -123,6 +144,40 @@ public void resolveArgumentWithSecurityContext() { ReactiveSecurityContextHolder.clearContext(); } + @Test + public void resolveArgumentWithSecurityContextNoAnnotation() { + MethodParameter parameter = ResolvableMethod.on(getClass()) + .named("securityContextNoAnnotation") + .build() + .arg(Mono.class, SecurityContext.class); + Authentication auth = buildAuthenticationWithPrincipal("hello"); + Context context = ReactiveSecurityContextHolder.withAuthentication(auth); + Mono argument = this.resolver.resolveArgument(parameter, this.bindingContext, this.exchange); + SecurityContext securityContext = (SecurityContext) argument.contextWrite(context) + .cast(Mono.class) + .block() + .block(); + assertThat(securityContext.getAuthentication()).isSameAs(auth); + ReactiveSecurityContextHolder.clearContext(); + } + + @Test + public void resolveArgumentWithCustomSecurityContextNoAnnotation() { + MethodParameter parameter = ResolvableMethod.on(getClass()) + .named("customSecurityContextNoAnnotation") + .build() + .arg(Mono.class, CustomSecurityContext.class); + Authentication auth = buildAuthenticationWithPrincipal("hello"); + Context context = ReactiveSecurityContextHolder.withSecurityContext(Mono.just(new CustomSecurityContext(auth))); + Mono argument = this.resolver.resolveArgument(parameter, this.bindingContext, this.exchange); + CustomSecurityContext securityContext = (CustomSecurityContext) argument.contextWrite(context) + .cast(Mono.class) + .block() + .block(); + assertThat(securityContext.getAuthentication()).isSameAs(auth); + ReactiveSecurityContextHolder.clearContext(); + } + @Test public void resolveArgumentWithCustomSecurityContext() { MethodParameter parameter = ResolvableMethod.on(getClass()) @@ -350,6 +405,12 @@ public void metaAnnotationWhenCurrentSecurityWithErrorOnInvalidTypeThenMisMatch( void securityContext(@CurrentSecurityContext Mono monoSecurityContext) { } + void securityContextNoAnnotation(Mono securityContextMono) { + } + + void customSecurityContextNoAnnotation(Mono securityContextMono) { + } + void customSecurityContext(@CurrentSecurityContext Mono monoSecurityContext) { } diff --git a/web/src/test/java/org/springframework/security/web/util/matcher/IpAddressMatcherTests.java b/web/src/test/java/org/springframework/security/web/util/matcher/IpAddressMatcherTests.java index 0362917be13..17c2bbadb3a 100644 --- a/web/src/test/java/org/springframework/security/web/util/matcher/IpAddressMatcherTests.java +++ b/web/src/test/java/org/springframework/security/web/util/matcher/IpAddressMatcherTests.java @@ -105,4 +105,10 @@ public void ipv6RequiredAddressMaskTooLongThenIllegalArgumentException() { "fe80::21f:5bff:fe33:bd68", 129)); } + @Test + public void invalidAddressThenIllegalArgumentException() { + assertThatIllegalArgumentException().isThrownBy(() -> new IpAddressMatcher("invalid-ip")) + .withMessage("ipAddress must start with a [, :, or a hexadecimal digit"); + } + }