Skip to content

Commit 9c352c4

Browse files
krezovicjzheaux
authored andcommitted
Support overriding RestOperations in OidcIdTokenDecoderFactory
Closes gh-14178
1 parent 0041c65 commit 9c352c4

File tree

2 files changed

+44
-3
lines changed

2 files changed

+44
-3
lines changed

oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenDecoderFactory.java

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2019 the original author or authors.
2+
* Copyright 2002-2023 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -49,6 +49,8 @@
4949
import org.springframework.security.oauth2.jwt.NimbusJwtDecoder;
5050
import org.springframework.util.Assert;
5151
import org.springframework.util.StringUtils;
52+
import org.springframework.web.client.RestOperations;
53+
import org.springframework.web.client.RestTemplate;
5254

5355
/**
5456
* A {@link JwtDecoderFactory factory} that provides a {@link JwtDecoder} used for
@@ -89,6 +91,9 @@ public final class OidcIdTokenDecoderFactory implements JwtDecoderFactory<Client
8991
private Function<ClientRegistration, Converter<Map<String, Object>, Map<String, Object>>> claimTypeConverterFactory = (
9092
clientRegistration) -> DEFAULT_CLAIM_TYPE_CONVERTER;
9193

94+
private Function<ClientRegistration, RestOperations> restOperationsFactory = (
95+
clientRegistration) -> new RestTemplate();
96+
9297
/**
9398
* Returns the default {@link Converter}'s used for type conversion of claim values
9499
* for an {@link OidcIdToken}.
@@ -164,7 +169,10 @@ private NimbusJwtDecoder buildDecoder(ClientRegistration clientRegistration) {
164169
null);
165170
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
166171
}
167-
return NimbusJwtDecoder.withJwkSetUri(jwkSetUri).jwsAlgorithm((SignatureAlgorithm) jwsAlgorithm).build();
172+
return NimbusJwtDecoder.withJwkSetUri(jwkSetUri)
173+
.jwsAlgorithm((SignatureAlgorithm) jwsAlgorithm)
174+
.restOperations(restOperationsFactory.apply(clientRegistration))
175+
.build();
168176
}
169177
if (jwsAlgorithm != null && MacAlgorithm.class.isAssignableFrom(jwsAlgorithm.getClass())) {
170178
// https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
@@ -237,4 +245,18 @@ public void setClaimTypeConverterFactory(
237245
this.claimTypeConverterFactory = claimTypeConverterFactory;
238246
}
239247

248+
/**
249+
* Sets the factory that provides a {@link RestOperations} used by
250+
* {@link NimbusJwtDecoder} to coordinate with the authorization servers indicated in
251+
* the <a href="https://tools.ietf.org/html/rfc7517#section-5">JWK Set</a> uri.
252+
* @param restOperationsFactory the factory that provides a {@link RestOperations}
253+
* used by {@link NimbusJwtDecoder}
254+
*
255+
* @since 6.3
256+
*/
257+
public void setRestOperationsFactory(Function<ClientRegistration, RestOperations> restOperationsFactory) {
258+
Assert.notNull(restOperationsFactory, "restOperationsFactory cannot be null");
259+
this.restOperationsFactory = restOperationsFactory;
260+
}
261+
240262
}

oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/OidcIdTokenDecoderFactoryTests.java

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2019 the original author or authors.
2+
* Copyright 2002-2023 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -34,6 +34,8 @@
3434
import org.springframework.security.oauth2.jose.jws.MacAlgorithm;
3535
import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
3636
import org.springframework.security.oauth2.jwt.Jwt;
37+
import org.springframework.web.client.RestOperations;
38+
import org.springframework.web.client.RestTemplate;
3739

3840
import static org.assertj.core.api.Assertions.assertThat;
3941
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
@@ -95,6 +97,12 @@ public void setClaimTypeConverterFactoryWhenNullThenThrowIllegalArgumentExceptio
9597
.isThrownBy(() -> this.idTokenDecoderFactory.setClaimTypeConverterFactory(null));
9698
}
9799

100+
@Test
101+
public void setRestOperationsFactoryWhenNullThenThrowIllegalArgumentException() {
102+
assertThatIllegalArgumentException()
103+
.isThrownBy(() -> this.idTokenDecoderFactory.setRestOperationsFactory(null));
104+
}
105+
98106
@Test
99107
public void createDecoderWhenClientRegistrationNullThenThrowIllegalArgumentException() {
100108
assertThatIllegalArgumentException().isThrownBy(() -> this.idTokenDecoderFactory.createDecoder(null));
@@ -177,4 +185,15 @@ public void createDecoderWhenCustomClaimTypeConverterFactorySetThenApplied() {
177185
verify(customClaimTypeConverterFactory).apply(same(clientRegistration));
178186
}
179187

188+
@Test
189+
public void createDecoderWhenCustomRestOperationsFactorySetThenApplied() {
190+
Function<ClientRegistration, RestOperations> customRestOperationsFactory = mock(
191+
Function.class);
192+
this.idTokenDecoderFactory.setRestOperationsFactory(customRestOperationsFactory);
193+
ClientRegistration clientRegistration = this.registration.build();
194+
given(customRestOperationsFactory.apply(same(clientRegistration)))
195+
.willReturn(new RestTemplate());
196+
this.idTokenDecoderFactory.createDecoder(clientRegistration);
197+
verify(customRestOperationsFactory).apply(same(clientRegistration));
198+
}
180199
}

0 commit comments

Comments
 (0)