diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager.java new file mode 100644 index 00000000000..38bf53487c7 --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager.java @@ -0,0 +1,147 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client; + +import org.springframework.lang.Nullable; +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; +import org.springframework.util.StringUtils; +import org.springframework.web.server.ServerWebExchange; +import reactor.core.publisher.Mono; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.Optional; +import java.util.function.Function; + +/** + * An implementation of an {@link ReactiveOAuth2AuthorizedClientManager} + * that is capable of operating outside of a {@code ServerHttpRequest} context, + * e.g. in a scheduled/background thread and/or in the service-tier. + * + * @author Ankur Pathak + * @see ReactiveOAuth2AuthorizedClientManager + * @see ReactiveOAuth2AuthorizedClientProvider + * @see ReactiveOAuth2AuthorizedClientService + * @since 5.3 + */ +public final class OAuth2AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager implements ReactiveOAuth2AuthorizedClientManager { + private final ReactiveClientRegistrationRepository clientRegistrationRepository; + private final ReactiveOAuth2AuthorizedClientService authorizedClientService; + private ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider = context -> Mono.empty(); + private Function>> contextAttributesMapper = new DefaultContextAttributesMapper(); + + /** + * Constructs an {@code OAuth2AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager} using the provided parameters. + * + * @param clientRegistrationRepository the repository of client registrations + * @param authorizedClientService the authorized client service + */ + public OAuth2AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager(ReactiveClientRegistrationRepository clientRegistrationRepository, + ReactiveOAuth2AuthorizedClientService authorizedClientService) { + Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null"); + Assert.notNull(authorizedClientService, "authorizedClientService cannot be null"); + this.clientRegistrationRepository = clientRegistrationRepository; + this.authorizedClientService = authorizedClientService; + } + + @Nullable + @Override + public Mono authorize(OAuth2AuthorizeRequest authorizeRequest) { + Assert.notNull(authorizeRequest, "authorizeRequest cannot be null"); + String clientRegistrationId = authorizeRequest.getClientRegistrationId(); + OAuth2AuthorizedClient authorizedClient = authorizeRequest.getAuthorizedClient(); + Authentication principal = authorizeRequest.getPrincipal(); + // @formatter:off + return Mono.justOrEmpty(authorizedClient) + .map(OAuth2AuthorizationContext::withAuthorizedClient) + .switchIfEmpty(Mono.defer(() -> this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId) + .flatMap(clientRegistration -> this.authorizedClientService.loadAuthorizedClient(clientRegistrationId, principal.getName()) + .map(OAuth2AuthorizationContext::withAuthorizedClient) + .switchIfEmpty(Mono.fromSupplier(() -> OAuth2AuthorizationContext.withClientRegistration(clientRegistration))) + ) + .switchIfEmpty(Mono.error(new IllegalArgumentException("Could not find ClientRegistration with id '" + clientRegistrationId + "'"))) + ) + ) + .flatMap(contextBuilder -> this.contextAttributesMapper.apply(authorizeRequest) + .filter(contextAttributes-> !CollectionUtils.isEmpty(contextAttributes)) + .map(contextAttributes -> contextBuilder.principal(principal) + .attributes(attributes -> { + attributes.putAll(contextAttributes); + }).build()) + ).flatMap(authorizationContext -> this.authorizedClientProvider.authorize(authorizationContext) + .doOnNext(_authorizedClient -> authorizedClientService.saveAuthorizedClient(_authorizedClient, principal)) + .switchIfEmpty(Mono.defer(()-> Mono.justOrEmpty(Optional.ofNullable(authorizationContext.getAuthorizedClient())))) + ); + // @formatter:on + } + + /** + * Sets the {@link ReactiveOAuth2AuthorizedClientProvider} used for authorizing (or re-authorizing) an OAuth 2.0 Client. + * + * @param authorizedClientProvider the {@link ReactiveOAuth2AuthorizedClientProvider} used for authorizing (or re-authorizing) an OAuth 2.0 Client + */ + public void setAuthorizedClientProvider(ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider) { + Assert.notNull(authorizedClientProvider, "authorizedClientProvider cannot be null"); + this.authorizedClientProvider = authorizedClientProvider; + } + + /** + * Sets the {@code Function} used for mapping attribute(s) from the {@link OAuth2AuthorizeRequest} to a {@code Map} of attributes + * to be associated to the {@link OAuth2AuthorizationContext#getAttributes() authorization context}. + * + * @param contextAttributesMapper the {@code Function} used for supplying the {@code Map} of attributes + * to the {@link OAuth2AuthorizationContext#getAttributes() authorization context} + */ + public void setContextAttributesMapper(Function>> contextAttributesMapper) { + Assert.notNull(contextAttributesMapper, "contextAttributesMapper cannot be null"); + this.contextAttributesMapper = contextAttributesMapper; + } + + private static Mono currentServerWebExchange() { + return Mono.subscriberContext() + .filter(c -> c.hasKey(ServerWebExchange.class)) + .map(c -> c.get(ServerWebExchange.class)); + } + + /** + * The default implementation of the {@link #setContextAttributesMapper(Function) contextAttributesMapper}. + */ + public static class DefaultContextAttributesMapper implements Function>> { + + @Override + public Mono> apply(OAuth2AuthorizeRequest authorizeRequest) { + ServerWebExchange serverWebExchange = authorizeRequest.getAttribute(ServerWebExchange.class.getName()); + return Mono.justOrEmpty(serverWebExchange) + .switchIfEmpty(Mono.defer(() -> currentServerWebExchange())) + .flatMap(exchange -> { + Map contextAttributes = Collections.emptyMap(); + String scope = exchange.getRequest().getQueryParams().getFirst(OAuth2ParameterNames.SCOPE); + if (StringUtils.hasText(scope)) { + contextAttributes = new HashMap<>(); + contextAttributes.put(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME, + StringUtils.delimitedListToStringArray(scope, " ")); + } + return Mono.just(contextAttributes); + }) + .defaultIfEmpty(Collections.emptyMap()); + } + } +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientServiceReactiveOAuth2AuthorizedClientManagerTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientServiceReactiveOAuth2AuthorizedClientManagerTests.java new file mode 100644 index 00000000000..2be9e4139d2 --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientServiceReactiveOAuth2AuthorizedClientManagerTests.java @@ -0,0 +1,300 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; +import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import java.util.function.Function; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.*; + +/** + * Tests for {@link OAuth2AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager}. + * + * @author Ankur Pathak + */ +public class OAuth2AuthorizedClientServiceReactiveOAuth2AuthorizedClientManagerTests { + private ReactiveClientRegistrationRepository clientRegistrationRepository; + private ReactiveOAuth2AuthorizedClientService authorizedClientService; + private ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider; + private Function contextAttributesMapper; + private OAuth2AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager authorizedClientManager; + private ClientRegistration clientRegistration; + private Authentication principal; + private OAuth2AuthorizedClient authorizedClient; + private ArgumentCaptor authorizationContextCaptor; + + @SuppressWarnings("unchecked") + @Before + public void setup() { + this.clientRegistrationRepository = mock(ReactiveClientRegistrationRepository.class); + this.authorizedClientService = mock(ReactiveOAuth2AuthorizedClientService.class); + this.authorizedClientProvider = mock(ReactiveOAuth2AuthorizedClientProvider.class); + this.contextAttributesMapper = mock(Function.class); + this.authorizedClientManager = new OAuth2AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager( + this.clientRegistrationRepository, this.authorizedClientService); + this.authorizedClientManager.setAuthorizedClientProvider(this.authorizedClientProvider); + this.authorizedClientManager.setContextAttributesMapper(this.contextAttributesMapper); + this.clientRegistration = TestClientRegistrations.clientRegistration().build(); + this.principal = new TestingAuthenticationToken("principal", "password"); + this.authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, this.principal.getName(), + TestOAuth2AccessTokens.scopes("read", "write"), TestOAuth2RefreshTokens.refreshToken()); + this.authorizationContextCaptor = ArgumentCaptor.forClass(OAuth2AuthorizationContext.class); + } + + @Test + public void constructorWhenClientRegistrationRepositoryIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new OAuth2AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager(null, this.authorizedClientService)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("reactiveClientRegistrationRepository cannot be null"); + } + + @Test + public void constructorWhenOAuth2AuthorizedClientServiceIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new OAuth2AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager(this.clientRegistrationRepository, null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("reactiveAuthorizedClientService cannot be null"); + } + + @Test + public void setAuthorizedClientProviderWhenNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.authorizedClientManager.setAuthorizedClientProvider(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("reactiveAuthorizedClientProvider cannot be null"); + } + + @Test + public void setContextAttributesMapperWhenNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.authorizedClientManager.setContextAttributesMapper(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("contextAttributesMapper cannot be null"); + } + + @Test + public void authorizeWhenRequestIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.authorizedClientManager.authorize(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("authorizeRequest cannot be null"); + } + + @Test + public void authorizeWhenClientRegistrationNotFoundThenThrowIllegalArgumentException() { + String clientRegistrationId = "invalid-registration-id"; + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(clientRegistrationId) + .principal(this.principal) + .build(); + when(this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId)).thenReturn(Mono.empty()); + StepVerifier.create(this.authorizedClientManager.authorize(authorizeRequest)) + .verifyError(IllegalArgumentException.class); + + } + + @SuppressWarnings("unchecked") + @Test + public void authorizeWhenNotAuthorizedAndUnsupportedProviderThenNotAuthorized() { + when(this.clientRegistrationRepository.findByRegistrationId( + eq(this.clientRegistration.getRegistrationId()))).thenReturn(Mono.just(this.clientRegistration)); + when(this.authorizedClientService.loadAuthorizedClient( + any(), any())).thenReturn(Mono.empty()); + + when(authorizedClientProvider.authorize(any())).thenReturn(Mono.empty()); + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId()) + .principal(this.principal) + .build(); + Mono authorizedClient = this.authorizedClientManager.authorize(authorizeRequest); + + authorizedClient.subscribe(); + + verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); + verify(this.contextAttributesMapper).apply(eq(authorizeRequest)); + + OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue(); + assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration); + assertThat(authorizationContext.getAuthorizedClient()).isNull(); + assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); + + StepVerifier.create(authorizedClient).expectComplete(); + verify(this.authorizedClientService, never()).saveAuthorizedClient( + any(OAuth2AuthorizedClient.class), eq(this.principal)); + } + + @SuppressWarnings("unchecked") + @Test + public void authorizeWhenNotAuthorizedAndSupportedProviderThenAuthorized() { + when(this.clientRegistrationRepository.findByRegistrationId( + eq(this.clientRegistration.getRegistrationId()))).thenReturn(Mono.just(this.clientRegistration)); + + when(this.authorizedClientService.loadAuthorizedClient( + any(), any())).thenReturn(Mono.empty()); + + when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(Mono.just(this.authorizedClient)); + + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId()) + .principal(this.principal) + .build(); + Mono authorizedClient = this.authorizedClientManager.authorize(authorizeRequest); + + authorizedClient.subscribe(); + + verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); + verify(this.contextAttributesMapper).apply(eq(authorizeRequest)); + + OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue(); + assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration); + assertThat(authorizationContext.getAuthorizedClient()).isNull(); + assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); + + StepVerifier.create(authorizedClient).expectNextCount(1).assertNext(x -> assertThat(x).isSameAs(this.authorizedClient)); + verify(this.authorizedClientService).saveAuthorizedClient( + eq(this.authorizedClient), eq(this.principal)); + } + + @SuppressWarnings("unchecked") + @Test + public void authorizeWhenAuthorizedAndSupportedProviderThenReauthorized() { + when(this.clientRegistrationRepository.findByRegistrationId( + eq(this.clientRegistration.getRegistrationId()))).thenReturn(Mono.just(this.clientRegistration)); + when(this.authorizedClientService.loadAuthorizedClient( + eq(this.clientRegistration.getRegistrationId()), eq(this.principal.getName()))).thenReturn(Mono.just(this.authorizedClient)); + + OAuth2AuthorizedClient reauthorizedClient = new OAuth2AuthorizedClient( + this.clientRegistration, this.principal.getName(), + TestOAuth2AccessTokens.noScopes(), TestOAuth2RefreshTokens.refreshToken()); + + when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(Mono.just(reauthorizedClient)); + + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId()) + .principal(this.principal) + .build(); + Mono authorizedClient = this.authorizedClientManager.authorize(authorizeRequest); + + authorizedClient.subscribe(); + + verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); + verify(this.contextAttributesMapper).apply(eq(authorizeRequest)); + + OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue(); + assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration); + assertThat(authorizationContext.getAuthorizedClient()).isSameAs(this.authorizedClient); + assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); + + StepVerifier.create(authorizedClient).expectNextCount(1).assertNext(x -> assertThat(x).isSameAs(this.authorizedClient)); + verify(this.authorizedClientService).saveAuthorizedClient( + eq(reauthorizedClient), eq(this.principal)); + } + + @SuppressWarnings("unchecked") + @Test + public void reauthorizeWhenUnsupportedProviderThenNotReauthorized() { + when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(Mono.empty()); + OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest.withAuthorizedClient(this.authorizedClient) + .principal(this.principal) + .build(); + Mono authorizedClient = this.authorizedClientManager.authorize(reauthorizeRequest); + + authorizedClient.subscribe(); + + verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); + verify(this.contextAttributesMapper).apply(eq(reauthorizeRequest)); + + OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue(); + assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration); + assertThat(authorizationContext.getAuthorizedClient()).isSameAs(this.authorizedClient); + assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); + + StepVerifier.create(authorizedClient).expectNextCount(1).assertNext(x -> assertThat(x).isSameAs(this.authorizedClient)); + verify(this.authorizedClientService, never()).saveAuthorizedClient( + any(OAuth2AuthorizedClient.class), eq(this.principal)); + } + + @SuppressWarnings("unchecked") + @Test + public void reauthorizeWhenSupportedProviderThenReauthorized() { + OAuth2AuthorizedClient reauthorizedClient = new OAuth2AuthorizedClient( + this.clientRegistration, this.principal.getName(), + TestOAuth2AccessTokens.noScopes(), TestOAuth2RefreshTokens.refreshToken()); + + when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(Mono.just(reauthorizedClient)); + + OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest.withAuthorizedClient(this.authorizedClient) + .principal(this.principal) + .build(); + Mono authorizedClient = this.authorizedClientManager.authorize(reauthorizeRequest); + + authorizedClient.subscribe(); + + verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); + verify(this.contextAttributesMapper).apply(eq(reauthorizeRequest)); + + OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue(); + assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration); + assertThat(authorizationContext.getAuthorizedClient()).isSameAs(this.authorizedClient); + assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); + + StepVerifier.create(authorizedClient).expectNextCount(1).assertNext(x -> assertThat(x).isSameAs(this.authorizedClient)); + verify(this.authorizedClientService).saveAuthorizedClient( + eq(reauthorizedClient), eq(this.principal)); + } + + @SuppressWarnings("unchecked") + @Test + public void reauthorizeWhenRequestAttributeScopeThenMappedToContext() { + OAuth2AuthorizedClient reauthorizedClient = new OAuth2AuthorizedClient( + this.clientRegistration, this.principal.getName(), + TestOAuth2AccessTokens.noScopes(), TestOAuth2RefreshTokens.refreshToken()); + + when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(Mono.just(reauthorizedClient)); + + + OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest.withAuthorizedClient(this.authorizedClient) + .principal(this.principal) + .attribute(OAuth2ParameterNames.SCOPE, "read write") + .build(); + Mono authorizedClient = this.authorizedClientManager.authorize(reauthorizeRequest); + + authorizedClient.subscribe(); + + verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); + + OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue(); + assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration); + assertThat(authorizationContext.getAuthorizedClient()).isSameAs(this.authorizedClient); + assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); + assertThat(authorizationContext.getAttributes()).containsKey(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME); + String[] requestScopeAttribute = authorizationContext.getAttribute(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME); + assertThat(requestScopeAttribute).contains("read", "write"); + + StepVerifier.create(authorizedClient).expectNextCount(1).assertNext(x -> assertThat(x).isSameAs(this.authorizedClient)); + verify(this.authorizedClientService).saveAuthorizedClient( + eq(reauthorizedClient), eq(this.principal)); + } +}