Skip to content

Commit eef15bc

Browse files
committed
Simplify population of OAuth2AuthorizationContext
1 parent 7303821 commit eef15bc

File tree

6 files changed

+117
-104
lines changed

6 files changed

+117
-104
lines changed

oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ClientCredentialsOAuth2AuthorizedClientProvider.java

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import org.springframework.security.oauth2.client.endpoint.DefaultClientCredentialsTokenResponseClient;
2020
import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient;
2121
import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest;
22+
import org.springframework.security.oauth2.client.registration.ClientRegistration;
2223
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
2324
import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
2425
import org.springframework.security.oauth2.core.AuthorizationGrantType;
@@ -38,8 +39,8 @@
3839
* @see DefaultClientCredentialsTokenResponseClient
3940
*/
4041
public final class ClientCredentialsOAuth2AuthorizedClientProvider implements OAuth2AuthorizedClientProvider {
41-
private static final String HTTP_SERVLET_REQUEST_ATTR_NAME = HttpServletRequest.class.getName();
42-
private static final String HTTP_SERVLET_RESPONSE_ATTR_NAME = HttpServletResponse.class.getName();
42+
private static final String HTTP_SERVLET_REQUEST_ATTRIBUTE_NAME = HttpServletRequest.class.getName();
43+
private static final String HTTP_SERVLET_RESPONSE_ATTRIBUTE_NAME = HttpServletResponse.class.getName();
4344
private final ClientRegistrationRepository clientRegistrationRepository;
4445
private final OAuth2AuthorizedClientRepository authorizedClientRepository;
4546
private OAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> accessTokenResponseClient =
@@ -59,6 +60,22 @@ public ClientCredentialsOAuth2AuthorizedClientProvider(ClientRegistrationReposit
5960
this.authorizedClientRepository = authorizedClientRepository;
6061
}
6162

63+
/**
64+
* Attempt to authorize (or re-authorize) the {@link OAuth2AuthorizationContext#getClientRegistration() client} in the provided {@code context}.
65+
* Returns {@code null} if authorization (or re-authorization) is not supported,
66+
* e.g. the client's {@link ClientRegistration#getAuthorizationGrantType() authorization grant type}
67+
* is not {@link AuthorizationGrantType#CLIENT_CREDENTIALS client_credentials}.
68+
*
69+
* <p>
70+
* The following {@link OAuth2AuthorizationContext#getAttributes() context attributes} are supported:
71+
* <ol>
72+
* <li>{@code "javax.servlet.http.HttpServletRequest"} (required) - the {@code HttpServletRequest}</li>
73+
* <li>{@code "javax.servlet.http.HttpServletResponse"} (required) - the {@code HttpServletResponse}</li>
74+
* </ol>
75+
*
76+
* @param context the context that holds authorization-specific state for the client
77+
* @return the {@link OAuth2AuthorizedClient} or {@code null} if authorization (or re-authorization) is not supported
78+
*/
6279
@Override
6380
@Nullable
6481
public OAuth2AuthorizedClient authorize(OAuth2AuthorizationContext context) {
@@ -67,10 +84,10 @@ public OAuth2AuthorizedClient authorize(OAuth2AuthorizationContext context) {
6784
return null;
6885
}
6986

70-
HttpServletRequest request = context.getAttribute(HTTP_SERVLET_REQUEST_ATTR_NAME);
71-
HttpServletResponse response = context.getAttribute(HTTP_SERVLET_RESPONSE_ATTR_NAME);
72-
Assert.notNull(request, "context.HttpServletRequest cannot be null");
73-
Assert.notNull(response, "context.HttpServletResponse cannot be null");
87+
HttpServletRequest request = context.getAttribute(HTTP_SERVLET_REQUEST_ATTRIBUTE_NAME);
88+
HttpServletResponse response = context.getAttribute(HTTP_SERVLET_RESPONSE_ATTRIBUTE_NAME);
89+
Assert.notNull(request, "The context attribute cannot be null '" + HTTP_SERVLET_REQUEST_ATTRIBUTE_NAME + "'");
90+
Assert.notNull(response, "The context attribute cannot be null '" + HTTP_SERVLET_RESPONSE_ATTRIBUTE_NAME + "'");
7491

7592
// As per spec, in section 4.4.3 Access Token Response
7693
// https://tools.ietf.org/html/rfc6749#section-4.4.3

oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProvider.java

Lines changed: 39 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,13 @@
2424
import org.springframework.security.oauth2.core.AuthorizationGrantType;
2525
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
2626
import org.springframework.util.Assert;
27+
import org.springframework.util.StringUtils;
2728

2829
import javax.servlet.http.HttpServletRequest;
2930
import javax.servlet.http.HttpServletResponse;
31+
import java.util.Arrays;
3032
import java.util.Set;
33+
import java.util.stream.Collectors;
3134

3235
/**
3336
* An implementation of an {@link OAuth2AuthorizedClientProvider}
@@ -39,9 +42,17 @@
3942
* @see DefaultRefreshTokenTokenResponseClient
4043
*/
4144
public final class RefreshTokenOAuth2AuthorizedClientProvider implements OAuth2AuthorizedClientProvider {
42-
private static final String HTTP_SERVLET_REQUEST_ATTR_NAME = HttpServletRequest.class.getName();
43-
private static final String HTTP_SERVLET_RESPONSE_ATTR_NAME = HttpServletResponse.class.getName();
44-
private static final String SCOPE_ATTR_NAME = "SCOPE";
45+
/**
46+
* The name of the {@link OAuth2AuthorizationContext#getAttribute(String) attribute}
47+
* in the {@link OAuth2AuthorizationContext context} associated to the value for the "requested scope(s)".
48+
* The value of the attribute is a space-delimited or comma-delimited {@code String} of scope(s)
49+
* to be requested by the {@link OAuth2AuthorizationContext#getClientRegistration() client}.
50+
*/
51+
public static final String REQUEST_SCOPE_ATTRIBUTE_NAME = "org.springframework.security.oauth2.client.REQUEST_SCOPE";
52+
53+
private static final String HTTP_SERVLET_REQUEST_ATTRIBUTE_NAME = HttpServletRequest.class.getName();
54+
private static final String HTTP_SERVLET_RESPONSE_ATTRIBUTE_NAME = HttpServletResponse.class.getName();
55+
4556
private final ClientRegistrationRepository clientRegistrationRepository;
4657
private final OAuth2AuthorizedClientRepository authorizedClientRepository;
4758
private OAuth2AccessTokenResponseClient<OAuth2RefreshTokenGrantRequest> accessTokenResponseClient =
@@ -61,6 +72,23 @@ public RefreshTokenOAuth2AuthorizedClientProvider(ClientRegistrationRepository c
6172
this.authorizedClientRepository = authorizedClientRepository;
6273
}
6374

75+
/**
76+
* Attempt to re-authorize the {@link OAuth2AuthorizationContext#getClientRegistration() client} in the provided {@code context}.
77+
* Returns {@code null} if re-authorization is not supported,
78+
* e.g. the {@link OAuth2AuthorizedClient#getRefreshToken() refresh token} is not available for the
79+
* {@link OAuth2AuthorizationContext#getAuthorizedClient() authorized client}.
80+
*
81+
* <p>
82+
* The following {@link OAuth2AuthorizationContext#getAttributes() context attributes} are supported:
83+
* <ol>
84+
* <li>{@code "javax.servlet.http.HttpServletRequest"} (required) - the {@code HttpServletRequest}</li>
85+
* <li>{@code "javax.servlet.http.HttpServletResponse"} (required) - the {@code HttpServletResponse}</li>
86+
* <li>{@code "org.springframework.security.oauth2.client.REQUEST_SCOPE"} (optional) - a space-delimited or comma-delimited {@code String} of scope(s) to be requested by the {@link OAuth2AuthorizationContext#getClientRegistration() client}</li>
87+
* </ol>
88+
*
89+
* @param context the context that holds authorization-specific state for the client
90+
* @return the {@link OAuth2AuthorizedClient} or {@code null} if re-authorization is not supported
91+
*/
6492
@Override
6593
@Nullable
6694
public OAuth2AuthorizedClient authorize(OAuth2AuthorizationContext context) {
@@ -69,16 +97,16 @@ public OAuth2AuthorizedClient authorize(OAuth2AuthorizationContext context) {
6997
return null;
7098
}
7199

72-
HttpServletRequest request = context.getAttribute(HTTP_SERVLET_REQUEST_ATTR_NAME);
73-
HttpServletResponse response = context.getAttribute(HTTP_SERVLET_RESPONSE_ATTR_NAME);
74-
Assert.notNull(request, "context.HttpServletRequest cannot be null");
75-
Assert.notNull(response, "context.HttpServletResponse cannot be null");
100+
HttpServletRequest request = context.getAttribute(HTTP_SERVLET_REQUEST_ATTRIBUTE_NAME);
101+
HttpServletResponse response = context.getAttribute(HTTP_SERVLET_RESPONSE_ATTRIBUTE_NAME);
102+
Assert.notNull(request, "The context attribute cannot be null '" + HTTP_SERVLET_REQUEST_ATTRIBUTE_NAME + "'");
103+
Assert.notNull(response, "The context attribute cannot be null '" + HTTP_SERVLET_RESPONSE_ATTRIBUTE_NAME + "'");
76104

77-
Object scopesObj = context.getAttribute(SCOPE_ATTR_NAME);
105+
String requestScope = context.getAttribute(REQUEST_SCOPE_ATTRIBUTE_NAME);
78106
Set<String> scopes = null;
79-
if (scopesObj != null) {
80-
Assert.isTrue(scopesObj instanceof Set, "The '" + SCOPE_ATTR_NAME + "' attribute must be of type " + Set.class.getName());
81-
scopes = (Set<String>) scopesObj;
107+
if (!StringUtils.isEmpty(requestScope)) {
108+
String delimiter = requestScope.indexOf(',') != -1 ? "," : " ";
109+
scopes = Arrays.stream(StringUtils.delimitedListToStringArray(requestScope, delimiter, " ")).collect(Collectors.toSet());
82110
}
83111

84112
OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest =

oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolver.java

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -123,17 +123,16 @@ public Object resolveArgument(MethodParameter parameter,
123123

124124
HttpServletResponse servletResponse = webRequest.getNativeResponse(HttpServletResponse.class);
125125

126-
OAuth2AuthorizationContext.Builder authorizationContextBuilder = OAuth2AuthorizationContext.authorize(clientRegistration);
127-
if (principal == null) {
128-
authorizationContextBuilder.principal("anonymousUser");
126+
OAuth2AuthorizationContext.Builder contextBuilder = OAuth2AuthorizationContext.authorize(clientRegistration);
127+
if (principal != null) {
128+
contextBuilder.principal(principal);
129129
} else {
130-
authorizationContextBuilder.principal(principal);
130+
contextBuilder.principal("anonymousUser");
131131
}
132-
OAuth2AuthorizationContext authorizationContext = authorizationContextBuilder
132+
OAuth2AuthorizationContext authorizationContext = contextBuilder
133133
.attribute(HttpServletRequest.class.getName(), servletRequest)
134134
.attribute(HttpServletResponse.class.getName(), servletResponse)
135135
.build();
136-
137136
return this.authorizedClientProvider.authorize(authorizationContext);
138137
}
139138

oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java

Lines changed: 29 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
import org.springframework.beans.factory.DisposableBean;
2121
import org.springframework.beans.factory.InitializingBean;
2222
import org.springframework.security.core.Authentication;
23-
import org.springframework.security.core.GrantedAuthority;
2423
import org.springframework.security.core.context.SecurityContextHolder;
2524
import org.springframework.security.oauth2.client.AuthorizationCodeOAuth2AuthorizedClientProvider;
2625
import org.springframework.security.oauth2.client.ClientCredentialsOAuth2AuthorizedClientProvider;
@@ -54,7 +53,7 @@
5453
import java.time.Clock;
5554
import java.time.Duration;
5655
import java.time.Instant;
57-
import java.util.Collection;
56+
import java.util.HashMap;
5857
import java.util.Map;
5958
import java.util.function.Consumer;
6059

@@ -317,7 +316,7 @@ public Mono<ClientResponse> filter(ClientRequest request, ExchangeFunction next)
317316
.filter(req -> req.attribute(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME).isPresent())
318317
.switchIfEmpty(mergeRequestAttributesFromContext(request))
319318
.filter(req -> req.attribute(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME).isPresent())
320-
.flatMap(req -> reauthorizeClientIfNecessary(req, next, getOAuth2AuthorizedClient(req.attributes())))
319+
.flatMap(req -> reauthorizeClientIfNecessary(getOAuth2AuthorizedClient(req.attributes()), req))
321320
.map(authorizedClient -> bearer(request, authorizedClient))
322321
.flatMap(next::exchange)
323322
.switchIfEmpty(next.exchange(request));
@@ -388,48 +387,59 @@ private void populateDefaultOAuth2AuthorizedClient(Map<String, Object> attrs) {
388387
OAuth2AuthorizedClient authorizedClient = this.authorizedClientRepository.loadAuthorizedClient(
389388
clientRegistrationId, authentication, request);
390389
if (authorizedClient == null) {
391-
authorizedClient = getAuthorizedClient(clientRegistrationId, attrs);
390+
authorizedClient = authorizeClient(clientRegistrationId, attrs);
392391
}
393392
oauth2AuthorizedClient(authorizedClient).accept(attrs);
394393
}
395394
}
396395

397-
private OAuth2AuthorizedClient getAuthorizedClient(String clientRegistrationId, Map<String, Object> attrs) {
396+
private OAuth2AuthorizedClient authorizeClient(String clientRegistrationId, Map<String, Object> attributes) {
398397
ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId);
399398
if (clientRegistration == null) {
400399
throw new IllegalArgumentException("Could not find ClientRegistration with id " + clientRegistrationId);
401400
}
402-
Authentication authentication = getAuthentication(attrs);
403-
if (authentication == null) {
404-
authentication = new PrincipalNameAuthentication("anonymousUser");
401+
402+
OAuth2AuthorizationContext.Builder contextBuilder = OAuth2AuthorizationContext.authorize(clientRegistration);
403+
Authentication authentication = getAuthentication(attributes);
404+
if (authentication != null) {
405+
contextBuilder.principal(authentication);
406+
} else {
407+
contextBuilder.principal("anonymousUser");
405408
}
406-
OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext.authorize(clientRegistration)
407-
.principal(authentication)
408-
.attribute(HttpServletRequest.class.getName(), getRequest(attrs))
409-
.attribute(HttpServletResponse.class.getName(), getResponse(attrs))
409+
OAuth2AuthorizationContext authorizationContext = contextBuilder
410+
.attributes(defaultContextAttributes(attributes))
410411
.build();
411412
return this.authorizedClientProvider.authorize(authorizationContext);
412413
}
413414

414415
private Mono<OAuth2AuthorizedClient> reauthorizeClientIfNecessary(
415-
ClientRequest request, ExchangeFunction next, OAuth2AuthorizedClient authorizedClient) {
416+
OAuth2AuthorizedClient authorizedClient, ClientRequest request) {
416417
if (this.authorizedClientProvider == null || !hasTokenExpired(authorizedClient)) {
417418
return Mono.just(authorizedClient);
418419
}
419420

420421
Map<String, Object> attributes = request.attributes();
422+
423+
OAuth2AuthorizationContext.Builder contextBuilder = OAuth2AuthorizationContext.reauthorize(authorizedClient);
421424
Authentication authentication = getAuthentication(attributes);
422-
if (authentication == null) {
423-
authentication = new PrincipalNameAuthentication(authorizedClient.getPrincipalName());
425+
if (authentication != null) {
426+
contextBuilder.principal(authentication);
427+
} else {
428+
contextBuilder.principal(authorizedClient.getPrincipalName());
424429
}
425-
OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext.reauthorize(authorizedClient)
426-
.principal(authentication)
427-
.attribute(HttpServletRequest.class.getName(), getRequest(attributes))
428-
.attribute(HttpServletResponse.class.getName(), getResponse(attributes))
430+
OAuth2AuthorizationContext authorizationContext = contextBuilder
431+
.attributes(defaultContextAttributes(attributes))
429432
.build();
430433
return Mono.fromSupplier(() -> this.authorizedClientProvider.authorize(authorizationContext));
431434
}
432435

436+
private Map<String, Object> defaultContextAttributes(Map<String, Object> attributes) {
437+
Map<String, Object> contextAttributes = new HashMap<>();
438+
contextAttributes.put(HttpServletRequest.class.getName(), getRequest(attributes));
439+
contextAttributes.put(HttpServletResponse.class.getName(), getResponse(attributes));
440+
return contextAttributes;
441+
}
442+
433443
private boolean hasTokenExpired(OAuth2AuthorizedClient authorizedClient) {
434444
Instant now = this.clock.instant();
435445
Instant expiresAt = authorizedClient.getAccessToken().getExpiresAt();
@@ -478,53 +488,6 @@ static HttpServletResponse getResponse(Map<String, Object> attrs) {
478488
return (HttpServletResponse) attrs.get(HTTP_SERVLET_RESPONSE_ATTR_NAME);
479489
}
480490

481-
private static class PrincipalNameAuthentication implements Authentication {
482-
private final String username;
483-
484-
private PrincipalNameAuthentication(String username) {
485-
this.username = username;
486-
}
487-
488-
@Override
489-
public Collection<? extends GrantedAuthority> getAuthorities() {
490-
throw unsupported();
491-
}
492-
493-
@Override
494-
public Object getCredentials() {
495-
throw unsupported();
496-
}
497-
498-
@Override
499-
public Object getDetails() {
500-
throw unsupported();
501-
}
502-
503-
@Override
504-
public Object getPrincipal() {
505-
throw unsupported();
506-
}
507-
508-
@Override
509-
public boolean isAuthenticated() {
510-
throw unsupported();
511-
}
512-
513-
@Override
514-
public void setAuthenticated(boolean isAuthenticated) throws IllegalArgumentException {
515-
throw unsupported();
516-
}
517-
518-
@Override
519-
public String getName() {
520-
return this.username;
521-
}
522-
523-
private UnsupportedOperationException unsupported() {
524-
return new UnsupportedOperationException("Not Supported");
525-
}
526-
}
527-
528491
private static class RequestContextSubscriber<T> implements CoreSubscriber<T> {
529492
private static final String CONTEXT_DEFAULTED_ATTR_NAME = RequestContextSubscriber.class.getName().concat(".CONTEXT_DEFAULTED_ATTR_NAME");
530493
private final CoreSubscriber<T> delegate;

oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/ClientCredentialsOAuth2AuthorizedClientProviderTests.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ public void authorizeWhenHttpServletRequestIsNullThenThrowIllegalArgumentExcepti
106106
OAuth2AuthorizationContext.authorize(this.clientRegistration).principal(this.principal).build();
107107
assertThatThrownBy(() -> this.authorizedClientProvider.authorize(authorizationContext))
108108
.isInstanceOf(IllegalArgumentException.class)
109-
.hasMessage("context.HttpServletRequest cannot be null");
109+
.hasMessage("The context attribute cannot be null 'javax.servlet.http.HttpServletRequest'");
110110
}
111111

112112
@Test
@@ -118,7 +118,7 @@ public void authorizeWhenHttpServletResponseIsNullThenThrowIllegalArgumentExcept
118118
.build();
119119
assertThatThrownBy(() -> this.authorizedClientProvider.authorize(authorizationContext))
120120
.isInstanceOf(IllegalArgumentException.class)
121-
.hasMessage("context.HttpServletResponse cannot be null");
121+
.hasMessage("The context attribute cannot be null 'javax.servlet.http.HttpServletResponse'");
122122
}
123123

124124
@Test

0 commit comments

Comments
 (0)