diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenDecoderFactory.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenDecoderFactory.java index d00490bf235..c66994df16d 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenDecoderFactory.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenDecoderFactory.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -49,6 +49,8 @@ import org.springframework.security.oauth2.jwt.NimbusJwtDecoder; import org.springframework.util.Assert; import org.springframework.util.StringUtils; +import org.springframework.web.client.RestOperations; +import org.springframework.web.client.RestTemplate; /** * A {@link JwtDecoderFactory factory} that provides a {@link JwtDecoder} used for @@ -89,6 +91,9 @@ public final class OidcIdTokenDecoderFactory implements JwtDecoderFactory, Map>> claimTypeConverterFactory = ( clientRegistration) -> DEFAULT_CLAIM_TYPE_CONVERTER; + private Function restOperationsFactory = ( + clientRegistration) -> new RestTemplate(); + /** * Returns the default {@link Converter}'s used for type conversion of claim values * for an {@link OidcIdToken}. @@ -164,7 +169,10 @@ private NimbusJwtDecoder buildDecoder(ClientRegistration clientRegistration) { null); throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); } - return NimbusJwtDecoder.withJwkSetUri(jwkSetUri).jwsAlgorithm((SignatureAlgorithm) jwsAlgorithm).build(); + return NimbusJwtDecoder.withJwkSetUri(jwkSetUri) + .jwsAlgorithm((SignatureAlgorithm) jwsAlgorithm) + .restOperations(restOperationsFactory.apply(clientRegistration)) + .build(); } if (jwsAlgorithm != null && MacAlgorithm.class.isAssignableFrom(jwsAlgorithm.getClass())) { // https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation @@ -237,4 +245,18 @@ public void setClaimTypeConverterFactory( this.claimTypeConverterFactory = claimTypeConverterFactory; } + /** + * Sets the factory that provides a {@link RestOperations} used by + * {@link NimbusJwtDecoder} to coordinate with the authorization servers indicated in + * the JWK Set uri. + * @param restOperationsFactory the factory that provides a {@link RestOperations} + * used by {@link NimbusJwtDecoder} + * + * @since 6.3 + */ + public void setRestOperationsFactory(Function restOperationsFactory) { + Assert.notNull(restOperationsFactory, "restOperationsFactory cannot be null"); + this.restOperationsFactory = restOperationsFactory; + } + } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/ReactiveOidcIdTokenDecoderFactory.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/ReactiveOidcIdTokenDecoderFactory.java index 84850bba6a2..58f3d5a3592 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/ReactiveOidcIdTokenDecoderFactory.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/ReactiveOidcIdTokenDecoderFactory.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -49,6 +49,7 @@ import org.springframework.security.oauth2.jwt.ReactiveJwtDecoderFactory; import org.springframework.util.Assert; import org.springframework.util.StringUtils; +import org.springframework.web.reactive.function.client.WebClient; /** * A {@link ReactiveJwtDecoderFactory factory} that provides a {@link ReactiveJwtDecoder} @@ -89,6 +90,8 @@ public final class ReactiveOidcIdTokenDecoderFactory implements ReactiveJwtDecod private Function, Map>> claimTypeConverterFactory = ( clientRegistration) -> DEFAULT_CLAIM_TYPE_CONVERTER; + private Function webClientFactory = (clientRegistration) -> WebClient.create(); + /** * Returns the default {@link Converter}'s used for type conversion of claim values * for an {@link OidcIdToken}. @@ -165,6 +168,7 @@ private NimbusReactiveJwtDecoder buildDecoder(ClientRegistration clientRegistrat } return NimbusReactiveJwtDecoder.withJwkSetUri(jwkSetUri) .jwsAlgorithm((SignatureAlgorithm) jwsAlgorithm) + .webClient(webClientFactory.apply(clientRegistration)) .build(); } if (jwsAlgorithm != null && MacAlgorithm.class.isAssignableFrom(jwsAlgorithm.getClass())) { @@ -241,4 +245,19 @@ public void setClaimTypeConverterFactory( this.claimTypeConverterFactory = claimTypeConverterFactory; } + /** + * Sets the factory that provides a {@link WebClient} used by + * {@link NimbusReactiveJwtDecoder} to coordinate with the authorization servers + * indicated in the JWK + * Set uri. + * @param webClientFactory the factory that provides a {@link WebClient} used by + * {@link NimbusReactiveJwtDecoder} + * + * @since 6.3 + */ + public void setWebClientFactory(Function webClientFactory) { + Assert.notNull(webClientFactory, "webClientFactory cannot be null"); + this.webClientFactory = webClientFactory; + } + } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenDecoderFactoryTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenDecoderFactoryTests.java index 33663bac650..0f6015e4107 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenDecoderFactoryTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenDecoderFactoryTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -34,6 +34,8 @@ import org.springframework.security.oauth2.jose.jws.MacAlgorithm; import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm; import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.web.client.RestOperations; +import org.springframework.web.client.RestTemplate; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; @@ -95,6 +97,12 @@ public void setClaimTypeConverterFactoryWhenNullThenThrowIllegalArgumentExceptio .isThrownBy(() -> this.idTokenDecoderFactory.setClaimTypeConverterFactory(null)); } + @Test + public void setRestOperationsFactoryWhenNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException() + .isThrownBy(() -> this.idTokenDecoderFactory.setRestOperationsFactory(null)); + } + @Test public void createDecoderWhenClientRegistrationNullThenThrowIllegalArgumentException() { assertThatIllegalArgumentException().isThrownBy(() -> this.idTokenDecoderFactory.createDecoder(null)); @@ -177,4 +185,15 @@ public void createDecoderWhenCustomClaimTypeConverterFactorySetThenApplied() { verify(customClaimTypeConverterFactory).apply(same(clientRegistration)); } + @Test + public void createDecoderWhenCustomRestOperationsFactorySetThenApplied() { + Function customRestOperationsFactory = mock( + Function.class); + this.idTokenDecoderFactory.setRestOperationsFactory(customRestOperationsFactory); + ClientRegistration clientRegistration = this.registration.build(); + given(customRestOperationsFactory.apply(same(clientRegistration))) + .willReturn(new RestTemplate()); + this.idTokenDecoderFactory.createDecoder(clientRegistration); + verify(customRestOperationsFactory).apply(same(clientRegistration)); + } } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/ReactiveOidcIdTokenDecoderFactoryTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/ReactiveOidcIdTokenDecoderFactoryTests.java index 3a3f668d7a3..e2764e81cde 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/ReactiveOidcIdTokenDecoderFactoryTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/ReactiveOidcIdTokenDecoderFactoryTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -34,6 +34,7 @@ import org.springframework.security.oauth2.jose.jws.MacAlgorithm; import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm; import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.web.reactive.function.client.WebClient; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; @@ -94,6 +95,12 @@ public void setClaimTypeConverterFactoryWhenNullThenThrowIllegalArgumentExceptio .isThrownBy(() -> this.idTokenDecoderFactory.setClaimTypeConverterFactory(null)); } + @Test + public void setWebClientFactoryWhenNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException() + .isThrownBy(() -> this.idTokenDecoderFactory.setWebClientFactory(null)); + } + @Test public void createDecoderWhenClientRegistrationNullThenThrowIllegalArgumentException() { assertThatIllegalArgumentException().isThrownBy(() -> this.idTokenDecoderFactory.createDecoder(null)); @@ -176,4 +183,15 @@ public void createDecoderWhenCustomClaimTypeConverterFactorySetThenApplied() { verify(customClaimTypeConverterFactory).apply(same(clientRegistration)); } + @Test + public void createDecoderWhenCustomWebClientFactorySetThenApplied() { + Function customWebClientFactory = mock( + Function.class); + this.idTokenDecoderFactory.setWebClientFactory(customWebClientFactory); + ClientRegistration clientRegistration = this.registration.build(); + given(customWebClientFactory.apply(same(clientRegistration))) + .willReturn(WebClient.create()); + this.idTokenDecoderFactory.createDecoder(clientRegistration); + verify(customWebClientFactory).apply(same(clientRegistration)); + } }