|
20 | 20 | import org.springframework.beans.factory.DisposableBean;
|
21 | 21 | import org.springframework.beans.factory.InitializingBean;
|
22 | 22 | import org.springframework.security.core.Authentication;
|
23 |
| -import org.springframework.security.core.GrantedAuthority; |
24 | 23 | import org.springframework.security.core.context.SecurityContextHolder;
|
25 | 24 | import org.springframework.security.oauth2.client.AuthorizationCodeOAuth2AuthorizedClientProvider;
|
26 | 25 | import org.springframework.security.oauth2.client.ClientCredentialsOAuth2AuthorizedClientProvider;
|
|
54 | 53 | import java.time.Clock;
|
55 | 54 | import java.time.Duration;
|
56 | 55 | import java.time.Instant;
|
57 |
| -import java.util.Collection; |
| 56 | +import java.util.HashMap; |
58 | 57 | import java.util.Map;
|
59 | 58 | import java.util.function.Consumer;
|
60 | 59 |
|
@@ -317,7 +316,7 @@ public Mono<ClientResponse> filter(ClientRequest request, ExchangeFunction next)
|
317 | 316 | .filter(req -> req.attribute(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME).isPresent())
|
318 | 317 | .switchIfEmpty(mergeRequestAttributesFromContext(request))
|
319 | 318 | .filter(req -> req.attribute(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME).isPresent())
|
320 |
| - .flatMap(req -> reauthorizeClientIfNecessary(req, next, getOAuth2AuthorizedClient(req.attributes()))) |
| 319 | + .flatMap(req -> reauthorizeClientIfNecessary(getOAuth2AuthorizedClient(req.attributes()), req)) |
321 | 320 | .map(authorizedClient -> bearer(request, authorizedClient))
|
322 | 321 | .flatMap(next::exchange)
|
323 | 322 | .switchIfEmpty(next.exchange(request));
|
@@ -388,48 +387,59 @@ private void populateDefaultOAuth2AuthorizedClient(Map<String, Object> attrs) {
|
388 | 387 | OAuth2AuthorizedClient authorizedClient = this.authorizedClientRepository.loadAuthorizedClient(
|
389 | 388 | clientRegistrationId, authentication, request);
|
390 | 389 | if (authorizedClient == null) {
|
391 |
| - authorizedClient = getAuthorizedClient(clientRegistrationId, attrs); |
| 390 | + authorizedClient = authorizeClient(clientRegistrationId, attrs); |
392 | 391 | }
|
393 | 392 | oauth2AuthorizedClient(authorizedClient).accept(attrs);
|
394 | 393 | }
|
395 | 394 | }
|
396 | 395 |
|
397 |
| - private OAuth2AuthorizedClient getAuthorizedClient(String clientRegistrationId, Map<String, Object> attrs) { |
| 396 | + private OAuth2AuthorizedClient authorizeClient(String clientRegistrationId, Map<String, Object> attributes) { |
398 | 397 | ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId);
|
399 | 398 | if (clientRegistration == null) {
|
400 | 399 | throw new IllegalArgumentException("Could not find ClientRegistration with id " + clientRegistrationId);
|
401 | 400 | }
|
402 |
| - Authentication authentication = getAuthentication(attrs); |
403 |
| - if (authentication == null) { |
404 |
| - authentication = new PrincipalNameAuthentication("anonymousUser"); |
| 401 | + |
| 402 | + OAuth2AuthorizationContext.Builder contextBuilder = OAuth2AuthorizationContext.authorize(clientRegistration); |
| 403 | + Authentication authentication = getAuthentication(attributes); |
| 404 | + if (authentication != null) { |
| 405 | + contextBuilder.principal(authentication); |
| 406 | + } else { |
| 407 | + contextBuilder.principal("anonymousUser"); |
405 | 408 | }
|
406 |
| - OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext.authorize(clientRegistration) |
407 |
| - .principal(authentication) |
408 |
| - .attribute(HttpServletRequest.class.getName(), getRequest(attrs)) |
409 |
| - .attribute(HttpServletResponse.class.getName(), getResponse(attrs)) |
| 409 | + OAuth2AuthorizationContext authorizationContext = contextBuilder |
| 410 | + .attributes(defaultContextAttributes(attributes)) |
410 | 411 | .build();
|
411 | 412 | return this.authorizedClientProvider.authorize(authorizationContext);
|
412 | 413 | }
|
413 | 414 |
|
414 | 415 | private Mono<OAuth2AuthorizedClient> reauthorizeClientIfNecessary(
|
415 |
| - ClientRequest request, ExchangeFunction next, OAuth2AuthorizedClient authorizedClient) { |
| 416 | + OAuth2AuthorizedClient authorizedClient, ClientRequest request) { |
416 | 417 | if (this.authorizedClientProvider == null || !hasTokenExpired(authorizedClient)) {
|
417 | 418 | return Mono.just(authorizedClient);
|
418 | 419 | }
|
419 | 420 |
|
420 | 421 | Map<String, Object> attributes = request.attributes();
|
| 422 | + |
| 423 | + OAuth2AuthorizationContext.Builder contextBuilder = OAuth2AuthorizationContext.reauthorize(authorizedClient); |
421 | 424 | Authentication authentication = getAuthentication(attributes);
|
422 |
| - if (authentication == null) { |
423 |
| - authentication = new PrincipalNameAuthentication(authorizedClient.getPrincipalName()); |
| 425 | + if (authentication != null) { |
| 426 | + contextBuilder.principal(authentication); |
| 427 | + } else { |
| 428 | + contextBuilder.principal(authorizedClient.getPrincipalName()); |
424 | 429 | }
|
425 |
| - OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext.reauthorize(authorizedClient) |
426 |
| - .principal(authentication) |
427 |
| - .attribute(HttpServletRequest.class.getName(), getRequest(attributes)) |
428 |
| - .attribute(HttpServletResponse.class.getName(), getResponse(attributes)) |
| 430 | + OAuth2AuthorizationContext authorizationContext = contextBuilder |
| 431 | + .attributes(defaultContextAttributes(attributes)) |
429 | 432 | .build();
|
430 | 433 | return Mono.fromSupplier(() -> this.authorizedClientProvider.authorize(authorizationContext));
|
431 | 434 | }
|
432 | 435 |
|
| 436 | + private Map<String, Object> defaultContextAttributes(Map<String, Object> attributes) { |
| 437 | + Map<String, Object> contextAttributes = new HashMap<>(); |
| 438 | + contextAttributes.put(HttpServletRequest.class.getName(), getRequest(attributes)); |
| 439 | + contextAttributes.put(HttpServletResponse.class.getName(), getResponse(attributes)); |
| 440 | + return contextAttributes; |
| 441 | + } |
| 442 | + |
433 | 443 | private boolean hasTokenExpired(OAuth2AuthorizedClient authorizedClient) {
|
434 | 444 | Instant now = this.clock.instant();
|
435 | 445 | Instant expiresAt = authorizedClient.getAccessToken().getExpiresAt();
|
@@ -478,53 +488,6 @@ static HttpServletResponse getResponse(Map<String, Object> attrs) {
|
478 | 488 | return (HttpServletResponse) attrs.get(HTTP_SERVLET_RESPONSE_ATTR_NAME);
|
479 | 489 | }
|
480 | 490 |
|
481 |
| - private static class PrincipalNameAuthentication implements Authentication { |
482 |
| - private final String username; |
483 |
| - |
484 |
| - private PrincipalNameAuthentication(String username) { |
485 |
| - this.username = username; |
486 |
| - } |
487 |
| - |
488 |
| - @Override |
489 |
| - public Collection<? extends GrantedAuthority> getAuthorities() { |
490 |
| - throw unsupported(); |
491 |
| - } |
492 |
| - |
493 |
| - @Override |
494 |
| - public Object getCredentials() { |
495 |
| - throw unsupported(); |
496 |
| - } |
497 |
| - |
498 |
| - @Override |
499 |
| - public Object getDetails() { |
500 |
| - throw unsupported(); |
501 |
| - } |
502 |
| - |
503 |
| - @Override |
504 |
| - public Object getPrincipal() { |
505 |
| - throw unsupported(); |
506 |
| - } |
507 |
| - |
508 |
| - @Override |
509 |
| - public boolean isAuthenticated() { |
510 |
| - throw unsupported(); |
511 |
| - } |
512 |
| - |
513 |
| - @Override |
514 |
| - public void setAuthenticated(boolean isAuthenticated) throws IllegalArgumentException { |
515 |
| - throw unsupported(); |
516 |
| - } |
517 |
| - |
518 |
| - @Override |
519 |
| - public String getName() { |
520 |
| - return this.username; |
521 |
| - } |
522 |
| - |
523 |
| - private UnsupportedOperationException unsupported() { |
524 |
| - return new UnsupportedOperationException("Not Supported"); |
525 |
| - } |
526 |
| - } |
527 |
| - |
528 | 491 | private static class RequestContextSubscriber<T> implements CoreSubscriber<T> {
|
529 | 492 | private static final String CONTEXT_DEFAULTED_ATTR_NAME = RequestContextSubscriber.class.getName().concat(".CONTEXT_DEFAULTED_ATTR_NAME");
|
530 | 493 | private final CoreSubscriber<T> delegate;
|
|
0 commit comments