diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager.java new file mode 100644 index 00000000000..16400dc922c --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager.java @@ -0,0 +1,137 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client; + +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; +import org.springframework.util.Assert; +import reactor.core.publisher.Mono; + +import java.util.Collections; +import java.util.Map; +import java.util.function.Function; + +/** + * An implementation of an {@link ReactiveOAuth2AuthorizedClientManager} + * that is capable of operating outside of a {@code ServerHttpRequest} context, + * e.g. in a scheduled/background thread and/or in the service-tier. + * + *

This is a reactive equivalent of {@link org.springframework.security.oauth2.client.AuthorizedClientServiceOAuth2AuthorizedClientManager}

+ * + * @author Ankur Pathak + * @author Phil Clay + * @see ReactiveOAuth2AuthorizedClientManager + * @see ReactiveOAuth2AuthorizedClientProvider + * @see ReactiveOAuth2AuthorizedClientService + * @since 5.2.2 + */ +public final class AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager + implements ReactiveOAuth2AuthorizedClientManager { + + private final ReactiveClientRegistrationRepository clientRegistrationRepository; + private final ReactiveOAuth2AuthorizedClientService authorizedClientService; + private ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider = context -> Mono.empty(); + private Function>> contextAttributesMapper = new DefaultContextAttributesMapper(); + + /** + * Constructs an {@code AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager} using the provided parameters. + * + * @param clientRegistrationRepository the repository of client registrations + * @param authorizedClientService the authorized client service + */ + public AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager( + ReactiveClientRegistrationRepository clientRegistrationRepository, + ReactiveOAuth2AuthorizedClientService authorizedClientService) { + Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null"); + Assert.notNull(authorizedClientService, "authorizedClientService cannot be null"); + this.clientRegistrationRepository = clientRegistrationRepository; + this.authorizedClientService = authorizedClientService; + } + + @Override + public Mono authorize(OAuth2AuthorizeRequest authorizeRequest) { + Assert.notNull(authorizeRequest, "authorizeRequest cannot be null"); + + return createAuthorizationContext(authorizeRequest) + .flatMap(this::authorizeAndSave); + } + + private Mono createAuthorizationContext(OAuth2AuthorizeRequest authorizeRequest) { + String clientRegistrationId = authorizeRequest.getClientRegistrationId(); + Authentication principal = authorizeRequest.getPrincipal(); + return Mono.justOrEmpty(authorizeRequest.getAuthorizedClient()) + .map(OAuth2AuthorizationContext::withAuthorizedClient) + .switchIfEmpty(Mono.defer(() -> this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId) + .flatMap(clientRegistration -> this.authorizedClientService.loadAuthorizedClient(clientRegistrationId, principal.getName()) + .map(OAuth2AuthorizationContext::withAuthorizedClient) + .switchIfEmpty(Mono.fromSupplier(() -> OAuth2AuthorizationContext.withClientRegistration(clientRegistration)))) + .switchIfEmpty(Mono.error(() -> new IllegalArgumentException("Could not find ClientRegistration with id '" + clientRegistrationId + "'"))))) + .flatMap(contextBuilder -> this.contextAttributesMapper.apply(authorizeRequest) + .defaultIfEmpty(Collections.emptyMap()) + .map(contextAttributes -> { + OAuth2AuthorizationContext.Builder builder = contextBuilder.principal(principal); + if (!contextAttributes.isEmpty()) { + builder = builder.attributes(attributes -> attributes.putAll(contextAttributes)); + } + return builder.build(); + })); + } + + private Mono authorizeAndSave(OAuth2AuthorizationContext authorizationContext) { + return this.authorizedClientProvider.authorize(authorizationContext) + .flatMap(authorizedClient -> this.authorizedClientService.saveAuthorizedClient( + authorizedClient, + authorizationContext.getPrincipal()) + .thenReturn(authorizedClient)) + .switchIfEmpty(Mono.defer(()-> Mono.justOrEmpty(authorizationContext.getAuthorizedClient()))); + } + + /** + * Sets the {@link ReactiveOAuth2AuthorizedClientProvider} used for authorizing (or re-authorizing) an OAuth 2.0 Client. + * + * @param authorizedClientProvider the {@link ReactiveOAuth2AuthorizedClientProvider} used for authorizing (or re-authorizing) an OAuth 2.0 Client + */ + public void setAuthorizedClientProvider(ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider) { + Assert.notNull(authorizedClientProvider, "authorizedClientProvider cannot be null"); + this.authorizedClientProvider = authorizedClientProvider; + } + + /** + * Sets the {@code Function} used for mapping attribute(s) from the {@link OAuth2AuthorizeRequest} to a {@code Map} of attributes + * to be associated to the {@link OAuth2AuthorizationContext#getAttributes() authorization context}. + * + * @param contextAttributesMapper the {@code Function} used for supplying the {@code Map} of attributes + * to the {@link OAuth2AuthorizationContext#getAttributes() authorization context} + */ + public void setContextAttributesMapper(Function>> contextAttributesMapper) { + Assert.notNull(contextAttributesMapper, "contextAttributesMapper cannot be null"); + this.contextAttributesMapper = contextAttributesMapper; + } + + /** + * The default implementation of the {@link #setContextAttributesMapper(Function) contextAttributesMapper}. + */ + public static class DefaultContextAttributesMapper implements Function>> { + + private final AuthorizedClientServiceOAuth2AuthorizedClientManager.DefaultContextAttributesMapper mapper = + new AuthorizedClientServiceOAuth2AuthorizedClientManager.DefaultContextAttributesMapper(); + + @Override + public Mono> apply(OAuth2AuthorizeRequest authorizeRequest) { + return Mono.fromCallable(() -> mapper.apply(authorizeRequest)); + } + } +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/AuthorizedClientServiceReactiveOAuth2AuthorizedClientManagerTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/AuthorizedClientServiceReactiveOAuth2AuthorizedClientManagerTests.java new file mode 100644 index 00000000000..ab9c7ff382a --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/AuthorizedClientServiceReactiveOAuth2AuthorizedClientManagerTests.java @@ -0,0 +1,317 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; +import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; +import reactor.test.publisher.PublisherProbe; + +import java.util.Map; +import java.util.function.Function; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +/** + * Tests for {@link AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager}. + * + * @author Ankur Pathak + * @author Phil Clay + */ +public class AuthorizedClientServiceReactiveOAuth2AuthorizedClientManagerTests { + private ReactiveClientRegistrationRepository clientRegistrationRepository; + private ReactiveOAuth2AuthorizedClientService authorizedClientService; + private ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider; + private Function>> contextAttributesMapper; + private AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager authorizedClientManager; + private ClientRegistration clientRegistration; + private Authentication principal; + private OAuth2AuthorizedClient authorizedClient; + private ArgumentCaptor authorizationContextCaptor; + private PublisherProbe saveAuthorizedClientProbe; + + @SuppressWarnings("unchecked") + @Before + public void setup() { + this.clientRegistrationRepository = mock(ReactiveClientRegistrationRepository.class); + this.authorizedClientService = mock(ReactiveOAuth2AuthorizedClientService.class); + this.saveAuthorizedClientProbe = PublisherProbe.empty(); + when(this.authorizedClientService.saveAuthorizedClient(any(), any())).thenReturn(this.saveAuthorizedClientProbe.mono()); + this.authorizedClientProvider = mock(ReactiveOAuth2AuthorizedClientProvider.class); + this.contextAttributesMapper = mock(Function.class); + when(this.contextAttributesMapper.apply(any())).thenReturn(Mono.empty()); + this.authorizedClientManager = new AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager( + this.clientRegistrationRepository, this.authorizedClientService); + this.authorizedClientManager.setAuthorizedClientProvider(this.authorizedClientProvider); + this.authorizedClientManager.setContextAttributesMapper(this.contextAttributesMapper); + this.clientRegistration = TestClientRegistrations.clientRegistration().build(); + this.principal = new TestingAuthenticationToken("principal", "password"); + this.authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, this.principal.getName(), + TestOAuth2AccessTokens.scopes("read", "write"), TestOAuth2RefreshTokens.refreshToken()); + this.authorizationContextCaptor = ArgumentCaptor.forClass(OAuth2AuthorizationContext.class); + } + + @Test + public void constructorWhenClientRegistrationRepositoryIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager(null, this.authorizedClientService)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("clientRegistrationRepository cannot be null"); + } + + @Test + public void constructorWhenOAuth2AuthorizedClientServiceIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager(this.clientRegistrationRepository, null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("authorizedClientService cannot be null"); + } + + @Test + public void setAuthorizedClientProviderWhenNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.authorizedClientManager.setAuthorizedClientProvider(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("authorizedClientProvider cannot be null"); + } + + @Test + public void setContextAttributesMapperWhenNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.authorizedClientManager.setContextAttributesMapper(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("contextAttributesMapper cannot be null"); + } + + @Test + public void authorizeWhenRequestIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.authorizedClientManager.authorize(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("authorizeRequest cannot be null"); + } + + @Test + public void authorizeWhenClientRegistrationNotFoundThenThrowIllegalArgumentException() { + String clientRegistrationId = "invalid-registration-id"; + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(clientRegistrationId) + .principal(this.principal) + .build(); + when(this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId)).thenReturn(Mono.empty()); + StepVerifier.create(this.authorizedClientManager.authorize(authorizeRequest)) + .verifyError(IllegalArgumentException.class); + + } + + @SuppressWarnings("unchecked") + @Test + public void authorizeWhenNotAuthorizedAndUnsupportedProviderThenNotAuthorized() { + when(this.clientRegistrationRepository.findByRegistrationId( + eq(this.clientRegistration.getRegistrationId()))).thenReturn(Mono.just(this.clientRegistration)); + when(this.authorizedClientService.loadAuthorizedClient( + any(), any())).thenReturn(Mono.empty()); + + when(authorizedClientProvider.authorize(any())).thenReturn(Mono.empty()); + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId()) + .principal(this.principal) + .build(); + Mono authorizedClient = this.authorizedClientManager.authorize(authorizeRequest); + + StepVerifier.create(authorizedClient).verifyComplete(); + + verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); + verify(this.contextAttributesMapper).apply(eq(authorizeRequest)); + + OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue(); + assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration); + assertThat(authorizationContext.getAuthorizedClient()).isNull(); + assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); + + verify(this.authorizedClientService, never()).saveAuthorizedClient( + any(OAuth2AuthorizedClient.class), eq(this.principal)); + } + + @SuppressWarnings("unchecked") + @Test + public void authorizeWhenNotAuthorizedAndSupportedProviderThenAuthorized() { + when(this.clientRegistrationRepository.findByRegistrationId( + eq(this.clientRegistration.getRegistrationId()))).thenReturn(Mono.just(this.clientRegistration)); + + when(this.authorizedClientService.loadAuthorizedClient( + any(), any())).thenReturn(Mono.empty()); + + when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(Mono.just(this.authorizedClient)); + + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId()) + .principal(this.principal) + .build(); + Mono authorizedClient = this.authorizedClientManager.authorize(authorizeRequest); + + StepVerifier.create(authorizedClient) + .expectNext(this.authorizedClient) + .verifyComplete(); + + verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); + verify(this.contextAttributesMapper).apply(eq(authorizeRequest)); + + OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue(); + assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration); + assertThat(authorizationContext.getAuthorizedClient()).isNull(); + assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); + + verify(this.authorizedClientService).saveAuthorizedClient( + eq(this.authorizedClient), eq(this.principal)); + this.saveAuthorizedClientProbe.assertWasSubscribed(); + } + + @SuppressWarnings("unchecked") + @Test + public void authorizeWhenAuthorizedAndSupportedProviderThenReauthorized() { + when(this.clientRegistrationRepository.findByRegistrationId( + eq(this.clientRegistration.getRegistrationId()))).thenReturn(Mono.just(this.clientRegistration)); + when(this.authorizedClientService.loadAuthorizedClient( + eq(this.clientRegistration.getRegistrationId()), eq(this.principal.getName()))).thenReturn(Mono.just(this.authorizedClient)); + + OAuth2AuthorizedClient reauthorizedClient = new OAuth2AuthorizedClient( + this.clientRegistration, this.principal.getName(), + TestOAuth2AccessTokens.noScopes(), TestOAuth2RefreshTokens.refreshToken()); + + when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(Mono.just(reauthorizedClient)); + + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId()) + .principal(this.principal) + .build(); + Mono authorizedClient = this.authorizedClientManager.authorize(authorizeRequest); + + StepVerifier.create(authorizedClient) + .expectNext(reauthorizedClient) + .verifyComplete(); + verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); + verify(this.contextAttributesMapper).apply(eq(authorizeRequest)); + + OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue(); + assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration); + assertThat(authorizationContext.getAuthorizedClient()).isSameAs(this.authorizedClient); + assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); + + verify(this.authorizedClientService).saveAuthorizedClient( + eq(reauthorizedClient), eq(this.principal)); + this.saveAuthorizedClientProbe.assertWasSubscribed(); + } + + @SuppressWarnings("unchecked") + @Test + public void reauthorizeWhenUnsupportedProviderThenNotReauthorized() { + when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(Mono.empty()); + OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest.withAuthorizedClient(this.authorizedClient) + .principal(this.principal) + .build(); + Mono authorizedClient = this.authorizedClientManager.authorize(reauthorizeRequest); + + StepVerifier.create(authorizedClient) + .expectNext(this.authorizedClient) + .verifyComplete(); + verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); + verify(this.contextAttributesMapper).apply(eq(reauthorizeRequest)); + + OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue(); + assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration); + assertThat(authorizationContext.getAuthorizedClient()).isSameAs(this.authorizedClient); + assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); + + verify(this.authorizedClientService, never()).saveAuthorizedClient( + any(OAuth2AuthorizedClient.class), eq(this.principal)); + } + + @SuppressWarnings("unchecked") + @Test + public void reauthorizeWhenSupportedProviderThenReauthorized() { + OAuth2AuthorizedClient reauthorizedClient = new OAuth2AuthorizedClient( + this.clientRegistration, this.principal.getName(), + TestOAuth2AccessTokens.noScopes(), TestOAuth2RefreshTokens.refreshToken()); + + when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(Mono.just(reauthorizedClient)); + + OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest.withAuthorizedClient(this.authorizedClient) + .principal(this.principal) + .build(); + Mono authorizedClient = this.authorizedClientManager.authorize(reauthorizeRequest); + + StepVerifier.create(authorizedClient) + .expectNext(reauthorizedClient) + .verifyComplete(); + + verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); + verify(this.contextAttributesMapper).apply(eq(reauthorizeRequest)); + + OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue(); + assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration); + assertThat(authorizationContext.getAuthorizedClient()).isSameAs(this.authorizedClient); + assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); + + verify(this.authorizedClientService).saveAuthorizedClient( + eq(reauthorizedClient), eq(this.principal)); + this.saveAuthorizedClientProbe.assertWasSubscribed(); + } + + @SuppressWarnings("unchecked") + @Test + public void reauthorizeWhenRequestAttributeScopeThenMappedToContext() { + OAuth2AuthorizedClient reauthorizedClient = new OAuth2AuthorizedClient( + this.clientRegistration, this.principal.getName(), + TestOAuth2AccessTokens.noScopes(), TestOAuth2RefreshTokens.refreshToken()); + + when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(Mono.just(reauthorizedClient)); + + OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest.withAuthorizedClient(this.authorizedClient) + .principal(this.principal) + .attribute(OAuth2ParameterNames.SCOPE, "read write") + .build(); + + this.authorizedClientManager.setContextAttributesMapper(new AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager.DefaultContextAttributesMapper()); + Mono authorizedClient = this.authorizedClientManager.authorize(reauthorizeRequest); + + StepVerifier.create(authorizedClient) + .expectNext(reauthorizedClient) + .verifyComplete(); + verify(this.authorizedClientService).saveAuthorizedClient( + eq(reauthorizedClient), eq(this.principal)); + this.saveAuthorizedClientProbe.assertWasSubscribed(); + + verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); + + OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue(); + assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration); + assertThat(authorizationContext.getAuthorizedClient()).isSameAs(this.authorizedClient); + assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); + assertThat(authorizationContext.getAttributes()).containsKey(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME); + String[] requestScopeAttribute = authorizationContext.getAttribute(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME); + assertThat(requestScopeAttribute).contains("read", "write"); + + } +}