From 4114e82d586f0e955eb5dee4e4fb601950ce986b Mon Sep 17 00:00:00 2001 From: Joe Grandja Date: Thu, 6 Jun 2019 12:38:30 -0400 Subject: [PATCH 01/19] Introduce OAuth2AuthorizedClientProvider --- .../client/OAuth2AuthorizationContext.java | 272 ++++++++++++++++++ .../OAuth2AuthorizedClientProvider.java | 45 +++ .../OAuth2AuthorizationContextTests.java | 103 +++++++ 3 files changed, 420 insertions(+) create mode 100644 oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizationContext.java create mode 100644 oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientProvider.java create mode 100644 oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/OAuth2AuthorizationContextTests.java diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizationContext.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizationContext.java new file mode 100644 index 00000000000..15bb1834bd0 --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizationContext.java @@ -0,0 +1,272 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client; + +import org.springframework.lang.Nullable; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.GrantedAuthority; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; + +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.Map; + +/** + * A context that holds authorization-specific state and is used by an {@link OAuth2AuthorizedClientProvider} + * when attempting to authorize (or re-authorize) an OAuth 2.0 Client. + * + * @author Joe Grandja + * @since 5.2 + * @see OAuth2AuthorizedClientProvider + */ +public final class OAuth2AuthorizationContext { + private ClientRegistration clientRegistration; + private Authentication principal; + private OAuth2AuthorizedClient authorizedClient; + private Map attributes; + + private OAuth2AuthorizationContext() { + } + + /** + * Returns the {@link ClientRegistration client} requiring authorization. + * + * @return the {@link ClientRegistration} + */ + public ClientRegistration getClientRegistration() { + return this.clientRegistration; + } + + /** + * Returns the {@code Principal} to be associated with the authorized client. + * + * @return the {@code Principal} to be associated with the authorized client + */ + public Authentication getPrincipal() { + return this.principal; + } + + /** + * Returns the {@link OAuth2AuthorizedClient authorized client} which requires re-authorization + * or {@code null} if the {@link #getClientRegistration() client} needs to be authorized. + * + * @return the {@link OAuth2AuthorizedClient} which requires re-authorization or {@code null} if the client needs to be authorized + */ + @Nullable + public OAuth2AuthorizedClient getAuthorizedClient() { + return this.authorizedClient; + } + + /** + * Returns the attributes associated to the context. + * + * @return a {@code Map} of the attributes associated to the context + */ + public Map getAttributes() { + return this.attributes; + } + + /** + * Returns the value of an attribute associated to the context, or {@code null} if not available. + * + * @param name the name of the attribute + * @param the type of the attribute + * @return the value of the attribute associated to the context + */ + @Nullable + @SuppressWarnings("unchecked") + public T getAttribute(String name) { + return (T) this.getAttributes().get(name); + } + + /** + * Returns {@code true} if the client needs to be authorized, otherwise {@code false}. + * + * @return {@code true} if the client needs to be authorized, otherwise {@code false}. + */ + public boolean authorizationRequired() { + return getAuthorizedClient() == null; + } + + /** + * Returns {@code true} if the client needs to be re-authorized, otherwise {@code false}. + * + * @return {@code true} if the client needs to be re-authorized, otherwise {@code false}. + */ + public boolean reauthorizationRequired() { + return getAuthorizedClient() != null; + } + + /** + * Returns a new {@link Builder} with the {@link ClientRegistration client} requiring authorization. + * + * @param clientRegistration the {@link ClientRegistration client} requiring authorization + * @return the {@link Builder} + */ + public static Builder authorize(ClientRegistration clientRegistration) { + return new Builder(clientRegistration); + } + + /** + * Returns a new {@link Builder} with the {@link OAuth2AuthorizedClient authorized client} requiring re-authorization. + * + * @param authorizedClient the {@link OAuth2AuthorizedClient authorized client} requiring re-authorization + * @return the {@link Builder} + */ + public static Builder reauthorize(OAuth2AuthorizedClient authorizedClient) { + return new Builder(authorizedClient); + } + + /** + * A builder for {@link OAuth2AuthorizationContext}. + */ + public static class Builder { + private ClientRegistration clientRegistration; + private Authentication principal; + private OAuth2AuthorizedClient authorizedClient; + private Map attributes; + + private Builder(ClientRegistration clientRegistration) { + Assert.notNull(clientRegistration, "clientRegistration cannot be null"); + this.clientRegistration = clientRegistration; + } + + private Builder(OAuth2AuthorizedClient authorizedClient) { + Assert.notNull(authorizedClient, "authorizedClient cannot be null"); + this.authorizedClient = authorizedClient; + } + + /** + * Sets the {@code Principal} to be associated with the authorized client + * + * @param principal the {@code Principal} to be associated with the authorized client + * @return the {@link Builder} + */ + public Builder principal(Authentication principal) { + this.principal = principal; + return this; + } + + /** + * Sets the {@code Principal}'s name to be associated with the authorized client + * + * @param principalName the {@code Principal}'s name to be associated with the authorized client + * @return the {@link Builder} + */ + public Builder principal(String principalName) { + this.principal = new PrincipalNameAuthentication(principalName); + return this; + } + + /** + * Sets the attributes associated to the context. + * + * @param attributes the attributes associated to the context + * @return the {@link Builder} + */ + public Builder attributes(Map attributes) { + this.attributes = attributes; + return this; + } + + /** + * Sets an attribute associated to the context. + * + * @param name the name of the attribute + * @param value the value of the attribute + * @return the {@link Builder} + */ + public Builder attribute(String name, Object value) { + if (this.attributes == null) { + this.attributes = new HashMap<>(); + } + this.attributes.put(name, value); + return this; + } + + /** + * Builds a new {@link OAuth2AuthorizationContext}. + * + * @return a {@link OAuth2AuthorizationContext} + */ + public OAuth2AuthorizationContext build() { + Assert.notNull(this.principal, "principal cannot be null"); + OAuth2AuthorizationContext context = new OAuth2AuthorizationContext(); + if (this.authorizedClient != null) { + context.clientRegistration = this.authorizedClient.getClientRegistration(); + context.authorizedClient = this.authorizedClient; + } else { + context.clientRegistration = this.clientRegistration; + } + context.principal = this.principal; + context.attributes = Collections.unmodifiableMap( + CollectionUtils.isEmpty(this.attributes) ? + Collections.emptyMap() : new LinkedHashMap<>(this.attributes)); + return context; + } + } + + private static class PrincipalNameAuthentication implements Authentication { + private final String principalName; + + private PrincipalNameAuthentication(String principalName) { + this.principalName = principalName; + } + + @Override + public Collection getAuthorities() { + throw unsupported(); + } + + @Override + public Object getCredentials() { + throw unsupported(); + } + + @Override + public Object getDetails() { + throw unsupported(); + } + + @Override + public Object getPrincipal() { + return getName(); + } + + @Override + public boolean isAuthenticated() { + throw unsupported(); + } + + @Override + public void setAuthenticated(boolean isAuthenticated) throws IllegalArgumentException { + throw unsupported(); + } + + @Override + public String getName() { + return this.principalName; + } + + private UnsupportedOperationException unsupported() { + return new UnsupportedOperationException("Not Supported"); + } + } +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientProvider.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientProvider.java new file mode 100644 index 00000000000..94beb09d4f4 --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientProvider.java @@ -0,0 +1,45 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client; + +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import reactor.util.annotation.Nullable; + +/** + * A strategy for authorizing (or re-authorizing) an OAuth 2.0 Client. + * Implementations will typically implement a specific {@link AuthorizationGrantType authorization grant} type. + * + * @author Joe Grandja + * @since 5.2 + * @see OAuth2AuthorizedClient + * @see OAuth2AuthorizationContext + * @see Section 1.3 Authorization Grant + */ +public interface OAuth2AuthorizedClientProvider { + + /** + * Attempt to authorize (or re-authorize) the {@link OAuth2AuthorizationContext#getClientRegistration() client} in the provided context. + * Implementations must return {@code null} if authorization is not supported for the specified client, + * e.g. the provider doesn't support the {@link ClientRegistration#getAuthorizationGrantType() authorization grant} type configured for the client. + * + * @param context the context that holds authorization-specific state for the client + * @return the {@link OAuth2AuthorizedClient} or {@code null} if authorization is not supported for the specified client + */ + @Nullable + OAuth2AuthorizedClient authorize(OAuth2AuthorizationContext context); + +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/OAuth2AuthorizationContextTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/OAuth2AuthorizationContextTests.java new file mode 100644 index 00000000000..4366c6a4de9 --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/OAuth2AuthorizationContextTests.java @@ -0,0 +1,103 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client; + +import org.junit.Before; +import org.junit.Test; +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; + +import static org.assertj.core.api.Assertions.*; + +/** + * Tests for {@link OAuth2AuthorizationContext}. + * + * @author Joe Grandja + */ +public class OAuth2AuthorizationContextTests { + private ClientRegistration clientRegistration; + private OAuth2AuthorizedClient authorizedClient; + private Authentication principal; + + @Before + public void setup() { + this.clientRegistration = TestClientRegistrations.clientRegistration().build(); + this.authorizedClient = new OAuth2AuthorizedClient( + this.clientRegistration, "principal", TestOAuth2AccessTokens.scopes("read", "write")); + this.principal = new TestingAuthenticationToken("principal", "password"); + } + + @Test + public void authorizeWhenClientRegistrationIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> OAuth2AuthorizationContext.authorize(null).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("clientRegistration cannot be null"); + } + + @Test + public void authorizeWhenPrincipalIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> OAuth2AuthorizationContext.authorize(this.clientRegistration).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("principal cannot be null"); + } + + @Test + public void authorizeWhenAllValuesProvidedThenAllValuesAreSet() { + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext.authorize(this.clientRegistration) + .principal(this.principal) + .attribute("attribute1", "value1") + .attribute("attribute2", "value2") + .build(); + assertThat(authorizationContext.getClientRegistration()).isSameAs(this.clientRegistration); + assertThat(authorizationContext.getPrincipal()).isSameAs(this.principal); + assertThat(authorizationContext.getAuthorizedClient()).isNull(); + assertThat(authorizationContext.getAttributes()).contains( + entry("attribute1", "value1"), entry("attribute2", "value2")); + assertThat(authorizationContext.authorizationRequired()).isTrue(); + } + + @Test + public void reauthorizeWhenAuthorizedClientIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> OAuth2AuthorizationContext.reauthorize(null).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("authorizedClient cannot be null"); + } + + @Test + public void reauthorizeWhenPrincipalIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> OAuth2AuthorizationContext.reauthorize(this.authorizedClient).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("principal cannot be null"); + } + + @Test + public void reauthorizeWhenAllValuesProvidedThenAllValuesAreSet() { + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext.reauthorize(this.authorizedClient) + .principal(this.principal) + .attribute("attribute1", "value1") + .attribute("attribute2", "value2") + .build(); + assertThat(authorizationContext.getAuthorizedClient()).isSameAs(this.authorizedClient); + assertThat(authorizationContext.getClientRegistration()).isSameAs(this.authorizedClient.getClientRegistration()); + assertThat(authorizationContext.getPrincipal()).isSameAs(this.principal); + assertThat(authorizationContext.getAttributes()).contains( + entry("attribute1", "value1"), entry("attribute2", "value2")); + assertThat(authorizationContext.reauthorizationRequired()).isTrue(); + } +} From 12d207a37cc2e725fe46bd85f8bf619de8e22f37 Mon Sep 17 00:00:00 2001 From: Joe Grandja Date: Fri, 7 Jun 2019 14:17:09 -0400 Subject: [PATCH 02/19] Add authorization_code OAuth2AuthorizedClientProvider --- ...ionCodeOAuth2AuthorizedClientProvider.java | 43 +++++++++++ ...deOAuth2AuthorizedClientProviderTests.java | 77 +++++++++++++++++++ 2 files changed, 120 insertions(+) create mode 100644 oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/AuthorizationCodeOAuth2AuthorizedClientProvider.java create mode 100644 oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/AuthorizationCodeOAuth2AuthorizedClientProviderTests.java diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/AuthorizationCodeOAuth2AuthorizedClientProvider.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/AuthorizationCodeOAuth2AuthorizedClientProvider.java new file mode 100644 index 00000000000..fc594d36af3 --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/AuthorizationCodeOAuth2AuthorizedClientProvider.java @@ -0,0 +1,43 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client; + +import org.springframework.lang.Nullable; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.util.Assert; + +/** + * An implementation of an {@link OAuth2AuthorizedClientProvider} + * for the {@link AuthorizationGrantType#AUTHORIZATION_CODE authorization_code} grant. + * + * @author Joe Grandja + * @since 5.2 + * @see OAuth2AuthorizedClientProvider + */ +public final class AuthorizationCodeOAuth2AuthorizedClientProvider implements OAuth2AuthorizedClientProvider { + + @Override + @Nullable + public OAuth2AuthorizedClient authorize(OAuth2AuthorizationContext context) { + Assert.notNull(context, "context cannot be null"); + if (AuthorizationGrantType.AUTHORIZATION_CODE.equals(context.getClientRegistration().getAuthorizationGrantType()) && + context.authorizationRequired()) { + // ClientAuthorizationRequiredException is caught by OAuth2AuthorizationRequestRedirectFilter which initiates authorization + throw new ClientAuthorizationRequiredException(context.getClientRegistration().getRegistrationId()); + } + return null; + } +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/AuthorizationCodeOAuth2AuthorizedClientProviderTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/AuthorizationCodeOAuth2AuthorizedClientProviderTests.java new file mode 100644 index 00000000000..338c3b1d960 --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/AuthorizationCodeOAuth2AuthorizedClientProviderTests.java @@ -0,0 +1,77 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client; + +import org.junit.Before; +import org.junit.Test; +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Tests for {@link AuthorizationCodeOAuth2AuthorizedClientProvider}. + * + * @author Joe Grandja + */ +public class AuthorizationCodeOAuth2AuthorizedClientProviderTests { + private AuthorizationCodeOAuth2AuthorizedClientProvider authorizedClientProvider = + new AuthorizationCodeOAuth2AuthorizedClientProvider(); + private ClientRegistration clientRegistration; + private OAuth2AuthorizedClient authorizedClient; + private Authentication principal; + + @Before + public void setup() { + this.clientRegistration = TestClientRegistrations.clientRegistration().build(); + this.authorizedClient = new OAuth2AuthorizedClient( + this.clientRegistration, "principal", TestOAuth2AccessTokens.scopes("read", "write")); + this.principal = new TestingAuthenticationToken("principal", "password"); + } + + @Test + public void authorizeWhenContextIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.authorizedClientProvider.authorize(null)) + .isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void authorizeWhenAuthorizationCodeAndNotAuthorizedThenAuthorize() { + OAuth2AuthorizationContext authorizationContext = + OAuth2AuthorizationContext.authorize(this.clientRegistration).principal(this.principal).build(); + assertThatThrownBy(() -> this.authorizedClientProvider.authorize(authorizationContext)) + .isInstanceOf(ClientAuthorizationRequiredException.class); + } + + @Test + public void authorizeWhenAuthorizationCodeAndAuthorizedThenUnableToAuthorize() { + OAuth2AuthorizationContext authorizationContext = + OAuth2AuthorizationContext.reauthorize(this.authorizedClient).principal(this.principal).build(); + assertThat(this.authorizedClientProvider.authorize(authorizationContext)).isNull(); + } + + @Test + public void authorizeWhenNotAuthorizationCodeThenUnableToAuthorize() { + ClientRegistration clientCredentialsClient = TestClientRegistrations.clientCredentials().build(); + OAuth2AuthorizationContext authorizationContext = + OAuth2AuthorizationContext.authorize(clientCredentialsClient).principal(this.principal).build(); + assertThat(this.authorizedClientProvider.authorize(authorizationContext)).isNull(); + } +} From c84990be80b5b4511db4db563dad4631ff6b50b4 Mon Sep 17 00:00:00 2001 From: Joe Grandja Date: Fri, 7 Jun 2019 16:45:22 -0400 Subject: [PATCH 03/19] Add client_credentials OAuth2AuthorizedClientProvider --- ...entialsOAuth2AuthorizedClientProvider.java | 110 ++++++++++++ ...lsOAuth2AuthorizedClientProviderTests.java | 170 ++++++++++++++++++ 2 files changed, 280 insertions(+) create mode 100644 oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ClientCredentialsOAuth2AuthorizedClientProvider.java create mode 100644 oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/ClientCredentialsOAuth2AuthorizedClientProviderTests.java diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ClientCredentialsOAuth2AuthorizedClientProvider.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ClientCredentialsOAuth2AuthorizedClientProvider.java new file mode 100644 index 00000000000..c0de3587ec4 --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ClientCredentialsOAuth2AuthorizedClientProvider.java @@ -0,0 +1,110 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client; + +import org.springframework.lang.Nullable; +import org.springframework.security.oauth2.client.endpoint.DefaultClientCredentialsTokenResponseClient; +import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; +import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest; +import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; +import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.util.Assert; + +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + +/** + * An implementation of an {@link OAuth2AuthorizedClientProvider} + * for the {@link AuthorizationGrantType#CLIENT_CREDENTIALS client_credentials} grant. + * + * @author Joe Grandja + * @since 5.2 + * @see OAuth2AuthorizedClientProvider + * @see DefaultClientCredentialsTokenResponseClient + */ +public final class ClientCredentialsOAuth2AuthorizedClientProvider implements OAuth2AuthorizedClientProvider { + private static final String HTTP_SERVLET_REQUEST_ATTR_NAME = HttpServletRequest.class.getName(); + private static final String HTTP_SERVLET_RESPONSE_ATTR_NAME = HttpServletResponse.class.getName(); + private final ClientRegistrationRepository clientRegistrationRepository; + private final OAuth2AuthorizedClientRepository authorizedClientRepository; + private OAuth2AccessTokenResponseClient accessTokenResponseClient = + new DefaultClientCredentialsTokenResponseClient(); + + /** + * Constructs a {@code ClientCredentialsOAuth2AuthorizedClientProvider} using the provided parameters. + * + * @param clientRegistrationRepository the repository of client registrations + * @param authorizedClientRepository the repository of authorized clients + */ + public ClientCredentialsOAuth2AuthorizedClientProvider(ClientRegistrationRepository clientRegistrationRepository, + OAuth2AuthorizedClientRepository authorizedClientRepository) { + Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null"); + Assert.notNull(authorizedClientRepository, "authorizedClientRepository cannot be null"); + this.clientRegistrationRepository = clientRegistrationRepository; + this.authorizedClientRepository = authorizedClientRepository; + } + + @Override + @Nullable + public OAuth2AuthorizedClient authorize(OAuth2AuthorizationContext context) { + Assert.notNull(context, "context cannot be null"); + if (!AuthorizationGrantType.CLIENT_CREDENTIALS.equals(context.getClientRegistration().getAuthorizationGrantType())) { + return null; + } + + HttpServletRequest request = context.getAttribute(HTTP_SERVLET_REQUEST_ATTR_NAME); + HttpServletResponse response = context.getAttribute(HTTP_SERVLET_RESPONSE_ATTR_NAME); + Assert.notNull(request, "context.HttpServletRequest cannot be null"); + Assert.notNull(response, "context.HttpServletResponse cannot be null"); + + // As per spec, in section 4.4.3 Access Token Response + // https://tools.ietf.org/html/rfc6749#section-4.4.3 + // A refresh token SHOULD NOT be included. + // + // Therefore, renewing an expired access token (re-authorization) + // is the same as acquiring a new access token (authorization). + + OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = + new OAuth2ClientCredentialsGrantRequest(context.getClientRegistration()); + OAuth2AccessTokenResponse tokenResponse = + this.accessTokenResponseClient.getTokenResponse(clientCredentialsGrantRequest); + + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( + context.getClientRegistration(), + context.getPrincipal().getName(), + tokenResponse.getAccessToken()); + + this.authorizedClientRepository.saveAuthorizedClient( + authorizedClient, + context.getPrincipal(), + request, + response); + + return authorizedClient; + } + + /** + * Sets the client used when requesting an access token credential at the Token Endpoint for the {@code client_credentials} grant. + * + * @param accessTokenResponseClient the client used when requesting an access token credential at the Token Endpoint for the {@code client_credentials} grant + */ + public void setAccessTokenResponseClient(OAuth2AccessTokenResponseClient accessTokenResponseClient) { + Assert.notNull(accessTokenResponseClient, "accessTokenResponseClient cannot be null"); + this.accessTokenResponseClient = accessTokenResponseClient; + } +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/ClientCredentialsOAuth2AuthorizedClientProviderTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/ClientCredentialsOAuth2AuthorizedClientProviderTests.java new file mode 100644 index 00000000000..a3b3b9d1ddf --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/ClientCredentialsOAuth2AuthorizedClientProviderTests.java @@ -0,0 +1,170 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client; + +import org.junit.Before; +import org.junit.Test; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; +import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; +import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses; + +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.*; + +/** + * Tests for {@link ClientCredentialsOAuth2AuthorizedClientProvider}. + * + * @author Joe Grandja + */ +public class ClientCredentialsOAuth2AuthorizedClientProviderTests { + private ClientRegistrationRepository clientRegistrationRepository; + private OAuth2AuthorizedClientRepository authorizedClientRepository; + private ClientCredentialsOAuth2AuthorizedClientProvider authorizedClientProvider; + private OAuth2AccessTokenResponseClient accessTokenResponseClient; + private ClientRegistration clientRegistration; + private Authentication principal; + + @Before + public void setup() { + this.clientRegistrationRepository = mock(ClientRegistrationRepository.class); + this.authorizedClientRepository = mock(OAuth2AuthorizedClientRepository.class); + this.authorizedClientProvider = new ClientCredentialsOAuth2AuthorizedClientProvider( + this.clientRegistrationRepository, this.authorizedClientRepository); + this.accessTokenResponseClient = mock(OAuth2AccessTokenResponseClient.class); + this.authorizedClientProvider.setAccessTokenResponseClient(this.accessTokenResponseClient); + this.clientRegistration = TestClientRegistrations.clientCredentials().build(); + this.principal = new TestingAuthenticationToken("principal", "password"); + } + + @Test + public void constructorWhenClientRegistrationRepositoryIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new ClientCredentialsOAuth2AuthorizedClientProvider(null, this.authorizedClientRepository)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("clientRegistrationRepository cannot be null"); + } + + @Test + public void constructorWhenOAuth2AuthorizedClientRepositoryIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new ClientCredentialsOAuth2AuthorizedClientProvider(this.clientRegistrationRepository, null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("authorizedClientRepository cannot be null"); + } + + @Test + public void setAccessTokenResponseClientWhenClientIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.authorizedClientProvider.setAccessTokenResponseClient(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("accessTokenResponseClient cannot be null"); + } + + @Test + public void authorizeWhenContextIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.authorizedClientProvider.authorize(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("context cannot be null"); + } + + @Test + public void authorizeWhenNotClientCredentialsThenUnableToAuthorize() { + ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); + OAuth2AuthorizationContext authorizationContext = + OAuth2AuthorizationContext.authorize(clientRegistration).principal(this.principal).build(); + assertThat(this.authorizedClientProvider.authorize(authorizationContext)).isNull(); + } + + @Test + public void authorizeWhenHttpServletRequestIsNullThenThrowIllegalArgumentException() { + OAuth2AuthorizationContext authorizationContext = + OAuth2AuthorizationContext.authorize(this.clientRegistration).principal(this.principal).build(); + assertThatThrownBy(() -> this.authorizedClientProvider.authorize(authorizationContext)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("context.HttpServletRequest cannot be null"); + } + + @Test + public void authorizeWhenHttpServletResponseIsNullThenThrowIllegalArgumentException() { + OAuth2AuthorizationContext authorizationContext = + OAuth2AuthorizationContext.authorize(this.clientRegistration) + .principal(this.principal) + .attribute(HttpServletRequest.class.getName(), new MockHttpServletRequest()) + .build(); + assertThatThrownBy(() -> this.authorizedClientProvider.authorize(authorizationContext)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("context.HttpServletResponse cannot be null"); + } + + @Test + public void authorizeWhenClientCredentialsAndNotAuthorizedThenAuthorize() { + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); + when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse); + + OAuth2AuthorizationContext authorizationContext = + OAuth2AuthorizationContext.authorize(this.clientRegistration) + .principal(this.principal) + .attribute(HttpServletRequest.class.getName(), new MockHttpServletRequest()) + .attribute(HttpServletResponse.class.getName(), new MockHttpServletResponse()) + .build(); + + OAuth2AuthorizedClient authorizedClient = this.authorizedClientProvider.authorize(authorizationContext); + + assertThat(authorizedClient.getClientRegistration()).isSameAs(this.clientRegistration); + assertThat(authorizedClient.getPrincipalName()).isEqualTo(this.principal.getName()); + assertThat(authorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken()); + verify(this.authorizedClientRepository).saveAuthorizedClient( + eq(authorizedClient), eq(this.principal), + any(HttpServletRequest.class), any(HttpServletResponse.class)); + } + + @Test + public void authorizeWhenClientCredentialsAndAuthorizedThenReauthorize() { + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); + when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse); + + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( + this.clientRegistration, this.principal.getName(), TestOAuth2AccessTokens.noScopes()); + + OAuth2AuthorizationContext authorizationContext = + OAuth2AuthorizationContext.reauthorize(authorizedClient) + .principal(this.principal) + .attribute(HttpServletRequest.class.getName(), new MockHttpServletRequest()) + .attribute(HttpServletResponse.class.getName(), new MockHttpServletResponse()) + .build(); + + authorizedClient = this.authorizedClientProvider.authorize(authorizationContext); + + assertThat(authorizedClient.getClientRegistration()).isSameAs(this.clientRegistration); + assertThat(authorizedClient.getPrincipalName()).isEqualTo(this.principal.getName()); + assertThat(authorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken()); + verify(this.authorizedClientRepository).saveAuthorizedClient( + eq(authorizedClient), eq(this.principal), + any(HttpServletRequest.class), any(HttpServletResponse.class)); + } +} From e3875ed443d7fad6b73c7e3ca65b6b9b89924379 Mon Sep 17 00:00:00 2001 From: Joe Grandja Date: Mon, 10 Jun 2019 08:08:18 -0400 Subject: [PATCH 04/19] Add refresh_token OAuth2AccessTokenResponseClient --- ...efaultRefreshTokenTokenResponseClient.java | 133 +++++++++++ .../OAuth2RefreshTokenGrantRequest.java | 82 +++++++ ...freshTokenGrantRequestEntityConverter.java | 89 +++++++ ...tRefreshTokenTokenResponseClientTests.java | 222 ++++++++++++++++++ ...TokenGrantRequestEntityConverterTests.java | 77 ++++++ .../OAuth2RefreshTokenGrantRequestTests.java | 77 ++++++ 6 files changed, 680 insertions(+) create mode 100644 oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/DefaultRefreshTokenTokenResponseClient.java create mode 100644 oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2RefreshTokenGrantRequest.java create mode 100644 oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2RefreshTokenGrantRequestEntityConverter.java create mode 100644 oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/DefaultRefreshTokenTokenResponseClientTests.java create mode 100644 oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2RefreshTokenGrantRequestEntityConverterTests.java create mode 100644 oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2RefreshTokenGrantRequestTests.java diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/DefaultRefreshTokenTokenResponseClient.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/DefaultRefreshTokenTokenResponseClient.java new file mode 100644 index 00000000000..d9c19b23f4f --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/DefaultRefreshTokenTokenResponseClient.java @@ -0,0 +1,133 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client.endpoint; + +import org.springframework.core.convert.converter.Converter; +import org.springframework.http.RequestEntity; +import org.springframework.http.ResponseEntity; +import org.springframework.http.converter.FormHttpMessageConverter; +import org.springframework.http.converter.HttpMessageConverter; +import org.springframework.security.oauth2.client.http.OAuth2ErrorResponseErrorHandler; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.OAuth2AuthorizationException; +import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.security.oauth2.core.http.converter.OAuth2AccessTokenResponseHttpMessageConverter; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; +import org.springframework.web.client.ResponseErrorHandler; +import org.springframework.web.client.RestClientException; +import org.springframework.web.client.RestOperations; +import org.springframework.web.client.RestTemplate; + +import java.util.Arrays; + +/** + * The default implementation of an {@link OAuth2AccessTokenResponseClient} + * for the {@link AuthorizationGrantType#REFRESH_TOKEN refresh_token} grant. + * This implementation uses a {@link RestOperations} when requesting + * an access token credential at the Authorization Server's Token Endpoint. + * + * @author Joe Grandja + * @since 5.2 + * @see OAuth2AccessTokenResponseClient + * @see OAuth2RefreshTokenGrantRequest + * @see OAuth2AccessTokenResponse + * @see Section 6 Refreshing an Access Token + */ +public final class DefaultRefreshTokenTokenResponseClient implements OAuth2AccessTokenResponseClient { + private static final String INVALID_TOKEN_RESPONSE_ERROR_CODE = "invalid_token_response"; + + private Converter> requestEntityConverter = + new OAuth2RefreshTokenGrantRequestEntityConverter(); + + private RestOperations restOperations; + + public DefaultRefreshTokenTokenResponseClient() { + RestTemplate restTemplate = new RestTemplate(Arrays.asList( + new FormHttpMessageConverter(), new OAuth2AccessTokenResponseHttpMessageConverter())); + restTemplate.setErrorHandler(new OAuth2ErrorResponseErrorHandler()); + this.restOperations = restTemplate; + } + + @Override + public OAuth2AccessTokenResponse getTokenResponse(OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest) { + Assert.notNull(refreshTokenGrantRequest, "refreshTokenGrantRequest cannot be null"); + + RequestEntity request = this.requestEntityConverter.convert(refreshTokenGrantRequest); + + ResponseEntity response; + try { + response = this.restOperations.exchange(request, OAuth2AccessTokenResponse.class); + } catch (RestClientException ex) { + OAuth2Error oauth2Error = new OAuth2Error(INVALID_TOKEN_RESPONSE_ERROR_CODE, + "An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response: " + ex.getMessage(), null); + throw new OAuth2AuthorizationException(oauth2Error, ex); + } + + OAuth2AccessTokenResponse tokenResponse = response.getBody(); + + if (CollectionUtils.isEmpty(tokenResponse.getAccessToken().getScopes()) || + tokenResponse.getRefreshToken() == null) { + OAuth2AccessTokenResponse.Builder tokenResponseBuilder = OAuth2AccessTokenResponse.withResponse(tokenResponse); + + if (CollectionUtils.isEmpty(tokenResponse.getAccessToken().getScopes())) { + // As per spec, in Section 5.1 Successful Access Token Response + // https://tools.ietf.org/html/rfc6749#section-5.1 + // If AccessTokenResponse.scope is empty, then default to the scope + // originally requested by the client in the Token Request + tokenResponseBuilder.scopes(refreshTokenGrantRequest.getAuthorizedClient().getAccessToken().getScopes()); + } + + if (tokenResponse.getRefreshToken() == null) { + // Reuse existing refresh token + tokenResponseBuilder.refreshToken(refreshTokenGrantRequest.getAuthorizedClient().getRefreshToken().getTokenValue()); + } + + tokenResponse = tokenResponseBuilder.build(); + } + + return tokenResponse; + } + + /** + * Sets the {@link Converter} used for converting the {@link OAuth2RefreshTokenGrantRequest} + * to a {@link RequestEntity} representation of the OAuth 2.0 Access Token Request. + * + * @param requestEntityConverter the {@link Converter} used for converting to a {@link RequestEntity} representation of the Access Token Request + */ + public void setRequestEntityConverter(Converter> requestEntityConverter) { + Assert.notNull(requestEntityConverter, "requestEntityConverter cannot be null"); + this.requestEntityConverter = requestEntityConverter; + } + + /** + * Sets the {@link RestOperations} used when requesting the OAuth 2.0 Access Token Response. + * + *

+ * NOTE: At a minimum, the supplied {@code restOperations} must be configured with the following: + *

    + *
  1. {@link HttpMessageConverter}'s - {@link FormHttpMessageConverter} and {@link OAuth2AccessTokenResponseHttpMessageConverter}
  2. + *
  3. {@link ResponseErrorHandler} - {@link OAuth2ErrorResponseErrorHandler}
  4. + *
+ * + * @param restOperations the {@link RestOperations} used when requesting the Access Token Response + */ + public void setRestOperations(RestOperations restOperations) { + Assert.notNull(restOperations, "restOperations cannot be null"); + this.restOperations = restOperations; + } +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2RefreshTokenGrantRequest.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2RefreshTokenGrantRequest.java new file mode 100644 index 00000000000..e84c4e365f5 --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2RefreshTokenGrantRequest.java @@ -0,0 +1,82 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client.endpoint; + +import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.util.Assert; + +import java.util.Collections; +import java.util.LinkedHashSet; +import java.util.Set; + +/** + * An OAuth 2.0 Refresh Token Grant request that holds + * the {@link OAuth2AuthorizedClient authorized client}. + * + * @author Joe Grandja + * @since 5.2 + * @see AbstractOAuth2AuthorizationGrantRequest + * @see OAuth2AuthorizedClient + * @see Section 6 Refreshing an Access Token + */ +public class OAuth2RefreshTokenGrantRequest extends AbstractOAuth2AuthorizationGrantRequest { + private final OAuth2AuthorizedClient authorizedClient; + private final Set scopes; + + /** + * Constructs an {@code OAuth2RefreshTokenGrantRequest} using the provided parameters. + * + * @param authorizedClient the authorized client + */ + public OAuth2RefreshTokenGrantRequest(OAuth2AuthorizedClient authorizedClient) { + this(authorizedClient, Collections.emptySet()); + } + + /** + * Constructs an {@code OAuth2RefreshTokenGrantRequest} using the provided parameters. + * + * @param authorizedClient the authorized client + * @param scopes the scopes + */ + public OAuth2RefreshTokenGrantRequest(OAuth2AuthorizedClient authorizedClient, Set scopes) { + super(AuthorizationGrantType.REFRESH_TOKEN); + Assert.notNull(authorizedClient, "authorizedClient cannot be null"); + Assert.notNull(authorizedClient.getRefreshToken(), "authorizedClient.refreshToken cannot be null"); + this.authorizedClient = authorizedClient; + this.scopes = Collections.unmodifiableSet(scopes != null ? + new LinkedHashSet<>(scopes) : Collections.emptySet()); + + } + + /** + * Returns the {@link OAuth2AuthorizedClient authorized client}. + * + * @return the {@link OAuth2AuthorizedClient} + */ + public OAuth2AuthorizedClient getAuthorizedClient() { + return this.authorizedClient; + } + + /** + * Returns the scope(s). + * + * @return the scope(s) + */ + public Set getScopes() { + return this.scopes; + } +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2RefreshTokenGrantRequestEntityConverter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2RefreshTokenGrantRequestEntityConverter.java new file mode 100644 index 00000000000..f3bdeb71d17 --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2RefreshTokenGrantRequestEntityConverter.java @@ -0,0 +1,89 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client.endpoint; + +import org.springframework.core.convert.converter.Converter; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.RequestEntity; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.core.ClientAuthenticationMethod; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.util.CollectionUtils; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.MultiValueMap; +import org.springframework.util.StringUtils; +import org.springframework.web.util.UriComponentsBuilder; + +import java.net.URI; + +/** + * A {@link Converter} that converts the provided {@link OAuth2RefreshTokenGrantRequest} + * to a {@link RequestEntity} representation of an OAuth 2.0 Access Token Request + * for the Refresh Token Grant. + * + * @author Joe Grandja + * @since 5.2 + * @see Converter + * @see OAuth2RefreshTokenGrantRequest + * @see RequestEntity + */ +public class OAuth2RefreshTokenGrantRequestEntityConverter implements Converter> { + + /** + * Returns the {@link RequestEntity} used for the Access Token Request. + * + * @param refreshTokenGrantRequest the refresh token grant request + * @return the {@link RequestEntity} used for the Access Token Request + */ + @Override + public RequestEntity convert(OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest) { + ClientRegistration clientRegistration = refreshTokenGrantRequest.getAuthorizedClient().getClientRegistration(); + + HttpHeaders headers = OAuth2AuthorizationGrantRequestEntityUtils.getTokenRequestHeaders(clientRegistration); + MultiValueMap formParameters = buildFormParameters(refreshTokenGrantRequest); + URI uri = UriComponentsBuilder.fromUriString(clientRegistration.getProviderDetails().getTokenUri()) + .build() + .toUri(); + + return new RequestEntity<>(formParameters, headers, HttpMethod.POST, uri); + } + + /** + * Returns a {@link MultiValueMap} of the form parameters used for the Access Token Request body. + * + * @param refreshTokenGrantRequest the refresh token grant request + * @return a {@link MultiValueMap} of the form parameters used for the Access Token Request body + */ + private MultiValueMap buildFormParameters(OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest) { + ClientRegistration clientRegistration = refreshTokenGrantRequest.getAuthorizedClient().getClientRegistration(); + + MultiValueMap formParameters = new LinkedMultiValueMap<>(); + formParameters.add(OAuth2ParameterNames.GRANT_TYPE, refreshTokenGrantRequest.getGrantType().getValue()); + formParameters.add(OAuth2ParameterNames.REFRESH_TOKEN, + refreshTokenGrantRequest.getAuthorizedClient().getRefreshToken().getTokenValue()); + if (!CollectionUtils.isEmpty(refreshTokenGrantRequest.getScopes())) { + formParameters.add(OAuth2ParameterNames.SCOPE, + StringUtils.collectionToDelimitedString(refreshTokenGrantRequest.getScopes(), " ")); + } + if (ClientAuthenticationMethod.POST.equals(clientRegistration.getClientAuthenticationMethod())) { + formParameters.add(OAuth2ParameterNames.CLIENT_ID, clientRegistration.getClientId()); + formParameters.add(OAuth2ParameterNames.CLIENT_SECRET, clientRegistration.getClientSecret()); + } + + return formParameters; + } +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/DefaultRefreshTokenTokenResponseClientTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/DefaultRefreshTokenTokenResponseClientTests.java new file mode 100644 index 00000000000..1284ce16737 --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/DefaultRefreshTokenTokenResponseClientTests.java @@ -0,0 +1,222 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client.endpoint; + +import okhttp3.mockwebserver.MockResponse; +import okhttp3.mockwebserver.MockWebServer; +import okhttp3.mockwebserver.RecordedRequest; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.MediaType; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.core.ClientAuthenticationMethod; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.OAuth2AuthorizationException; +import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; +import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; + +import java.time.Instant; +import java.util.Collections; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Tests for {@link DefaultRefreshTokenTokenResponseClient}. + * + * @author Joe Grandja + */ +public class DefaultRefreshTokenTokenResponseClientTests { + private DefaultRefreshTokenTokenResponseClient tokenResponseClient = new DefaultRefreshTokenTokenResponseClient(); + private ClientRegistration.Builder clientRegistrationBuilder; + private OAuth2AuthorizedClient authorizedClient; + private MockWebServer server; + + @Before + public void setup() throws Exception { + this.server = new MockWebServer(); + this.server.start(); + String tokenUri = this.server.url("/oauth2/token").toString(); + this.clientRegistrationBuilder = TestClientRegistrations.clientRegistration().tokenUri(tokenUri); + this.authorizedClient = new OAuth2AuthorizedClient(this.clientRegistrationBuilder.build(), + "principal", TestOAuth2AccessTokens.scopes("read", "write"), TestOAuth2RefreshTokens.refreshToken()); + } + + @After + public void cleanup() throws Exception { + this.server.shutdown(); + } + + @Test + public void setRequestEntityConverterWhenConverterIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.tokenResponseClient.setRequestEntityConverter(null)) + .isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void setRestOperationsWhenRestOperationsIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.tokenResponseClient.setRestOperations(null)) + .isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void getTokenResponseWhenRequestIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.tokenResponseClient.getTokenResponse(null)) + .isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void getTokenResponseWhenSuccessResponseThenReturnAccessTokenResponse() throws Exception { + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + + Instant expiresAtBefore = Instant.now().plusSeconds(3600); + + OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = + new OAuth2RefreshTokenGrantRequest(this.authorizedClient); + + OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(refreshTokenGrantRequest); + + Instant expiresAtAfter = Instant.now().plusSeconds(3600); + + RecordedRequest recordedRequest = this.server.takeRequest(); + assertThat(recordedRequest.getMethod()).isEqualTo(HttpMethod.POST.toString()); + assertThat(recordedRequest.getHeader(HttpHeaders.ACCEPT)).isEqualTo(MediaType.APPLICATION_JSON_UTF8_VALUE); + assertThat(recordedRequest.getHeader(HttpHeaders.CONTENT_TYPE)).isEqualTo(MediaType.APPLICATION_FORM_URLENCODED_VALUE + ";charset=UTF-8"); + assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).startsWith("Basic "); + + String formParameters = recordedRequest.getBody().readUtf8(); + assertThat(formParameters).contains("grant_type=refresh_token"); + assertThat(formParameters).contains("refresh_token=refresh-token"); + + assertThat(accessTokenResponse.getAccessToken().getTokenValue()).isEqualTo("access-token-1234"); + assertThat(accessTokenResponse.getAccessToken().getTokenType()).isEqualTo(OAuth2AccessToken.TokenType.BEARER); + assertThat(accessTokenResponse.getAccessToken().getExpiresAt()).isBetween(expiresAtBefore, expiresAtAfter); + assertThat(accessTokenResponse.getAccessToken().getScopes()).containsExactly(this.authorizedClient.getAccessToken().getScopes().toArray(new String[0])); + assertThat(accessTokenResponse.getRefreshToken().getTokenValue()).isEqualTo(this.authorizedClient.getRefreshToken().getTokenValue()); + } + + @Test + public void getTokenResponseWhenClientAuthenticationPostThenFormParametersAreSent() throws Exception { + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + + ClientRegistration clientRegistration = this.clientRegistrationBuilder + .clientAuthenticationMethod(ClientAuthenticationMethod.POST) + .build(); + this.authorizedClient = new OAuth2AuthorizedClient(clientRegistration, + "principal", TestOAuth2AccessTokens.scopes("read", "write"), TestOAuth2RefreshTokens.refreshToken()); + + OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = + new OAuth2RefreshTokenGrantRequest(this.authorizedClient); + + this.tokenResponseClient.getTokenResponse(refreshTokenGrantRequest); + + RecordedRequest recordedRequest = this.server.takeRequest(); + assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isNull(); + + String formParameters = recordedRequest.getBody().readUtf8(); + assertThat(formParameters).contains("client_id=client-id"); + assertThat(formParameters).contains("client_secret=client-secret"); + } + + @Test + public void getTokenResponseWhenSuccessResponseAndNotBearerTokenTypeThenThrowOAuth2AuthorizationException() { + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"not-bearer\",\n" + + " \"expires_in\": \"3600\"\n" + + "}\n"; + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + + OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = + new OAuth2RefreshTokenGrantRequest(this.authorizedClient); + + assertThatThrownBy(() -> this.tokenResponseClient.getTokenResponse(refreshTokenGrantRequest)) + .isInstanceOf(OAuth2AuthorizationException.class) + .hasMessageContaining("[invalid_token_response] An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response") + .hasMessageContaining("tokenType cannot be null"); + } + + @Test + public void getTokenResponseWhenSuccessResponseIncludesScopeThenAccessTokenHasResponseScope() throws Exception { + String accessTokenSuccessResponse = "{\n" + + " \"access_token\": \"access-token-1234\",\n" + + " \"token_type\": \"bearer\",\n" + + " \"expires_in\": \"3600\",\n" + + " \"scope\": \"read\"\n" + + "}\n"; + this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); + + OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = + new OAuth2RefreshTokenGrantRequest(this.authorizedClient, Collections.singleton("read")); + + OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(refreshTokenGrantRequest); + + RecordedRequest recordedRequest = this.server.takeRequest(); + String formParameters = recordedRequest.getBody().readUtf8(); + assertThat(formParameters).contains("scope=read"); + + assertThat(accessTokenResponse.getAccessToken().getScopes()).containsExactly("read"); + } + + @Test + public void getTokenResponseWhenErrorResponseThenThrowOAuth2AuthorizationException() { + String accessTokenErrorResponse = "{\n" + + " \"error\": \"unauthorized_client\"\n" + + "}\n"; + this.server.enqueue(jsonResponse(accessTokenErrorResponse).setResponseCode(400)); + + OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = + new OAuth2RefreshTokenGrantRequest(this.authorizedClient); + + assertThatThrownBy(() -> this.tokenResponseClient.getTokenResponse(refreshTokenGrantRequest)) + .isInstanceOf(OAuth2AuthorizationException.class) + .hasMessageContaining("[unauthorized_client]"); + } + + @Test + public void getTokenResponseWhenServerErrorResponseThenThrowOAuth2AuthorizationException() { + this.server.enqueue(new MockResponse().setResponseCode(500)); + + OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = + new OAuth2RefreshTokenGrantRequest(this.authorizedClient); + + assertThatThrownBy(() -> this.tokenResponseClient.getTokenResponse(refreshTokenGrantRequest)) + .isInstanceOf(OAuth2AuthorizationException.class) + .hasMessage("[invalid_token_response] An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response: 500 Server Error"); + } + + private MockResponse jsonResponse(String json) { + return new MockResponse() + .setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) + .setBody(json); + } +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2RefreshTokenGrantRequestEntityConverterTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2RefreshTokenGrantRequestEntityConverterTests.java new file mode 100644 index 00000000000..9925b80a62d --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2RefreshTokenGrantRequestEntityConverterTests.java @@ -0,0 +1,77 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client.endpoint; + +import org.junit.Before; +import org.junit.Test; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.MediaType; +import org.springframework.http.RequestEntity; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; +import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.util.MultiValueMap; + +import java.util.Collections; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.springframework.http.MediaType.APPLICATION_FORM_URLENCODED_VALUE; + +/** + * Tests for {@link OAuth2RefreshTokenGrantRequestEntityConverter}. + * + * @author Joe Grandja + */ +public class OAuth2RefreshTokenGrantRequestEntityConverterTests { + private OAuth2RefreshTokenGrantRequestEntityConverter converter = new OAuth2RefreshTokenGrantRequestEntityConverter(); + private OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest; + + @Before + public void setup() { + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(TestClientRegistrations.clientRegistration().build(), + "principal", TestOAuth2AccessTokens.scopes("read", "write"), TestOAuth2RefreshTokens.refreshToken()); + this.refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest(authorizedClient, Collections.singleton("read")); + } + + @SuppressWarnings("unchecked") + @Test + public void convertWhenGrantRequestValidThenConverts() { + RequestEntity requestEntity = this.converter.convert(this.refreshTokenGrantRequest); + + OAuth2AuthorizedClient authorizedClient = this.refreshTokenGrantRequest.getAuthorizedClient(); + + assertThat(requestEntity.getMethod()).isEqualTo(HttpMethod.POST); + assertThat(requestEntity.getUrl().toASCIIString()).isEqualTo( + authorizedClient.getClientRegistration().getProviderDetails().getTokenUri()); + + HttpHeaders headers = requestEntity.getHeaders(); + assertThat(headers.getAccept()).contains(MediaType.APPLICATION_JSON_UTF8); + assertThat(headers.getContentType()).isEqualTo( + MediaType.valueOf(APPLICATION_FORM_URLENCODED_VALUE + ";charset=UTF-8")); + assertThat(headers.getFirst(HttpHeaders.AUTHORIZATION)).startsWith("Basic "); + + MultiValueMap formParameters = (MultiValueMap) requestEntity.getBody(); + assertThat(formParameters.getFirst(OAuth2ParameterNames.GRANT_TYPE)).isEqualTo( + AuthorizationGrantType.REFRESH_TOKEN.getValue()); + assertThat(formParameters.getFirst(OAuth2ParameterNames.REFRESH_TOKEN)).isEqualTo( + authorizedClient.getRefreshToken().getTokenValue()); + assertThat(formParameters.getFirst(OAuth2ParameterNames.SCOPE)).isEqualTo("read"); + } +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2RefreshTokenGrantRequestTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2RefreshTokenGrantRequestTests.java new file mode 100644 index 00000000000..0a0d1fd739b --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2RefreshTokenGrantRequestTests.java @@ -0,0 +1,77 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client.endpoint; + +import org.junit.Before; +import org.junit.Test; +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; +import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens; + +import java.util.Arrays; +import java.util.HashSet; +import java.util.Set; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Tests for {@link OAuth2RefreshTokenGrantRequest}. + * + * @author Joe Grandja + */ +public class OAuth2RefreshTokenGrantRequestTests { + private ClientRegistration clientRegistration; + private Authentication principal; + private OAuth2AuthorizedClient authorizedClient; + + @Before + public void setup() { + this.clientRegistration = TestClientRegistrations.clientRegistration().build(); + this.principal = new TestingAuthenticationToken("principal", "password"); + this.authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, this.principal.getName(), + TestOAuth2AccessTokens.scopes("read", "write"), TestOAuth2RefreshTokens.refreshToken()); + } + + @Test + public void constructorWhenAuthorizedClientIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new OAuth2RefreshTokenGrantRequest(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("authorizedClient cannot be null"); + } + + @Test + public void constructorWhenRefreshTokenIsNullThenThrowIllegalArgumentException() { + this.authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, + this.principal.getName(), TestOAuth2AccessTokens.scopes("read", "write")); + assertThatThrownBy(() -> new OAuth2RefreshTokenGrantRequest(this.authorizedClient)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("authorizedClient.refreshToken cannot be null"); + } + + @Test + public void constructorWhenValidParametersProvidedThenCreated() { + Set scopes = new HashSet<>(Arrays.asList("read", "write")); + OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = + new OAuth2RefreshTokenGrantRequest(this.authorizedClient, scopes); + assertThat(refreshTokenGrantRequest.getAuthorizedClient()).isSameAs(this.authorizedClient); + assertThat(refreshTokenGrantRequest.getScopes()).isEqualTo(scopes); + } +} From c100e622490810d65f2230c45b10e9899fafbc1e Mon Sep 17 00:00:00 2001 From: Joe Grandja Date: Tue, 11 Jun 2019 21:00:47 -0400 Subject: [PATCH 05/19] Add refresh_token OAuth2AuthorizedClientProvider --- ...shTokenOAuth2AuthorizedClientProvider.java | 113 ++++++++++ ...enOAuth2AuthorizedClientProviderTests.java | 211 ++++++++++++++++++ 2 files changed, 324 insertions(+) create mode 100644 oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProvider.java create mode 100644 oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProviderTests.java diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProvider.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProvider.java new file mode 100644 index 00000000000..0ac3defd432 --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProvider.java @@ -0,0 +1,113 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client; + +import org.springframework.lang.Nullable; +import org.springframework.security.oauth2.client.endpoint.DefaultRefreshTokenTokenResponseClient; +import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; +import org.springframework.security.oauth2.client.endpoint.OAuth2RefreshTokenGrantRequest; +import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; +import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.util.Assert; + +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import java.util.Set; + +/** + * An implementation of an {@link OAuth2AuthorizedClientProvider} + * for the {@link AuthorizationGrantType#REFRESH_TOKEN refresh_token} grant. + * + * @author Joe Grandja + * @since 5.2 + * @see OAuth2AuthorizedClientProvider + * @see DefaultRefreshTokenTokenResponseClient + */ +public final class RefreshTokenOAuth2AuthorizedClientProvider implements OAuth2AuthorizedClientProvider { + private static final String HTTP_SERVLET_REQUEST_ATTR_NAME = HttpServletRequest.class.getName(); + private static final String HTTP_SERVLET_RESPONSE_ATTR_NAME = HttpServletResponse.class.getName(); + private static final String SCOPE_ATTR_NAME = "SCOPE"; + private final ClientRegistrationRepository clientRegistrationRepository; + private final OAuth2AuthorizedClientRepository authorizedClientRepository; + private OAuth2AccessTokenResponseClient accessTokenResponseClient = + new DefaultRefreshTokenTokenResponseClient(); + + /** + * Constructs a {@code RefreshTokenOAuth2AuthorizedClientProvider} using the provided parameters. + * + * @param clientRegistrationRepository the repository of client registrations + * @param authorizedClientRepository the repository of authorized clients + */ + public RefreshTokenOAuth2AuthorizedClientProvider(ClientRegistrationRepository clientRegistrationRepository, + OAuth2AuthorizedClientRepository authorizedClientRepository) { + Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null"); + Assert.notNull(authorizedClientRepository, "authorizedClientRepository cannot be null"); + this.clientRegistrationRepository = clientRegistrationRepository; + this.authorizedClientRepository = authorizedClientRepository; + } + + @Override + @Nullable + public OAuth2AuthorizedClient authorize(OAuth2AuthorizationContext context) { + Assert.notNull(context, "context cannot be null"); + if (!context.reauthorizationRequired() || context.getAuthorizedClient().getRefreshToken() == null) { + return null; + } + + HttpServletRequest request = context.getAttribute(HTTP_SERVLET_REQUEST_ATTR_NAME); + HttpServletResponse response = context.getAttribute(HTTP_SERVLET_RESPONSE_ATTR_NAME); + Assert.notNull(request, "context.HttpServletRequest cannot be null"); + Assert.notNull(response, "context.HttpServletResponse cannot be null"); + + Object scopesObj = context.getAttribute(SCOPE_ATTR_NAME); + Set scopes = null; + if (scopesObj != null) { + Assert.isTrue(scopesObj instanceof Set, "The '" + SCOPE_ATTR_NAME + "' attribute must be of type " + Set.class.getName()); + scopes = (Set) scopesObj; + } + + OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = + new OAuth2RefreshTokenGrantRequest(context.getAuthorizedClient(), scopes); + OAuth2AccessTokenResponse tokenResponse = + this.accessTokenResponseClient.getTokenResponse(refreshTokenGrantRequest); + + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( + context.getClientRegistration(), + context.getPrincipal().getName(), + tokenResponse.getAccessToken(), + tokenResponse.getRefreshToken()); + + this.authorizedClientRepository.saveAuthorizedClient( + authorizedClient, + context.getPrincipal(), + request, + response); + + return authorizedClient; + } + + /** + * Sets the client used when requesting an access token credential at the Token Endpoint for the {@code refresh_token} grant. + * + * @param accessTokenResponseClient the client used when requesting an access token credential at the Token Endpoint for the {@code refresh_token} grant + */ + public void setAccessTokenResponseClient(OAuth2AccessTokenResponseClient accessTokenResponseClient) { + Assert.notNull(accessTokenResponseClient, "accessTokenResponseClient cannot be null"); + this.accessTokenResponseClient = accessTokenResponseClient; + } +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProviderTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProviderTests.java new file mode 100644 index 00000000000..344d3838f03 --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProviderTests.java @@ -0,0 +1,211 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; +import org.springframework.security.oauth2.client.endpoint.OAuth2RefreshTokenGrantRequest; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; +import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; +import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses; + +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import java.util.Collections; +import java.util.Set; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.*; + +/** + * Tests for {@link RefreshTokenOAuth2AuthorizedClientProvider}. + * + * @author Joe Grandja + */ +public class RefreshTokenOAuth2AuthorizedClientProviderTests { + private ClientRegistrationRepository clientRegistrationRepository; + private OAuth2AuthorizedClientRepository authorizedClientRepository; + private RefreshTokenOAuth2AuthorizedClientProvider authorizedClientProvider; + private OAuth2AccessTokenResponseClient accessTokenResponseClient; + private ClientRegistration clientRegistration; + private Authentication principal; + private OAuth2AuthorizedClient authorizedClient; + + @Before + public void setup() { + this.clientRegistrationRepository = mock(ClientRegistrationRepository.class); + this.authorizedClientRepository = mock(OAuth2AuthorizedClientRepository.class); + this.authorizedClientProvider = new RefreshTokenOAuth2AuthorizedClientProvider( + this.clientRegistrationRepository, this.authorizedClientRepository); + this.accessTokenResponseClient = mock(OAuth2AccessTokenResponseClient.class); + this.authorizedClientProvider.setAccessTokenResponseClient(this.accessTokenResponseClient); + this.clientRegistration = TestClientRegistrations.clientRegistration().build(); + this.principal = new TestingAuthenticationToken("principal", "password"); + this.authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, this.principal.getName(), + TestOAuth2AccessTokens.scopes("read", "write"), TestOAuth2RefreshTokens.refreshToken()); + } + + @Test + public void constructorWhenClientRegistrationRepositoryIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new RefreshTokenOAuth2AuthorizedClientProvider(null, this.authorizedClientRepository)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("clientRegistrationRepository cannot be null"); + } + + @Test + public void constructorWhenOAuth2AuthorizedClientRepositoryIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new RefreshTokenOAuth2AuthorizedClientProvider(this.clientRegistrationRepository, null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("authorizedClientRepository cannot be null"); + } + + @Test + public void setAccessTokenResponseClientWhenClientIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.authorizedClientProvider.setAccessTokenResponseClient(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("accessTokenResponseClient cannot be null"); + } + + @Test + public void authorizeWhenContextIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.authorizedClientProvider.authorize(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("context cannot be null"); + } + + @Test + public void authorizeWhenNotAuthorizedThenUnableToReauthorize() { + OAuth2AuthorizationContext authorizationContext = + OAuth2AuthorizationContext.authorize(this.clientRegistration).principal(this.principal).build(); + assertThat(this.authorizedClientProvider.authorize(authorizationContext)).isNull(); + } + + @Test + public void authorizeWhenAuthorizedAndRefreshTokenIsNullThenUnableToReauthorize() { + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( + this.clientRegistration, this.principal.getName(), this.authorizedClient.getAccessToken()); + OAuth2AuthorizationContext authorizationContext = + OAuth2AuthorizationContext.reauthorize(authorizedClient).principal(this.principal).build(); + assertThat(this.authorizedClientProvider.authorize(authorizationContext)).isNull(); + } + + @Test + public void authorizeWhenHttpServletRequestIsNullThenThrowIllegalArgumentException() { + OAuth2AuthorizationContext authorizationContext = + OAuth2AuthorizationContext.reauthorize(this.authorizedClient).principal(this.principal).build(); + assertThatThrownBy(() -> this.authorizedClientProvider.authorize(authorizationContext)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("context.HttpServletRequest cannot be null"); + } + + @Test + public void authorizeWhenHttpServletResponseIsNullThenThrowIllegalArgumentException() { + OAuth2AuthorizationContext authorizationContext = + OAuth2AuthorizationContext.reauthorize(this.authorizedClient) + .principal(this.principal) + .attribute(HttpServletRequest.class.getName(), new MockHttpServletRequest()) + .build(); + assertThatThrownBy(() -> this.authorizedClientProvider.authorize(authorizationContext)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("context.HttpServletResponse cannot be null"); + } + + @Test + public void authorizeWhenAuthorizedWithRefreshTokenThenReauthorize() { + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse() + .refreshToken("new-refresh-token") + .build(); + when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse); + + OAuth2AuthorizationContext authorizationContext = + OAuth2AuthorizationContext.reauthorize(this.authorizedClient) + .principal(this.principal) + .attribute(HttpServletRequest.class.getName(), new MockHttpServletRequest()) + .attribute(HttpServletResponse.class.getName(), new MockHttpServletResponse()) + .build(); + + OAuth2AuthorizedClient authorizedClient = this.authorizedClientProvider.authorize(authorizationContext); + + assertThat(authorizedClient.getClientRegistration()).isSameAs(this.clientRegistration); + assertThat(authorizedClient.getPrincipalName()).isEqualTo(this.principal.getName()); + assertThat(authorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken()); + assertThat(authorizedClient.getRefreshToken()).isEqualTo(accessTokenResponse.getRefreshToken()); + verify(this.authorizedClientRepository).saveAuthorizedClient( + eq(authorizedClient), eq(this.principal), + any(HttpServletRequest.class), any(HttpServletResponse.class)); + } + + @Test + public void authorizeWhenAuthorizedAndScopeProvidedThenScopeRequested() { + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse() + .refreshToken("new-refresh-token") + .build(); + when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse); + + Set scope = Collections.singleton("read"); + + OAuth2AuthorizationContext authorizationContext = + OAuth2AuthorizationContext.reauthorize(this.authorizedClient) + .principal(this.principal) + .attribute(HttpServletRequest.class.getName(), new MockHttpServletRequest()) + .attribute(HttpServletResponse.class.getName(), new MockHttpServletResponse()) + .attribute("SCOPE", scope) + .build(); + + this.authorizedClientProvider.authorize(authorizationContext); + + ArgumentCaptor refreshTokenGrantRequestArgCaptor = + ArgumentCaptor.forClass(OAuth2RefreshTokenGrantRequest.class); + verify(this.accessTokenResponseClient).getTokenResponse(refreshTokenGrantRequestArgCaptor.capture()); + assertThat(refreshTokenGrantRequestArgCaptor.getValue().getScopes()).isEqualTo(scope); + } + + @Test + public void authorizeWhenAuthorizedAndInvalidScopeProvidedThenThrowIllegalArgumentException() { + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse() + .refreshToken("new-refresh-token") + .build(); + when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse); + + String scope = "read"; + + OAuth2AuthorizationContext authorizationContext = + OAuth2AuthorizationContext.reauthorize(this.authorizedClient) + .principal(this.principal) + .attribute(HttpServletRequest.class.getName(), new MockHttpServletRequest()) + .attribute(HttpServletResponse.class.getName(), new MockHttpServletResponse()) + .attribute("SCOPE", scope) + .build(); + + assertThatThrownBy(() -> this.authorizedClientProvider.authorize(authorizationContext)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("The 'SCOPE' attribute must be of type " + Set.class.getName()); + } +} From 22d43b99bec22282899f7000ff284a2bbab9a1bb Mon Sep 17 00:00:00 2001 From: Joe Grandja Date: Wed, 12 Jun 2019 07:34:56 -0400 Subject: [PATCH 06/19] Add delegating OAuth2AuthorizedClientProvider --- ...egatingOAuth2AuthorizedClientProvider.java | 73 ++++++++++++++++ ...ngOAuth2AuthorizedClientProviderTests.java | 83 +++++++++++++++++++ 2 files changed, 156 insertions(+) create mode 100644 oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/DelegatingOAuth2AuthorizedClientProvider.java create mode 100644 oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/DelegatingOAuth2AuthorizedClientProviderTests.java diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/DelegatingOAuth2AuthorizedClientProvider.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/DelegatingOAuth2AuthorizedClientProvider.java new file mode 100644 index 00000000000..0343b96071c --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/DelegatingOAuth2AuthorizedClientProvider.java @@ -0,0 +1,73 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client; + +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Objects; + +/** + * An implementation of an {@link OAuth2AuthorizedClientProvider} that simply delegates + * to it's internal {@code List} of {@link OAuth2AuthorizedClientProvider}(s). + *

+ * Each provider is given a chance to + * {@link OAuth2AuthorizedClientProvider#authorize(OAuth2AuthorizationContext) authorize} + * the {@link OAuth2AuthorizationContext#getClientRegistration() client} in the provided context + * with the first {@code non-null} {@link OAuth2AuthorizedClient} being returned. + * + * @author Joe Grandja + * @since 5.2 + * @see OAuth2AuthorizedClientProvider + */ +public final class DelegatingOAuth2AuthorizedClientProvider implements OAuth2AuthorizedClientProvider { + private final List authorizedClientProviders; + + /** + * Constructs a {@code DelegatingOAuth2AuthorizedClientProvider} using the provided parameters. + * + * @param authorizedClientProviders a list of {@link OAuth2AuthorizedClientProvider}(s) + */ + public DelegatingOAuth2AuthorizedClientProvider(OAuth2AuthorizedClientProvider... authorizedClientProviders) { + Assert.notEmpty(authorizedClientProviders, "authorizedClientProviders cannot be empty"); + this.authorizedClientProviders = Collections.unmodifiableList(Arrays.asList(authorizedClientProviders)); + } + + /** + * Constructs a {@code DelegatingOAuth2AuthorizedClientProvider} using the provided parameters. + * + * @param authorizedClientProviders a {@code List} of {@link OAuth2AuthorizedClientProvider}(s) + */ + public DelegatingOAuth2AuthorizedClientProvider(List authorizedClientProviders) { + Assert.notEmpty(authorizedClientProviders, "authorizedClientProviders cannot be empty"); + this.authorizedClientProviders = Collections.unmodifiableList(new ArrayList<>(authorizedClientProviders)); + } + + @Override + @Nullable + public OAuth2AuthorizedClient authorize(OAuth2AuthorizationContext context) { + Assert.notNull(context, "context cannot be null"); + return this.authorizedClientProviders.stream() + .map(authorizedClientProvider -> authorizedClientProvider.authorize(context)) + .filter(Objects::nonNull) + .findFirst() + .orElse(null); + } +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/DelegatingOAuth2AuthorizedClientProviderTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/DelegatingOAuth2AuthorizedClientProviderTests.java new file mode 100644 index 00000000000..aa833ba309d --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/DelegatingOAuth2AuthorizedClientProviderTests.java @@ -0,0 +1,83 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client; + +import org.junit.Test; +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; + +import java.util.Collections; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +/** + * Tests for {@link DelegatingOAuth2AuthorizedClientProvider}. + * + * @author Joe Grandja + */ +public class DelegatingOAuth2AuthorizedClientProviderTests { + + @Test + public void constructorWhenProvidersIsEmptyThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new DelegatingOAuth2AuthorizedClientProvider(new OAuth2AuthorizedClientProvider[0])) + .isInstanceOf(IllegalArgumentException.class); + assertThatThrownBy(() -> new DelegatingOAuth2AuthorizedClientProvider(Collections.emptyList())) + .isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void authorizeWhenContextIsNullThenThrowIllegalArgumentException() { + DelegatingOAuth2AuthorizedClientProvider delegate = new DelegatingOAuth2AuthorizedClientProvider( + mock(OAuth2AuthorizedClientProvider.class)); + assertThatThrownBy(() -> delegate.authorize(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("context cannot be null"); + } + + @Test + public void authorizeWhenProviderCanAuthorizeThenReturnAuthorizedClient() { + Authentication principal = new TestingAuthenticationToken("principal", "password"); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( + TestClientRegistrations.clientRegistration().build(), principal.getName(), TestOAuth2AccessTokens.noScopes()); + + OAuth2AuthorizedClientProvider authorizedClientProvider = mock(OAuth2AuthorizedClientProvider.class); + when(authorizedClientProvider.authorize(any())).thenReturn(authorizedClient); + + DelegatingOAuth2AuthorizedClientProvider delegate = new DelegatingOAuth2AuthorizedClientProvider( + mock(OAuth2AuthorizedClientProvider.class), mock(OAuth2AuthorizedClientProvider.class), authorizedClientProvider); + OAuth2AuthorizationContext context = OAuth2AuthorizationContext.reauthorize(authorizedClient).principal(principal).build(); + OAuth2AuthorizedClient reauthorizedClient = delegate.authorize(context); + assertThat(reauthorizedClient).isSameAs(authorizedClient); + } + + @Test + public void authorizeWhenProviderCantAuthorizeThenReturnNull() { + OAuth2AuthorizationContext context = OAuth2AuthorizationContext + .authorize(TestClientRegistrations.clientRegistration().build()) + .principal(new TestingAuthenticationToken("principal", "password")) + .build(); + + DelegatingOAuth2AuthorizedClientProvider delegate = new DelegatingOAuth2AuthorizedClientProvider( + mock(OAuth2AuthorizedClientProvider.class), mock(OAuth2AuthorizedClientProvider.class)); + assertThat(delegate.authorize(context)).isNull(); + } +} From 7303821ae5b89faa4837f49c7f8ab6d4822afbfc Mon Sep 17 00:00:00 2001 From: Joe Grandja Date: Wed, 12 Jun 2019 15:07:18 -0400 Subject: [PATCH 07/19] Refactor and use OAuth2AuthorizedClientProvider implementations --- .../OAuth2ClientConfiguration.java | 17 +- ...Auth2AuthorizedClientArgumentResolver.java | 76 ++++--- ...uthorizedClientExchangeFilterFunction.java | 202 +++++++----------- ...AuthorizedClientArgumentResolverTests.java | 24 ++- ...izedClientExchangeFilterFunctionTests.java | 190 ++++++++-------- 5 files changed, 246 insertions(+), 263 deletions(-) diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfiguration.java b/config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfiguration.java index 4151010a846..09a9ae86b7e 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfiguration.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfiguration.java @@ -15,13 +15,15 @@ */ package org.springframework.security.config.annotation.web.configuration; -import java.util.List; -import java.util.Optional; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Import; import org.springframework.context.annotation.ImportSelector; import org.springframework.core.type.AnnotationMetadata; +import org.springframework.security.oauth2.client.AuthorizationCodeOAuth2AuthorizedClientProvider; +import org.springframework.security.oauth2.client.ClientCredentialsOAuth2AuthorizedClientProvider; +import org.springframework.security.oauth2.client.DelegatingOAuth2AuthorizedClientProvider; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProvider; import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest; import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; @@ -31,6 +33,9 @@ import org.springframework.web.method.support.HandlerMethodArgumentResolver; import org.springframework.web.servlet.config.annotation.WebMvcConfigurer; +import java.util.List; +import java.util.Optional; + /** * {@link Configuration} for OAuth 2.0 Client support. * @@ -71,7 +76,13 @@ public void addArgumentResolvers(List argumentRes new OAuth2AuthorizedClientArgumentResolver( this.clientRegistrationRepository, this.authorizedClientRepository); if (this.accessTokenResponseClient != null) { - authorizedClientArgumentResolver.setClientCredentialsTokenResponseClient(this.accessTokenResponseClient); + ClientCredentialsOAuth2AuthorizedClientProvider clientCredentialsAuthorizedClientProvider = + new ClientCredentialsOAuth2AuthorizedClientProvider( + this.clientRegistrationRepository, this.authorizedClientRepository); + clientCredentialsAuthorizedClientProvider.setAccessTokenResponseClient(this.accessTokenResponseClient); + OAuth2AuthorizedClientProvider authorizedClientProvider = new DelegatingOAuth2AuthorizedClientProvider( + new AuthorizationCodeOAuth2AuthorizedClientProvider(), clientCredentialsAuthorizedClientProvider); + authorizedClientArgumentResolver.setAuthorizedClientProvider(authorizedClientProvider); } argumentResolvers.add(authorizedClientArgumentResolver); } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolver.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolver.java index 7a61e319c65..6e36805492c 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolver.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolver.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2019 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,8 +21,12 @@ import org.springframework.lang.Nullable; import org.springframework.security.core.Authentication; import org.springframework.security.core.context.SecurityContextHolder; -import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException; +import org.springframework.security.oauth2.client.AuthorizationCodeOAuth2AuthorizedClientProvider; +import org.springframework.security.oauth2.client.ClientCredentialsOAuth2AuthorizedClientProvider; +import org.springframework.security.oauth2.client.DelegatingOAuth2AuthorizedClientProvider; +import org.springframework.security.oauth2.client.OAuth2AuthorizationContext; import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProvider; import org.springframework.security.oauth2.client.annotation.RegisteredOAuth2AuthorizedClient; import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; import org.springframework.security.oauth2.client.endpoint.DefaultClientCredentialsTokenResponseClient; @@ -31,8 +35,6 @@ import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; -import org.springframework.security.oauth2.core.AuthorizationGrantType; -import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.util.Assert; import org.springframework.util.StringUtils; import org.springframework.web.bind.support.WebDataBinderFactory; @@ -66,8 +68,7 @@ public final class OAuth2AuthorizedClientArgumentResolver implements HandlerMethodArgumentResolver { private final ClientRegistrationRepository clientRegistrationRepository; private final OAuth2AuthorizedClientRepository authorizedClientRepository; - private OAuth2AccessTokenResponseClient clientCredentialsTokenResponseClient = - new DefaultClientCredentialsTokenResponseClient(); + private OAuth2AuthorizedClientProvider authorizedClientProvider; /** * Constructs an {@code OAuth2AuthorizedClientArgumentResolver} using the provided parameters. @@ -81,6 +82,7 @@ public OAuth2AuthorizedClientArgumentResolver(ClientRegistrationRepository clien Assert.notNull(authorizedClientRepository, "authorizedClientRepository cannot be null"); this.clientRegistrationRepository = clientRegistrationRepository; this.authorizedClientRepository = authorizedClientRepository; + this.authorizedClientProvider = createAuthorizedClientProvider(new DefaultClientCredentialsTokenResponseClient()); } @Override @@ -119,16 +121,20 @@ public Object resolveArgument(MethodParameter parameter, return null; } - if (AuthorizationGrantType.AUTHORIZATION_CODE.equals(clientRegistration.getAuthorizationGrantType())) { - throw new ClientAuthorizationRequiredException(clientRegistrationId); - } + HttpServletResponse servletResponse = webRequest.getNativeResponse(HttpServletResponse.class); - if (AuthorizationGrantType.CLIENT_CREDENTIALS.equals(clientRegistration.getAuthorizationGrantType())) { - HttpServletResponse servletResponse = webRequest.getNativeResponse(HttpServletResponse.class); - authorizedClient = this.authorizeClientCredentialsClient(clientRegistration, servletRequest, servletResponse); + OAuth2AuthorizationContext.Builder authorizationContextBuilder = OAuth2AuthorizationContext.authorize(clientRegistration); + if (principal == null) { + authorizationContextBuilder.principal("anonymousUser"); + } else { + authorizationContextBuilder.principal(principal); } + OAuth2AuthorizationContext authorizationContext = authorizationContextBuilder + .attribute(HttpServletRequest.class.getName(), servletRequest) + .attribute(HttpServletResponse.class.getName(), servletResponse) + .build(); - return authorizedClient; + return this.authorizedClientProvider.authorize(authorizationContext); } private String resolveClientRegistrationId(MethodParameter parameter) { @@ -149,37 +155,37 @@ private String resolveClientRegistrationId(MethodParameter parameter) { return clientRegistrationId; } - private OAuth2AuthorizedClient authorizeClientCredentialsClient(ClientRegistration clientRegistration, - HttpServletRequest request, HttpServletResponse response) { - OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = - new OAuth2ClientCredentialsGrantRequest(clientRegistration); - OAuth2AccessTokenResponse tokenResponse = - this.clientCredentialsTokenResponseClient.getTokenResponse(clientCredentialsGrantRequest); - - Authentication principal = SecurityContextHolder.getContext().getAuthentication(); - - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( - clientRegistration, - (principal != null ? principal.getName() : "anonymousUser"), - tokenResponse.getAccessToken()); - - this.authorizedClientRepository.saveAuthorizedClient( - authorizedClient, - principal, - request, - response); - - return authorizedClient; + /** + * Sets the {@link OAuth2AuthorizedClientProvider} used for authorizing (or re-authorizing) an OAuth 2.0 Client. + * + * @since 5.2 + * @param authorizedClientProvider the {@link OAuth2AuthorizedClientProvider} used for authorizing (or re-authorizing) an OAuth 2.0 Client. + */ + public void setAuthorizedClientProvider(OAuth2AuthorizedClientProvider authorizedClientProvider) { + Assert.notNull(authorizedClientProvider, "authorizedClientProvider cannot be null"); + this.authorizedClientProvider = authorizedClientProvider; } /** * Sets the client used when requesting an access token credential at the Token Endpoint for the {@code client_credentials} grant. * + * @deprecated Use {@link #setAuthorizedClientProvider(OAuth2AuthorizedClientProvider)} instead by providing it an instance of {@link ClientCredentialsOAuth2AuthorizedClientProvider} configured with a {@link ClientCredentialsOAuth2AuthorizedClientProvider#setAccessTokenResponseClient(OAuth2AccessTokenResponseClient) DefaultClientCredentialsTokenResponseClient} or a custom one. + * * @param clientCredentialsTokenResponseClient the client used when requesting an access token credential at the Token Endpoint for the {@code client_credentials} grant */ + @Deprecated public final void setClientCredentialsTokenResponseClient( OAuth2AccessTokenResponseClient clientCredentialsTokenResponseClient) { Assert.notNull(clientCredentialsTokenResponseClient, "clientCredentialsTokenResponseClient cannot be null"); - this.clientCredentialsTokenResponseClient = clientCredentialsTokenResponseClient; + this.authorizedClientProvider = createAuthorizedClientProvider(clientCredentialsTokenResponseClient); + } + + private OAuth2AuthorizedClientProvider createAuthorizedClientProvider( + OAuth2AccessTokenResponseClient clientCredentialsTokenResponseClient) { + ClientCredentialsOAuth2AuthorizedClientProvider clientCredentialsAuthorizedClientProvider = + new ClientCredentialsOAuth2AuthorizedClientProvider(this.clientRegistrationRepository, this.authorizedClientRepository); + clientCredentialsAuthorizedClientProvider.setAccessTokenResponseClient(clientCredentialsTokenResponseClient); + return new DelegatingOAuth2AuthorizedClientProvider( + new AuthorizationCodeOAuth2AuthorizedClientProvider(), clientCredentialsAuthorizedClientProvider); } } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java index 0054ca164c5..3249de24c3d 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java @@ -19,29 +19,25 @@ import org.reactivestreams.Subscription; import org.springframework.beans.factory.DisposableBean; import org.springframework.beans.factory.InitializingBean; -import org.springframework.http.HttpHeaders; -import org.springframework.http.HttpMethod; -import org.springframework.http.MediaType; import org.springframework.security.core.Authentication; import org.springframework.security.core.GrantedAuthority; -import org.springframework.security.core.context.ReactiveSecurityContextHolder; import org.springframework.security.core.context.SecurityContextHolder; -import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException; +import org.springframework.security.oauth2.client.AuthorizationCodeOAuth2AuthorizedClientProvider; +import org.springframework.security.oauth2.client.ClientCredentialsOAuth2AuthorizedClientProvider; +import org.springframework.security.oauth2.client.DelegatingOAuth2AuthorizedClientProvider; +import org.springframework.security.oauth2.client.OAuth2AuthorizationContext; import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProvider; +import org.springframework.security.oauth2.client.RefreshTokenOAuth2AuthorizedClientProvider; import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; -import org.springframework.security.oauth2.client.endpoint.DefaultClientCredentialsTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; -import org.springframework.security.oauth2.core.AuthorizationGrantType; -import org.springframework.security.oauth2.core.OAuth2RefreshToken; -import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.util.Assert; import org.springframework.web.context.request.RequestContextHolder; import org.springframework.web.context.request.ServletRequestAttributes; -import org.springframework.web.reactive.function.BodyInserters; import org.springframework.web.reactive.function.client.ClientRequest; import org.springframework.web.reactive.function.client.ClientResponse; import org.springframework.web.reactive.function.client.ExchangeFilterFunction; @@ -51,22 +47,17 @@ import reactor.core.publisher.Hooks; import reactor.core.publisher.Mono; import reactor.core.publisher.Operators; -import reactor.core.scheduler.Schedulers; import reactor.util.context.Context; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; -import java.net.URI; import java.time.Clock; import java.time.Duration; import java.time.Instant; import java.util.Collection; import java.util.Map; -import java.util.Optional; import java.util.function.Consumer; -import static org.springframework.security.oauth2.core.web.reactive.function.OAuth2BodyExtractors.oauth2AccessTokenResponse; - /** * Provides an easy mechanism for using an {@link OAuth2AuthorizedClient} to make OAuth2 requests by including the * token as a Bearer Token. It also provides mechanisms for looking up the {@link OAuth2AuthorizedClient}. This class is @@ -75,7 +66,7 @@ * Example usage: * *

- * OAuth2AuthorizedClientExchangeFilterFunction oauth2 = new OAuth2AuthorizedClientExchangeFilterFunction(authorizedClientService);
+ * ServletOAuth2AuthorizedClientExchangeFilterFunction oauth2 = new ServletOAuth2AuthorizedClientExchangeFilterFunction(clientRegistrationRepository, authorizedClientRepository);
  * WebClient webClient = WebClient.builder()
  *    .apply(oauth2.oauth2Configuration())
  *    .build();
@@ -92,17 +83,18 @@
  * are true:
  *
  * 
    - *
  • The ReactiveOAuth2AuthorizedClientService on the + *
  • The {@link #setAuthorizedClientProvider(OAuth2AuthorizedClientProvider) OAuth2AuthorizedClientProvider} on the * {@link ServletOAuth2AuthorizedClientExchangeFilterFunction} is not null
  • - *
  • A refresh token is present on the OAuth2AuthorizedClient
  • + *
  • A refresh token is present on the {@link OAuth2AuthorizedClient}
  • *
  • The access token will be expired in * {@link #setAccessTokenExpiresSkew(Duration)}
  • - *
  • The {@link ReactiveSecurityContextHolder} will be used to attempt to save - * the token. If it is empty, then the principal name on the OAuth2AuthorizedClient + *
  • The {@link SecurityContextHolder} will be used to attempt to save + * the token. If it is empty, then the principal name on the {@link OAuth2AuthorizedClient} * will be used to create an Authentication for saving.
  • *
* * @author Rob Winch + * @author Joe Grandja * @since 5.1 */ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction @@ -127,8 +119,7 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction private OAuth2AuthorizedClientRepository authorizedClientRepository; - private OAuth2AccessTokenResponseClient clientCredentialsTokenResponseClient = - new DefaultClientCredentialsTokenResponseClient(); + private OAuth2AuthorizedClientProvider authorizedClientProvider; private boolean defaultOAuth2AuthorizedClient; @@ -142,6 +133,15 @@ public ServletOAuth2AuthorizedClientExchangeFilterFunction( OAuth2AuthorizedClientRepository authorizedClientRepository) { this.clientRegistrationRepository = clientRegistrationRepository; this.authorizedClientRepository = authorizedClientRepository; + this.authorizedClientProvider = createDefaultAuthorizedClientProvider(clientRegistrationRepository, authorizedClientRepository); + } + + private static OAuth2AuthorizedClientProvider createDefaultAuthorizedClientProvider( + ClientRegistrationRepository clientRegistrationRepository, OAuth2AuthorizedClientRepository authorizedClientRepository) { + return new DelegatingOAuth2AuthorizedClientProvider( + new AuthorizationCodeOAuth2AuthorizedClientProvider(), + new ClientCredentialsOAuth2AuthorizedClientProvider(clientRegistrationRepository, authorizedClientRepository), + new RefreshTokenOAuth2AuthorizedClientProvider(clientRegistrationRepository, authorizedClientRepository)); } @Override @@ -155,14 +155,39 @@ public void destroy() throws Exception { } /** - * Sets the {@link OAuth2AccessTokenResponseClient} to be used for getting an {@link OAuth2AuthorizedClient} for - * client_credentials grant. + * Sets the {@link OAuth2AuthorizedClientProvider} used for authorizing (or re-authorizing) an OAuth 2.0 Client. + * + * @since 5.2 + * @param authorizedClientProvider the {@link OAuth2AuthorizedClientProvider} used for authorizing (or re-authorizing) an OAuth 2.0 Client. + */ + public void setAuthorizedClientProvider(OAuth2AuthorizedClientProvider authorizedClientProvider) { + Assert.notNull(authorizedClientProvider, "authorizedClientProvider cannot be null"); + this.authorizedClientProvider = authorizedClientProvider; + } + + /** + * Sets the {@link OAuth2AccessTokenResponseClient} used for getting an {@link OAuth2AuthorizedClient} for the client_credentials grant. + * + * @deprecated Use {@link #setAuthorizedClientProvider(OAuth2AuthorizedClientProvider)} instead by providing it an instance of {@link ClientCredentialsOAuth2AuthorizedClientProvider} configured with a {@link ClientCredentialsOAuth2AuthorizedClientProvider#setAccessTokenResponseClient(OAuth2AccessTokenResponseClient) DefaultClientCredentialsTokenResponseClient} or a custom one. + * * @param clientCredentialsTokenResponseClient the client to use */ + @Deprecated public void setClientCredentialsTokenResponseClient( OAuth2AccessTokenResponseClient clientCredentialsTokenResponseClient) { Assert.notNull(clientCredentialsTokenResponseClient, "clientCredentialsTokenResponseClient cannot be null"); - this.clientCredentialsTokenResponseClient = clientCredentialsTokenResponseClient; + this.authorizedClientProvider = createAuthorizedClientProvider(clientCredentialsTokenResponseClient); + } + + private OAuth2AuthorizedClientProvider createAuthorizedClientProvider( + OAuth2AccessTokenResponseClient clientCredentialsTokenResponseClient) { + ClientCredentialsOAuth2AuthorizedClientProvider clientCredentialsAuthorizedClientProvider = + new ClientCredentialsOAuth2AuthorizedClientProvider(this.clientRegistrationRepository, this.authorizedClientRepository); + clientCredentialsAuthorizedClientProvider.setAccessTokenResponseClient(clientCredentialsTokenResponseClient); + return new DelegatingOAuth2AuthorizedClientProvider( + new AuthorizationCodeOAuth2AuthorizedClientProvider(), + clientCredentialsAuthorizedClientProvider, + new RefreshTokenOAuth2AuthorizedClientProvider(this.clientRegistrationRepository, this.authorizedClientRepository)); } /** @@ -292,7 +317,7 @@ public Mono filter(ClientRequest request, ExchangeFunction next) .filter(req -> req.attribute(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME).isPresent()) .switchIfEmpty(mergeRequestAttributesFromContext(request)) .filter(req -> req.attribute(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME).isPresent()) - .flatMap(req -> authorizedClient(req, next, getOAuth2AuthorizedClient(req.attributes()))) + .flatMap(req -> reauthorizeClientIfNecessary(req, next, getOAuth2AuthorizedClient(req.attributes()))) .map(authorizedClient -> bearer(request, authorizedClient)) .flatMap(next::exchange) .switchIfEmpty(next.exchange(request)); @@ -319,8 +344,8 @@ private void populateRequestAttributes(Map attrs, Context ctx) { } private void populateDefaultRequestResponse(Map attrs) { - if (attrs.containsKey(HTTP_SERVLET_REQUEST_ATTR_NAME) && attrs.containsKey( - HTTP_SERVLET_RESPONSE_ATTR_NAME)) { + if (attrs.containsKey(HTTP_SERVLET_REQUEST_ATTR_NAME) && + attrs.containsKey(HTTP_SERVLET_RESPONSE_ATTR_NAME)) { return; } ServletRequestAttributes context = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes(); @@ -343,8 +368,8 @@ private void populateDefaultAuthentication(Map attrs) { } private void populateDefaultOAuth2AuthorizedClient(Map attrs) { - if (this.authorizedClientRepository == null - || attrs.containsKey(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME)) { + if (this.authorizedClientRepository == null || + attrs.containsKey(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME)) { return; } @@ -360,9 +385,8 @@ private void populateDefaultOAuth2AuthorizedClient(Map attrs) { } if (clientRegistrationId != null) { HttpServletRequest request = getRequest(attrs); - OAuth2AuthorizedClient authorizedClient = this.authorizedClientRepository - .loadAuthorizedClient(clientRegistrationId, authentication, - request); + OAuth2AuthorizedClient authorizedClient = this.authorizedClientRepository.loadAuthorizedClient( + clientRegistrationId, authentication, request); if (authorizedClient == null) { authorizedClient = getAuthorizedClient(clientRegistrationId, attrs); } @@ -375,92 +399,35 @@ private OAuth2AuthorizedClient getAuthorizedClient(String clientRegistrationId, if (clientRegistration == null) { throw new IllegalArgumentException("Could not find ClientRegistration with id " + clientRegistrationId); } - if (isClientCredentialsGrantType(clientRegistration)) { - return authorizeWithClientCredentials(clientRegistration, attrs); + Authentication authentication = getAuthentication(attrs); + if (authentication == null) { + authentication = new PrincipalNameAuthentication("anonymousUser"); } - throw new ClientAuthorizationRequiredException(clientRegistrationId); - } - - private boolean isClientCredentialsGrantType(ClientRegistration clientRegistration) { - return AuthorizationGrantType.CLIENT_CREDENTIALS.equals(clientRegistration.getAuthorizationGrantType()); + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext.authorize(clientRegistration) + .principal(authentication) + .attribute(HttpServletRequest.class.getName(), getRequest(attrs)) + .attribute(HttpServletResponse.class.getName(), getResponse(attrs)) + .build(); + return this.authorizedClientProvider.authorize(authorizationContext); } - private OAuth2AuthorizedClient authorizeWithClientCredentials( - ClientRegistration clientRegistration, Map attrs) { - HttpServletRequest request = getRequest(attrs); - HttpServletResponse response = getResponse(attrs); - OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = - new OAuth2ClientCredentialsGrantRequest(clientRegistration); - OAuth2AccessTokenResponse tokenResponse = - this.clientCredentialsTokenResponseClient.getTokenResponse(clientCredentialsGrantRequest); - - Authentication principal = getAuthentication(attrs); - - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( - clientRegistration, - (principal != null ? principal.getName() : "anonymousUser"), - tokenResponse.getAccessToken()); - - this.authorizedClientRepository.saveAuthorizedClient( - authorizedClient, - principal, - request, - response); - - return authorizedClient; - } + private Mono reauthorizeClientIfNecessary( + ClientRequest request, ExchangeFunction next, OAuth2AuthorizedClient authorizedClient) { + if (this.authorizedClientProvider == null || !hasTokenExpired(authorizedClient)) { + return Mono.just(authorizedClient); + } - private Mono authorizedClient(ClientRequest request, ExchangeFunction next, OAuth2AuthorizedClient authorizedClient) { - ClientRegistration clientRegistration = authorizedClient.getClientRegistration(); - if (isClientCredentialsGrantType(clientRegistration) && hasTokenExpired(authorizedClient)) { - // Client credentials grant do not have refresh tokens but can expire so we need to get another one - return Mono.fromSupplier(() -> authorizeWithClientCredentials(clientRegistration, request.attributes())); - } else if (shouldRefreshToken(authorizedClient)) { - return authorizeWithRefreshToken(request, next, authorizedClient); + Map attributes = request.attributes(); + Authentication authentication = getAuthentication(attributes); + if (authentication == null) { + authentication = new PrincipalNameAuthentication(authorizedClient.getPrincipalName()); } - return Mono.just(authorizedClient); - } - - private Mono authorizeWithRefreshToken(ClientRequest request, ExchangeFunction next, - OAuth2AuthorizedClient authorizedClient) { - ClientRegistration clientRegistration = authorizedClient - .getClientRegistration(); - String tokenUri = clientRegistration - .getProviderDetails().getTokenUri(); - ClientRequest refreshRequest = ClientRequest.create(HttpMethod.POST, URI.create(tokenUri)) - .header(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE) - .headers(headers -> headers.setBasicAuth(clientRegistration.getClientId(), clientRegistration.getClientSecret())) - .body(refreshTokenBody(authorizedClient.getRefreshToken().getTokenValue())) + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext.reauthorize(authorizedClient) + .principal(authentication) + .attribute(HttpServletRequest.class.getName(), getRequest(attributes)) + .attribute(HttpServletResponse.class.getName(), getResponse(attributes)) .build(); - return next.exchange(refreshRequest) - .flatMap(response -> response.body(oauth2AccessTokenResponse())) - .map(accessTokenResponse -> { - OAuth2RefreshToken refreshToken = Optional.ofNullable(accessTokenResponse.getRefreshToken()) - .orElse(authorizedClient.getRefreshToken()); - return new OAuth2AuthorizedClient(authorizedClient.getClientRegistration(), authorizedClient.getPrincipalName(), accessTokenResponse.getAccessToken(), refreshToken); - }) - .map(result -> { - Authentication principal = (Authentication) request.attribute( - AUTHENTICATION_ATTR_NAME).orElse(new PrincipalNameAuthentication(authorizedClient.getPrincipalName())); - HttpServletRequest httpRequest = (HttpServletRequest) request.attributes().get( - HTTP_SERVLET_REQUEST_ATTR_NAME); - HttpServletResponse httpResponse = (HttpServletResponse) request.attributes().get( - HTTP_SERVLET_RESPONSE_ATTR_NAME); - this.authorizedClientRepository.saveAuthorizedClient(result, principal, httpRequest, httpResponse); - return result; - }) - .publishOn(Schedulers.elastic()); - } - - private boolean shouldRefreshToken(OAuth2AuthorizedClient authorizedClient) { - if (this.authorizedClientRepository == null) { - return false; - } - OAuth2RefreshToken refreshToken = authorizedClient.getRefreshToken(); - if (refreshToken == null) { - return false; - } - return hasTokenExpired(authorizedClient); + return Mono.fromSupplier(() -> this.authorizedClientProvider.authorize(authorizationContext)); } private boolean hasTokenExpired(OAuth2AuthorizedClient authorizedClient) { @@ -491,12 +458,6 @@ private CoreSubscriber createRequestContextSubscriber(CoreSubscriber d return new RequestContextSubscriber<>(delegate, request, response, authentication); } - private static BodyInserters.FormInserter refreshTokenBody(String refreshToken) { - return BodyInserters - .fromFormData("grant_type", AuthorizationGrantType.REFRESH_TOKEN.getValue()) - .with("refresh_token", refreshToken); - } - static OAuth2AuthorizedClient getOAuth2AuthorizedClient(Map attrs) { return (OAuth2AuthorizedClient) attrs.get(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME); } @@ -550,8 +511,7 @@ public boolean isAuthenticated() { } @Override - public void setAuthenticated(boolean isAuthenticated) - throws IllegalArgumentException { + public void setAuthenticated(boolean isAuthenticated) throws IllegalArgumentException { throw unsupported(); } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolverTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolverTests.java index 508c3eda454..491cfb52daa 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolverTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolverTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2019 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,11 +20,13 @@ import org.junit.Test; import org.springframework.core.MethodParameter; import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException; +import org.springframework.security.oauth2.client.ClientCredentialsOAuth2AuthorizedClientProvider; import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; import org.springframework.security.oauth2.client.annotation.RegisteredOAuth2AuthorizedClient; import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; @@ -42,6 +44,7 @@ import org.springframework.web.context.request.ServletWebRequest; import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; import java.lang.reflect.Method; import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; @@ -67,6 +70,7 @@ public class OAuth2AuthorizedClientArgumentResolverTests { private OAuth2AuthorizedClientRepository authorizedClientRepository; private OAuth2AuthorizedClientArgumentResolver argumentResolver; private MockHttpServletRequest request; + private MockHttpServletResponse response; @Before public void setup() { @@ -109,6 +113,7 @@ public void setup() { eq(this.registration2.getRegistrationId()), any(Authentication.class), any(HttpServletRequest.class))) .thenReturn(this.authorizedClient2); this.request = new MockHttpServletRequest(); + this.response = new MockHttpServletResponse(); } @After @@ -128,6 +133,12 @@ public void constructorWhenOAuth2AuthorizedClientRepositoryIsNullThenThrowIllega .isInstanceOf(IllegalArgumentException.class); } + @Test + public void setAuthorizedClientProviderWhenProviderIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.argumentResolver.setAuthorizedClientProvider(null)) + .isInstanceOf(IllegalArgumentException.class); + } + @Test public void setClientCredentialsTokenResponseClientWhenClientIsNullThenThrowIllegalArgumentException() { assertThatThrownBy(() -> this.argumentResolver.setClientCredentialsTokenResponseClient(null)) @@ -206,7 +217,12 @@ public void resolveArgumentWhenAuthorizedClientNotFoundForAuthorizationCodeClien public void resolveArgumentWhenAuthorizedClientNotFoundForClientCredentialsClientThenResolvesFromTokenResponseClient() throws Exception { OAuth2AccessTokenResponseClient clientCredentialsTokenResponseClient = mock(OAuth2AccessTokenResponseClient.class); - this.argumentResolver.setClientCredentialsTokenResponseClient(clientCredentialsTokenResponseClient); + ClientCredentialsOAuth2AuthorizedClientProvider clientCredentialsAuthorizedClientProvider = + new ClientCredentialsOAuth2AuthorizedClientProvider( + this.clientRegistrationRepository, this.authorizedClientRepository); + clientCredentialsAuthorizedClientProvider.setAccessTokenResponseClient(clientCredentialsTokenResponseClient); + this.argumentResolver.setAuthorizedClientProvider(clientCredentialsAuthorizedClientProvider); + OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse .withToken("access-token-1234") .tokenType(OAuth2AccessToken.TokenType.BEARER) @@ -219,7 +235,7 @@ public void resolveArgumentWhenAuthorizedClientNotFoundForClientCredentialsClien MethodParameter methodParameter = this.getMethodParameter("clientCredentialsClient", OAuth2AuthorizedClient.class); OAuth2AuthorizedClient authorizedClient = (OAuth2AuthorizedClient) this.argumentResolver.resolveArgument( - methodParameter, null, new ServletWebRequest(this.request), null); + methodParameter, null, new ServletWebRequest(this.request, this.response), null); assertThat(authorizedClient).isNotNull(); assertThat(authorizedClient.getClientRegistration()).isSameAs(this.registration2); @@ -227,7 +243,7 @@ public void resolveArgumentWhenAuthorizedClientNotFoundForClientCredentialsClien assertThat(authorizedClient.getAccessToken()).isSameAs(accessTokenResponse.getAccessToken()); verify(this.authorizedClientRepository).saveAuthorizedClient( - eq(authorizedClient), eq(this.authentication), any(HttpServletRequest.class), eq(null)); + eq(authorizedClient), eq(this.authentication), any(HttpServletRequest.class), any(HttpServletResponse.class)); } private MethodParameter getMethodParameter(String methodName, Class... paramTypes) { diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java index d99de2db281..0ad69a0e13e 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java @@ -13,7 +13,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package org.springframework.security.oauth2.client.web.reactive.function.client; import org.junit.After; @@ -28,6 +27,9 @@ import org.springframework.core.codec.CharSequenceEncoder; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; +import org.springframework.http.HttpStatus; +import org.springframework.http.RequestEntity; +import org.springframework.http.ResponseEntity; import org.springframework.http.codec.EncoderHttpMessageWriter; import org.springframework.http.codec.FormHttpMessageWriter; import org.springframework.http.codec.HttpMessageWriter; @@ -44,10 +46,17 @@ import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.oauth2.client.AuthorizationCodeOAuth2AuthorizedClientProvider; +import org.springframework.security.oauth2.client.ClientCredentialsOAuth2AuthorizedClientProvider; +import org.springframework.security.oauth2.client.DelegatingOAuth2AuthorizedClientProvider; import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProvider; +import org.springframework.security.oauth2.client.RefreshTokenOAuth2AuthorizedClientProvider; import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; +import org.springframework.security.oauth2.client.endpoint.DefaultRefreshTokenTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest; +import org.springframework.security.oauth2.client.endpoint.OAuth2RefreshTokenGrantRequest; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; import org.springframework.security.oauth2.client.registration.TestClientRegistrations; @@ -57,12 +66,12 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses; import org.springframework.security.oauth2.core.user.OAuth2User; +import org.springframework.web.client.RestOperations; import org.springframework.web.context.request.RequestContextHolder; import org.springframework.web.context.request.ServletRequestAttributes; import org.springframework.web.reactive.function.BodyInserter; import org.springframework.web.reactive.function.client.ClientRequest; import org.springframework.web.reactive.function.client.WebClient; -import reactor.core.publisher.Mono; import java.net.URI; import java.time.Duration; @@ -76,6 +85,7 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.*; @@ -95,6 +105,8 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests { @Mock private OAuth2AccessTokenResponseClient clientCredentialsTokenResponseClient; @Mock + private OAuth2AccessTokenResponseClient refreshTokenTokenResponseClient; + @Mock private WebClient.RequestHeadersSpec spec; @Captor private ArgumentCaptor>> attrs; @@ -106,14 +118,15 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests { */ private Map result = new HashMap<>(); - private ServletOAuth2AuthorizedClientExchangeFilterFunction function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(); + private ServletOAuth2AuthorizedClientExchangeFilterFunction function; + + private OAuth2AuthorizedClientProvider authorizedClientProvider; private MockExchangeFunction exchange = new MockExchangeFunction(); private Authentication authentication; - private ClientRegistration registration = TestClientRegistrations.clientRegistration() - .build(); + private ClientRegistration registration = TestClientRegistrations.clientRegistration().build(); private OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, "token-0", @@ -123,6 +136,19 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests { @Before public void setup() { this.authentication = new TestingAuthenticationToken("test", "this"); + this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction( + this.clientRegistrationRepository, this.authorizedClientRepository); + ClientCredentialsOAuth2AuthorizedClientProvider clientCredentialsAuthorizedClientProvider = + new ClientCredentialsOAuth2AuthorizedClientProvider(this.clientRegistrationRepository, this.authorizedClientRepository); + clientCredentialsAuthorizedClientProvider.setAccessTokenResponseClient(this.clientCredentialsTokenResponseClient); + RefreshTokenOAuth2AuthorizedClientProvider refreshTokenAuthorizedClientProvider = + new RefreshTokenOAuth2AuthorizedClientProvider(this.clientRegistrationRepository, this.authorizedClientRepository); + refreshTokenAuthorizedClientProvider.setAccessTokenResponseClient(this.refreshTokenTokenResponseClient); + this.authorizedClientProvider = new DelegatingOAuth2AuthorizedClientProvider( + new AuthorizationCodeOAuth2AuthorizedClientProvider(), + clientCredentialsAuthorizedClientProvider, + refreshTokenAuthorizedClientProvider); + this.function.setAuthorizedClientProvider(this.authorizedClientProvider); } @After @@ -131,6 +157,12 @@ public void cleanup() { RequestContextHolder.resetRequestAttributes(); } + @Test + public void setAuthorizedClientProviderWhenProviderIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.function.setAuthorizedClientProvider(null)) + .isInstanceOf(IllegalArgumentException.class); + } + @Test public void defaultRequestRequestResponseWhenNullRequestContextThenRequestAndResponseNull() { Map attrs = getDefaultRequestAttributes(); @@ -156,8 +188,6 @@ public void defaultRequestAuthenticationWhenSecurityContextEmptyThenAuthenticati @Test public void defaultRequestAuthenticationWhenAuthenticationSetThenAuthenticationSet() { - this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, - this.authorizedClientRepository); SecurityContextHolder.getContext().setAuthentication(this.authentication); Map attrs = getDefaultRequestAttributes(); assertThat(getAuthentication(attrs)).isEqualTo(this.authentication); @@ -166,8 +196,6 @@ public void defaultRequestAuthenticationWhenAuthenticationSetThenAuthenticationS @Test public void defaultRequestOAuth2AuthorizedClientWhenOAuth2AuthorizationClientAndClientIdThenNotOverride() { - this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, - this.authorizedClientRepository); OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", this.accessToken); oauth2AuthorizedClient(authorizedClient).accept(this.result); @@ -178,8 +206,6 @@ public void defaultRequestOAuth2AuthorizedClientWhenOAuth2AuthorizationClientAnd @Test public void defaultRequestOAuth2AuthorizedClientWhenAuthenticationNullAndClientRegistrationIdNullThenOAuth2AuthorizedClientNull() { - this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, - this.authorizedClientRepository); Map attrs = getDefaultRequestAttributes(); assertThat(getOAuth2AuthorizedClient(attrs)).isNull(); verifyZeroInteractions(this.authorizedClientRepository); @@ -187,8 +213,6 @@ public void defaultRequestOAuth2AuthorizedClientWhenAuthenticationNullAndClientR @Test public void defaultRequestOAuth2AuthorizedClientWhenAuthenticationWrongTypeAndClientRegistrationIdNullThenOAuth2AuthorizedClientNull() { - this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, - this.authorizedClientRepository); Map attrs = getDefaultRequestAttributes(); assertThat(getOAuth2AuthorizedClient(attrs)).isNull(); verifyZeroInteractions(this.authorizedClientRepository); @@ -208,8 +232,6 @@ public void defaultRequestOAuth2AuthorizedClientWhenRepositoryNullThenOAuth2Auth @Test public void defaultRequestOAuth2AuthorizedClientWhenDefaultTrueAndAuthenticationAndClientRegistrationIdNullThenOAuth2AuthorizedClient() { - this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, - this.authorizedClientRepository); this.function.setDefaultOAuth2AuthorizedClient(true); OAuth2User user = mock(OAuth2User.class); List authorities = AuthorityUtils.createAuthorityList("ROLE_USER"); @@ -227,8 +249,6 @@ public void defaultRequestOAuth2AuthorizedClientWhenDefaultTrueAndAuthentication @Test public void defaultRequestOAuth2AuthorizedClientWhenDefaultFalseAndAuthenticationAndClientRegistrationIdNullThenOAuth2AuthorizedClient() { - this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, - this.authorizedClientRepository); OAuth2User user = mock(OAuth2User.class); List authorities = AuthorityUtils.createAuthorityList("ROLE_USER"); OAuth2AuthenticationToken token = new OAuth2AuthenticationToken(user, authorities, "id"); @@ -241,8 +261,6 @@ public void defaultRequestOAuth2AuthorizedClientWhenDefaultFalseAndAuthenticatio @Test public void defaultRequestOAuth2AuthorizedClientWhenAuthenticationAndClientRegistrationIdThenIdIsExplicit() { - this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, - this.authorizedClientRepository); OAuth2User user = mock(OAuth2User.class); List authorities = AuthorityUtils.createAuthorityList("ROLE_USER"); OAuth2AuthenticationToken token = new OAuth2AuthenticationToken(user, authorities, "id"); @@ -260,8 +278,6 @@ public void defaultRequestOAuth2AuthorizedClientWhenAuthenticationAndClientRegis @Test public void defaultRequestOAuth2AuthorizedClientWhenAuthenticationNullAndClientRegistrationIdThenOAuth2AuthorizedClient() { - this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, - this.authorizedClientRepository); OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", this.accessToken); when(this.authorizedClientRepository.loadAuthorizedClient(any(), any(), any())).thenReturn(authorizedClient); @@ -276,17 +292,17 @@ public void defaultRequestOAuth2AuthorizedClientWhenAuthenticationNullAndClientR @Test public void defaultRequestWhenClientCredentialsThenAuthorizedClient() { this.registration = TestClientRegistrations.clientCredentials().build(); - this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, - this.authorizedClientRepository); - this.function.setClientCredentialsTokenResponseClient(this.clientCredentialsTokenResponseClient); when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(this.registration); OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses .accessTokenResponse().build(); - when(this.clientCredentialsTokenResponseClient.getTokenResponse(any())).thenReturn( - accessTokenResponse); + when(this.clientCredentialsTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse); clientRegistrationId(this.registration.getRegistrationId()).accept(this.result); + MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletResponse response = new MockHttpServletResponse(); + RequestContextHolder.setRequestAttributes(new ServletRequestAttributes(request, response)); + Map attrs = getDefaultRequestAttributes(); OAuth2AuthorizedClient authorizedClient = getOAuth2AuthorizedClient(attrs); @@ -299,15 +315,15 @@ public void defaultRequestWhenClientCredentialsThenAuthorizedClient() { @Test public void defaultRequestWhenDefaultClientRegistrationIdThenAuthorizedClient() { this.registration = TestClientRegistrations.clientCredentials().build(); - this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, - this.authorizedClientRepository); this.function.setDefaultClientRegistrationId(this.registration.getRegistrationId()); - this.function.setClientCredentialsTokenResponseClient(this.clientCredentialsTokenResponseClient); when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(this.registration); OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses .accessTokenResponse().build(); - when(this.clientCredentialsTokenResponseClient.getTokenResponse(any())).thenReturn( - accessTokenResponse); + when(this.clientCredentialsTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse); + + MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletResponse response = new MockHttpServletResponse(); + RequestContextHolder.setRequestAttributes(new ServletRequestAttributes(request, response)); Map attrs = getDefaultRequestAttributes(); OAuth2AuthorizedClient authorizedClient = getOAuth2AuthorizedClient(attrs); @@ -321,9 +337,6 @@ public void defaultRequestWhenDefaultClientRegistrationIdThenAuthorizedClient() @Test public void defaultRequestWhenClientIdNotFoundThenIllegalArgumentException() { this.registration = TestClientRegistrations.clientCredentials().build(); - this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, - this.authorizedClientRepository); - clientRegistrationId(this.registration.getRegistrationId()).accept(this.result); assertThatCode(() -> getDefaultRequestAttributes()) @@ -384,7 +397,7 @@ public void filterWhenRefreshRequiredThenRefresh() { .expiresIn(3600) .refreshToken("refresh-1") .build(); - when(this.exchange.getResponse().body(any())).thenReturn(Mono.just(response)); + when(this.refreshTokenTokenResponseClient.getTokenResponse(any())).thenReturn(response); Instant issuedAt = Instant.now().minus(Duration.ofDays(1)); Instant accessTokenExpiresAt = issuedAt.plus(Duration.ofHours(1)); @@ -392,8 +405,6 @@ public void filterWhenRefreshRequiredThenRefresh() { this.accessToken.getTokenValue(), issuedAt, accessTokenExpiresAt); - this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, - this.authorizedClientRepository); OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", issuedAt); OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, @@ -401,10 +412,13 @@ public void filterWhenRefreshRequiredThenRefresh() { ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) .attributes(oauth2AuthorizedClient(authorizedClient)) .attributes(authentication(this.authentication)) + .attributes(httpServletRequest(new MockHttpServletRequest())) + .attributes(httpServletResponse(new MockHttpServletResponse())) .build(); this.function.filter(request, this.exchange).block(); + verify(this.refreshTokenTokenResponseClient).getTokenResponse(any()); verify(this.authorizedClientRepository).saveAuthorizedClient(this.authorizedClientCaptor.capture(), eq(this.authentication), any(), any()); OAuth2AuthorizedClient newAuthorizedClient = authorizedClientCaptor.getValue(); @@ -412,19 +426,13 @@ public void filterWhenRefreshRequiredThenRefresh() { assertThat(newAuthorizedClient.getRefreshToken()).isEqualTo(response.getRefreshToken()); List requests = this.exchange.getRequests(); - assertThat(requests).hasSize(2); + assertThat(requests).hasSize(1); ClientRequest request0 = requests.get(0); - assertThat(request0.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ="); - assertThat(request0.url().toASCIIString()).isEqualTo("https://example.com/login/oauth/access_token"); - assertThat(request0.method()).isEqualTo(HttpMethod.POST); - assertThat(getBody(request0)).isEqualTo("grant_type=refresh_token&refresh_token=refresh-token"); - - ClientRequest request1 = requests.get(1); - assertThat(request1.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-1"); - assertThat(request1.url().toASCIIString()).isEqualTo("https://example.com"); - assertThat(request1.method()).isEqualTo(HttpMethod.GET); - assertThat(getBody(request1)).isEmpty(); + assertThat(request0.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-1"); + assertThat(request0.url().toASCIIString()).isEqualTo("https://example.com"); + assertThat(request0.method()).isEqualTo(HttpMethod.GET); + assertThat(getBody(request0)).isEmpty(); } @Test @@ -434,7 +442,18 @@ public void filterWhenRefreshRequiredThenRefreshAndResponseDoesNotContainRefresh .expiresIn(3600) // .refreshToken(xxx) // No refreshToken in response .build(); - when(this.exchange.getResponse().body(any())).thenReturn(Mono.just(response)); + + RestOperations refreshTokenClient = mock(RestOperations.class); + when(refreshTokenClient.exchange(any(RequestEntity.class), eq(OAuth2AccessTokenResponse.class))) + .thenReturn(new ResponseEntity(response, HttpStatus.OK)); + DefaultRefreshTokenTokenResponseClient refreshTokenTokenResponseClient = new DefaultRefreshTokenTokenResponseClient(); + refreshTokenTokenResponseClient.setRestOperations(refreshTokenClient); + + RefreshTokenOAuth2AuthorizedClientProvider authorizedClientProvider = new RefreshTokenOAuth2AuthorizedClientProvider( + this.clientRegistrationRepository, this.authorizedClientRepository); + authorizedClientProvider.setAccessTokenResponseClient(refreshTokenTokenResponseClient); + this.function.setAuthorizedClientProvider(authorizedClientProvider); + Instant issuedAt = Instant.now().minus(Duration.ofDays(1)); Instant accessTokenExpiresAt = issuedAt.plus(Duration.ofHours(1)); @@ -442,8 +461,6 @@ public void filterWhenRefreshRequiredThenRefreshAndResponseDoesNotContainRefresh this.accessToken.getTokenValue(), issuedAt, accessTokenExpiresAt); - this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, - this.authorizedClientRepository); OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", issuedAt); OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, @@ -451,40 +468,32 @@ public void filterWhenRefreshRequiredThenRefreshAndResponseDoesNotContainRefresh ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) .attributes(oauth2AuthorizedClient(authorizedClient)) .attributes(authentication(this.authentication)) + .attributes(httpServletRequest(new MockHttpServletRequest())) + .attributes(httpServletResponse(new MockHttpServletResponse())) .build(); this.function.filter(request, this.exchange).block(); + verify(refreshTokenClient).exchange(any(RequestEntity.class), eq(OAuth2AccessTokenResponse.class)); verify(this.authorizedClientRepository).saveAuthorizedClient(this.authorizedClientCaptor.capture(), eq(this.authentication), any(), any()); OAuth2AuthorizedClient newAuthorizedClient = authorizedClientCaptor.getValue(); assertThat(newAuthorizedClient.getAccessToken()).isEqualTo(response.getAccessToken()); - assertThat(newAuthorizedClient.getRefreshToken()).isEqualTo(refreshToken); + assertThat(newAuthorizedClient.getRefreshToken().getTokenValue()).isEqualTo(refreshToken.getTokenValue()); List requests = this.exchange.getRequests(); - assertThat(requests).hasSize(2); + assertThat(requests).hasSize(1); ClientRequest request0 = requests.get(0); - assertThat(request0.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ="); - assertThat(request0.url().toASCIIString()).isEqualTo("https://example.com/login/oauth/access_token"); - assertThat(request0.method()).isEqualTo(HttpMethod.POST); - assertThat(getBody(request0)).isEqualTo("grant_type=refresh_token&refresh_token=refresh-token"); - - ClientRequest request1 = requests.get(1); - assertThat(request1.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-1"); - assertThat(request1.url().toASCIIString()).isEqualTo("https://example.com"); - assertThat(request1.method()).isEqualTo(HttpMethod.GET); - assertThat(getBody(request1)).isEmpty(); + assertThat(request0.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-1"); + assertThat(request0.url().toASCIIString()).isEqualTo("https://example.com"); + assertThat(request0.method()).isEqualTo(HttpMethod.GET); + assertThat(getBody(request0)).isEmpty(); } @Test public void filterWhenClientCredentialsTokenNotExpiredThenUseCurrentToken() { this.registration = TestClientRegistrations.clientCredentials().build(); - - this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, - this.authorizedClientRepository); - this.function.setClientCredentialsTokenResponseClient(this.clientCredentialsTokenResponseClient); - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", this.accessToken, null); ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) @@ -524,22 +533,21 @@ public void filterWhenClientCredentialsTokenExpiredThenGetNewToken() { this.accessToken.getTokenValue(), issuedAt, accessTokenExpiresAt); - this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, - this.authorizedClientRepository); - this.function.setClientCredentialsTokenResponseClient(this.clientCredentialsTokenResponseClient); OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", this.accessToken, null); ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) .attributes(oauth2AuthorizedClient(authorizedClient)) .attributes(authentication(this.authentication)) + .attributes(httpServletRequest(new MockHttpServletRequest())) + .attributes(httpServletResponse(new MockHttpServletResponse())) .build(); this.function.filter(request, this.exchange).block(); verify(this.authorizedClientRepository).saveAuthorizedClient(any(), eq(this.authentication), any(), any()); - verify(clientCredentialsTokenResponseClient).getTokenResponse(any()); + verify(this.clientCredentialsTokenResponseClient).getTokenResponse(any()); List requests = this.exchange.getRequests(); assertThat(requests).hasSize(1); @@ -558,7 +566,7 @@ public void filterWhenRefreshRequiredAndEmptyReactiveSecurityContextThenSaved() .expiresIn(3600) .refreshToken("refresh-1") .build(); - when(this.exchange.getResponse().body(any())).thenReturn(Mono.just(response)); + when(this.refreshTokenTokenResponseClient.getTokenResponse(any())).thenReturn(response); Instant issuedAt = Instant.now().minus(Duration.ofDays(1)); Instant accessTokenExpiresAt = issuedAt.plus(Duration.ofHours(1)); @@ -566,42 +574,33 @@ public void filterWhenRefreshRequiredAndEmptyReactiveSecurityContextThenSaved() this.accessToken.getTokenValue(), issuedAt, accessTokenExpiresAt); - this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, - this.authorizedClientRepository); OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", issuedAt); OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", this.accessToken, refreshToken); ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) .attributes(oauth2AuthorizedClient(authorizedClient)) + .attributes(httpServletRequest(new MockHttpServletRequest())) + .attributes(httpServletResponse(new MockHttpServletResponse())) .build(); - this.function.filter(request, this.exchange) - .block(); + this.function.filter(request, this.exchange).block(); + verify(this.refreshTokenTokenResponseClient).getTokenResponse(any()); verify(this.authorizedClientRepository).saveAuthorizedClient(any(), any(), any(), any()); List requests = this.exchange.getRequests(); - assertThat(requests).hasSize(2); + assertThat(requests).hasSize(1); ClientRequest request0 = requests.get(0); - assertThat(request0.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ="); - assertThat(request0.url().toASCIIString()).isEqualTo("https://example.com/login/oauth/access_token"); - assertThat(request0.method()).isEqualTo(HttpMethod.POST); - assertThat(getBody(request0)).isEqualTo("grant_type=refresh_token&refresh_token=refresh-token"); - - ClientRequest request1 = requests.get(1); - assertThat(request1.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-1"); - assertThat(request1.url().toASCIIString()).isEqualTo("https://example.com"); - assertThat(request1.method()).isEqualTo(HttpMethod.GET); - assertThat(getBody(request1)).isEmpty(); + assertThat(request0.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-1"); + assertThat(request0.url().toASCIIString()).isEqualTo("https://example.com"); + assertThat(request0.method()).isEqualTo(HttpMethod.GET); + assertThat(getBody(request0)).isEmpty(); } @Test public void filterWhenRefreshTokenNullThenShouldRefreshFalse() { - this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, - this.authorizedClientRepository); - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", this.accessToken); ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) @@ -622,9 +621,6 @@ public void filterWhenRefreshTokenNullThenShouldRefreshFalse() { @Test public void filterWhenNotExpiredThenShouldRefreshFalse() { - this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, - this.authorizedClientRepository); - OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", this.accessToken.getIssuedAt()); OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", this.accessToken, refreshToken); @@ -647,8 +643,6 @@ public void filterWhenNotExpiredThenShouldRefreshFalse() { // gh-6483 @Test public void filterWhenChainedThenDefaultsStillAvailable() throws Exception { - this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction( - this.clientRegistrationRepository, this.authorizedClientRepository); this.function.afterPropertiesSet(); // Hooks.onLastOperator() initialized this.function.setDefaultOAuth2AuthorizedClient(true); @@ -698,8 +692,6 @@ public void filterWhenChainedThenDefaultsStillAvailable() throws Exception { @Test public void filterWhenRequestAttributesNotSetAndHooksNotInitThenDefaultsNotAvailable() throws Exception { - this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction( - this.clientRegistrationRepository, this.authorizedClientRepository); // this.function.afterPropertiesSet(); // Hooks.onLastOperator() NOT initialized this.function.setDefaultOAuth2AuthorizedClient(true); @@ -729,8 +721,6 @@ public void filterWhenRequestAttributesNotSetAndHooksNotInitThenDefaultsNotAvail @Test public void filterWhenRequestAttributesNotSetAndHooksInitHooksResetThenDefaultsNotAvailable() throws Exception { - this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction( - this.clientRegistrationRepository, this.authorizedClientRepository); this.function.afterPropertiesSet(); // Hooks.onLastOperator() initialized this.function.destroy(); // Hooks.onLastOperator() released this.function.setDefaultOAuth2AuthorizedClient(true); From eef15bc6980c49df499b910af4759b46fab69944 Mon Sep 17 00:00:00 2001 From: Joe Grandja Date: Fri, 14 Jun 2019 21:04:34 -0400 Subject: [PATCH 08/19] Simplify population of OAuth2AuthorizationContext --- ...entialsOAuth2AuthorizedClientProvider.java | 29 ++++-- ...shTokenOAuth2AuthorizedClientProvider.java | 50 +++++++--- ...Auth2AuthorizedClientArgumentResolver.java | 11 +-- ...uthorizedClientExchangeFilterFunction.java | 95 ++++++------------- ...lsOAuth2AuthorizedClientProviderTests.java | 4 +- ...enOAuth2AuthorizedClientProviderTests.java | 32 ++++--- 6 files changed, 117 insertions(+), 104 deletions(-) diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ClientCredentialsOAuth2AuthorizedClientProvider.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ClientCredentialsOAuth2AuthorizedClientProvider.java index c0de3587ec4..ac1e0168f1e 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ClientCredentialsOAuth2AuthorizedClientProvider.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ClientCredentialsOAuth2AuthorizedClientProvider.java @@ -19,6 +19,7 @@ import org.springframework.security.oauth2.client.endpoint.DefaultClientCredentialsTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest; +import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.core.AuthorizationGrantType; @@ -38,8 +39,8 @@ * @see DefaultClientCredentialsTokenResponseClient */ public final class ClientCredentialsOAuth2AuthorizedClientProvider implements OAuth2AuthorizedClientProvider { - private static final String HTTP_SERVLET_REQUEST_ATTR_NAME = HttpServletRequest.class.getName(); - private static final String HTTP_SERVLET_RESPONSE_ATTR_NAME = HttpServletResponse.class.getName(); + private static final String HTTP_SERVLET_REQUEST_ATTRIBUTE_NAME = HttpServletRequest.class.getName(); + private static final String HTTP_SERVLET_RESPONSE_ATTRIBUTE_NAME = HttpServletResponse.class.getName(); private final ClientRegistrationRepository clientRegistrationRepository; private final OAuth2AuthorizedClientRepository authorizedClientRepository; private OAuth2AccessTokenResponseClient accessTokenResponseClient = @@ -59,6 +60,22 @@ public ClientCredentialsOAuth2AuthorizedClientProvider(ClientRegistrationReposit this.authorizedClientRepository = authorizedClientRepository; } + /** + * Attempt to authorize (or re-authorize) the {@link OAuth2AuthorizationContext#getClientRegistration() client} in the provided {@code context}. + * Returns {@code null} if authorization (or re-authorization) is not supported, + * e.g. the client's {@link ClientRegistration#getAuthorizationGrantType() authorization grant type} + * is not {@link AuthorizationGrantType#CLIENT_CREDENTIALS client_credentials}. + * + *

+ * The following {@link OAuth2AuthorizationContext#getAttributes() context attributes} are supported: + *

    + *
  1. {@code "javax.servlet.http.HttpServletRequest"} (required) - the {@code HttpServletRequest}
  2. + *
  3. {@code "javax.servlet.http.HttpServletResponse"} (required) - the {@code HttpServletResponse}
  4. + *
+ * + * @param context the context that holds authorization-specific state for the client + * @return the {@link OAuth2AuthorizedClient} or {@code null} if authorization (or re-authorization) is not supported + */ @Override @Nullable public OAuth2AuthorizedClient authorize(OAuth2AuthorizationContext context) { @@ -67,10 +84,10 @@ public OAuth2AuthorizedClient authorize(OAuth2AuthorizationContext context) { return null; } - HttpServletRequest request = context.getAttribute(HTTP_SERVLET_REQUEST_ATTR_NAME); - HttpServletResponse response = context.getAttribute(HTTP_SERVLET_RESPONSE_ATTR_NAME); - Assert.notNull(request, "context.HttpServletRequest cannot be null"); - Assert.notNull(response, "context.HttpServletResponse cannot be null"); + HttpServletRequest request = context.getAttribute(HTTP_SERVLET_REQUEST_ATTRIBUTE_NAME); + HttpServletResponse response = context.getAttribute(HTTP_SERVLET_RESPONSE_ATTRIBUTE_NAME); + Assert.notNull(request, "The context attribute cannot be null '" + HTTP_SERVLET_REQUEST_ATTRIBUTE_NAME + "'"); + Assert.notNull(response, "The context attribute cannot be null '" + HTTP_SERVLET_RESPONSE_ATTRIBUTE_NAME + "'"); // As per spec, in section 4.4.3 Access Token Response // https://tools.ietf.org/html/rfc6749#section-4.4.3 diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProvider.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProvider.java index 0ac3defd432..11d4c021abd 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProvider.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProvider.java @@ -24,10 +24,13 @@ import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.util.Assert; +import org.springframework.util.StringUtils; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; +import java.util.Arrays; import java.util.Set; +import java.util.stream.Collectors; /** * An implementation of an {@link OAuth2AuthorizedClientProvider} @@ -39,9 +42,17 @@ * @see DefaultRefreshTokenTokenResponseClient */ public final class RefreshTokenOAuth2AuthorizedClientProvider implements OAuth2AuthorizedClientProvider { - private static final String HTTP_SERVLET_REQUEST_ATTR_NAME = HttpServletRequest.class.getName(); - private static final String HTTP_SERVLET_RESPONSE_ATTR_NAME = HttpServletResponse.class.getName(); - private static final String SCOPE_ATTR_NAME = "SCOPE"; + /** + * The name of the {@link OAuth2AuthorizationContext#getAttribute(String) attribute} + * in the {@link OAuth2AuthorizationContext context} associated to the value for the "requested scope(s)". + * The value of the attribute is a space-delimited or comma-delimited {@code String} of scope(s) + * to be requested by the {@link OAuth2AuthorizationContext#getClientRegistration() client}. + */ + public static final String REQUEST_SCOPE_ATTRIBUTE_NAME = "org.springframework.security.oauth2.client.REQUEST_SCOPE"; + + private static final String HTTP_SERVLET_REQUEST_ATTRIBUTE_NAME = HttpServletRequest.class.getName(); + private static final String HTTP_SERVLET_RESPONSE_ATTRIBUTE_NAME = HttpServletResponse.class.getName(); + private final ClientRegistrationRepository clientRegistrationRepository; private final OAuth2AuthorizedClientRepository authorizedClientRepository; private OAuth2AccessTokenResponseClient accessTokenResponseClient = @@ -61,6 +72,23 @@ public RefreshTokenOAuth2AuthorizedClientProvider(ClientRegistrationRepository c this.authorizedClientRepository = authorizedClientRepository; } + /** + * Attempt to re-authorize the {@link OAuth2AuthorizationContext#getClientRegistration() client} in the provided {@code context}. + * Returns {@code null} if re-authorization is not supported, + * e.g. the {@link OAuth2AuthorizedClient#getRefreshToken() refresh token} is not available for the + * {@link OAuth2AuthorizationContext#getAuthorizedClient() authorized client}. + * + *

+ * The following {@link OAuth2AuthorizationContext#getAttributes() context attributes} are supported: + *

    + *
  1. {@code "javax.servlet.http.HttpServletRequest"} (required) - the {@code HttpServletRequest}
  2. + *
  3. {@code "javax.servlet.http.HttpServletResponse"} (required) - the {@code HttpServletResponse}
  4. + *
  5. {@code "org.springframework.security.oauth2.client.REQUEST_SCOPE"} (optional) - a space-delimited or comma-delimited {@code String} of scope(s) to be requested by the {@link OAuth2AuthorizationContext#getClientRegistration() client}
  6. + *
+ * + * @param context the context that holds authorization-specific state for the client + * @return the {@link OAuth2AuthorizedClient} or {@code null} if re-authorization is not supported + */ @Override @Nullable public OAuth2AuthorizedClient authorize(OAuth2AuthorizationContext context) { @@ -69,16 +97,16 @@ public OAuth2AuthorizedClient authorize(OAuth2AuthorizationContext context) { return null; } - HttpServletRequest request = context.getAttribute(HTTP_SERVLET_REQUEST_ATTR_NAME); - HttpServletResponse response = context.getAttribute(HTTP_SERVLET_RESPONSE_ATTR_NAME); - Assert.notNull(request, "context.HttpServletRequest cannot be null"); - Assert.notNull(response, "context.HttpServletResponse cannot be null"); + HttpServletRequest request = context.getAttribute(HTTP_SERVLET_REQUEST_ATTRIBUTE_NAME); + HttpServletResponse response = context.getAttribute(HTTP_SERVLET_RESPONSE_ATTRIBUTE_NAME); + Assert.notNull(request, "The context attribute cannot be null '" + HTTP_SERVLET_REQUEST_ATTRIBUTE_NAME + "'"); + Assert.notNull(response, "The context attribute cannot be null '" + HTTP_SERVLET_RESPONSE_ATTRIBUTE_NAME + "'"); - Object scopesObj = context.getAttribute(SCOPE_ATTR_NAME); + String requestScope = context.getAttribute(REQUEST_SCOPE_ATTRIBUTE_NAME); Set scopes = null; - if (scopesObj != null) { - Assert.isTrue(scopesObj instanceof Set, "The '" + SCOPE_ATTR_NAME + "' attribute must be of type " + Set.class.getName()); - scopes = (Set) scopesObj; + if (!StringUtils.isEmpty(requestScope)) { + String delimiter = requestScope.indexOf(',') != -1 ? "," : " "; + scopes = Arrays.stream(StringUtils.delimitedListToStringArray(requestScope, delimiter, " ")).collect(Collectors.toSet()); } OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolver.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolver.java index 6e36805492c..b483ad43b2d 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolver.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolver.java @@ -123,17 +123,16 @@ public Object resolveArgument(MethodParameter parameter, HttpServletResponse servletResponse = webRequest.getNativeResponse(HttpServletResponse.class); - OAuth2AuthorizationContext.Builder authorizationContextBuilder = OAuth2AuthorizationContext.authorize(clientRegistration); - if (principal == null) { - authorizationContextBuilder.principal("anonymousUser"); + OAuth2AuthorizationContext.Builder contextBuilder = OAuth2AuthorizationContext.authorize(clientRegistration); + if (principal != null) { + contextBuilder.principal(principal); } else { - authorizationContextBuilder.principal(principal); + contextBuilder.principal("anonymousUser"); } - OAuth2AuthorizationContext authorizationContext = authorizationContextBuilder + OAuth2AuthorizationContext authorizationContext = contextBuilder .attribute(HttpServletRequest.class.getName(), servletRequest) .attribute(HttpServletResponse.class.getName(), servletResponse) .build(); - return this.authorizedClientProvider.authorize(authorizationContext); } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java index 3249de24c3d..99d3368cae9 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java @@ -20,7 +20,6 @@ import org.springframework.beans.factory.DisposableBean; import org.springframework.beans.factory.InitializingBean; import org.springframework.security.core.Authentication; -import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.oauth2.client.AuthorizationCodeOAuth2AuthorizedClientProvider; import org.springframework.security.oauth2.client.ClientCredentialsOAuth2AuthorizedClientProvider; @@ -54,7 +53,7 @@ import java.time.Clock; import java.time.Duration; import java.time.Instant; -import java.util.Collection; +import java.util.HashMap; import java.util.Map; import java.util.function.Consumer; @@ -317,7 +316,7 @@ public Mono filter(ClientRequest request, ExchangeFunction next) .filter(req -> req.attribute(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME).isPresent()) .switchIfEmpty(mergeRequestAttributesFromContext(request)) .filter(req -> req.attribute(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME).isPresent()) - .flatMap(req -> reauthorizeClientIfNecessary(req, next, getOAuth2AuthorizedClient(req.attributes()))) + .flatMap(req -> reauthorizeClientIfNecessary(getOAuth2AuthorizedClient(req.attributes()), req)) .map(authorizedClient -> bearer(request, authorizedClient)) .flatMap(next::exchange) .switchIfEmpty(next.exchange(request)); @@ -388,48 +387,59 @@ private void populateDefaultOAuth2AuthorizedClient(Map attrs) { OAuth2AuthorizedClient authorizedClient = this.authorizedClientRepository.loadAuthorizedClient( clientRegistrationId, authentication, request); if (authorizedClient == null) { - authorizedClient = getAuthorizedClient(clientRegistrationId, attrs); + authorizedClient = authorizeClient(clientRegistrationId, attrs); } oauth2AuthorizedClient(authorizedClient).accept(attrs); } } - private OAuth2AuthorizedClient getAuthorizedClient(String clientRegistrationId, Map attrs) { + private OAuth2AuthorizedClient authorizeClient(String clientRegistrationId, Map attributes) { ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId); if (clientRegistration == null) { throw new IllegalArgumentException("Could not find ClientRegistration with id " + clientRegistrationId); } - Authentication authentication = getAuthentication(attrs); - if (authentication == null) { - authentication = new PrincipalNameAuthentication("anonymousUser"); + + OAuth2AuthorizationContext.Builder contextBuilder = OAuth2AuthorizationContext.authorize(clientRegistration); + Authentication authentication = getAuthentication(attributes); + if (authentication != null) { + contextBuilder.principal(authentication); + } else { + contextBuilder.principal("anonymousUser"); } - OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext.authorize(clientRegistration) - .principal(authentication) - .attribute(HttpServletRequest.class.getName(), getRequest(attrs)) - .attribute(HttpServletResponse.class.getName(), getResponse(attrs)) + OAuth2AuthorizationContext authorizationContext = contextBuilder + .attributes(defaultContextAttributes(attributes)) .build(); return this.authorizedClientProvider.authorize(authorizationContext); } private Mono reauthorizeClientIfNecessary( - ClientRequest request, ExchangeFunction next, OAuth2AuthorizedClient authorizedClient) { + OAuth2AuthorizedClient authorizedClient, ClientRequest request) { if (this.authorizedClientProvider == null || !hasTokenExpired(authorizedClient)) { return Mono.just(authorizedClient); } Map attributes = request.attributes(); + + OAuth2AuthorizationContext.Builder contextBuilder = OAuth2AuthorizationContext.reauthorize(authorizedClient); Authentication authentication = getAuthentication(attributes); - if (authentication == null) { - authentication = new PrincipalNameAuthentication(authorizedClient.getPrincipalName()); + if (authentication != null) { + contextBuilder.principal(authentication); + } else { + contextBuilder.principal(authorizedClient.getPrincipalName()); } - OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext.reauthorize(authorizedClient) - .principal(authentication) - .attribute(HttpServletRequest.class.getName(), getRequest(attributes)) - .attribute(HttpServletResponse.class.getName(), getResponse(attributes)) + OAuth2AuthorizationContext authorizationContext = contextBuilder + .attributes(defaultContextAttributes(attributes)) .build(); return Mono.fromSupplier(() -> this.authorizedClientProvider.authorize(authorizationContext)); } + private Map defaultContextAttributes(Map attributes) { + Map contextAttributes = new HashMap<>(); + contextAttributes.put(HttpServletRequest.class.getName(), getRequest(attributes)); + contextAttributes.put(HttpServletResponse.class.getName(), getResponse(attributes)); + return contextAttributes; + } + private boolean hasTokenExpired(OAuth2AuthorizedClient authorizedClient) { Instant now = this.clock.instant(); Instant expiresAt = authorizedClient.getAccessToken().getExpiresAt(); @@ -478,53 +488,6 @@ static HttpServletResponse getResponse(Map attrs) { return (HttpServletResponse) attrs.get(HTTP_SERVLET_RESPONSE_ATTR_NAME); } - private static class PrincipalNameAuthentication implements Authentication { - private final String username; - - private PrincipalNameAuthentication(String username) { - this.username = username; - } - - @Override - public Collection getAuthorities() { - throw unsupported(); - } - - @Override - public Object getCredentials() { - throw unsupported(); - } - - @Override - public Object getDetails() { - throw unsupported(); - } - - @Override - public Object getPrincipal() { - throw unsupported(); - } - - @Override - public boolean isAuthenticated() { - throw unsupported(); - } - - @Override - public void setAuthenticated(boolean isAuthenticated) throws IllegalArgumentException { - throw unsupported(); - } - - @Override - public String getName() { - return this.username; - } - - private UnsupportedOperationException unsupported() { - return new UnsupportedOperationException("Not Supported"); - } - } - private static class RequestContextSubscriber implements CoreSubscriber { private static final String CONTEXT_DEFAULTED_ATTR_NAME = RequestContextSubscriber.class.getName().concat(".CONTEXT_DEFAULTED_ATTR_NAME"); private final CoreSubscriber delegate; diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/ClientCredentialsOAuth2AuthorizedClientProviderTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/ClientCredentialsOAuth2AuthorizedClientProviderTests.java index a3b3b9d1ddf..b700e94102c 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/ClientCredentialsOAuth2AuthorizedClientProviderTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/ClientCredentialsOAuth2AuthorizedClientProviderTests.java @@ -106,7 +106,7 @@ public void authorizeWhenHttpServletRequestIsNullThenThrowIllegalArgumentExcepti OAuth2AuthorizationContext.authorize(this.clientRegistration).principal(this.principal).build(); assertThatThrownBy(() -> this.authorizedClientProvider.authorize(authorizationContext)) .isInstanceOf(IllegalArgumentException.class) - .hasMessage("context.HttpServletRequest cannot be null"); + .hasMessage("The context attribute cannot be null 'javax.servlet.http.HttpServletRequest'"); } @Test @@ -118,7 +118,7 @@ public void authorizeWhenHttpServletResponseIsNullThenThrowIllegalArgumentExcept .build(); assertThatThrownBy(() -> this.authorizedClientProvider.authorize(authorizationContext)) .isInstanceOf(IllegalArgumentException.class) - .hasMessage("context.HttpServletResponse cannot be null"); + .hasMessage("The context attribute cannot be null 'javax.servlet.http.HttpServletResponse'"); } @Test diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProviderTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProviderTests.java index 344d3838f03..284eea6175a 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProviderTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProviderTests.java @@ -35,7 +35,8 @@ import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; -import java.util.Collections; +import java.util.Arrays; +import java.util.HashSet; import java.util.Set; import static org.assertj.core.api.Assertions.assertThat; @@ -122,7 +123,7 @@ public void authorizeWhenHttpServletRequestIsNullThenThrowIllegalArgumentExcepti OAuth2AuthorizationContext.reauthorize(this.authorizedClient).principal(this.principal).build(); assertThatThrownBy(() -> this.authorizedClientProvider.authorize(authorizationContext)) .isInstanceOf(IllegalArgumentException.class) - .hasMessage("context.HttpServletRequest cannot be null"); + .hasMessage("The context attribute cannot be null 'javax.servlet.http.HttpServletRequest'"); } @Test @@ -134,7 +135,7 @@ public void authorizeWhenHttpServletResponseIsNullThenThrowIllegalArgumentExcept .build(); assertThatThrownBy(() -> this.authorizedClientProvider.authorize(authorizationContext)) .isInstanceOf(IllegalArgumentException.class) - .hasMessage("context.HttpServletResponse cannot be null"); + .hasMessage("The context attribute cannot be null 'javax.servlet.http.HttpServletResponse'"); } @Test @@ -163,20 +164,21 @@ public void authorizeWhenAuthorizedWithRefreshTokenThenReauthorize() { } @Test - public void authorizeWhenAuthorizedAndScopeProvidedThenScopeRequested() { + public void authorizeWhenAuthorizedAndSpaceDelimitedScopeProvidedThenScopeRequested() { OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse() .refreshToken("new-refresh-token") .build(); when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse); - Set scope = Collections.singleton("read"); + String scope = "read write"; + Set scopes = new HashSet<>(Arrays.asList("read", "write")); OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext.reauthorize(this.authorizedClient) .principal(this.principal) .attribute(HttpServletRequest.class.getName(), new MockHttpServletRequest()) .attribute(HttpServletResponse.class.getName(), new MockHttpServletResponse()) - .attribute("SCOPE", scope) + .attribute(RefreshTokenOAuth2AuthorizedClientProvider.REQUEST_SCOPE_ATTRIBUTE_NAME, scope) .build(); this.authorizedClientProvider.authorize(authorizationContext); @@ -184,28 +186,32 @@ public void authorizeWhenAuthorizedAndScopeProvidedThenScopeRequested() { ArgumentCaptor refreshTokenGrantRequestArgCaptor = ArgumentCaptor.forClass(OAuth2RefreshTokenGrantRequest.class); verify(this.accessTokenResponseClient).getTokenResponse(refreshTokenGrantRequestArgCaptor.capture()); - assertThat(refreshTokenGrantRequestArgCaptor.getValue().getScopes()).isEqualTo(scope); + assertThat(refreshTokenGrantRequestArgCaptor.getValue().getScopes()).isEqualTo(scopes); } @Test - public void authorizeWhenAuthorizedAndInvalidScopeProvidedThenThrowIllegalArgumentException() { + public void authorizeWhenAuthorizedAndCommaDelimitedScopeProvidedThenScopeRequested() { OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse() .refreshToken("new-refresh-token") .build(); when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse); - String scope = "read"; + String scope = "read, write"; + Set scopes = new HashSet<>(Arrays.asList("read", "write")); OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext.reauthorize(this.authorizedClient) .principal(this.principal) .attribute(HttpServletRequest.class.getName(), new MockHttpServletRequest()) .attribute(HttpServletResponse.class.getName(), new MockHttpServletResponse()) - .attribute("SCOPE", scope) + .attribute(RefreshTokenOAuth2AuthorizedClientProvider.REQUEST_SCOPE_ATTRIBUTE_NAME, scope) .build(); - assertThatThrownBy(() -> this.authorizedClientProvider.authorize(authorizationContext)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("The 'SCOPE' attribute must be of type " + Set.class.getName()); + this.authorizedClientProvider.authorize(authorizationContext); + + ArgumentCaptor refreshTokenGrantRequestArgCaptor = + ArgumentCaptor.forClass(OAuth2RefreshTokenGrantRequest.class); + verify(this.accessTokenResponseClient).getTokenResponse(refreshTokenGrantRequestArgCaptor.capture()); + assertThat(refreshTokenGrantRequestArgCaptor.getValue().getScopes()).isEqualTo(scopes); } } From 2269d09d7ae43b1680db9423aed67f377d67ee1d Mon Sep 17 00:00:00 2001 From: Joe Grandja Date: Tue, 18 Jun 2019 14:26:37 -0400 Subject: [PATCH 09/19] Rename methods in OAuth2AuthorizationContext --- ...ionCodeOAuth2AuthorizedClientProvider.java | 2 +- .../client/OAuth2AuthorizationContext.java | 32 +++++++++---------- ...shTokenOAuth2AuthorizedClientProvider.java | 2 +- ...Auth2AuthorizedClientArgumentResolver.java | 2 +- ...uthorizedClientExchangeFilterFunction.java | 4 +-- ...deOAuth2AuthorizedClientProviderTests.java | 6 ++-- ...lsOAuth2AuthorizedClientProviderTests.java | 10 +++--- ...ngOAuth2AuthorizedClientProviderTests.java | 4 +-- .../OAuth2AuthorizationContextTests.java | 16 +++++----- ...enOAuth2AuthorizedClientProviderTests.java | 14 ++++---- 10 files changed, 46 insertions(+), 46 deletions(-) diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/AuthorizationCodeOAuth2AuthorizedClientProvider.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/AuthorizationCodeOAuth2AuthorizedClientProvider.java index fc594d36af3..e0ed9e9d294 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/AuthorizationCodeOAuth2AuthorizedClientProvider.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/AuthorizationCodeOAuth2AuthorizedClientProvider.java @@ -34,7 +34,7 @@ public final class AuthorizationCodeOAuth2AuthorizedClientProvider implements OA public OAuth2AuthorizedClient authorize(OAuth2AuthorizationContext context) { Assert.notNull(context, "context cannot be null"); if (AuthorizationGrantType.AUTHORIZATION_CODE.equals(context.getClientRegistration().getAuthorizationGrantType()) && - context.authorizationRequired()) { + context.authorizationRequested()) { // ClientAuthorizationRequiredException is caught by OAuth2AuthorizationRequestRedirectFilter which initiates authorization throw new ClientAuthorizationRequiredException(context.getClientRegistration().getRegistrationId()); } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizationContext.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizationContext.java index 15bb1834bd0..122111fbc6d 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizationContext.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizationContext.java @@ -46,7 +46,7 @@ private OAuth2AuthorizationContext() { } /** - * Returns the {@link ClientRegistration client} requiring authorization. + * Returns the {@link ClientRegistration client} requesting authorization. * * @return the {@link ClientRegistration} */ @@ -64,10 +64,10 @@ public Authentication getPrincipal() { } /** - * Returns the {@link OAuth2AuthorizedClient authorized client} which requires re-authorization - * or {@code null} if the {@link #getClientRegistration() client} needs to be authorized. + * Returns the {@link OAuth2AuthorizedClient authorized client} requesting re-authorization + * or {@code null} if the {@link #getClientRegistration() client} is requesting to be authorized. * - * @return the {@link OAuth2AuthorizedClient} which requires re-authorization or {@code null} if the client needs to be authorized + * @return the {@link OAuth2AuthorizedClient} requesting re-authorization or {@code null} if the client is requesting to be authorized */ @Nullable public OAuth2AuthorizedClient getAuthorizedClient() { @@ -97,40 +97,40 @@ public T getAttribute(String name) { } /** - * Returns {@code true} if the client needs to be authorized, otherwise {@code false}. + * Returns {@code true} if the client is requesting authorization, otherwise {@code false}. * - * @return {@code true} if the client needs to be authorized, otherwise {@code false}. + * @return {@code true} if the client is requesting authorization, otherwise {@code false}. */ - public boolean authorizationRequired() { + public boolean authorizationRequested() { return getAuthorizedClient() == null; } /** - * Returns {@code true} if the client needs to be re-authorized, otherwise {@code false}. + * Returns {@code true} if the client is requesting re-authorization, otherwise {@code false}. * - * @return {@code true} if the client needs to be re-authorized, otherwise {@code false}. + * @return {@code true} if the client is requesting re-authorization, otherwise {@code false}. */ - public boolean reauthorizationRequired() { + public boolean reauthorizationRequested() { return getAuthorizedClient() != null; } /** - * Returns a new {@link Builder} with the {@link ClientRegistration client} requiring authorization. + * Returns a new {@link Builder} with the {@link ClientRegistration client} requesting authorization. * - * @param clientRegistration the {@link ClientRegistration client} requiring authorization + * @param clientRegistration the {@link ClientRegistration client} requesting authorization * @return the {@link Builder} */ - public static Builder authorize(ClientRegistration clientRegistration) { + public static Builder forAuthorization(ClientRegistration clientRegistration) { return new Builder(clientRegistration); } /** - * Returns a new {@link Builder} with the {@link OAuth2AuthorizedClient authorized client} requiring re-authorization. + * Returns a new {@link Builder} with the {@link OAuth2AuthorizedClient authorized client} requesting re-authorization. * - * @param authorizedClient the {@link OAuth2AuthorizedClient authorized client} requiring re-authorization + * @param authorizedClient the {@link OAuth2AuthorizedClient authorized client} requesting re-authorization * @return the {@link Builder} */ - public static Builder reauthorize(OAuth2AuthorizedClient authorizedClient) { + public static Builder forReauthorization(OAuth2AuthorizedClient authorizedClient) { return new Builder(authorizedClient); } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProvider.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProvider.java index 11d4c021abd..d3a8d3fb9d7 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProvider.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProvider.java @@ -93,7 +93,7 @@ public RefreshTokenOAuth2AuthorizedClientProvider(ClientRegistrationRepository c @Nullable public OAuth2AuthorizedClient authorize(OAuth2AuthorizationContext context) { Assert.notNull(context, "context cannot be null"); - if (!context.reauthorizationRequired() || context.getAuthorizedClient().getRefreshToken() == null) { + if (!context.reauthorizationRequested() || context.getAuthorizedClient().getRefreshToken() == null) { return null; } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolver.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolver.java index b483ad43b2d..eb124e31bc3 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolver.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolver.java @@ -123,7 +123,7 @@ public Object resolveArgument(MethodParameter parameter, HttpServletResponse servletResponse = webRequest.getNativeResponse(HttpServletResponse.class); - OAuth2AuthorizationContext.Builder contextBuilder = OAuth2AuthorizationContext.authorize(clientRegistration); + OAuth2AuthorizationContext.Builder contextBuilder = OAuth2AuthorizationContext.forAuthorization(clientRegistration); if (principal != null) { contextBuilder.principal(principal); } else { diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java index 99d3368cae9..085980e1ea2 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java @@ -399,7 +399,7 @@ private OAuth2AuthorizedClient authorizeClient(String clientRegistrationId, Map< throw new IllegalArgumentException("Could not find ClientRegistration with id " + clientRegistrationId); } - OAuth2AuthorizationContext.Builder contextBuilder = OAuth2AuthorizationContext.authorize(clientRegistration); + OAuth2AuthorizationContext.Builder contextBuilder = OAuth2AuthorizationContext.forAuthorization(clientRegistration); Authentication authentication = getAuthentication(attributes); if (authentication != null) { contextBuilder.principal(authentication); @@ -420,7 +420,7 @@ private Mono reauthorizeClientIfNecessary( Map attributes = request.attributes(); - OAuth2AuthorizationContext.Builder contextBuilder = OAuth2AuthorizationContext.reauthorize(authorizedClient); + OAuth2AuthorizationContext.Builder contextBuilder = OAuth2AuthorizationContext.forReauthorization(authorizedClient); Authentication authentication = getAuthentication(attributes); if (authentication != null) { contextBuilder.principal(authentication); diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/AuthorizationCodeOAuth2AuthorizedClientProviderTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/AuthorizationCodeOAuth2AuthorizedClientProviderTests.java index 338c3b1d960..a9b60eb2b2a 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/AuthorizationCodeOAuth2AuthorizedClientProviderTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/AuthorizationCodeOAuth2AuthorizedClientProviderTests.java @@ -55,7 +55,7 @@ public void authorizeWhenContextIsNullThenThrowIllegalArgumentException() { @Test public void authorizeWhenAuthorizationCodeAndNotAuthorizedThenAuthorize() { OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.authorize(this.clientRegistration).principal(this.principal).build(); + OAuth2AuthorizationContext.forAuthorization(this.clientRegistration).principal(this.principal).build(); assertThatThrownBy(() -> this.authorizedClientProvider.authorize(authorizationContext)) .isInstanceOf(ClientAuthorizationRequiredException.class); } @@ -63,7 +63,7 @@ public void authorizeWhenAuthorizationCodeAndNotAuthorizedThenAuthorize() { @Test public void authorizeWhenAuthorizationCodeAndAuthorizedThenUnableToAuthorize() { OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.reauthorize(this.authorizedClient).principal(this.principal).build(); + OAuth2AuthorizationContext.forReauthorization(this.authorizedClient).principal(this.principal).build(); assertThat(this.authorizedClientProvider.authorize(authorizationContext)).isNull(); } @@ -71,7 +71,7 @@ public void authorizeWhenAuthorizationCodeAndAuthorizedThenUnableToAuthorize() { public void authorizeWhenNotAuthorizationCodeThenUnableToAuthorize() { ClientRegistration clientCredentialsClient = TestClientRegistrations.clientCredentials().build(); OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.authorize(clientCredentialsClient).principal(this.principal).build(); + OAuth2AuthorizationContext.forAuthorization(clientCredentialsClient).principal(this.principal).build(); assertThat(this.authorizedClientProvider.authorize(authorizationContext)).isNull(); } } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/ClientCredentialsOAuth2AuthorizedClientProviderTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/ClientCredentialsOAuth2AuthorizedClientProviderTests.java index b700e94102c..e0893a947fd 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/ClientCredentialsOAuth2AuthorizedClientProviderTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/ClientCredentialsOAuth2AuthorizedClientProviderTests.java @@ -96,14 +96,14 @@ public void authorizeWhenContextIsNullThenThrowIllegalArgumentException() { public void authorizeWhenNotClientCredentialsThenUnableToAuthorize() { ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.authorize(clientRegistration).principal(this.principal).build(); + OAuth2AuthorizationContext.forAuthorization(clientRegistration).principal(this.principal).build(); assertThat(this.authorizedClientProvider.authorize(authorizationContext)).isNull(); } @Test public void authorizeWhenHttpServletRequestIsNullThenThrowIllegalArgumentException() { OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.authorize(this.clientRegistration).principal(this.principal).build(); + OAuth2AuthorizationContext.forAuthorization(this.clientRegistration).principal(this.principal).build(); assertThatThrownBy(() -> this.authorizedClientProvider.authorize(authorizationContext)) .isInstanceOf(IllegalArgumentException.class) .hasMessage("The context attribute cannot be null 'javax.servlet.http.HttpServletRequest'"); @@ -112,7 +112,7 @@ public void authorizeWhenHttpServletRequestIsNullThenThrowIllegalArgumentExcepti @Test public void authorizeWhenHttpServletResponseIsNullThenThrowIllegalArgumentException() { OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.authorize(this.clientRegistration) + OAuth2AuthorizationContext.forAuthorization(this.clientRegistration) .principal(this.principal) .attribute(HttpServletRequest.class.getName(), new MockHttpServletRequest()) .build(); @@ -127,7 +127,7 @@ public void authorizeWhenClientCredentialsAndNotAuthorizedThenAuthorize() { when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse); OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.authorize(this.clientRegistration) + OAuth2AuthorizationContext.forAuthorization(this.clientRegistration) .principal(this.principal) .attribute(HttpServletRequest.class.getName(), new MockHttpServletRequest()) .attribute(HttpServletResponse.class.getName(), new MockHttpServletResponse()) @@ -152,7 +152,7 @@ public void authorizeWhenClientCredentialsAndAuthorizedThenReauthorize() { this.clientRegistration, this.principal.getName(), TestOAuth2AccessTokens.noScopes()); OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.reauthorize(authorizedClient) + OAuth2AuthorizationContext.forReauthorization(authorizedClient) .principal(this.principal) .attribute(HttpServletRequest.class.getName(), new MockHttpServletRequest()) .attribute(HttpServletResponse.class.getName(), new MockHttpServletResponse()) diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/DelegatingOAuth2AuthorizedClientProviderTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/DelegatingOAuth2AuthorizedClientProviderTests.java index aa833ba309d..efd2c51ce6a 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/DelegatingOAuth2AuthorizedClientProviderTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/DelegatingOAuth2AuthorizedClientProviderTests.java @@ -64,7 +64,7 @@ public void authorizeWhenProviderCanAuthorizeThenReturnAuthorizedClient() { DelegatingOAuth2AuthorizedClientProvider delegate = new DelegatingOAuth2AuthorizedClientProvider( mock(OAuth2AuthorizedClientProvider.class), mock(OAuth2AuthorizedClientProvider.class), authorizedClientProvider); - OAuth2AuthorizationContext context = OAuth2AuthorizationContext.reauthorize(authorizedClient).principal(principal).build(); + OAuth2AuthorizationContext context = OAuth2AuthorizationContext.forReauthorization(authorizedClient).principal(principal).build(); OAuth2AuthorizedClient reauthorizedClient = delegate.authorize(context); assertThat(reauthorizedClient).isSameAs(authorizedClient); } @@ -72,7 +72,7 @@ public void authorizeWhenProviderCanAuthorizeThenReturnAuthorizedClient() { @Test public void authorizeWhenProviderCantAuthorizeThenReturnNull() { OAuth2AuthorizationContext context = OAuth2AuthorizationContext - .authorize(TestClientRegistrations.clientRegistration().build()) + .forAuthorization(TestClientRegistrations.clientRegistration().build()) .principal(new TestingAuthenticationToken("principal", "password")) .build(); diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/OAuth2AuthorizationContextTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/OAuth2AuthorizationContextTests.java index 4366c6a4de9..2dc50c59c61 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/OAuth2AuthorizationContextTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/OAuth2AuthorizationContextTests.java @@ -45,21 +45,21 @@ public void setup() { @Test public void authorizeWhenClientRegistrationIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> OAuth2AuthorizationContext.authorize(null).build()) + assertThatThrownBy(() -> OAuth2AuthorizationContext.forAuthorization(null).build()) .isInstanceOf(IllegalArgumentException.class) .hasMessage("clientRegistration cannot be null"); } @Test public void authorizeWhenPrincipalIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> OAuth2AuthorizationContext.authorize(this.clientRegistration).build()) + assertThatThrownBy(() -> OAuth2AuthorizationContext.forAuthorization(this.clientRegistration).build()) .isInstanceOf(IllegalArgumentException.class) .hasMessage("principal cannot be null"); } @Test public void authorizeWhenAllValuesProvidedThenAllValuesAreSet() { - OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext.authorize(this.clientRegistration) + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext.forAuthorization(this.clientRegistration) .principal(this.principal) .attribute("attribute1", "value1") .attribute("attribute2", "value2") @@ -69,26 +69,26 @@ public void authorizeWhenAllValuesProvidedThenAllValuesAreSet() { assertThat(authorizationContext.getAuthorizedClient()).isNull(); assertThat(authorizationContext.getAttributes()).contains( entry("attribute1", "value1"), entry("attribute2", "value2")); - assertThat(authorizationContext.authorizationRequired()).isTrue(); + assertThat(authorizationContext.authorizationRequested()).isTrue(); } @Test public void reauthorizeWhenAuthorizedClientIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> OAuth2AuthorizationContext.reauthorize(null).build()) + assertThatThrownBy(() -> OAuth2AuthorizationContext.forReauthorization(null).build()) .isInstanceOf(IllegalArgumentException.class) .hasMessage("authorizedClient cannot be null"); } @Test public void reauthorizeWhenPrincipalIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> OAuth2AuthorizationContext.reauthorize(this.authorizedClient).build()) + assertThatThrownBy(() -> OAuth2AuthorizationContext.forReauthorization(this.authorizedClient).build()) .isInstanceOf(IllegalArgumentException.class) .hasMessage("principal cannot be null"); } @Test public void reauthorizeWhenAllValuesProvidedThenAllValuesAreSet() { - OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext.reauthorize(this.authorizedClient) + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext.forReauthorization(this.authorizedClient) .principal(this.principal) .attribute("attribute1", "value1") .attribute("attribute2", "value2") @@ -98,6 +98,6 @@ public void reauthorizeWhenAllValuesProvidedThenAllValuesAreSet() { assertThat(authorizationContext.getPrincipal()).isSameAs(this.principal); assertThat(authorizationContext.getAttributes()).contains( entry("attribute1", "value1"), entry("attribute2", "value2")); - assertThat(authorizationContext.reauthorizationRequired()).isTrue(); + assertThat(authorizationContext.reauthorizationRequested()).isTrue(); } } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProviderTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProviderTests.java index 284eea6175a..a361e3cc7f6 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProviderTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProviderTests.java @@ -104,7 +104,7 @@ public void authorizeWhenContextIsNullThenThrowIllegalArgumentException() { @Test public void authorizeWhenNotAuthorizedThenUnableToReauthorize() { OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.authorize(this.clientRegistration).principal(this.principal).build(); + OAuth2AuthorizationContext.forAuthorization(this.clientRegistration).principal(this.principal).build(); assertThat(this.authorizedClientProvider.authorize(authorizationContext)).isNull(); } @@ -113,14 +113,14 @@ public void authorizeWhenAuthorizedAndRefreshTokenIsNullThenUnableToReauthorize( OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( this.clientRegistration, this.principal.getName(), this.authorizedClient.getAccessToken()); OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.reauthorize(authorizedClient).principal(this.principal).build(); + OAuth2AuthorizationContext.forReauthorization(authorizedClient).principal(this.principal).build(); assertThat(this.authorizedClientProvider.authorize(authorizationContext)).isNull(); } @Test public void authorizeWhenHttpServletRequestIsNullThenThrowIllegalArgumentException() { OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.reauthorize(this.authorizedClient).principal(this.principal).build(); + OAuth2AuthorizationContext.forReauthorization(this.authorizedClient).principal(this.principal).build(); assertThatThrownBy(() -> this.authorizedClientProvider.authorize(authorizationContext)) .isInstanceOf(IllegalArgumentException.class) .hasMessage("The context attribute cannot be null 'javax.servlet.http.HttpServletRequest'"); @@ -129,7 +129,7 @@ public void authorizeWhenHttpServletRequestIsNullThenThrowIllegalArgumentExcepti @Test public void authorizeWhenHttpServletResponseIsNullThenThrowIllegalArgumentException() { OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.reauthorize(this.authorizedClient) + OAuth2AuthorizationContext.forReauthorization(this.authorizedClient) .principal(this.principal) .attribute(HttpServletRequest.class.getName(), new MockHttpServletRequest()) .build(); @@ -146,7 +146,7 @@ public void authorizeWhenAuthorizedWithRefreshTokenThenReauthorize() { when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse); OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.reauthorize(this.authorizedClient) + OAuth2AuthorizationContext.forReauthorization(this.authorizedClient) .principal(this.principal) .attribute(HttpServletRequest.class.getName(), new MockHttpServletRequest()) .attribute(HttpServletResponse.class.getName(), new MockHttpServletResponse()) @@ -174,7 +174,7 @@ public void authorizeWhenAuthorizedAndSpaceDelimitedScopeProvidedThenScopeReques Set scopes = new HashSet<>(Arrays.asList("read", "write")); OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.reauthorize(this.authorizedClient) + OAuth2AuthorizationContext.forReauthorization(this.authorizedClient) .principal(this.principal) .attribute(HttpServletRequest.class.getName(), new MockHttpServletRequest()) .attribute(HttpServletResponse.class.getName(), new MockHttpServletResponse()) @@ -200,7 +200,7 @@ public void authorizeWhenAuthorizedAndCommaDelimitedScopeProvidedThenScopeReques Set scopes = new HashSet<>(Arrays.asList("read", "write")); OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.reauthorize(this.authorizedClient) + OAuth2AuthorizationContext.forReauthorization(this.authorizedClient) .principal(this.principal) .attribute(HttpServletRequest.class.getName(), new MockHttpServletRequest()) .attribute(HttpServletResponse.class.getName(), new MockHttpServletResponse()) From dd6c5b0cf592cdfc8e969635f92173393d31dca7 Mon Sep 17 00:00:00 2001 From: Joe Grandja Date: Tue, 18 Jun 2019 15:22:12 -0400 Subject: [PATCH 10/19] OAuth2AuthorizedClientProvider should not save OAuth2AuthorizedClient --- .../OAuth2ClientConfiguration.java | 3 +- ...entialsOAuth2AuthorizedClientProvider.java | 41 ------------ ...shTokenOAuth2AuthorizedClientProvider.java | 42 ------------- ...Auth2AuthorizedClientArgumentResolver.java | 21 ++++--- ...uthorizedClientExchangeFilterFunction.java | 50 ++++++++------- ...lsOAuth2AuthorizedClientProviderTests.java | 62 +------------------ ...enOAuth2AuthorizedClientProviderTests.java | 58 +---------------- ...AuthorizedClientArgumentResolverTests.java | 3 +- ...izedClientExchangeFilterFunctionTests.java | 31 +--------- 9 files changed, 49 insertions(+), 262 deletions(-) diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfiguration.java b/config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfiguration.java index 09a9ae86b7e..d5e5bfbd021 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfiguration.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfiguration.java @@ -77,8 +77,7 @@ public void addArgumentResolvers(List argumentRes this.clientRegistrationRepository, this.authorizedClientRepository); if (this.accessTokenResponseClient != null) { ClientCredentialsOAuth2AuthorizedClientProvider clientCredentialsAuthorizedClientProvider = - new ClientCredentialsOAuth2AuthorizedClientProvider( - this.clientRegistrationRepository, this.authorizedClientRepository); + new ClientCredentialsOAuth2AuthorizedClientProvider(); clientCredentialsAuthorizedClientProvider.setAccessTokenResponseClient(this.accessTokenResponseClient); OAuth2AuthorizedClientProvider authorizedClientProvider = new DelegatingOAuth2AuthorizedClientProvider( new AuthorizationCodeOAuth2AuthorizedClientProvider(), clientCredentialsAuthorizedClientProvider); diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ClientCredentialsOAuth2AuthorizedClientProvider.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ClientCredentialsOAuth2AuthorizedClientProvider.java index ac1e0168f1e..e67c69d170b 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ClientCredentialsOAuth2AuthorizedClientProvider.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ClientCredentialsOAuth2AuthorizedClientProvider.java @@ -20,15 +20,10 @@ import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest; import org.springframework.security.oauth2.client.registration.ClientRegistration; -import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; -import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.util.Assert; -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; - /** * An implementation of an {@link OAuth2AuthorizedClientProvider} * for the {@link AuthorizationGrantType#CLIENT_CREDENTIALS client_credentials} grant. @@ -39,40 +34,15 @@ * @see DefaultClientCredentialsTokenResponseClient */ public final class ClientCredentialsOAuth2AuthorizedClientProvider implements OAuth2AuthorizedClientProvider { - private static final String HTTP_SERVLET_REQUEST_ATTRIBUTE_NAME = HttpServletRequest.class.getName(); - private static final String HTTP_SERVLET_RESPONSE_ATTRIBUTE_NAME = HttpServletResponse.class.getName(); - private final ClientRegistrationRepository clientRegistrationRepository; - private final OAuth2AuthorizedClientRepository authorizedClientRepository; private OAuth2AccessTokenResponseClient accessTokenResponseClient = new DefaultClientCredentialsTokenResponseClient(); - /** - * Constructs a {@code ClientCredentialsOAuth2AuthorizedClientProvider} using the provided parameters. - * - * @param clientRegistrationRepository the repository of client registrations - * @param authorizedClientRepository the repository of authorized clients - */ - public ClientCredentialsOAuth2AuthorizedClientProvider(ClientRegistrationRepository clientRegistrationRepository, - OAuth2AuthorizedClientRepository authorizedClientRepository) { - Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null"); - Assert.notNull(authorizedClientRepository, "authorizedClientRepository cannot be null"); - this.clientRegistrationRepository = clientRegistrationRepository; - this.authorizedClientRepository = authorizedClientRepository; - } - /** * Attempt to authorize (or re-authorize) the {@link OAuth2AuthorizationContext#getClientRegistration() client} in the provided {@code context}. * Returns {@code null} if authorization (or re-authorization) is not supported, * e.g. the client's {@link ClientRegistration#getAuthorizationGrantType() authorization grant type} * is not {@link AuthorizationGrantType#CLIENT_CREDENTIALS client_credentials}. * - *

- * The following {@link OAuth2AuthorizationContext#getAttributes() context attributes} are supported: - *

    - *
  1. {@code "javax.servlet.http.HttpServletRequest"} (required) - the {@code HttpServletRequest}
  2. - *
  3. {@code "javax.servlet.http.HttpServletResponse"} (required) - the {@code HttpServletResponse}
  4. - *
- * * @param context the context that holds authorization-specific state for the client * @return the {@link OAuth2AuthorizedClient} or {@code null} if authorization (or re-authorization) is not supported */ @@ -84,11 +54,6 @@ public OAuth2AuthorizedClient authorize(OAuth2AuthorizationContext context) { return null; } - HttpServletRequest request = context.getAttribute(HTTP_SERVLET_REQUEST_ATTRIBUTE_NAME); - HttpServletResponse response = context.getAttribute(HTTP_SERVLET_RESPONSE_ATTRIBUTE_NAME); - Assert.notNull(request, "The context attribute cannot be null '" + HTTP_SERVLET_REQUEST_ATTRIBUTE_NAME + "'"); - Assert.notNull(response, "The context attribute cannot be null '" + HTTP_SERVLET_RESPONSE_ATTRIBUTE_NAME + "'"); - // As per spec, in section 4.4.3 Access Token Response // https://tools.ietf.org/html/rfc6749#section-4.4.3 // A refresh token SHOULD NOT be included. @@ -106,12 +71,6 @@ public OAuth2AuthorizedClient authorize(OAuth2AuthorizationContext context) { context.getPrincipal().getName(), tokenResponse.getAccessToken()); - this.authorizedClientRepository.saveAuthorizedClient( - authorizedClient, - context.getPrincipal(), - request, - response); - return authorizedClient; } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProvider.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProvider.java index d3a8d3fb9d7..5ba4e95a0bf 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProvider.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProvider.java @@ -19,15 +19,11 @@ import org.springframework.security.oauth2.client.endpoint.DefaultRefreshTokenTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.OAuth2RefreshTokenGrantRequest; -import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; -import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.util.Assert; import org.springframework.util.StringUtils; -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; import java.util.Arrays; import java.util.Set; import java.util.stream.Collectors; @@ -50,42 +46,15 @@ public final class RefreshTokenOAuth2AuthorizedClientProvider implements OAuth2A */ public static final String REQUEST_SCOPE_ATTRIBUTE_NAME = "org.springframework.security.oauth2.client.REQUEST_SCOPE"; - private static final String HTTP_SERVLET_REQUEST_ATTRIBUTE_NAME = HttpServletRequest.class.getName(); - private static final String HTTP_SERVLET_RESPONSE_ATTRIBUTE_NAME = HttpServletResponse.class.getName(); - - private final ClientRegistrationRepository clientRegistrationRepository; - private final OAuth2AuthorizedClientRepository authorizedClientRepository; private OAuth2AccessTokenResponseClient accessTokenResponseClient = new DefaultRefreshTokenTokenResponseClient(); - /** - * Constructs a {@code RefreshTokenOAuth2AuthorizedClientProvider} using the provided parameters. - * - * @param clientRegistrationRepository the repository of client registrations - * @param authorizedClientRepository the repository of authorized clients - */ - public RefreshTokenOAuth2AuthorizedClientProvider(ClientRegistrationRepository clientRegistrationRepository, - OAuth2AuthorizedClientRepository authorizedClientRepository) { - Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null"); - Assert.notNull(authorizedClientRepository, "authorizedClientRepository cannot be null"); - this.clientRegistrationRepository = clientRegistrationRepository; - this.authorizedClientRepository = authorizedClientRepository; - } - /** * Attempt to re-authorize the {@link OAuth2AuthorizationContext#getClientRegistration() client} in the provided {@code context}. * Returns {@code null} if re-authorization is not supported, * e.g. the {@link OAuth2AuthorizedClient#getRefreshToken() refresh token} is not available for the * {@link OAuth2AuthorizationContext#getAuthorizedClient() authorized client}. * - *

- * The following {@link OAuth2AuthorizationContext#getAttributes() context attributes} are supported: - *

    - *
  1. {@code "javax.servlet.http.HttpServletRequest"} (required) - the {@code HttpServletRequest}
  2. - *
  3. {@code "javax.servlet.http.HttpServletResponse"} (required) - the {@code HttpServletResponse}
  4. - *
  5. {@code "org.springframework.security.oauth2.client.REQUEST_SCOPE"} (optional) - a space-delimited or comma-delimited {@code String} of scope(s) to be requested by the {@link OAuth2AuthorizationContext#getClientRegistration() client}
  6. - *
- * * @param context the context that holds authorization-specific state for the client * @return the {@link OAuth2AuthorizedClient} or {@code null} if re-authorization is not supported */ @@ -97,11 +66,6 @@ public OAuth2AuthorizedClient authorize(OAuth2AuthorizationContext context) { return null; } - HttpServletRequest request = context.getAttribute(HTTP_SERVLET_REQUEST_ATTRIBUTE_NAME); - HttpServletResponse response = context.getAttribute(HTTP_SERVLET_RESPONSE_ATTRIBUTE_NAME); - Assert.notNull(request, "The context attribute cannot be null '" + HTTP_SERVLET_REQUEST_ATTRIBUTE_NAME + "'"); - Assert.notNull(response, "The context attribute cannot be null '" + HTTP_SERVLET_RESPONSE_ATTRIBUTE_NAME + "'"); - String requestScope = context.getAttribute(REQUEST_SCOPE_ATTRIBUTE_NAME); Set scopes = null; if (!StringUtils.isEmpty(requestScope)) { @@ -120,12 +84,6 @@ public OAuth2AuthorizedClient authorize(OAuth2AuthorizationContext context) { tokenResponse.getAccessToken(), tokenResponse.getRefreshToken()); - this.authorizedClientRepository.saveAuthorizedClient( - authorizedClient, - context.getPrincipal(), - request, - response); - return authorizedClient; } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolver.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolver.java index eb124e31bc3..0b72616461b 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolver.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolver.java @@ -121,19 +121,24 @@ public Object resolveArgument(MethodParameter parameter, return null; } - HttpServletResponse servletResponse = webRequest.getNativeResponse(HttpServletResponse.class); - OAuth2AuthorizationContext.Builder contextBuilder = OAuth2AuthorizationContext.forAuthorization(clientRegistration); if (principal != null) { contextBuilder.principal(principal); } else { contextBuilder.principal("anonymousUser"); } - OAuth2AuthorizationContext authorizationContext = contextBuilder - .attribute(HttpServletRequest.class.getName(), servletRequest) - .attribute(HttpServletResponse.class.getName(), servletResponse) - .build(); - return this.authorizedClientProvider.authorize(authorizationContext); + OAuth2AuthorizationContext authorizationContext = contextBuilder.build(); + authorizedClient = this.authorizedClientProvider.authorize(authorizationContext); + + HttpServletResponse servletResponse = webRequest.getNativeResponse(HttpServletResponse.class); + + this.authorizedClientRepository.saveAuthorizedClient( + authorizedClient, + authorizationContext.getPrincipal(), + servletRequest, + servletResponse); + + return authorizedClient; } private String resolveClientRegistrationId(MethodParameter parameter) { @@ -182,7 +187,7 @@ public final void setClientCredentialsTokenResponseClient( private OAuth2AuthorizedClientProvider createAuthorizedClientProvider( OAuth2AccessTokenResponseClient clientCredentialsTokenResponseClient) { ClientCredentialsOAuth2AuthorizedClientProvider clientCredentialsAuthorizedClientProvider = - new ClientCredentialsOAuth2AuthorizedClientProvider(this.clientRegistrationRepository, this.authorizedClientRepository); + new ClientCredentialsOAuth2AuthorizedClientProvider(); clientCredentialsAuthorizedClientProvider.setAccessTokenResponseClient(clientCredentialsTokenResponseClient); return new DelegatingOAuth2AuthorizedClientProvider( new AuthorizationCodeOAuth2AuthorizedClientProvider(), clientCredentialsAuthorizedClientProvider); diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java index 085980e1ea2..476b6e32da0 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java @@ -53,7 +53,6 @@ import java.time.Clock; import java.time.Duration; import java.time.Instant; -import java.util.HashMap; import java.util.Map; import java.util.function.Consumer; @@ -132,15 +131,14 @@ public ServletOAuth2AuthorizedClientExchangeFilterFunction( OAuth2AuthorizedClientRepository authorizedClientRepository) { this.clientRegistrationRepository = clientRegistrationRepository; this.authorizedClientRepository = authorizedClientRepository; - this.authorizedClientProvider = createDefaultAuthorizedClientProvider(clientRegistrationRepository, authorizedClientRepository); + this.authorizedClientProvider = createDefaultAuthorizedClientProvider(); } - private static OAuth2AuthorizedClientProvider createDefaultAuthorizedClientProvider( - ClientRegistrationRepository clientRegistrationRepository, OAuth2AuthorizedClientRepository authorizedClientRepository) { + private static OAuth2AuthorizedClientProvider createDefaultAuthorizedClientProvider() { return new DelegatingOAuth2AuthorizedClientProvider( new AuthorizationCodeOAuth2AuthorizedClientProvider(), - new ClientCredentialsOAuth2AuthorizedClientProvider(clientRegistrationRepository, authorizedClientRepository), - new RefreshTokenOAuth2AuthorizedClientProvider(clientRegistrationRepository, authorizedClientRepository)); + new ClientCredentialsOAuth2AuthorizedClientProvider(), + new RefreshTokenOAuth2AuthorizedClientProvider()); } @Override @@ -181,12 +179,12 @@ public void setClientCredentialsTokenResponseClient( private OAuth2AuthorizedClientProvider createAuthorizedClientProvider( OAuth2AccessTokenResponseClient clientCredentialsTokenResponseClient) { ClientCredentialsOAuth2AuthorizedClientProvider clientCredentialsAuthorizedClientProvider = - new ClientCredentialsOAuth2AuthorizedClientProvider(this.clientRegistrationRepository, this.authorizedClientRepository); + new ClientCredentialsOAuth2AuthorizedClientProvider(); clientCredentialsAuthorizedClientProvider.setAccessTokenResponseClient(clientCredentialsTokenResponseClient); return new DelegatingOAuth2AuthorizedClientProvider( new AuthorizationCodeOAuth2AuthorizedClientProvider(), clientCredentialsAuthorizedClientProvider, - new RefreshTokenOAuth2AuthorizedClientProvider(this.clientRegistrationRepository, this.authorizedClientRepository)); + new RefreshTokenOAuth2AuthorizedClientProvider()); } /** @@ -406,10 +404,16 @@ private OAuth2AuthorizedClient authorizeClient(String clientRegistrationId, Map< } else { contextBuilder.principal("anonymousUser"); } - OAuth2AuthorizationContext authorizationContext = contextBuilder - .attributes(defaultContextAttributes(attributes)) - .build(); - return this.authorizedClientProvider.authorize(authorizationContext); + OAuth2AuthorizationContext authorizationContext = contextBuilder.build(); + OAuth2AuthorizedClient authorizedClient = this.authorizedClientProvider.authorize(authorizationContext); + + this.authorizedClientRepository.saveAuthorizedClient( + authorizedClient, + authorizationContext.getPrincipal(), + getRequest(attributes), + getResponse(attributes)); + + return authorizedClient; } private Mono reauthorizeClientIfNecessary( @@ -427,17 +431,17 @@ private Mono reauthorizeClientIfNecessary( } else { contextBuilder.principal(authorizedClient.getPrincipalName()); } - OAuth2AuthorizationContext authorizationContext = contextBuilder - .attributes(defaultContextAttributes(attributes)) - .build(); - return Mono.fromSupplier(() -> this.authorizedClientProvider.authorize(authorizationContext)); - } - - private Map defaultContextAttributes(Map attributes) { - Map contextAttributes = new HashMap<>(); - contextAttributes.put(HttpServletRequest.class.getName(), getRequest(attributes)); - contextAttributes.put(HttpServletResponse.class.getName(), getResponse(attributes)); - return contextAttributes; + OAuth2AuthorizationContext authorizationContext = contextBuilder.build(); + + return Mono.fromSupplier(() -> { + OAuth2AuthorizedClient reauthorizedClient = this.authorizedClientProvider.authorize(authorizationContext); + this.authorizedClientRepository.saveAuthorizedClient( + reauthorizedClient, + authorizationContext.getPrincipal(), + getRequest(attributes), + getResponse(attributes)); + return reauthorizedClient; + }); } private boolean hasTokenExpired(OAuth2AuthorizedClient authorizedClient) { diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/ClientCredentialsOAuth2AuthorizedClientProviderTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/ClientCredentialsOAuth2AuthorizedClientProviderTests.java index e0893a947fd..5f21d728a41 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/ClientCredentialsOAuth2AuthorizedClientProviderTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/ClientCredentialsOAuth2AuthorizedClientProviderTests.java @@ -17,27 +17,21 @@ import org.junit.Before; import org.junit.Test; -import org.springframework.mock.web.MockHttpServletRequest; -import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest; import org.springframework.security.oauth2.client.registration.ClientRegistration; -import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; import org.springframework.security.oauth2.client.registration.TestClientRegistrations; -import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses; -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; - import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.*; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; /** * Tests for {@link ClientCredentialsOAuth2AuthorizedClientProvider}. @@ -45,8 +39,6 @@ * @author Joe Grandja */ public class ClientCredentialsOAuth2AuthorizedClientProviderTests { - private ClientRegistrationRepository clientRegistrationRepository; - private OAuth2AuthorizedClientRepository authorizedClientRepository; private ClientCredentialsOAuth2AuthorizedClientProvider authorizedClientProvider; private OAuth2AccessTokenResponseClient accessTokenResponseClient; private ClientRegistration clientRegistration; @@ -54,30 +46,13 @@ public class ClientCredentialsOAuth2AuthorizedClientProviderTests { @Before public void setup() { - this.clientRegistrationRepository = mock(ClientRegistrationRepository.class); - this.authorizedClientRepository = mock(OAuth2AuthorizedClientRepository.class); - this.authorizedClientProvider = new ClientCredentialsOAuth2AuthorizedClientProvider( - this.clientRegistrationRepository, this.authorizedClientRepository); + this.authorizedClientProvider = new ClientCredentialsOAuth2AuthorizedClientProvider(); this.accessTokenResponseClient = mock(OAuth2AccessTokenResponseClient.class); this.authorizedClientProvider.setAccessTokenResponseClient(this.accessTokenResponseClient); this.clientRegistration = TestClientRegistrations.clientCredentials().build(); this.principal = new TestingAuthenticationToken("principal", "password"); } - @Test - public void constructorWhenClientRegistrationRepositoryIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new ClientCredentialsOAuth2AuthorizedClientProvider(null, this.authorizedClientRepository)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("clientRegistrationRepository cannot be null"); - } - - @Test - public void constructorWhenOAuth2AuthorizedClientRepositoryIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new ClientCredentialsOAuth2AuthorizedClientProvider(this.clientRegistrationRepository, null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("authorizedClientRepository cannot be null"); - } - @Test public void setAccessTokenResponseClientWhenClientIsNullThenThrowIllegalArgumentException() { assertThatThrownBy(() -> this.authorizedClientProvider.setAccessTokenResponseClient(null)) @@ -100,27 +75,6 @@ public void authorizeWhenNotClientCredentialsThenUnableToAuthorize() { assertThat(this.authorizedClientProvider.authorize(authorizationContext)).isNull(); } - @Test - public void authorizeWhenHttpServletRequestIsNullThenThrowIllegalArgumentException() { - OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.forAuthorization(this.clientRegistration).principal(this.principal).build(); - assertThatThrownBy(() -> this.authorizedClientProvider.authorize(authorizationContext)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("The context attribute cannot be null 'javax.servlet.http.HttpServletRequest'"); - } - - @Test - public void authorizeWhenHttpServletResponseIsNullThenThrowIllegalArgumentException() { - OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.forAuthorization(this.clientRegistration) - .principal(this.principal) - .attribute(HttpServletRequest.class.getName(), new MockHttpServletRequest()) - .build(); - assertThatThrownBy(() -> this.authorizedClientProvider.authorize(authorizationContext)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("The context attribute cannot be null 'javax.servlet.http.HttpServletResponse'"); - } - @Test public void authorizeWhenClientCredentialsAndNotAuthorizedThenAuthorize() { OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); @@ -129,8 +83,6 @@ public void authorizeWhenClientCredentialsAndNotAuthorizedThenAuthorize() { OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext.forAuthorization(this.clientRegistration) .principal(this.principal) - .attribute(HttpServletRequest.class.getName(), new MockHttpServletRequest()) - .attribute(HttpServletResponse.class.getName(), new MockHttpServletResponse()) .build(); OAuth2AuthorizedClient authorizedClient = this.authorizedClientProvider.authorize(authorizationContext); @@ -138,9 +90,6 @@ public void authorizeWhenClientCredentialsAndNotAuthorizedThenAuthorize() { assertThat(authorizedClient.getClientRegistration()).isSameAs(this.clientRegistration); assertThat(authorizedClient.getPrincipalName()).isEqualTo(this.principal.getName()); assertThat(authorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken()); - verify(this.authorizedClientRepository).saveAuthorizedClient( - eq(authorizedClient), eq(this.principal), - any(HttpServletRequest.class), any(HttpServletResponse.class)); } @Test @@ -154,8 +103,6 @@ public void authorizeWhenClientCredentialsAndAuthorizedThenReauthorize() { OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext.forReauthorization(authorizedClient) .principal(this.principal) - .attribute(HttpServletRequest.class.getName(), new MockHttpServletRequest()) - .attribute(HttpServletResponse.class.getName(), new MockHttpServletResponse()) .build(); authorizedClient = this.authorizedClientProvider.authorize(authorizationContext); @@ -163,8 +110,5 @@ public void authorizeWhenClientCredentialsAndAuthorizedThenReauthorize() { assertThat(authorizedClient.getClientRegistration()).isSameAs(this.clientRegistration); assertThat(authorizedClient.getPrincipalName()).isEqualTo(this.principal.getName()); assertThat(authorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken()); - verify(this.authorizedClientRepository).saveAuthorizedClient( - eq(authorizedClient), eq(this.principal), - any(HttpServletRequest.class), any(HttpServletResponse.class)); } } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProviderTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProviderTests.java index a361e3cc7f6..20b44202b28 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProviderTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProviderTests.java @@ -18,23 +18,17 @@ import org.junit.Before; import org.junit.Test; import org.mockito.ArgumentCaptor; -import org.springframework.mock.web.MockHttpServletRequest; -import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.OAuth2RefreshTokenGrantRequest; import org.springframework.security.oauth2.client.registration.ClientRegistration; -import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; import org.springframework.security.oauth2.client.registration.TestClientRegistrations; -import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses; -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; import java.util.Arrays; import java.util.HashSet; import java.util.Set; @@ -42,7 +36,6 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.*; /** @@ -51,8 +44,6 @@ * @author Joe Grandja */ public class RefreshTokenOAuth2AuthorizedClientProviderTests { - private ClientRegistrationRepository clientRegistrationRepository; - private OAuth2AuthorizedClientRepository authorizedClientRepository; private RefreshTokenOAuth2AuthorizedClientProvider authorizedClientProvider; private OAuth2AccessTokenResponseClient accessTokenResponseClient; private ClientRegistration clientRegistration; @@ -61,10 +52,7 @@ public class RefreshTokenOAuth2AuthorizedClientProviderTests { @Before public void setup() { - this.clientRegistrationRepository = mock(ClientRegistrationRepository.class); - this.authorizedClientRepository = mock(OAuth2AuthorizedClientRepository.class); - this.authorizedClientProvider = new RefreshTokenOAuth2AuthorizedClientProvider( - this.clientRegistrationRepository, this.authorizedClientRepository); + this.authorizedClientProvider = new RefreshTokenOAuth2AuthorizedClientProvider(); this.accessTokenResponseClient = mock(OAuth2AccessTokenResponseClient.class); this.authorizedClientProvider.setAccessTokenResponseClient(this.accessTokenResponseClient); this.clientRegistration = TestClientRegistrations.clientRegistration().build(); @@ -73,20 +61,6 @@ public void setup() { TestOAuth2AccessTokens.scopes("read", "write"), TestOAuth2RefreshTokens.refreshToken()); } - @Test - public void constructorWhenClientRegistrationRepositoryIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new RefreshTokenOAuth2AuthorizedClientProvider(null, this.authorizedClientRepository)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("clientRegistrationRepository cannot be null"); - } - - @Test - public void constructorWhenOAuth2AuthorizedClientRepositoryIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new RefreshTokenOAuth2AuthorizedClientProvider(this.clientRegistrationRepository, null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("authorizedClientRepository cannot be null"); - } - @Test public void setAccessTokenResponseClientWhenClientIsNullThenThrowIllegalArgumentException() { assertThatThrownBy(() -> this.authorizedClientProvider.setAccessTokenResponseClient(null)) @@ -117,27 +91,6 @@ public void authorizeWhenAuthorizedAndRefreshTokenIsNullThenUnableToReauthorize( assertThat(this.authorizedClientProvider.authorize(authorizationContext)).isNull(); } - @Test - public void authorizeWhenHttpServletRequestIsNullThenThrowIllegalArgumentException() { - OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.forReauthorization(this.authorizedClient).principal(this.principal).build(); - assertThatThrownBy(() -> this.authorizedClientProvider.authorize(authorizationContext)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("The context attribute cannot be null 'javax.servlet.http.HttpServletRequest'"); - } - - @Test - public void authorizeWhenHttpServletResponseIsNullThenThrowIllegalArgumentException() { - OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.forReauthorization(this.authorizedClient) - .principal(this.principal) - .attribute(HttpServletRequest.class.getName(), new MockHttpServletRequest()) - .build(); - assertThatThrownBy(() -> this.authorizedClientProvider.authorize(authorizationContext)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("The context attribute cannot be null 'javax.servlet.http.HttpServletResponse'"); - } - @Test public void authorizeWhenAuthorizedWithRefreshTokenThenReauthorize() { OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse() @@ -148,8 +101,6 @@ public void authorizeWhenAuthorizedWithRefreshTokenThenReauthorize() { OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext.forReauthorization(this.authorizedClient) .principal(this.principal) - .attribute(HttpServletRequest.class.getName(), new MockHttpServletRequest()) - .attribute(HttpServletResponse.class.getName(), new MockHttpServletResponse()) .build(); OAuth2AuthorizedClient authorizedClient = this.authorizedClientProvider.authorize(authorizationContext); @@ -158,9 +109,6 @@ public void authorizeWhenAuthorizedWithRefreshTokenThenReauthorize() { assertThat(authorizedClient.getPrincipalName()).isEqualTo(this.principal.getName()); assertThat(authorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken()); assertThat(authorizedClient.getRefreshToken()).isEqualTo(accessTokenResponse.getRefreshToken()); - verify(this.authorizedClientRepository).saveAuthorizedClient( - eq(authorizedClient), eq(this.principal), - any(HttpServletRequest.class), any(HttpServletResponse.class)); } @Test @@ -176,8 +124,6 @@ public void authorizeWhenAuthorizedAndSpaceDelimitedScopeProvidedThenScopeReques OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext.forReauthorization(this.authorizedClient) .principal(this.principal) - .attribute(HttpServletRequest.class.getName(), new MockHttpServletRequest()) - .attribute(HttpServletResponse.class.getName(), new MockHttpServletResponse()) .attribute(RefreshTokenOAuth2AuthorizedClientProvider.REQUEST_SCOPE_ATTRIBUTE_NAME, scope) .build(); @@ -202,8 +148,6 @@ public void authorizeWhenAuthorizedAndCommaDelimitedScopeProvidedThenScopeReques OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext.forReauthorization(this.authorizedClient) .principal(this.principal) - .attribute(HttpServletRequest.class.getName(), new MockHttpServletRequest()) - .attribute(HttpServletResponse.class.getName(), new MockHttpServletResponse()) .attribute(RefreshTokenOAuth2AuthorizedClientProvider.REQUEST_SCOPE_ATTRIBUTE_NAME, scope) .build(); diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolverTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolverTests.java index 491cfb52daa..c4b2562d9d2 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolverTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolverTests.java @@ -218,8 +218,7 @@ public void resolveArgumentWhenAuthorizedClientNotFoundForClientCredentialsClien OAuth2AccessTokenResponseClient clientCredentialsTokenResponseClient = mock(OAuth2AccessTokenResponseClient.class); ClientCredentialsOAuth2AuthorizedClientProvider clientCredentialsAuthorizedClientProvider = - new ClientCredentialsOAuth2AuthorizedClientProvider( - this.clientRegistrationRepository, this.authorizedClientRepository); + new ClientCredentialsOAuth2AuthorizedClientProvider(); clientCredentialsAuthorizedClientProvider.setAccessTokenResponseClient(clientCredentialsTokenResponseClient); this.argumentResolver.setAuthorizedClientProvider(clientCredentialsAuthorizedClientProvider); diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java index 0ad69a0e13e..4f26f293112 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java @@ -139,10 +139,10 @@ public void setup() { this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction( this.clientRegistrationRepository, this.authorizedClientRepository); ClientCredentialsOAuth2AuthorizedClientProvider clientCredentialsAuthorizedClientProvider = - new ClientCredentialsOAuth2AuthorizedClientProvider(this.clientRegistrationRepository, this.authorizedClientRepository); + new ClientCredentialsOAuth2AuthorizedClientProvider(); clientCredentialsAuthorizedClientProvider.setAccessTokenResponseClient(this.clientCredentialsTokenResponseClient); RefreshTokenOAuth2AuthorizedClientProvider refreshTokenAuthorizedClientProvider = - new RefreshTokenOAuth2AuthorizedClientProvider(this.clientRegistrationRepository, this.authorizedClientRepository); + new RefreshTokenOAuth2AuthorizedClientProvider(); refreshTokenAuthorizedClientProvider.setAccessTokenResponseClient(this.refreshTokenTokenResponseClient); this.authorizedClientProvider = new DelegatingOAuth2AuthorizedClientProvider( new AuthorizationCodeOAuth2AuthorizedClientProvider(), @@ -299,10 +299,6 @@ public void defaultRequestWhenClientCredentialsThenAuthorizedClient() { clientRegistrationId(this.registration.getRegistrationId()).accept(this.result); - MockHttpServletRequest request = new MockHttpServletRequest(); - MockHttpServletResponse response = new MockHttpServletResponse(); - RequestContextHolder.setRequestAttributes(new ServletRequestAttributes(request, response)); - Map attrs = getDefaultRequestAttributes(); OAuth2AuthorizedClient authorizedClient = getOAuth2AuthorizedClient(attrs); @@ -321,10 +317,6 @@ public void defaultRequestWhenDefaultClientRegistrationIdThenAuthorizedClient() .accessTokenResponse().build(); when(this.clientCredentialsTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse); - MockHttpServletRequest request = new MockHttpServletRequest(); - MockHttpServletResponse response = new MockHttpServletResponse(); - RequestContextHolder.setRequestAttributes(new ServletRequestAttributes(request, response)); - Map attrs = getDefaultRequestAttributes(); OAuth2AuthorizedClient authorizedClient = getOAuth2AuthorizedClient(attrs); @@ -412,8 +404,6 @@ public void filterWhenRefreshRequiredThenRefresh() { ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) .attributes(oauth2AuthorizedClient(authorizedClient)) .attributes(authentication(this.authentication)) - .attributes(httpServletRequest(new MockHttpServletRequest())) - .attributes(httpServletResponse(new MockHttpServletResponse())) .build(); this.function.filter(request, this.exchange).block(); @@ -449,8 +439,7 @@ public void filterWhenRefreshRequiredThenRefreshAndResponseDoesNotContainRefresh DefaultRefreshTokenTokenResponseClient refreshTokenTokenResponseClient = new DefaultRefreshTokenTokenResponseClient(); refreshTokenTokenResponseClient.setRestOperations(refreshTokenClient); - RefreshTokenOAuth2AuthorizedClientProvider authorizedClientProvider = new RefreshTokenOAuth2AuthorizedClientProvider( - this.clientRegistrationRepository, this.authorizedClientRepository); + RefreshTokenOAuth2AuthorizedClientProvider authorizedClientProvider = new RefreshTokenOAuth2AuthorizedClientProvider(); authorizedClientProvider.setAccessTokenResponseClient(refreshTokenTokenResponseClient); this.function.setAuthorizedClientProvider(authorizedClientProvider); @@ -468,8 +457,6 @@ public void filterWhenRefreshRequiredThenRefreshAndResponseDoesNotContainRefresh ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) .attributes(oauth2AuthorizedClient(authorizedClient)) .attributes(authentication(this.authentication)) - .attributes(httpServletRequest(new MockHttpServletRequest())) - .attributes(httpServletResponse(new MockHttpServletResponse())) .build(); this.function.filter(request, this.exchange).block(); @@ -539,8 +526,6 @@ public void filterWhenClientCredentialsTokenExpiredThenGetNewToken() { ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) .attributes(oauth2AuthorizedClient(authorizedClient)) .attributes(authentication(this.authentication)) - .attributes(httpServletRequest(new MockHttpServletRequest())) - .attributes(httpServletResponse(new MockHttpServletResponse())) .build(); this.function.filter(request, this.exchange).block(); @@ -580,8 +565,6 @@ public void filterWhenRefreshRequiredAndEmptyReactiveSecurityContextThenSaved() "principalName", this.accessToken, refreshToken); ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) .attributes(oauth2AuthorizedClient(authorizedClient)) - .attributes(httpServletRequest(new MockHttpServletRequest())) - .attributes(httpServletResponse(new MockHttpServletResponse())) .build(); this.function.filter(request, this.exchange).block(); @@ -695,10 +678,6 @@ public void filterWhenRequestAttributesNotSetAndHooksNotInitThenDefaultsNotAvail // this.function.afterPropertiesSet(); // Hooks.onLastOperator() NOT initialized this.function.setDefaultOAuth2AuthorizedClient(true); - MockHttpServletRequest servletRequest = new MockHttpServletRequest(); - MockHttpServletResponse servletResponse = new MockHttpServletResponse(); - RequestContextHolder.setRequestAttributes(new ServletRequestAttributes(servletRequest, servletResponse)); - OAuth2User user = mock(OAuth2User.class); List authorities = AuthorityUtils.createAuthorityList("ROLE_USER"); OAuth2AuthenticationToken authentication = new OAuth2AuthenticationToken( @@ -725,10 +704,6 @@ public void filterWhenRequestAttributesNotSetAndHooksInitHooksResetThenDefaultsN this.function.destroy(); // Hooks.onLastOperator() released this.function.setDefaultOAuth2AuthorizedClient(true); - MockHttpServletRequest servletRequest = new MockHttpServletRequest(); - MockHttpServletResponse servletResponse = new MockHttpServletResponse(); - RequestContextHolder.setRequestAttributes(new ServletRequestAttributes(servletRequest, servletResponse)); - OAuth2User user = mock(OAuth2User.class); List authorities = AuthorityUtils.createAuthorityList("ROLE_USER"); OAuth2AuthenticationToken authentication = new OAuth2AuthenticationToken( From 21d8528cfcb9a0a392a4d4f7c84384e5974a0282 Mon Sep 17 00:00:00 2001 From: Joe Grandja Date: Thu, 27 Jun 2019 12:56:43 -0400 Subject: [PATCH 11/19] Revert "OAuth2AuthorizedClientProvider should not save OAuth2AuthorizedClient" This reverts commit dd6c5b0cf592cdfc8e969635f92173393d31dca7. --- .../OAuth2ClientConfiguration.java | 3 +- ...entialsOAuth2AuthorizedClientProvider.java | 41 ++++++++++++ ...shTokenOAuth2AuthorizedClientProvider.java | 42 +++++++++++++ ...Auth2AuthorizedClientArgumentResolver.java | 21 +++---- ...uthorizedClientExchangeFilterFunction.java | 50 +++++++-------- ...lsOAuth2AuthorizedClientProviderTests.java | 62 ++++++++++++++++++- ...enOAuth2AuthorizedClientProviderTests.java | 58 ++++++++++++++++- ...AuthorizedClientArgumentResolverTests.java | 3 +- ...izedClientExchangeFilterFunctionTests.java | 31 +++++++++- 9 files changed, 262 insertions(+), 49 deletions(-) diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfiguration.java b/config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfiguration.java index d5e5bfbd021..09a9ae86b7e 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfiguration.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfiguration.java @@ -77,7 +77,8 @@ public void addArgumentResolvers(List argumentRes this.clientRegistrationRepository, this.authorizedClientRepository); if (this.accessTokenResponseClient != null) { ClientCredentialsOAuth2AuthorizedClientProvider clientCredentialsAuthorizedClientProvider = - new ClientCredentialsOAuth2AuthorizedClientProvider(); + new ClientCredentialsOAuth2AuthorizedClientProvider( + this.clientRegistrationRepository, this.authorizedClientRepository); clientCredentialsAuthorizedClientProvider.setAccessTokenResponseClient(this.accessTokenResponseClient); OAuth2AuthorizedClientProvider authorizedClientProvider = new DelegatingOAuth2AuthorizedClientProvider( new AuthorizationCodeOAuth2AuthorizedClientProvider(), clientCredentialsAuthorizedClientProvider); diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ClientCredentialsOAuth2AuthorizedClientProvider.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ClientCredentialsOAuth2AuthorizedClientProvider.java index e67c69d170b..ac1e0168f1e 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ClientCredentialsOAuth2AuthorizedClientProvider.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ClientCredentialsOAuth2AuthorizedClientProvider.java @@ -20,10 +20,15 @@ import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest; import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; +import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.util.Assert; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + /** * An implementation of an {@link OAuth2AuthorizedClientProvider} * for the {@link AuthorizationGrantType#CLIENT_CREDENTIALS client_credentials} grant. @@ -34,15 +39,40 @@ * @see DefaultClientCredentialsTokenResponseClient */ public final class ClientCredentialsOAuth2AuthorizedClientProvider implements OAuth2AuthorizedClientProvider { + private static final String HTTP_SERVLET_REQUEST_ATTRIBUTE_NAME = HttpServletRequest.class.getName(); + private static final String HTTP_SERVLET_RESPONSE_ATTRIBUTE_NAME = HttpServletResponse.class.getName(); + private final ClientRegistrationRepository clientRegistrationRepository; + private final OAuth2AuthorizedClientRepository authorizedClientRepository; private OAuth2AccessTokenResponseClient accessTokenResponseClient = new DefaultClientCredentialsTokenResponseClient(); + /** + * Constructs a {@code ClientCredentialsOAuth2AuthorizedClientProvider} using the provided parameters. + * + * @param clientRegistrationRepository the repository of client registrations + * @param authorizedClientRepository the repository of authorized clients + */ + public ClientCredentialsOAuth2AuthorizedClientProvider(ClientRegistrationRepository clientRegistrationRepository, + OAuth2AuthorizedClientRepository authorizedClientRepository) { + Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null"); + Assert.notNull(authorizedClientRepository, "authorizedClientRepository cannot be null"); + this.clientRegistrationRepository = clientRegistrationRepository; + this.authorizedClientRepository = authorizedClientRepository; + } + /** * Attempt to authorize (or re-authorize) the {@link OAuth2AuthorizationContext#getClientRegistration() client} in the provided {@code context}. * Returns {@code null} if authorization (or re-authorization) is not supported, * e.g. the client's {@link ClientRegistration#getAuthorizationGrantType() authorization grant type} * is not {@link AuthorizationGrantType#CLIENT_CREDENTIALS client_credentials}. * + *

+ * The following {@link OAuth2AuthorizationContext#getAttributes() context attributes} are supported: + *

    + *
  1. {@code "javax.servlet.http.HttpServletRequest"} (required) - the {@code HttpServletRequest}
  2. + *
  3. {@code "javax.servlet.http.HttpServletResponse"} (required) - the {@code HttpServletResponse}
  4. + *
+ * * @param context the context that holds authorization-specific state for the client * @return the {@link OAuth2AuthorizedClient} or {@code null} if authorization (or re-authorization) is not supported */ @@ -54,6 +84,11 @@ public OAuth2AuthorizedClient authorize(OAuth2AuthorizationContext context) { return null; } + HttpServletRequest request = context.getAttribute(HTTP_SERVLET_REQUEST_ATTRIBUTE_NAME); + HttpServletResponse response = context.getAttribute(HTTP_SERVLET_RESPONSE_ATTRIBUTE_NAME); + Assert.notNull(request, "The context attribute cannot be null '" + HTTP_SERVLET_REQUEST_ATTRIBUTE_NAME + "'"); + Assert.notNull(response, "The context attribute cannot be null '" + HTTP_SERVLET_RESPONSE_ATTRIBUTE_NAME + "'"); + // As per spec, in section 4.4.3 Access Token Response // https://tools.ietf.org/html/rfc6749#section-4.4.3 // A refresh token SHOULD NOT be included. @@ -71,6 +106,12 @@ public OAuth2AuthorizedClient authorize(OAuth2AuthorizationContext context) { context.getPrincipal().getName(), tokenResponse.getAccessToken()); + this.authorizedClientRepository.saveAuthorizedClient( + authorizedClient, + context.getPrincipal(), + request, + response); + return authorizedClient; } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProvider.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProvider.java index 5ba4e95a0bf..d3a8d3fb9d7 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProvider.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProvider.java @@ -19,11 +19,15 @@ import org.springframework.security.oauth2.client.endpoint.DefaultRefreshTokenTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.OAuth2RefreshTokenGrantRequest; +import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; +import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.util.Assert; import org.springframework.util.StringUtils; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; import java.util.Arrays; import java.util.Set; import java.util.stream.Collectors; @@ -46,15 +50,42 @@ public final class RefreshTokenOAuth2AuthorizedClientProvider implements OAuth2A */ public static final String REQUEST_SCOPE_ATTRIBUTE_NAME = "org.springframework.security.oauth2.client.REQUEST_SCOPE"; + private static final String HTTP_SERVLET_REQUEST_ATTRIBUTE_NAME = HttpServletRequest.class.getName(); + private static final String HTTP_SERVLET_RESPONSE_ATTRIBUTE_NAME = HttpServletResponse.class.getName(); + + private final ClientRegistrationRepository clientRegistrationRepository; + private final OAuth2AuthorizedClientRepository authorizedClientRepository; private OAuth2AccessTokenResponseClient accessTokenResponseClient = new DefaultRefreshTokenTokenResponseClient(); + /** + * Constructs a {@code RefreshTokenOAuth2AuthorizedClientProvider} using the provided parameters. + * + * @param clientRegistrationRepository the repository of client registrations + * @param authorizedClientRepository the repository of authorized clients + */ + public RefreshTokenOAuth2AuthorizedClientProvider(ClientRegistrationRepository clientRegistrationRepository, + OAuth2AuthorizedClientRepository authorizedClientRepository) { + Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null"); + Assert.notNull(authorizedClientRepository, "authorizedClientRepository cannot be null"); + this.clientRegistrationRepository = clientRegistrationRepository; + this.authorizedClientRepository = authorizedClientRepository; + } + /** * Attempt to re-authorize the {@link OAuth2AuthorizationContext#getClientRegistration() client} in the provided {@code context}. * Returns {@code null} if re-authorization is not supported, * e.g. the {@link OAuth2AuthorizedClient#getRefreshToken() refresh token} is not available for the * {@link OAuth2AuthorizationContext#getAuthorizedClient() authorized client}. * + *

+ * The following {@link OAuth2AuthorizationContext#getAttributes() context attributes} are supported: + *

    + *
  1. {@code "javax.servlet.http.HttpServletRequest"} (required) - the {@code HttpServletRequest}
  2. + *
  3. {@code "javax.servlet.http.HttpServletResponse"} (required) - the {@code HttpServletResponse}
  4. + *
  5. {@code "org.springframework.security.oauth2.client.REQUEST_SCOPE"} (optional) - a space-delimited or comma-delimited {@code String} of scope(s) to be requested by the {@link OAuth2AuthorizationContext#getClientRegistration() client}
  6. + *
+ * * @param context the context that holds authorization-specific state for the client * @return the {@link OAuth2AuthorizedClient} or {@code null} if re-authorization is not supported */ @@ -66,6 +97,11 @@ public OAuth2AuthorizedClient authorize(OAuth2AuthorizationContext context) { return null; } + HttpServletRequest request = context.getAttribute(HTTP_SERVLET_REQUEST_ATTRIBUTE_NAME); + HttpServletResponse response = context.getAttribute(HTTP_SERVLET_RESPONSE_ATTRIBUTE_NAME); + Assert.notNull(request, "The context attribute cannot be null '" + HTTP_SERVLET_REQUEST_ATTRIBUTE_NAME + "'"); + Assert.notNull(response, "The context attribute cannot be null '" + HTTP_SERVLET_RESPONSE_ATTRIBUTE_NAME + "'"); + String requestScope = context.getAttribute(REQUEST_SCOPE_ATTRIBUTE_NAME); Set scopes = null; if (!StringUtils.isEmpty(requestScope)) { @@ -84,6 +120,12 @@ public OAuth2AuthorizedClient authorize(OAuth2AuthorizationContext context) { tokenResponse.getAccessToken(), tokenResponse.getRefreshToken()); + this.authorizedClientRepository.saveAuthorizedClient( + authorizedClient, + context.getPrincipal(), + request, + response); + return authorizedClient; } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolver.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolver.java index 0b72616461b..eb124e31bc3 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolver.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolver.java @@ -121,24 +121,19 @@ public Object resolveArgument(MethodParameter parameter, return null; } + HttpServletResponse servletResponse = webRequest.getNativeResponse(HttpServletResponse.class); + OAuth2AuthorizationContext.Builder contextBuilder = OAuth2AuthorizationContext.forAuthorization(clientRegistration); if (principal != null) { contextBuilder.principal(principal); } else { contextBuilder.principal("anonymousUser"); } - OAuth2AuthorizationContext authorizationContext = contextBuilder.build(); - authorizedClient = this.authorizedClientProvider.authorize(authorizationContext); - - HttpServletResponse servletResponse = webRequest.getNativeResponse(HttpServletResponse.class); - - this.authorizedClientRepository.saveAuthorizedClient( - authorizedClient, - authorizationContext.getPrincipal(), - servletRequest, - servletResponse); - - return authorizedClient; + OAuth2AuthorizationContext authorizationContext = contextBuilder + .attribute(HttpServletRequest.class.getName(), servletRequest) + .attribute(HttpServletResponse.class.getName(), servletResponse) + .build(); + return this.authorizedClientProvider.authorize(authorizationContext); } private String resolveClientRegistrationId(MethodParameter parameter) { @@ -187,7 +182,7 @@ public final void setClientCredentialsTokenResponseClient( private OAuth2AuthorizedClientProvider createAuthorizedClientProvider( OAuth2AccessTokenResponseClient clientCredentialsTokenResponseClient) { ClientCredentialsOAuth2AuthorizedClientProvider clientCredentialsAuthorizedClientProvider = - new ClientCredentialsOAuth2AuthorizedClientProvider(); + new ClientCredentialsOAuth2AuthorizedClientProvider(this.clientRegistrationRepository, this.authorizedClientRepository); clientCredentialsAuthorizedClientProvider.setAccessTokenResponseClient(clientCredentialsTokenResponseClient); return new DelegatingOAuth2AuthorizedClientProvider( new AuthorizationCodeOAuth2AuthorizedClientProvider(), clientCredentialsAuthorizedClientProvider); diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java index 476b6e32da0..085980e1ea2 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java @@ -53,6 +53,7 @@ import java.time.Clock; import java.time.Duration; import java.time.Instant; +import java.util.HashMap; import java.util.Map; import java.util.function.Consumer; @@ -131,14 +132,15 @@ public ServletOAuth2AuthorizedClientExchangeFilterFunction( OAuth2AuthorizedClientRepository authorizedClientRepository) { this.clientRegistrationRepository = clientRegistrationRepository; this.authorizedClientRepository = authorizedClientRepository; - this.authorizedClientProvider = createDefaultAuthorizedClientProvider(); + this.authorizedClientProvider = createDefaultAuthorizedClientProvider(clientRegistrationRepository, authorizedClientRepository); } - private static OAuth2AuthorizedClientProvider createDefaultAuthorizedClientProvider() { + private static OAuth2AuthorizedClientProvider createDefaultAuthorizedClientProvider( + ClientRegistrationRepository clientRegistrationRepository, OAuth2AuthorizedClientRepository authorizedClientRepository) { return new DelegatingOAuth2AuthorizedClientProvider( new AuthorizationCodeOAuth2AuthorizedClientProvider(), - new ClientCredentialsOAuth2AuthorizedClientProvider(), - new RefreshTokenOAuth2AuthorizedClientProvider()); + new ClientCredentialsOAuth2AuthorizedClientProvider(clientRegistrationRepository, authorizedClientRepository), + new RefreshTokenOAuth2AuthorizedClientProvider(clientRegistrationRepository, authorizedClientRepository)); } @Override @@ -179,12 +181,12 @@ public void setClientCredentialsTokenResponseClient( private OAuth2AuthorizedClientProvider createAuthorizedClientProvider( OAuth2AccessTokenResponseClient clientCredentialsTokenResponseClient) { ClientCredentialsOAuth2AuthorizedClientProvider clientCredentialsAuthorizedClientProvider = - new ClientCredentialsOAuth2AuthorizedClientProvider(); + new ClientCredentialsOAuth2AuthorizedClientProvider(this.clientRegistrationRepository, this.authorizedClientRepository); clientCredentialsAuthorizedClientProvider.setAccessTokenResponseClient(clientCredentialsTokenResponseClient); return new DelegatingOAuth2AuthorizedClientProvider( new AuthorizationCodeOAuth2AuthorizedClientProvider(), clientCredentialsAuthorizedClientProvider, - new RefreshTokenOAuth2AuthorizedClientProvider()); + new RefreshTokenOAuth2AuthorizedClientProvider(this.clientRegistrationRepository, this.authorizedClientRepository)); } /** @@ -404,16 +406,10 @@ private OAuth2AuthorizedClient authorizeClient(String clientRegistrationId, Map< } else { contextBuilder.principal("anonymousUser"); } - OAuth2AuthorizationContext authorizationContext = contextBuilder.build(); - OAuth2AuthorizedClient authorizedClient = this.authorizedClientProvider.authorize(authorizationContext); - - this.authorizedClientRepository.saveAuthorizedClient( - authorizedClient, - authorizationContext.getPrincipal(), - getRequest(attributes), - getResponse(attributes)); - - return authorizedClient; + OAuth2AuthorizationContext authorizationContext = contextBuilder + .attributes(defaultContextAttributes(attributes)) + .build(); + return this.authorizedClientProvider.authorize(authorizationContext); } private Mono reauthorizeClientIfNecessary( @@ -431,17 +427,17 @@ private Mono reauthorizeClientIfNecessary( } else { contextBuilder.principal(authorizedClient.getPrincipalName()); } - OAuth2AuthorizationContext authorizationContext = contextBuilder.build(); - - return Mono.fromSupplier(() -> { - OAuth2AuthorizedClient reauthorizedClient = this.authorizedClientProvider.authorize(authorizationContext); - this.authorizedClientRepository.saveAuthorizedClient( - reauthorizedClient, - authorizationContext.getPrincipal(), - getRequest(attributes), - getResponse(attributes)); - return reauthorizedClient; - }); + OAuth2AuthorizationContext authorizationContext = contextBuilder + .attributes(defaultContextAttributes(attributes)) + .build(); + return Mono.fromSupplier(() -> this.authorizedClientProvider.authorize(authorizationContext)); + } + + private Map defaultContextAttributes(Map attributes) { + Map contextAttributes = new HashMap<>(); + contextAttributes.put(HttpServletRequest.class.getName(), getRequest(attributes)); + contextAttributes.put(HttpServletResponse.class.getName(), getResponse(attributes)); + return contextAttributes; } private boolean hasTokenExpired(OAuth2AuthorizedClient authorizedClient) { diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/ClientCredentialsOAuth2AuthorizedClientProviderTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/ClientCredentialsOAuth2AuthorizedClientProviderTests.java index 5f21d728a41..e0893a947fd 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/ClientCredentialsOAuth2AuthorizedClientProviderTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/ClientCredentialsOAuth2AuthorizedClientProviderTests.java @@ -17,21 +17,27 @@ import org.junit.Before; import org.junit.Test; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest; import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; +import static org.mockito.Mockito.*; /** * Tests for {@link ClientCredentialsOAuth2AuthorizedClientProvider}. @@ -39,6 +45,8 @@ * @author Joe Grandja */ public class ClientCredentialsOAuth2AuthorizedClientProviderTests { + private ClientRegistrationRepository clientRegistrationRepository; + private OAuth2AuthorizedClientRepository authorizedClientRepository; private ClientCredentialsOAuth2AuthorizedClientProvider authorizedClientProvider; private OAuth2AccessTokenResponseClient accessTokenResponseClient; private ClientRegistration clientRegistration; @@ -46,13 +54,30 @@ public class ClientCredentialsOAuth2AuthorizedClientProviderTests { @Before public void setup() { - this.authorizedClientProvider = new ClientCredentialsOAuth2AuthorizedClientProvider(); + this.clientRegistrationRepository = mock(ClientRegistrationRepository.class); + this.authorizedClientRepository = mock(OAuth2AuthorizedClientRepository.class); + this.authorizedClientProvider = new ClientCredentialsOAuth2AuthorizedClientProvider( + this.clientRegistrationRepository, this.authorizedClientRepository); this.accessTokenResponseClient = mock(OAuth2AccessTokenResponseClient.class); this.authorizedClientProvider.setAccessTokenResponseClient(this.accessTokenResponseClient); this.clientRegistration = TestClientRegistrations.clientCredentials().build(); this.principal = new TestingAuthenticationToken("principal", "password"); } + @Test + public void constructorWhenClientRegistrationRepositoryIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new ClientCredentialsOAuth2AuthorizedClientProvider(null, this.authorizedClientRepository)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("clientRegistrationRepository cannot be null"); + } + + @Test + public void constructorWhenOAuth2AuthorizedClientRepositoryIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new ClientCredentialsOAuth2AuthorizedClientProvider(this.clientRegistrationRepository, null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("authorizedClientRepository cannot be null"); + } + @Test public void setAccessTokenResponseClientWhenClientIsNullThenThrowIllegalArgumentException() { assertThatThrownBy(() -> this.authorizedClientProvider.setAccessTokenResponseClient(null)) @@ -75,6 +100,27 @@ public void authorizeWhenNotClientCredentialsThenUnableToAuthorize() { assertThat(this.authorizedClientProvider.authorize(authorizationContext)).isNull(); } + @Test + public void authorizeWhenHttpServletRequestIsNullThenThrowIllegalArgumentException() { + OAuth2AuthorizationContext authorizationContext = + OAuth2AuthorizationContext.forAuthorization(this.clientRegistration).principal(this.principal).build(); + assertThatThrownBy(() -> this.authorizedClientProvider.authorize(authorizationContext)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("The context attribute cannot be null 'javax.servlet.http.HttpServletRequest'"); + } + + @Test + public void authorizeWhenHttpServletResponseIsNullThenThrowIllegalArgumentException() { + OAuth2AuthorizationContext authorizationContext = + OAuth2AuthorizationContext.forAuthorization(this.clientRegistration) + .principal(this.principal) + .attribute(HttpServletRequest.class.getName(), new MockHttpServletRequest()) + .build(); + assertThatThrownBy(() -> this.authorizedClientProvider.authorize(authorizationContext)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("The context attribute cannot be null 'javax.servlet.http.HttpServletResponse'"); + } + @Test public void authorizeWhenClientCredentialsAndNotAuthorizedThenAuthorize() { OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); @@ -83,6 +129,8 @@ public void authorizeWhenClientCredentialsAndNotAuthorizedThenAuthorize() { OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext.forAuthorization(this.clientRegistration) .principal(this.principal) + .attribute(HttpServletRequest.class.getName(), new MockHttpServletRequest()) + .attribute(HttpServletResponse.class.getName(), new MockHttpServletResponse()) .build(); OAuth2AuthorizedClient authorizedClient = this.authorizedClientProvider.authorize(authorizationContext); @@ -90,6 +138,9 @@ public void authorizeWhenClientCredentialsAndNotAuthorizedThenAuthorize() { assertThat(authorizedClient.getClientRegistration()).isSameAs(this.clientRegistration); assertThat(authorizedClient.getPrincipalName()).isEqualTo(this.principal.getName()); assertThat(authorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken()); + verify(this.authorizedClientRepository).saveAuthorizedClient( + eq(authorizedClient), eq(this.principal), + any(HttpServletRequest.class), any(HttpServletResponse.class)); } @Test @@ -103,6 +154,8 @@ public void authorizeWhenClientCredentialsAndAuthorizedThenReauthorize() { OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext.forReauthorization(authorizedClient) .principal(this.principal) + .attribute(HttpServletRequest.class.getName(), new MockHttpServletRequest()) + .attribute(HttpServletResponse.class.getName(), new MockHttpServletResponse()) .build(); authorizedClient = this.authorizedClientProvider.authorize(authorizationContext); @@ -110,5 +163,8 @@ public void authorizeWhenClientCredentialsAndAuthorizedThenReauthorize() { assertThat(authorizedClient.getClientRegistration()).isSameAs(this.clientRegistration); assertThat(authorizedClient.getPrincipalName()).isEqualTo(this.principal.getName()); assertThat(authorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken()); + verify(this.authorizedClientRepository).saveAuthorizedClient( + eq(authorizedClient), eq(this.principal), + any(HttpServletRequest.class), any(HttpServletResponse.class)); } } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProviderTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProviderTests.java index 20b44202b28..a361e3cc7f6 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProviderTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProviderTests.java @@ -18,17 +18,23 @@ import org.junit.Before; import org.junit.Test; import org.mockito.ArgumentCaptor; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.OAuth2RefreshTokenGrantRequest; import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; import java.util.Arrays; import java.util.HashSet; import java.util.Set; @@ -36,6 +42,7 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.*; /** @@ -44,6 +51,8 @@ * @author Joe Grandja */ public class RefreshTokenOAuth2AuthorizedClientProviderTests { + private ClientRegistrationRepository clientRegistrationRepository; + private OAuth2AuthorizedClientRepository authorizedClientRepository; private RefreshTokenOAuth2AuthorizedClientProvider authorizedClientProvider; private OAuth2AccessTokenResponseClient accessTokenResponseClient; private ClientRegistration clientRegistration; @@ -52,7 +61,10 @@ public class RefreshTokenOAuth2AuthorizedClientProviderTests { @Before public void setup() { - this.authorizedClientProvider = new RefreshTokenOAuth2AuthorizedClientProvider(); + this.clientRegistrationRepository = mock(ClientRegistrationRepository.class); + this.authorizedClientRepository = mock(OAuth2AuthorizedClientRepository.class); + this.authorizedClientProvider = new RefreshTokenOAuth2AuthorizedClientProvider( + this.clientRegistrationRepository, this.authorizedClientRepository); this.accessTokenResponseClient = mock(OAuth2AccessTokenResponseClient.class); this.authorizedClientProvider.setAccessTokenResponseClient(this.accessTokenResponseClient); this.clientRegistration = TestClientRegistrations.clientRegistration().build(); @@ -61,6 +73,20 @@ public void setup() { TestOAuth2AccessTokens.scopes("read", "write"), TestOAuth2RefreshTokens.refreshToken()); } + @Test + public void constructorWhenClientRegistrationRepositoryIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new RefreshTokenOAuth2AuthorizedClientProvider(null, this.authorizedClientRepository)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("clientRegistrationRepository cannot be null"); + } + + @Test + public void constructorWhenOAuth2AuthorizedClientRepositoryIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new RefreshTokenOAuth2AuthorizedClientProvider(this.clientRegistrationRepository, null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("authorizedClientRepository cannot be null"); + } + @Test public void setAccessTokenResponseClientWhenClientIsNullThenThrowIllegalArgumentException() { assertThatThrownBy(() -> this.authorizedClientProvider.setAccessTokenResponseClient(null)) @@ -91,6 +117,27 @@ public void authorizeWhenAuthorizedAndRefreshTokenIsNullThenUnableToReauthorize( assertThat(this.authorizedClientProvider.authorize(authorizationContext)).isNull(); } + @Test + public void authorizeWhenHttpServletRequestIsNullThenThrowIllegalArgumentException() { + OAuth2AuthorizationContext authorizationContext = + OAuth2AuthorizationContext.forReauthorization(this.authorizedClient).principal(this.principal).build(); + assertThatThrownBy(() -> this.authorizedClientProvider.authorize(authorizationContext)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("The context attribute cannot be null 'javax.servlet.http.HttpServletRequest'"); + } + + @Test + public void authorizeWhenHttpServletResponseIsNullThenThrowIllegalArgumentException() { + OAuth2AuthorizationContext authorizationContext = + OAuth2AuthorizationContext.forReauthorization(this.authorizedClient) + .principal(this.principal) + .attribute(HttpServletRequest.class.getName(), new MockHttpServletRequest()) + .build(); + assertThatThrownBy(() -> this.authorizedClientProvider.authorize(authorizationContext)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("The context attribute cannot be null 'javax.servlet.http.HttpServletResponse'"); + } + @Test public void authorizeWhenAuthorizedWithRefreshTokenThenReauthorize() { OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse() @@ -101,6 +148,8 @@ public void authorizeWhenAuthorizedWithRefreshTokenThenReauthorize() { OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext.forReauthorization(this.authorizedClient) .principal(this.principal) + .attribute(HttpServletRequest.class.getName(), new MockHttpServletRequest()) + .attribute(HttpServletResponse.class.getName(), new MockHttpServletResponse()) .build(); OAuth2AuthorizedClient authorizedClient = this.authorizedClientProvider.authorize(authorizationContext); @@ -109,6 +158,9 @@ public void authorizeWhenAuthorizedWithRefreshTokenThenReauthorize() { assertThat(authorizedClient.getPrincipalName()).isEqualTo(this.principal.getName()); assertThat(authorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken()); assertThat(authorizedClient.getRefreshToken()).isEqualTo(accessTokenResponse.getRefreshToken()); + verify(this.authorizedClientRepository).saveAuthorizedClient( + eq(authorizedClient), eq(this.principal), + any(HttpServletRequest.class), any(HttpServletResponse.class)); } @Test @@ -124,6 +176,8 @@ public void authorizeWhenAuthorizedAndSpaceDelimitedScopeProvidedThenScopeReques OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext.forReauthorization(this.authorizedClient) .principal(this.principal) + .attribute(HttpServletRequest.class.getName(), new MockHttpServletRequest()) + .attribute(HttpServletResponse.class.getName(), new MockHttpServletResponse()) .attribute(RefreshTokenOAuth2AuthorizedClientProvider.REQUEST_SCOPE_ATTRIBUTE_NAME, scope) .build(); @@ -148,6 +202,8 @@ public void authorizeWhenAuthorizedAndCommaDelimitedScopeProvidedThenScopeReques OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext.forReauthorization(this.authorizedClient) .principal(this.principal) + .attribute(HttpServletRequest.class.getName(), new MockHttpServletRequest()) + .attribute(HttpServletResponse.class.getName(), new MockHttpServletResponse()) .attribute(RefreshTokenOAuth2AuthorizedClientProvider.REQUEST_SCOPE_ATTRIBUTE_NAME, scope) .build(); diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolverTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolverTests.java index c4b2562d9d2..491cfb52daa 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolverTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolverTests.java @@ -218,7 +218,8 @@ public void resolveArgumentWhenAuthorizedClientNotFoundForClientCredentialsClien OAuth2AccessTokenResponseClient clientCredentialsTokenResponseClient = mock(OAuth2AccessTokenResponseClient.class); ClientCredentialsOAuth2AuthorizedClientProvider clientCredentialsAuthorizedClientProvider = - new ClientCredentialsOAuth2AuthorizedClientProvider(); + new ClientCredentialsOAuth2AuthorizedClientProvider( + this.clientRegistrationRepository, this.authorizedClientRepository); clientCredentialsAuthorizedClientProvider.setAccessTokenResponseClient(clientCredentialsTokenResponseClient); this.argumentResolver.setAuthorizedClientProvider(clientCredentialsAuthorizedClientProvider); diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java index 4f26f293112..0ad69a0e13e 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java @@ -139,10 +139,10 @@ public void setup() { this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction( this.clientRegistrationRepository, this.authorizedClientRepository); ClientCredentialsOAuth2AuthorizedClientProvider clientCredentialsAuthorizedClientProvider = - new ClientCredentialsOAuth2AuthorizedClientProvider(); + new ClientCredentialsOAuth2AuthorizedClientProvider(this.clientRegistrationRepository, this.authorizedClientRepository); clientCredentialsAuthorizedClientProvider.setAccessTokenResponseClient(this.clientCredentialsTokenResponseClient); RefreshTokenOAuth2AuthorizedClientProvider refreshTokenAuthorizedClientProvider = - new RefreshTokenOAuth2AuthorizedClientProvider(); + new RefreshTokenOAuth2AuthorizedClientProvider(this.clientRegistrationRepository, this.authorizedClientRepository); refreshTokenAuthorizedClientProvider.setAccessTokenResponseClient(this.refreshTokenTokenResponseClient); this.authorizedClientProvider = new DelegatingOAuth2AuthorizedClientProvider( new AuthorizationCodeOAuth2AuthorizedClientProvider(), @@ -299,6 +299,10 @@ public void defaultRequestWhenClientCredentialsThenAuthorizedClient() { clientRegistrationId(this.registration.getRegistrationId()).accept(this.result); + MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletResponse response = new MockHttpServletResponse(); + RequestContextHolder.setRequestAttributes(new ServletRequestAttributes(request, response)); + Map attrs = getDefaultRequestAttributes(); OAuth2AuthorizedClient authorizedClient = getOAuth2AuthorizedClient(attrs); @@ -317,6 +321,10 @@ public void defaultRequestWhenDefaultClientRegistrationIdThenAuthorizedClient() .accessTokenResponse().build(); when(this.clientCredentialsTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse); + MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletResponse response = new MockHttpServletResponse(); + RequestContextHolder.setRequestAttributes(new ServletRequestAttributes(request, response)); + Map attrs = getDefaultRequestAttributes(); OAuth2AuthorizedClient authorizedClient = getOAuth2AuthorizedClient(attrs); @@ -404,6 +412,8 @@ public void filterWhenRefreshRequiredThenRefresh() { ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) .attributes(oauth2AuthorizedClient(authorizedClient)) .attributes(authentication(this.authentication)) + .attributes(httpServletRequest(new MockHttpServletRequest())) + .attributes(httpServletResponse(new MockHttpServletResponse())) .build(); this.function.filter(request, this.exchange).block(); @@ -439,7 +449,8 @@ public void filterWhenRefreshRequiredThenRefreshAndResponseDoesNotContainRefresh DefaultRefreshTokenTokenResponseClient refreshTokenTokenResponseClient = new DefaultRefreshTokenTokenResponseClient(); refreshTokenTokenResponseClient.setRestOperations(refreshTokenClient); - RefreshTokenOAuth2AuthorizedClientProvider authorizedClientProvider = new RefreshTokenOAuth2AuthorizedClientProvider(); + RefreshTokenOAuth2AuthorizedClientProvider authorizedClientProvider = new RefreshTokenOAuth2AuthorizedClientProvider( + this.clientRegistrationRepository, this.authorizedClientRepository); authorizedClientProvider.setAccessTokenResponseClient(refreshTokenTokenResponseClient); this.function.setAuthorizedClientProvider(authorizedClientProvider); @@ -457,6 +468,8 @@ public void filterWhenRefreshRequiredThenRefreshAndResponseDoesNotContainRefresh ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) .attributes(oauth2AuthorizedClient(authorizedClient)) .attributes(authentication(this.authentication)) + .attributes(httpServletRequest(new MockHttpServletRequest())) + .attributes(httpServletResponse(new MockHttpServletResponse())) .build(); this.function.filter(request, this.exchange).block(); @@ -526,6 +539,8 @@ public void filterWhenClientCredentialsTokenExpiredThenGetNewToken() { ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) .attributes(oauth2AuthorizedClient(authorizedClient)) .attributes(authentication(this.authentication)) + .attributes(httpServletRequest(new MockHttpServletRequest())) + .attributes(httpServletResponse(new MockHttpServletResponse())) .build(); this.function.filter(request, this.exchange).block(); @@ -565,6 +580,8 @@ public void filterWhenRefreshRequiredAndEmptyReactiveSecurityContextThenSaved() "principalName", this.accessToken, refreshToken); ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) .attributes(oauth2AuthorizedClient(authorizedClient)) + .attributes(httpServletRequest(new MockHttpServletRequest())) + .attributes(httpServletResponse(new MockHttpServletResponse())) .build(); this.function.filter(request, this.exchange).block(); @@ -678,6 +695,10 @@ public void filterWhenRequestAttributesNotSetAndHooksNotInitThenDefaultsNotAvail // this.function.afterPropertiesSet(); // Hooks.onLastOperator() NOT initialized this.function.setDefaultOAuth2AuthorizedClient(true); + MockHttpServletRequest servletRequest = new MockHttpServletRequest(); + MockHttpServletResponse servletResponse = new MockHttpServletResponse(); + RequestContextHolder.setRequestAttributes(new ServletRequestAttributes(servletRequest, servletResponse)); + OAuth2User user = mock(OAuth2User.class); List authorities = AuthorityUtils.createAuthorityList("ROLE_USER"); OAuth2AuthenticationToken authentication = new OAuth2AuthenticationToken( @@ -704,6 +725,10 @@ public void filterWhenRequestAttributesNotSetAndHooksInitHooksResetThenDefaultsN this.function.destroy(); // Hooks.onLastOperator() released this.function.setDefaultOAuth2AuthorizedClient(true); + MockHttpServletRequest servletRequest = new MockHttpServletRequest(); + MockHttpServletResponse servletResponse = new MockHttpServletResponse(); + RequestContextHolder.setRequestAttributes(new ServletRequestAttributes(servletRequest, servletResponse)); + OAuth2User user = mock(OAuth2User.class); List authorities = AuthorityUtils.createAuthorityList("ROLE_USER"); OAuth2AuthenticationToken authentication = new OAuth2AuthenticationToken( From f9fc7d3c6631dcea55c2f46d60be6b2707f91e45 Mon Sep 17 00:00:00 2001 From: Joe Grandja Date: Thu, 27 Jun 2019 15:30:15 -0400 Subject: [PATCH 12/19] OAuth2AuthorizedClientProvider implementations load/save OAuth2AuthorizedClient - #59 Redesign OAuth2AuthorizedClientProvider to load/save OAuth2AuthorizedClient - #60 ClientCredentialsOAuth2AuthorizedClientProvider should load/save OAuth2AuthorizedClient - #61 RefreshTokenOAuth2AuthorizedClientProvider should load/save OAuth2AuthorizedClient - #62 Refactor and use redesigned OAuth2AuthorizedClientProvider implementations --- .../OAuth2ClientConfiguration.java | 15 +- .../OAuth2ClientConfigurationTests.java | 31 ++-- ...ionCodeOAuth2AuthorizedClientProvider.java | 62 ++++++- ...entialsOAuth2AuthorizedClientProvider.java | 56 ++++-- ...DefaultOAuth2AuthorizedClientProvider.java | 88 ++++++++++ ...egatingOAuth2AuthorizedClientProvider.java | 18 +- .../client/OAuth2AuthorizationContext.java | 92 +++------- .../OAuth2AuthorizedClientProvider.java | 2 +- ...shTokenOAuth2AuthorizedClientProvider.java | 78 ++++++--- ...Auth2AuthorizedClientArgumentResolver.java | 30 ++-- ...uthorizedClientExchangeFilterFunction.java | 107 +++++------- ...deOAuth2AuthorizedClientProviderTests.java | 108 +++++++++++- ...lsOAuth2AuthorizedClientProviderTests.java | 105 ++++++++++-- ...ltOAuth2AuthorizedClientProviderTests.java | 137 +++++++++++++++ ...ngOAuth2AuthorizedClientProviderTests.java | 39 ++++- .../OAuth2AuthorizationContextTests.java | 48 ++---- ...enOAuth2AuthorizedClientProviderTests.java | 160 +++++++++++++----- ...AuthorizedClientArgumentResolverTests.java | 13 +- ...izedClientExchangeFilterFunctionTests.java | 110 ++++++++++-- 19 files changed, 957 insertions(+), 342 deletions(-) create mode 100644 oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/DefaultOAuth2AuthorizedClientProvider.java create mode 100644 oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/DefaultOAuth2AuthorizedClientProviderTests.java diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfiguration.java b/config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfiguration.java index 09a9ae86b7e..be486ffb545 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfiguration.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfiguration.java @@ -22,8 +22,9 @@ import org.springframework.core.type.AnnotationMetadata; import org.springframework.security.oauth2.client.AuthorizationCodeOAuth2AuthorizedClientProvider; import org.springframework.security.oauth2.client.ClientCredentialsOAuth2AuthorizedClientProvider; +import org.springframework.security.oauth2.client.DefaultOAuth2AuthorizedClientProvider; import org.springframework.security.oauth2.client.DelegatingOAuth2AuthorizedClientProvider; -import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProvider; +import org.springframework.security.oauth2.client.RefreshTokenOAuth2AuthorizedClientProvider; import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest; import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; @@ -80,8 +81,16 @@ public void addArgumentResolvers(List argumentRes new ClientCredentialsOAuth2AuthorizedClientProvider( this.clientRegistrationRepository, this.authorizedClientRepository); clientCredentialsAuthorizedClientProvider.setAccessTokenResponseClient(this.accessTokenResponseClient); - OAuth2AuthorizedClientProvider authorizedClientProvider = new DelegatingOAuth2AuthorizedClientProvider( - new AuthorizationCodeOAuth2AuthorizedClientProvider(), clientCredentialsAuthorizedClientProvider); + AuthorizationCodeOAuth2AuthorizedClientProvider authorizationCodeAuthorizedClientProvider = + new AuthorizationCodeOAuth2AuthorizedClientProvider( + this.clientRegistrationRepository, this.authorizedClientRepository); + RefreshTokenOAuth2AuthorizedClientProvider refreshTokenAuthorizedClientProvider = + new RefreshTokenOAuth2AuthorizedClientProvider( + this.clientRegistrationRepository, this.authorizedClientRepository); + DelegatingOAuth2AuthorizedClientProvider authorizedClientProvider = new DelegatingOAuth2AuthorizedClientProvider( + authorizationCodeAuthorizedClientProvider, refreshTokenAuthorizedClientProvider, clientCredentialsAuthorizedClientProvider); + authorizedClientProvider.setDefaultAuthorizedClientProvider( + new DefaultOAuth2AuthorizedClientProvider(this.clientRegistrationRepository, this.authorizedClientRepository)); authorizedClientArgumentResolver.setAuthorizedClientProvider(authorizedClientProvider); } argumentResolvers.add(authorizedClientArgumentResolver); diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfigurationTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfigurationTests.java index b01ece3c430..43f9523ee7a 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfigurationTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfigurationTests.java @@ -15,21 +15,6 @@ */ package org.springframework.security.config.annotation.web.configuration; -import static org.assertj.core.api.Assertions.assertThatThrownBy; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.verifyZeroInteractions; -import static org.mockito.Mockito.when; -import static org.springframework.security.oauth2.client.registration.TestClientRegistrations.clientCredentials; -import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.authentication; -import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; -import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.content; -import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; - -import javax.servlet.http.HttpServletRequest; import org.junit.Rule; import org.junit.Test; import org.springframework.beans.factory.NoSuchBeanDefinitionException; @@ -53,6 +38,19 @@ import org.springframework.web.bind.annotation.RestController; import org.springframework.web.servlet.config.annotation.EnableWebMvc; +import javax.servlet.http.HttpServletRequest; + +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.*; +import static org.springframework.security.oauth2.client.registration.TestClientRegistrations.clientCredentials; +import static org.springframework.security.oauth2.client.registration.TestClientRegistrations.clientRegistration; +import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.authentication; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.content; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; + /** * Tests for {@link OAuth2ClientConfiguration}. * @@ -72,6 +70,9 @@ public void requestWhenAuthorizedClientFoundThenMethodArgumentResolved() throws TestingAuthenticationToken authentication = new TestingAuthenticationToken(principalName, "password"); ClientRegistrationRepository clientRegistrationRepository = mock(ClientRegistrationRepository.class); + ClientRegistration clientRegistration = clientRegistration().registrationId(clientRegistrationId).build(); + when(clientRegistrationRepository.findByRegistrationId(eq(clientRegistrationId))).thenReturn(clientRegistration); + OAuth2AuthorizedClientRepository authorizedClientRepository = mock(OAuth2AuthorizedClientRepository.class); OAuth2AuthorizedClient authorizedClient = mock(OAuth2AuthorizedClient.class); when(authorizedClientRepository.loadAuthorizedClient( diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/AuthorizationCodeOAuth2AuthorizedClientProvider.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/AuthorizationCodeOAuth2AuthorizedClientProvider.java index e0ed9e9d294..eb775e25eda 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/AuthorizationCodeOAuth2AuthorizedClientProvider.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/AuthorizationCodeOAuth2AuthorizedClientProvider.java @@ -16,9 +16,15 @@ package org.springframework.security.oauth2.client; import org.springframework.lang.Nullable; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; +import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.util.Assert; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + /** * An implementation of an {@link OAuth2AuthorizedClientProvider} * for the {@link AuthorizationGrantType#AUTHORIZATION_CODE authorization_code} grant. @@ -28,16 +34,66 @@ * @see OAuth2AuthorizedClientProvider */ public final class AuthorizationCodeOAuth2AuthorizedClientProvider implements OAuth2AuthorizedClientProvider { + private static final String HTTP_SERVLET_REQUEST_ATTRIBUTE_NAME = HttpServletRequest.class.getName(); + private static final String HTTP_SERVLET_RESPONSE_ATTRIBUTE_NAME = HttpServletResponse.class.getName(); + private final ClientRegistrationRepository clientRegistrationRepository; + private final OAuth2AuthorizedClientRepository authorizedClientRepository; + + /** + * Constructs an {@code AuthorizationCodeOAuth2AuthorizedClientProvider} using the provided parameters. + * + * @param clientRegistrationRepository the repository of client registrations + * @param authorizedClientRepository the repository of authorized clients + */ + public AuthorizationCodeOAuth2AuthorizedClientProvider(ClientRegistrationRepository clientRegistrationRepository, + OAuth2AuthorizedClientRepository authorizedClientRepository) { + Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null"); + Assert.notNull(authorizedClientRepository, "authorizedClientRepository cannot be null"); + this.clientRegistrationRepository = clientRegistrationRepository; + this.authorizedClientRepository = authorizedClientRepository; + } + /** + * Attempt to authorize the {@link OAuth2AuthorizationContext#getClientRegistrationId() client} in the provided {@code context}. + * Returns {@code null} if authorization is not supported, + * e.g. the client's {@link ClientRegistration#getAuthorizationGrantType() authorization grant type} + * is not {@link AuthorizationGrantType#AUTHORIZATION_CODE authorization_code} OR the client is already authorized. + * + *

+ * The following {@link OAuth2AuthorizationContext#getAttributes() context attributes} are supported: + *

    + *
  1. {@code "javax.servlet.http.HttpServletRequest"} (required) - the {@code HttpServletRequest}
  2. + *
  3. {@code "javax.servlet.http.HttpServletResponse"} (required) - the {@code HttpServletResponse}
  4. + *
+ * + * @param context the context that holds authorization-specific state for the client + * @return the {@link OAuth2AuthorizedClient} or {@code null} if authorization is not supported + */ @Override @Nullable public OAuth2AuthorizedClient authorize(OAuth2AuthorizationContext context) { Assert.notNull(context, "context cannot be null"); - if (AuthorizationGrantType.AUTHORIZATION_CODE.equals(context.getClientRegistration().getAuthorizationGrantType()) && - context.authorizationRequested()) { + + HttpServletRequest request = context.getAttribute(HTTP_SERVLET_REQUEST_ATTRIBUTE_NAME); + HttpServletResponse response = context.getAttribute(HTTP_SERVLET_RESPONSE_ATTRIBUTE_NAME); + Assert.notNull(request, "The context attribute cannot be null '" + HTTP_SERVLET_REQUEST_ATTRIBUTE_NAME + "'"); + Assert.notNull(response, "The context attribute cannot be null '" + HTTP_SERVLET_RESPONSE_ATTRIBUTE_NAME + "'"); + + String clientRegistrationId = context.getClientRegistrationId(); + ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId); + Assert.notNull(clientRegistration, "Could not find ClientRegistration with id '" + clientRegistrationId + "'"); + + if (!AuthorizationGrantType.AUTHORIZATION_CODE.equals(clientRegistration.getAuthorizationGrantType())) { + return null; + } + + OAuth2AuthorizedClient authorizedClient = this.authorizedClientRepository.loadAuthorizedClient( + clientRegistrationId, context.getPrincipal(), request); + if (authorizedClient == null) { // ClientAuthorizationRequiredException is caught by OAuth2AuthorizationRequestRedirectFilter which initiates authorization - throw new ClientAuthorizationRequiredException(context.getClientRegistration().getRegistrationId()); + throw new ClientAuthorizationRequiredException(clientRegistrationId); } + return null; } } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ClientCredentialsOAuth2AuthorizedClientProvider.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ClientCredentialsOAuth2AuthorizedClientProvider.java index ac1e0168f1e..62de43008b6 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ClientCredentialsOAuth2AuthorizedClientProvider.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ClientCredentialsOAuth2AuthorizedClientProvider.java @@ -22,12 +22,15 @@ import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; +import org.springframework.security.oauth2.core.AbstractOAuth2Token; import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.util.Assert; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; +import java.time.Duration; +import java.time.Instant; /** * An implementation of an {@link OAuth2AuthorizedClientProvider} @@ -45,6 +48,7 @@ public final class ClientCredentialsOAuth2AuthorizedClientProvider implements OA private final OAuth2AuthorizedClientRepository authorizedClientRepository; private OAuth2AccessTokenResponseClient accessTokenResponseClient = new DefaultClientCredentialsTokenResponseClient(); + private Duration clockSkew = Duration.ofSeconds(60); /** * Constructs a {@code ClientCredentialsOAuth2AuthorizedClientProvider} using the provided parameters. @@ -61,10 +65,11 @@ public ClientCredentialsOAuth2AuthorizedClientProvider(ClientRegistrationReposit } /** - * Attempt to authorize (or re-authorize) the {@link OAuth2AuthorizationContext#getClientRegistration() client} in the provided {@code context}. + * Attempt to authorize (or re-authorize) the {@link OAuth2AuthorizationContext#getClientRegistrationId() client} in the provided {@code context}. * Returns {@code null} if authorization (or re-authorization) is not supported, * e.g. the client's {@link ClientRegistration#getAuthorizationGrantType() authorization grant type} - * is not {@link AuthorizationGrantType#CLIENT_CREDENTIALS client_credentials}. + * is not {@link AuthorizationGrantType#CLIENT_CREDENTIALS client_credentials} OR + * the {@link OAuth2AuthorizedClient#getAccessToken() access token} is not expired. * *

* The following {@link OAuth2AuthorizationContext#getAttributes() context attributes} are supported: @@ -80,15 +85,26 @@ public ClientCredentialsOAuth2AuthorizedClientProvider(ClientRegistrationReposit @Nullable public OAuth2AuthorizedClient authorize(OAuth2AuthorizationContext context) { Assert.notNull(context, "context cannot be null"); - if (!AuthorizationGrantType.CLIENT_CREDENTIALS.equals(context.getClientRegistration().getAuthorizationGrantType())) { - return null; - } HttpServletRequest request = context.getAttribute(HTTP_SERVLET_REQUEST_ATTRIBUTE_NAME); HttpServletResponse response = context.getAttribute(HTTP_SERVLET_RESPONSE_ATTRIBUTE_NAME); Assert.notNull(request, "The context attribute cannot be null '" + HTTP_SERVLET_REQUEST_ATTRIBUTE_NAME + "'"); Assert.notNull(response, "The context attribute cannot be null '" + HTTP_SERVLET_RESPONSE_ATTRIBUTE_NAME + "'"); + String clientRegistrationId = context.getClientRegistrationId(); + ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId); + Assert.notNull(clientRegistration, "Could not find ClientRegistration with id '" + clientRegistrationId + "'"); + + if (!AuthorizationGrantType.CLIENT_CREDENTIALS.equals(clientRegistration.getAuthorizationGrantType())) { + return null; + } + + OAuth2AuthorizedClient authorizedClient = this.authorizedClientRepository.loadAuthorizedClient( + clientRegistrationId, context.getPrincipal(), request); + if (authorizedClient != null && !hasTokenExpired(authorizedClient.getAccessToken())) { + return null; + } + // As per spec, in section 4.4.3 Access Token Response // https://tools.ietf.org/html/rfc6749#section-4.4.3 // A refresh token SHOULD NOT be included. @@ -97,24 +113,23 @@ public OAuth2AuthorizedClient authorize(OAuth2AuthorizationContext context) { // is the same as acquiring a new access token (authorization). OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = - new OAuth2ClientCredentialsGrantRequest(context.getClientRegistration()); + new OAuth2ClientCredentialsGrantRequest(clientRegistration); OAuth2AccessTokenResponse tokenResponse = this.accessTokenResponseClient.getTokenResponse(clientCredentialsGrantRequest); - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( - context.getClientRegistration(), - context.getPrincipal().getName(), - tokenResponse.getAccessToken()); + authorizedClient = new OAuth2AuthorizedClient( + clientRegistration, context.getPrincipal().getName(), tokenResponse.getAccessToken()); this.authorizedClientRepository.saveAuthorizedClient( - authorizedClient, - context.getPrincipal(), - request, - response); + authorizedClient, context.getPrincipal(), request, response); return authorizedClient; } + private boolean hasTokenExpired(AbstractOAuth2Token token) { + return token.getExpiresAt().isBefore(Instant.now().minus(this.clockSkew)); + } + /** * Sets the client used when requesting an access token credential at the Token Endpoint for the {@code client_credentials} grant. * @@ -124,4 +139,17 @@ public void setAccessTokenResponseClient(OAuth2AccessTokenResponseClient= 0, "clockSkew must be >= 0"); + this.clockSkew = clockSkew; + } } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/DefaultOAuth2AuthorizedClientProvider.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/DefaultOAuth2AuthorizedClientProvider.java new file mode 100644 index 00000000000..494f14f96dd --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/DefaultOAuth2AuthorizedClientProvider.java @@ -0,0 +1,88 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client; + +import org.springframework.lang.Nullable; +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; +import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; +import org.springframework.util.Assert; + +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + +/** + * The default implementation of an {@link OAuth2AuthorizedClientProvider} that simply + * {@link OAuth2AuthorizedClientRepository#loadAuthorizedClient(String, Authentication, HttpServletRequest) loads} + * an {@link OAuth2AuthorizedClient} from the authorized client repository. + * + * @author Joe Grandja + * @since 5.2 + * @see OAuth2AuthorizedClientProvider + */ +public final class DefaultOAuth2AuthorizedClientProvider implements OAuth2AuthorizedClientProvider { + private static final String HTTP_SERVLET_REQUEST_ATTRIBUTE_NAME = HttpServletRequest.class.getName(); + private static final String HTTP_SERVLET_RESPONSE_ATTRIBUTE_NAME = HttpServletResponse.class.getName(); + private final ClientRegistrationRepository clientRegistrationRepository; + private final OAuth2AuthorizedClientRepository authorizedClientRepository; + + /** + * Constructs an {@code DefaultOAuth2AuthorizedClientProvider} using the provided parameters. + * + * @param clientRegistrationRepository the repository of client registrations + * @param authorizedClientRepository the repository of authorized clients + */ + public DefaultOAuth2AuthorizedClientProvider(ClientRegistrationRepository clientRegistrationRepository, + OAuth2AuthorizedClientRepository authorizedClientRepository) { + Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null"); + Assert.notNull(authorizedClientRepository, "authorizedClientRepository cannot be null"); + this.clientRegistrationRepository = clientRegistrationRepository; + this.authorizedClientRepository = authorizedClientRepository; + } + + /** + * Attempts to {@link OAuth2AuthorizedClientRepository#loadAuthorizedClient(String, Authentication, HttpServletRequest) load} + * an {@link OAuth2AuthorizedClient} using the {@link OAuth2AuthorizationContext#getClientRegistrationId() client} + * in the provided {@code context}. Returns {@code null} if the client is not authorized. + * + *

+ * The following {@link OAuth2AuthorizationContext#getAttributes() context attributes} are supported: + *

    + *
  1. {@code "javax.servlet.http.HttpServletRequest"} (required) - the {@code HttpServletRequest}
  2. + *
  3. {@code "javax.servlet.http.HttpServletResponse"} (required) - the {@code HttpServletResponse}
  4. + *
+ * + * @param context the context that holds authorization-specific state for the client + * @return the {@link OAuth2AuthorizedClient} or {@code null} if the client is not authorized + */ + @Override + @Nullable + public OAuth2AuthorizedClient authorize(OAuth2AuthorizationContext context) { + Assert.notNull(context, "context cannot be null"); + + HttpServletRequest request = context.getAttribute(HTTP_SERVLET_REQUEST_ATTRIBUTE_NAME); + HttpServletResponse response = context.getAttribute(HTTP_SERVLET_RESPONSE_ATTRIBUTE_NAME); + Assert.notNull(request, "The context attribute cannot be null '" + HTTP_SERVLET_REQUEST_ATTRIBUTE_NAME + "'"); + Assert.notNull(response, "The context attribute cannot be null '" + HTTP_SERVLET_RESPONSE_ATTRIBUTE_NAME + "'"); + + String clientRegistrationId = context.getClientRegistrationId(); + ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId); + Assert.notNull(clientRegistration, "Could not find ClientRegistration with id '" + clientRegistrationId + "'"); + + return this.authorizedClientRepository.loadAuthorizedClient(clientRegistrationId, context.getPrincipal(), request); + } +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/DelegatingOAuth2AuthorizedClientProvider.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/DelegatingOAuth2AuthorizedClientProvider.java index 0343b96071c..3167abbc9be 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/DelegatingOAuth2AuthorizedClientProvider.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/DelegatingOAuth2AuthorizedClientProvider.java @@ -30,7 +30,7 @@ *

* Each provider is given a chance to * {@link OAuth2AuthorizedClientProvider#authorize(OAuth2AuthorizationContext) authorize} - * the {@link OAuth2AuthorizationContext#getClientRegistration() client} in the provided context + * the {@link OAuth2AuthorizationContext#getClientRegistrationId() client} in the provided context * with the first {@code non-null} {@link OAuth2AuthorizedClient} being returned. * * @author Joe Grandja @@ -39,6 +39,7 @@ */ public final class DelegatingOAuth2AuthorizedClientProvider implements OAuth2AuthorizedClientProvider { private final List authorizedClientProviders; + private OAuth2AuthorizedClientProvider defaultAuthorizedClientProvider; /** * Constructs a {@code DelegatingOAuth2AuthorizedClientProvider} using the provided parameters. @@ -68,6 +69,19 @@ public OAuth2AuthorizedClient authorize(OAuth2AuthorizationContext context) { .map(authorizedClientProvider -> authorizedClientProvider.authorize(context)) .filter(Objects::nonNull) .findFirst() - .orElse(null); + .orElse(this.defaultAuthorizedClientProvider != null ? + this.defaultAuthorizedClientProvider.authorize(context) : null); + } + + /** + * Sets the default {@link OAuth2AuthorizedClientProvider} used if none of the + * {@link OAuth2AuthorizedClientProvider}(s) in the {@code List} + * are able to authorize the {@link OAuth2AuthorizationContext#getClientRegistrationId() client}. + * + * @param authorizedClientProvider the default {@link OAuth2AuthorizedClientProvider} + */ + public void setDefaultAuthorizedClientProvider(OAuth2AuthorizedClientProvider authorizedClientProvider) { + Assert.notNull(authorizedClientProvider, "authorizedClientProvider cannot be null"); + this.defaultAuthorizedClientProvider = authorizedClientProvider; } } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizationContext.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizationContext.java index 122111fbc6d..55790d962cd 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizationContext.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizationContext.java @@ -37,43 +37,31 @@ * @see OAuth2AuthorizedClientProvider */ public final class OAuth2AuthorizationContext { - private ClientRegistration clientRegistration; + private String clientRegistrationId; private Authentication principal; - private OAuth2AuthorizedClient authorizedClient; private Map attributes; private OAuth2AuthorizationContext() { } /** - * Returns the {@link ClientRegistration client} requesting authorization. + * Returns the {@link ClientRegistration client registration} identifier. * - * @return the {@link ClientRegistration} + * @return the client registration identifier */ - public ClientRegistration getClientRegistration() { - return this.clientRegistration; + public String getClientRegistrationId() { + return this.clientRegistrationId; } /** - * Returns the {@code Principal} to be associated with the authorized client. + * Returns the {@code Principal} (to be) associated to the authorized client. * - * @return the {@code Principal} to be associated with the authorized client + * @return the {@code Principal} (to be) associated to the authorized client */ public Authentication getPrincipal() { return this.principal; } - /** - * Returns the {@link OAuth2AuthorizedClient authorized client} requesting re-authorization - * or {@code null} if the {@link #getClientRegistration() client} is requesting to be authorized. - * - * @return the {@link OAuth2AuthorizedClient} requesting re-authorization or {@code null} if the client is requesting to be authorized - */ - @Nullable - public OAuth2AuthorizedClient getAuthorizedClient() { - return this.authorizedClient; - } - /** * Returns the attributes associated to the context. * @@ -97,66 +85,32 @@ public T getAttribute(String name) { } /** - * Returns {@code true} if the client is requesting authorization, otherwise {@code false}. + * Returns a new {@link Builder} initialized with the {@link ClientRegistration client registration} identifier. * - * @return {@code true} if the client is requesting authorization, otherwise {@code false}. - */ - public boolean authorizationRequested() { - return getAuthorizedClient() == null; - } - - /** - * Returns {@code true} if the client is requesting re-authorization, otherwise {@code false}. - * - * @return {@code true} if the client is requesting re-authorization, otherwise {@code false}. - */ - public boolean reauthorizationRequested() { - return getAuthorizedClient() != null; - } - - /** - * Returns a new {@link Builder} with the {@link ClientRegistration client} requesting authorization. - * - * @param clientRegistration the {@link ClientRegistration client} requesting authorization - * @return the {@link Builder} - */ - public static Builder forAuthorization(ClientRegistration clientRegistration) { - return new Builder(clientRegistration); - } - - /** - * Returns a new {@link Builder} with the {@link OAuth2AuthorizedClient authorized client} requesting re-authorization. - * - * @param authorizedClient the {@link OAuth2AuthorizedClient authorized client} requesting re-authorization + * @param clientRegistrationId the {@link ClientRegistration client registration} identifier * @return the {@link Builder} */ - public static Builder forReauthorization(OAuth2AuthorizedClient authorizedClient) { - return new Builder(authorizedClient); + public static Builder forClient(String clientRegistrationId) { + return new Builder(clientRegistrationId); } /** * A builder for {@link OAuth2AuthorizationContext}. */ public static class Builder { - private ClientRegistration clientRegistration; + private String clientRegistrationId; private Authentication principal; - private OAuth2AuthorizedClient authorizedClient; private Map attributes; - private Builder(ClientRegistration clientRegistration) { - Assert.notNull(clientRegistration, "clientRegistration cannot be null"); - this.clientRegistration = clientRegistration; - } - - private Builder(OAuth2AuthorizedClient authorizedClient) { - Assert.notNull(authorizedClient, "authorizedClient cannot be null"); - this.authorizedClient = authorizedClient; + private Builder(String clientRegistrationId) { + Assert.hasText(clientRegistrationId, "clientRegistrationId cannot be empty"); + this.clientRegistrationId = clientRegistrationId; } /** - * Sets the {@code Principal} to be associated with the authorized client + * Sets the {@code Principal} (to be) associated to the authorized client. * - * @param principal the {@code Principal} to be associated with the authorized client + * @param principal the {@code Principal} (to be) associated to the authorized client * @return the {@link Builder} */ public Builder principal(Authentication principal) { @@ -165,9 +119,9 @@ public Builder principal(Authentication principal) { } /** - * Sets the {@code Principal}'s name to be associated with the authorized client + * Sets the {@code Principal}'s name (to be) associated to the authorized client. * - * @param principalName the {@code Principal}'s name to be associated with the authorized client + * @param principalName the {@code Principal}'s name (to be) associated to the authorized client * @return the {@link Builder} */ public Builder principal(String principalName) { @@ -209,12 +163,7 @@ public Builder attribute(String name, Object value) { public OAuth2AuthorizationContext build() { Assert.notNull(this.principal, "principal cannot be null"); OAuth2AuthorizationContext context = new OAuth2AuthorizationContext(); - if (this.authorizedClient != null) { - context.clientRegistration = this.authorizedClient.getClientRegistration(); - context.authorizedClient = this.authorizedClient; - } else { - context.clientRegistration = this.clientRegistration; - } + context.clientRegistrationId = this.clientRegistrationId; context.principal = this.principal; context.attributes = Collections.unmodifiableMap( CollectionUtils.isEmpty(this.attributes) ? @@ -227,6 +176,7 @@ private static class PrincipalNameAuthentication implements Authentication { private final String principalName; private PrincipalNameAuthentication(String principalName) { + Assert.hasText(principalName, "principalName cannot be empty"); this.principalName = principalName; } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientProvider.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientProvider.java index 94beb09d4f4..a063fe976fb 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientProvider.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientProvider.java @@ -32,7 +32,7 @@ public interface OAuth2AuthorizedClientProvider { /** - * Attempt to authorize (or re-authorize) the {@link OAuth2AuthorizationContext#getClientRegistration() client} in the provided context. + * Attempt to authorize (or re-authorize) the {@link OAuth2AuthorizationContext#getClientRegistrationId() client} in the provided context. * Implementations must return {@code null} if authorization is not supported for the specified client, * e.g. the provider doesn't support the {@link ClientRegistration#getAuthorizationGrantType() authorization grant} type configured for the client. * diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProvider.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProvider.java index d3a8d3fb9d7..cc31ec46436 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProvider.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProvider.java @@ -19,18 +19,22 @@ import org.springframework.security.oauth2.client.endpoint.DefaultRefreshTokenTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.OAuth2RefreshTokenGrantRequest; +import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; +import org.springframework.security.oauth2.core.AbstractOAuth2Token; import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.util.Assert; -import org.springframework.util.StringUtils; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; +import java.time.Duration; +import java.time.Instant; import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; import java.util.Set; -import java.util.stream.Collectors; /** * An implementation of an {@link OAuth2AuthorizedClientProvider} @@ -45,8 +49,8 @@ public final class RefreshTokenOAuth2AuthorizedClientProvider implements OAuth2A /** * The name of the {@link OAuth2AuthorizationContext#getAttribute(String) attribute} * in the {@link OAuth2AuthorizationContext context} associated to the value for the "requested scope(s)". - * The value of the attribute is a space-delimited or comma-delimited {@code String} of scope(s) - * to be requested by the {@link OAuth2AuthorizationContext#getClientRegistration() client}. + * The value of the attribute is a {@code String[]} of scope(s) to be requested + * by the {@link OAuth2AuthorizationContext#getClientRegistrationId() client}. */ public static final String REQUEST_SCOPE_ATTRIBUTE_NAME = "org.springframework.security.oauth2.client.REQUEST_SCOPE"; @@ -57,6 +61,7 @@ public final class RefreshTokenOAuth2AuthorizedClientProvider implements OAuth2A private final OAuth2AuthorizedClientRepository authorizedClientRepository; private OAuth2AccessTokenResponseClient accessTokenResponseClient = new DefaultRefreshTokenTokenResponseClient(); + private Duration clockSkew = Duration.ofSeconds(60); /** * Constructs a {@code RefreshTokenOAuth2AuthorizedClientProvider} using the provided parameters. @@ -73,17 +78,17 @@ public RefreshTokenOAuth2AuthorizedClientProvider(ClientRegistrationRepository c } /** - * Attempt to re-authorize the {@link OAuth2AuthorizationContext#getClientRegistration() client} in the provided {@code context}. + * Attempt to re-authorize the {@link OAuth2AuthorizationContext#getClientRegistrationId() client} in the provided {@code context}. * Returns {@code null} if re-authorization is not supported, - * e.g. the {@link OAuth2AuthorizedClient#getRefreshToken() refresh token} is not available for the - * {@link OAuth2AuthorizationContext#getAuthorizedClient() authorized client}. + * e.g. the client is not authorized OR the {@link OAuth2AuthorizedClient#getRefreshToken() refresh token} + * is not available for the authorized client OR the {@link OAuth2AuthorizedClient#getAccessToken() access token} is not expired. * *

* The following {@link OAuth2AuthorizationContext#getAttributes() context attributes} are supported: *

    *
  1. {@code "javax.servlet.http.HttpServletRequest"} (required) - the {@code HttpServletRequest}
  2. *
  3. {@code "javax.servlet.http.HttpServletResponse"} (required) - the {@code HttpServletResponse}
  4. - *
  5. {@code "org.springframework.security.oauth2.client.REQUEST_SCOPE"} (optional) - a space-delimited or comma-delimited {@code String} of scope(s) to be requested by the {@link OAuth2AuthorizationContext#getClientRegistration() client}
  6. + *
  7. {@code "org.springframework.security.oauth2.client.REQUEST_SCOPE"} (optional) - a {@code String[]} of scope(s) to be requested by the {@link OAuth2AuthorizationContext#getClientRegistrationId() client}
  8. *
* * @param context the context that holds authorization-specific state for the client @@ -93,42 +98,50 @@ public RefreshTokenOAuth2AuthorizedClientProvider(ClientRegistrationRepository c @Nullable public OAuth2AuthorizedClient authorize(OAuth2AuthorizationContext context) { Assert.notNull(context, "context cannot be null"); - if (!context.reauthorizationRequested() || context.getAuthorizedClient().getRefreshToken() == null) { - return null; - } HttpServletRequest request = context.getAttribute(HTTP_SERVLET_REQUEST_ATTRIBUTE_NAME); HttpServletResponse response = context.getAttribute(HTTP_SERVLET_RESPONSE_ATTRIBUTE_NAME); Assert.notNull(request, "The context attribute cannot be null '" + HTTP_SERVLET_REQUEST_ATTRIBUTE_NAME + "'"); Assert.notNull(response, "The context attribute cannot be null '" + HTTP_SERVLET_RESPONSE_ATTRIBUTE_NAME + "'"); - String requestScope = context.getAttribute(REQUEST_SCOPE_ATTRIBUTE_NAME); - Set scopes = null; - if (!StringUtils.isEmpty(requestScope)) { - String delimiter = requestScope.indexOf(',') != -1 ? "," : " "; - scopes = Arrays.stream(StringUtils.delimitedListToStringArray(requestScope, delimiter, " ")).collect(Collectors.toSet()); + String clientRegistrationId = context.getClientRegistrationId(); + ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId); + Assert.notNull(clientRegistration, "Could not find ClientRegistration with id '" + clientRegistrationId + "'"); + + OAuth2AuthorizedClient authorizedClient = this.authorizedClientRepository.loadAuthorizedClient( + clientRegistrationId, context.getPrincipal(), request); + if (authorizedClient == null || + authorizedClient.getRefreshToken() == null || + !hasTokenExpired(authorizedClient.getAccessToken())) { + return null; + } + + Object requestScope = context.getAttribute(REQUEST_SCOPE_ATTRIBUTE_NAME); + Set scopes = Collections.emptySet(); + if (requestScope != null) { + Assert.isInstanceOf(String[].class, requestScope, + "The context attribute must be of type String[] '" + REQUEST_SCOPE_ATTRIBUTE_NAME + "'"); + scopes = new HashSet<>(Arrays.asList((String[]) requestScope)); } OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = - new OAuth2RefreshTokenGrantRequest(context.getAuthorizedClient(), scopes); + new OAuth2RefreshTokenGrantRequest(authorizedClient, scopes); OAuth2AccessTokenResponse tokenResponse = this.accessTokenResponseClient.getTokenResponse(refreshTokenGrantRequest); - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( - context.getClientRegistration(), - context.getPrincipal().getName(), - tokenResponse.getAccessToken(), - tokenResponse.getRefreshToken()); + authorizedClient = new OAuth2AuthorizedClient(clientRegistration, + context.getPrincipal().getName(), tokenResponse.getAccessToken(), tokenResponse.getRefreshToken()); this.authorizedClientRepository.saveAuthorizedClient( - authorizedClient, - context.getPrincipal(), - request, - response); + authorizedClient, context.getPrincipal(), request, response); return authorizedClient; } + private boolean hasTokenExpired(AbstractOAuth2Token token) { + return token.getExpiresAt().isBefore(Instant.now().minus(this.clockSkew)); + } + /** * Sets the client used when requesting an access token credential at the Token Endpoint for the {@code refresh_token} grant. * @@ -138,4 +151,17 @@ public void setAccessTokenResponseClient(OAuth2AccessTokenResponseClient= 0, "clockSkew must be >= 0"); + this.clockSkew = clockSkew; + } } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolver.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolver.java index eb124e31bc3..946b9bd4689 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolver.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolver.java @@ -23,16 +23,17 @@ import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.oauth2.client.AuthorizationCodeOAuth2AuthorizedClientProvider; import org.springframework.security.oauth2.client.ClientCredentialsOAuth2AuthorizedClientProvider; +import org.springframework.security.oauth2.client.DefaultOAuth2AuthorizedClientProvider; import org.springframework.security.oauth2.client.DelegatingOAuth2AuthorizedClientProvider; import org.springframework.security.oauth2.client.OAuth2AuthorizationContext; import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProvider; +import org.springframework.security.oauth2.client.RefreshTokenOAuth2AuthorizedClientProvider; import org.springframework.security.oauth2.client.annotation.RegisteredOAuth2AuthorizedClient; import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; import org.springframework.security.oauth2.client.endpoint.DefaultClientCredentialsTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest; -import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; import org.springframework.util.Assert; @@ -109,21 +110,9 @@ public Object resolveArgument(MethodParameter parameter, Authentication principal = SecurityContextHolder.getContext().getAuthentication(); HttpServletRequest servletRequest = webRequest.getNativeRequest(HttpServletRequest.class); - - OAuth2AuthorizedClient authorizedClient = this.authorizedClientRepository.loadAuthorizedClient( - clientRegistrationId, principal, servletRequest); - if (authorizedClient != null) { - return authorizedClient; - } - - ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId); - if (clientRegistration == null) { - return null; - } - HttpServletResponse servletResponse = webRequest.getNativeResponse(HttpServletResponse.class); - OAuth2AuthorizationContext.Builder contextBuilder = OAuth2AuthorizationContext.forAuthorization(clientRegistration); + OAuth2AuthorizationContext.Builder contextBuilder = OAuth2AuthorizationContext.forClient(clientRegistrationId); if (principal != null) { contextBuilder.principal(principal); } else { @@ -184,7 +173,16 @@ private OAuth2AuthorizedClientProvider createAuthorizedClientProvider( ClientCredentialsOAuth2AuthorizedClientProvider clientCredentialsAuthorizedClientProvider = new ClientCredentialsOAuth2AuthorizedClientProvider(this.clientRegistrationRepository, this.authorizedClientRepository); clientCredentialsAuthorizedClientProvider.setAccessTokenResponseClient(clientCredentialsTokenResponseClient); - return new DelegatingOAuth2AuthorizedClientProvider( - new AuthorizationCodeOAuth2AuthorizedClientProvider(), clientCredentialsAuthorizedClientProvider); + AuthorizationCodeOAuth2AuthorizedClientProvider authorizationCodeAuthorizedClientProvider = + new AuthorizationCodeOAuth2AuthorizedClientProvider(this.clientRegistrationRepository, this.authorizedClientRepository); + RefreshTokenOAuth2AuthorizedClientProvider refreshTokenAuthorizedClientProvider = + new RefreshTokenOAuth2AuthorizedClientProvider(this.clientRegistrationRepository, this.authorizedClientRepository); + + DelegatingOAuth2AuthorizedClientProvider delegate = new DelegatingOAuth2AuthorizedClientProvider( + authorizationCodeAuthorizedClientProvider, refreshTokenAuthorizedClientProvider, clientCredentialsAuthorizedClientProvider); + delegate.setDefaultAuthorizedClientProvider( + new DefaultOAuth2AuthorizedClientProvider(this.clientRegistrationRepository, this.authorizedClientRepository)); + + return delegate; } } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java index 085980e1ea2..a402d3768f0 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java @@ -23,12 +23,14 @@ import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.oauth2.client.AuthorizationCodeOAuth2AuthorizedClientProvider; import org.springframework.security.oauth2.client.ClientCredentialsOAuth2AuthorizedClientProvider; +import org.springframework.security.oauth2.client.DefaultOAuth2AuthorizedClientProvider; import org.springframework.security.oauth2.client.DelegatingOAuth2AuthorizedClientProvider; import org.springframework.security.oauth2.client.OAuth2AuthorizationContext; import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProvider; import org.springframework.security.oauth2.client.RefreshTokenOAuth2AuthorizedClientProvider; import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; +import org.springframework.security.oauth2.client.endpoint.DefaultClientCredentialsTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest; import org.springframework.security.oauth2.client.registration.ClientRegistration; @@ -50,10 +52,7 @@ import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; -import java.time.Clock; import java.time.Duration; -import java.time.Instant; -import java.util.HashMap; import java.util.Map; import java.util.function.Consumer; @@ -110,8 +109,6 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction private static final String REQUEST_CONTEXT_OPERATOR_KEY = RequestContextSubscriber.class.getName(); - private Clock clock = Clock.systemUTC(); - private Duration accessTokenExpiresSkew = Duration.ofMinutes(1); private ClientRegistrationRepository clientRegistrationRepository; @@ -137,10 +134,19 @@ public ServletOAuth2AuthorizedClientExchangeFilterFunction( private static OAuth2AuthorizedClientProvider createDefaultAuthorizedClientProvider( ClientRegistrationRepository clientRegistrationRepository, OAuth2AuthorizedClientRepository authorizedClientRepository) { - return new DelegatingOAuth2AuthorizedClientProvider( - new AuthorizationCodeOAuth2AuthorizedClientProvider(), - new ClientCredentialsOAuth2AuthorizedClientProvider(clientRegistrationRepository, authorizedClientRepository), - new RefreshTokenOAuth2AuthorizedClientProvider(clientRegistrationRepository, authorizedClientRepository)); + AuthorizationCodeOAuth2AuthorizedClientProvider authorizationCodeAuthorizedClientProvider = + new AuthorizationCodeOAuth2AuthorizedClientProvider(clientRegistrationRepository, authorizedClientRepository); + RefreshTokenOAuth2AuthorizedClientProvider refreshTokenAuthorizedClientProvider = + new RefreshTokenOAuth2AuthorizedClientProvider(clientRegistrationRepository, authorizedClientRepository); + ClientCredentialsOAuth2AuthorizedClientProvider clientCredentialsAuthorizedClientProvider = + new ClientCredentialsOAuth2AuthorizedClientProvider(clientRegistrationRepository, authorizedClientRepository); + + DelegatingOAuth2AuthorizedClientProvider delegate = new DelegatingOAuth2AuthorizedClientProvider( + authorizationCodeAuthorizedClientProvider, refreshTokenAuthorizedClientProvider, clientCredentialsAuthorizedClientProvider); + delegate.setDefaultAuthorizedClientProvider( + new DefaultOAuth2AuthorizedClientProvider(clientRegistrationRepository, authorizedClientRepository)); + + return delegate; } @Override @@ -178,15 +184,27 @@ public void setClientCredentialsTokenResponseClient( this.authorizedClientProvider = createAuthorizedClientProvider(clientCredentialsTokenResponseClient); } + private OAuth2AuthorizedClientProvider createAuthorizedClientProvider() { + return createAuthorizedClientProvider(new DefaultClientCredentialsTokenResponseClient()); + } + private OAuth2AuthorizedClientProvider createAuthorizedClientProvider( OAuth2AccessTokenResponseClient clientCredentialsTokenResponseClient) { ClientCredentialsOAuth2AuthorizedClientProvider clientCredentialsAuthorizedClientProvider = new ClientCredentialsOAuth2AuthorizedClientProvider(this.clientRegistrationRepository, this.authorizedClientRepository); clientCredentialsAuthorizedClientProvider.setAccessTokenResponseClient(clientCredentialsTokenResponseClient); - return new DelegatingOAuth2AuthorizedClientProvider( - new AuthorizationCodeOAuth2AuthorizedClientProvider(), - clientCredentialsAuthorizedClientProvider, - new RefreshTokenOAuth2AuthorizedClientProvider(this.clientRegistrationRepository, this.authorizedClientRepository)); + clientCredentialsAuthorizedClientProvider.setClockSkew(this.accessTokenExpiresSkew); + AuthorizationCodeOAuth2AuthorizedClientProvider authorizationCodeAuthorizedClientProvider = + new AuthorizationCodeOAuth2AuthorizedClientProvider(this.clientRegistrationRepository, this.authorizedClientRepository); + RefreshTokenOAuth2AuthorizedClientProvider refreshTokenAuthorizedClientProvider = + new RefreshTokenOAuth2AuthorizedClientProvider(this.clientRegistrationRepository, this.authorizedClientRepository); + refreshTokenAuthorizedClientProvider.setClockSkew(this.accessTokenExpiresSkew); + + DelegatingOAuth2AuthorizedClientProvider delegate = new DelegatingOAuth2AuthorizedClientProvider( + authorizationCodeAuthorizedClientProvider, refreshTokenAuthorizedClientProvider, clientCredentialsAuthorizedClientProvider); + delegate.setDefaultAuthorizedClientProvider( + new DefaultOAuth2AuthorizedClientProvider(this.clientRegistrationRepository, this.authorizedClientRepository)); + return delegate; } /** @@ -200,7 +218,6 @@ public void setDefaultOAuth2AuthorizedClient(boolean defaultOAuth2AuthorizedClie this.defaultOAuth2AuthorizedClient = defaultOAuth2AuthorizedClient; } - /** * If set, will be used as the default {@link ClientRegistration#getRegistrationId()}. It is * recommended to be cautious with this feature since all HTTP requests will receive the access token. @@ -308,6 +325,7 @@ public static Consumer> httpServletResponse(HttpServletRespo public void setAccessTokenExpiresSkew(Duration accessTokenExpiresSkew) { Assert.notNull(accessTokenExpiresSkew, "accessTokenExpiresSkew cannot be null"); this.accessTokenExpiresSkew = accessTokenExpiresSkew; + this.authorizedClientProvider = createAuthorizedClientProvider(); } @Override @@ -316,7 +334,7 @@ public Mono filter(ClientRequest request, ExchangeFunction next) .filter(req -> req.attribute(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME).isPresent()) .switchIfEmpty(mergeRequestAttributesFromContext(request)) .filter(req -> req.attribute(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME).isPresent()) - .flatMap(req -> reauthorizeClientIfNecessary(getOAuth2AuthorizedClient(req.attributes()), req)) + .flatMap(req -> authorizedClient(getOAuth2AuthorizedClient(req.attributes()), req)) .map(authorizedClient -> bearer(request, authorizedClient)) .flatMap(next::exchange) .switchIfEmpty(next.exchange(request)); @@ -387,66 +405,35 @@ private void populateDefaultOAuth2AuthorizedClient(Map attrs) { OAuth2AuthorizedClient authorizedClient = this.authorizedClientRepository.loadAuthorizedClient( clientRegistrationId, authentication, request); if (authorizedClient == null) { - authorizedClient = authorizeClient(clientRegistrationId, attrs); + authorizedClient = this.authorizedClientProvider.authorize( + createAuthorizationContext(clientRegistrationId, attrs)); } oauth2AuthorizedClient(authorizedClient).accept(attrs); } } - private OAuth2AuthorizedClient authorizeClient(String clientRegistrationId, Map attributes) { - ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId); - if (clientRegistration == null) { - throw new IllegalArgumentException("Could not find ClientRegistration with id " + clientRegistrationId); - } - - OAuth2AuthorizationContext.Builder contextBuilder = OAuth2AuthorizationContext.forAuthorization(clientRegistration); - Authentication authentication = getAuthentication(attributes); - if (authentication != null) { - contextBuilder.principal(authentication); - } else { - contextBuilder.principal("anonymousUser"); - } - OAuth2AuthorizationContext authorizationContext = contextBuilder - .attributes(defaultContextAttributes(attributes)) - .build(); - return this.authorizedClientProvider.authorize(authorizationContext); - } - - private Mono reauthorizeClientIfNecessary( + private Mono authorizedClient( OAuth2AuthorizedClient authorizedClient, ClientRequest request) { - if (this.authorizedClientProvider == null || !hasTokenExpired(authorizedClient)) { + if (this.authorizedClientProvider == null) { return Mono.just(authorizedClient); } + OAuth2AuthorizationContext authorizationContext = createAuthorizationContext( + authorizedClient.getClientRegistration().getRegistrationId(), request.attributes()); + return Mono.fromSupplier(() -> this.authorizedClientProvider.authorize(authorizationContext)); + } - Map attributes = request.attributes(); - - OAuth2AuthorizationContext.Builder contextBuilder = OAuth2AuthorizationContext.forReauthorization(authorizedClient); + private OAuth2AuthorizationContext createAuthorizationContext(String clientRegistrationId, Map attributes) { + OAuth2AuthorizationContext.Builder contextBuilder = OAuth2AuthorizationContext.forClient(clientRegistrationId); Authentication authentication = getAuthentication(attributes); if (authentication != null) { contextBuilder.principal(authentication); } else { - contextBuilder.principal(authorizedClient.getPrincipalName()); + contextBuilder.principal("anonymousUser"); } - OAuth2AuthorizationContext authorizationContext = contextBuilder - .attributes(defaultContextAttributes(attributes)) + return contextBuilder + .attribute(HttpServletRequest.class.getName(), getRequest(attributes)) + .attribute(HttpServletResponse.class.getName(), getResponse(attributes)) .build(); - return Mono.fromSupplier(() -> this.authorizedClientProvider.authorize(authorizationContext)); - } - - private Map defaultContextAttributes(Map attributes) { - Map contextAttributes = new HashMap<>(); - contextAttributes.put(HttpServletRequest.class.getName(), getRequest(attributes)); - contextAttributes.put(HttpServletResponse.class.getName(), getResponse(attributes)); - return contextAttributes; - } - - private boolean hasTokenExpired(OAuth2AuthorizedClient authorizedClient) { - Instant now = this.clock.instant(); - Instant expiresAt = authorizedClient.getAccessToken().getExpiresAt(); - if (now.isAfter(expiresAt.minus(this.accessTokenExpiresSkew))) { - return true; - } - return false; } private ClientRequest bearer(ClientRequest request, OAuth2AuthorizedClient authorizedClient) { diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/AuthorizationCodeOAuth2AuthorizedClientProviderTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/AuthorizationCodeOAuth2AuthorizedClientProviderTests.java index a9b60eb2b2a..6998636dfdf 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/AuthorizationCodeOAuth2AuthorizedClientProviderTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/AuthorizationCodeOAuth2AuthorizedClientProviderTests.java @@ -17,14 +17,25 @@ import org.junit.Before; import org.junit.Test; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; /** * Tests for {@link AuthorizationCodeOAuth2AuthorizedClientProvider}. @@ -32,20 +43,39 @@ * @author Joe Grandja */ public class AuthorizationCodeOAuth2AuthorizedClientProviderTests { - private AuthorizationCodeOAuth2AuthorizedClientProvider authorizedClientProvider = - new AuthorizationCodeOAuth2AuthorizedClientProvider(); + private ClientRegistrationRepository clientRegistrationRepository; + private OAuth2AuthorizedClientRepository authorizedClientRepository; + private AuthorizationCodeOAuth2AuthorizedClientProvider authorizedClientProvider; private ClientRegistration clientRegistration; private OAuth2AuthorizedClient authorizedClient; private Authentication principal; @Before public void setup() { + this.clientRegistrationRepository = mock(ClientRegistrationRepository.class); + this.authorizedClientRepository = mock(OAuth2AuthorizedClientRepository.class); + this.authorizedClientProvider = new AuthorizationCodeOAuth2AuthorizedClientProvider( + this.clientRegistrationRepository, this.authorizedClientRepository); this.clientRegistration = TestClientRegistrations.clientRegistration().build(); this.authorizedClient = new OAuth2AuthorizedClient( this.clientRegistration, "principal", TestOAuth2AccessTokens.scopes("read", "write")); this.principal = new TestingAuthenticationToken("principal", "password"); } + @Test + public void constructorWhenClientRegistrationRepositoryIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new AuthorizationCodeOAuth2AuthorizedClientProvider(null, this.authorizedClientRepository)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("clientRegistrationRepository cannot be null"); + } + + @Test + public void constructorWhenOAuth2AuthorizedClientRepositoryIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new AuthorizationCodeOAuth2AuthorizedClientProvider(this.clientRegistrationRepository, null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("authorizedClientRepository cannot be null"); + } + @Test public void authorizeWhenContextIsNullThenThrowIllegalArgumentException() { assertThatThrownBy(() -> this.authorizedClientProvider.authorize(null)) @@ -53,25 +83,85 @@ public void authorizeWhenContextIsNullThenThrowIllegalArgumentException() { } @Test - public void authorizeWhenAuthorizationCodeAndNotAuthorizedThenAuthorize() { + public void authorizeWhenHttpServletRequestIsNullThenThrowIllegalArgumentException() { OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.forAuthorization(this.clientRegistration).principal(this.principal).build(); + OAuth2AuthorizationContext.forClient(this.clientRegistration.getRegistrationId()) + .principal(this.principal) + .build(); assertThatThrownBy(() -> this.authorizedClientProvider.authorize(authorizationContext)) - .isInstanceOf(ClientAuthorizationRequiredException.class); + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("The context attribute cannot be null 'javax.servlet.http.HttpServletRequest'"); } @Test - public void authorizeWhenAuthorizationCodeAndAuthorizedThenUnableToAuthorize() { + public void authorizeWhenHttpServletResponseIsNullThenThrowIllegalArgumentException() { OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.forReauthorization(this.authorizedClient).principal(this.principal).build(); - assertThat(this.authorizedClientProvider.authorize(authorizationContext)).isNull(); + OAuth2AuthorizationContext.forClient(this.clientRegistration.getRegistrationId()) + .principal(this.principal) + .attribute(HttpServletRequest.class.getName(), new MockHttpServletRequest()) + .build(); + assertThatThrownBy(() -> this.authorizedClientProvider.authorize(authorizationContext)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("The context attribute cannot be null 'javax.servlet.http.HttpServletResponse'"); + } + + @Test + public void authorizeWhenClientRegistrationNotFoundThenThrowIllegalArgumentException() { + OAuth2AuthorizationContext authorizationContext = + OAuth2AuthorizationContext.forClient(this.clientRegistration.getRegistrationId()) + .principal(this.principal) + .attribute(HttpServletRequest.class.getName(), new MockHttpServletRequest()) + .attribute(HttpServletResponse.class.getName(), new MockHttpServletResponse()) + .build(); + assertThatThrownBy(() -> this.authorizedClientProvider.authorize(authorizationContext)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Could not find ClientRegistration with id '" + this.clientRegistration.getRegistrationId() + "'"); } @Test public void authorizeWhenNotAuthorizationCodeThenUnableToAuthorize() { ClientRegistration clientCredentialsClient = TestClientRegistrations.clientCredentials().build(); + when(this.clientRegistrationRepository.findByRegistrationId( + eq(clientCredentialsClient.getRegistrationId()))).thenReturn(clientCredentialsClient); + + OAuth2AuthorizationContext authorizationContext = + OAuth2AuthorizationContext.forClient(clientCredentialsClient.getRegistrationId()) + .principal(this.principal) + .attribute(HttpServletRequest.class.getName(), new MockHttpServletRequest()) + .attribute(HttpServletResponse.class.getName(), new MockHttpServletResponse()) + .build(); + assertThat(this.authorizedClientProvider.authorize(authorizationContext)).isNull(); + } + + @Test + public void authorizeWhenAuthorizationCodeAndAuthorizedThenNotAuthorize() { + when(this.clientRegistrationRepository.findByRegistrationId( + eq(this.clientRegistration.getRegistrationId()))).thenReturn(this.clientRegistration); + + when(this.authorizedClientRepository.loadAuthorizedClient(eq(this.clientRegistration.getRegistrationId()), + eq(this.principal), any(HttpServletRequest.class))).thenReturn(this.authorizedClient); + OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.forAuthorization(clientCredentialsClient).principal(this.principal).build(); + OAuth2AuthorizationContext.forClient(this.clientRegistration.getRegistrationId()) + .principal(this.principal) + .attribute(HttpServletRequest.class.getName(), new MockHttpServletRequest()) + .attribute(HttpServletResponse.class.getName(), new MockHttpServletResponse()) + .build(); assertThat(this.authorizedClientProvider.authorize(authorizationContext)).isNull(); } + + @Test + public void authorizeWhenAuthorizationCodeAndNotAuthorizedThenAuthorize() { + when(this.clientRegistrationRepository.findByRegistrationId( + eq(this.clientRegistration.getRegistrationId()))).thenReturn(this.clientRegistration); + + OAuth2AuthorizationContext authorizationContext = + OAuth2AuthorizationContext.forClient(this.clientRegistration.getRegistrationId()) + .principal(this.principal) + .attribute(HttpServletRequest.class.getName(), new MockHttpServletRequest()) + .attribute(HttpServletResponse.class.getName(), new MockHttpServletResponse()) + .build(); + assertThatThrownBy(() -> this.authorizedClientProvider.authorize(authorizationContext)) + .isInstanceOf(ClientAuthorizationRequiredException.class); + } } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/ClientCredentialsOAuth2AuthorizedClientProviderTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/ClientCredentialsOAuth2AuthorizedClientProviderTests.java index e0893a947fd..a2a94538102 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/ClientCredentialsOAuth2AuthorizedClientProviderTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/ClientCredentialsOAuth2AuthorizedClientProviderTests.java @@ -27,15 +27,18 @@ import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; import org.springframework.security.oauth2.client.registration.TestClientRegistrations; import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; +import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; +import java.time.Duration; +import java.time.Instant; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; +import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.*; @@ -86,24 +89,32 @@ public void setAccessTokenResponseClientWhenClientIsNullThenThrowIllegalArgument } @Test - public void authorizeWhenContextIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.authorizedClientProvider.authorize(null)) + public void setClockSkewWhenNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.authorizedClientProvider.setClockSkew(null)) .isInstanceOf(IllegalArgumentException.class) - .hasMessage("context cannot be null"); + .hasMessage("clockSkew cannot be null"); } @Test - public void authorizeWhenNotClientCredentialsThenUnableToAuthorize() { - ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); - OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.forAuthorization(clientRegistration).principal(this.principal).build(); - assertThat(this.authorizedClientProvider.authorize(authorizationContext)).isNull(); + public void setClockSkewWhenNegativeSecondsThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.authorizedClientProvider.setClockSkew(Duration.ofSeconds(-1))) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("clockSkew must be >= 0"); + } + + @Test + public void authorizeWhenContextIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.authorizedClientProvider.authorize(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("context cannot be null"); } @Test public void authorizeWhenHttpServletRequestIsNullThenThrowIllegalArgumentException() { OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.forAuthorization(this.clientRegistration).principal(this.principal).build(); + OAuth2AuthorizationContext.forClient(this.clientRegistration.getRegistrationId()) + .principal(this.principal) + .build(); assertThatThrownBy(() -> this.authorizedClientProvider.authorize(authorizationContext)) .isInstanceOf(IllegalArgumentException.class) .hasMessage("The context attribute cannot be null 'javax.servlet.http.HttpServletRequest'"); @@ -112,7 +123,7 @@ public void authorizeWhenHttpServletRequestIsNullThenThrowIllegalArgumentExcepti @Test public void authorizeWhenHttpServletResponseIsNullThenThrowIllegalArgumentException() { OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.forAuthorization(this.clientRegistration) + OAuth2AuthorizationContext.forClient(this.clientRegistration.getRegistrationId()) .principal(this.principal) .attribute(HttpServletRequest.class.getName(), new MockHttpServletRequest()) .build(); @@ -121,13 +132,44 @@ public void authorizeWhenHttpServletResponseIsNullThenThrowIllegalArgumentExcept .hasMessage("The context attribute cannot be null 'javax.servlet.http.HttpServletResponse'"); } + @Test + public void authorizeWhenClientRegistrationNotFoundThenThrowIllegalArgumentException() { + OAuth2AuthorizationContext authorizationContext = + OAuth2AuthorizationContext.forClient(this.clientRegistration.getRegistrationId()) + .principal(this.principal) + .attribute(HttpServletRequest.class.getName(), new MockHttpServletRequest()) + .attribute(HttpServletResponse.class.getName(), new MockHttpServletResponse()) + .build(); + assertThatThrownBy(() -> this.authorizedClientProvider.authorize(authorizationContext)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Could not find ClientRegistration with id '" + this.clientRegistration.getRegistrationId() + "'"); + } + + @Test + public void authorizeWhenNotClientCredentialsThenUnableToAuthorize() { + ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); + when(this.clientRegistrationRepository.findByRegistrationId( + eq(clientRegistration.getRegistrationId()))).thenReturn(clientRegistration); + + OAuth2AuthorizationContext authorizationContext = + OAuth2AuthorizationContext.forClient(clientRegistration.getRegistrationId()) + .principal(this.principal) + .attribute(HttpServletRequest.class.getName(), new MockHttpServletRequest()) + .attribute(HttpServletResponse.class.getName(), new MockHttpServletResponse()) + .build(); + assertThat(this.authorizedClientProvider.authorize(authorizationContext)).isNull(); + } + @Test public void authorizeWhenClientCredentialsAndNotAuthorizedThenAuthorize() { + when(this.clientRegistrationRepository.findByRegistrationId( + eq(this.clientRegistration.getRegistrationId()))).thenReturn(this.clientRegistration); + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse); OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.forAuthorization(this.clientRegistration) + OAuth2AuthorizationContext.forClient(this.clientRegistration.getRegistrationId()) .principal(this.principal) .attribute(HttpServletRequest.class.getName(), new MockHttpServletRequest()) .attribute(HttpServletResponse.class.getName(), new MockHttpServletResponse()) @@ -144,15 +186,24 @@ public void authorizeWhenClientCredentialsAndNotAuthorizedThenAuthorize() { } @Test - public void authorizeWhenClientCredentialsAndAuthorizedThenReauthorize() { + public void authorizeWhenClientCredentialsAndTokenExpiredThenReauthorize() { + when(this.clientRegistrationRepository.findByRegistrationId( + eq(this.clientRegistration.getRegistrationId()))).thenReturn(this.clientRegistration); + + Instant issuedAt = Instant.now().minus(Duration.ofDays(1)); + Instant expiresAt = issuedAt.plus(Duration.ofMinutes(60)); + OAuth2AccessToken accessToken = new OAuth2AccessToken( + OAuth2AccessToken.TokenType.BEARER, "access-token-1234", issuedAt, expiresAt); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( + this.clientRegistration, this.principal.getName(), accessToken); + when(this.authorizedClientRepository.loadAuthorizedClient(eq(this.clientRegistration.getRegistrationId()), + eq(this.principal), any(HttpServletRequest.class))).thenReturn(authorizedClient); + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse); - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( - this.clientRegistration, this.principal.getName(), TestOAuth2AccessTokens.noScopes()); - OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.forReauthorization(authorizedClient) + OAuth2AuthorizationContext.forClient(this.clientRegistration.getRegistrationId()) .principal(this.principal) .attribute(HttpServletRequest.class.getName(), new MockHttpServletRequest()) .attribute(HttpServletResponse.class.getName(), new MockHttpServletResponse()) @@ -167,4 +218,24 @@ public void authorizeWhenClientCredentialsAndAuthorizedThenReauthorize() { eq(authorizedClient), eq(this.principal), any(HttpServletRequest.class), any(HttpServletResponse.class)); } + + @Test + public void authorizeWhenClientCredentialsAndTokenNotExpiredThenNotReauthorize() { + when(this.clientRegistrationRepository.findByRegistrationId( + eq(this.clientRegistration.getRegistrationId()))).thenReturn(this.clientRegistration); + + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( + this.clientRegistration, this.principal.getName(), TestOAuth2AccessTokens.noScopes()); + when(this.authorizedClientRepository.loadAuthorizedClient(eq(this.clientRegistration.getRegistrationId()), + eq(this.principal), any(HttpServletRequest.class))).thenReturn(authorizedClient); + + OAuth2AuthorizationContext authorizationContext = + OAuth2AuthorizationContext.forClient(this.clientRegistration.getRegistrationId()) + .principal(this.principal) + .attribute(HttpServletRequest.class.getName(), new MockHttpServletRequest()) + .attribute(HttpServletResponse.class.getName(), new MockHttpServletResponse()) + .build(); + + assertThat(this.authorizedClientProvider.authorize(authorizationContext)).isNull(); + } } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/DefaultOAuth2AuthorizedClientProviderTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/DefaultOAuth2AuthorizedClientProviderTests.java new file mode 100644 index 00000000000..ff17df67d00 --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/DefaultOAuth2AuthorizedClientProviderTests.java @@ -0,0 +1,137 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client; + +import org.junit.Before; +import org.junit.Test; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; +import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; + +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +/** + * Tests for {@link DefaultOAuth2AuthorizedClientProvider}. + * + * @author Joe Grandja + */ +public class DefaultOAuth2AuthorizedClientProviderTests { + private ClientRegistrationRepository clientRegistrationRepository; + private OAuth2AuthorizedClientRepository authorizedClientRepository; + private DefaultOAuth2AuthorizedClientProvider authorizedClientProvider; + private ClientRegistration clientRegistration; + private OAuth2AuthorizedClient authorizedClient; + private Authentication principal; + + @Before + public void setup() { + this.clientRegistrationRepository = mock(ClientRegistrationRepository.class); + this.authorizedClientRepository = mock(OAuth2AuthorizedClientRepository.class); + this.authorizedClientProvider = new DefaultOAuth2AuthorizedClientProvider( + this.clientRegistrationRepository, this.authorizedClientRepository); + this.clientRegistration = TestClientRegistrations.clientRegistration().build(); + this.authorizedClient = new OAuth2AuthorizedClient( + this.clientRegistration, "principal", TestOAuth2AccessTokens.scopes("read", "write")); + this.principal = new TestingAuthenticationToken("principal", "password"); + } + + @Test + public void constructorWhenClientRegistrationRepositoryIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new DefaultOAuth2AuthorizedClientProvider(null, this.authorizedClientRepository)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("clientRegistrationRepository cannot be null"); + } + + @Test + public void constructorWhenOAuth2AuthorizedClientRepositoryIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new DefaultOAuth2AuthorizedClientProvider(this.clientRegistrationRepository, null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("authorizedClientRepository cannot be null"); + } + + @Test + public void authorizeWhenContextIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.authorizedClientProvider.authorize(null)) + .isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void authorizeWhenHttpServletRequestIsNullThenThrowIllegalArgumentException() { + OAuth2AuthorizationContext authorizationContext = + OAuth2AuthorizationContext.forClient(this.clientRegistration.getRegistrationId()) + .principal(this.principal) + .build(); + assertThatThrownBy(() -> this.authorizedClientProvider.authorize(authorizationContext)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("The context attribute cannot be null 'javax.servlet.http.HttpServletRequest'"); + } + + @Test + public void authorizeWhenHttpServletResponseIsNullThenThrowIllegalArgumentException() { + OAuth2AuthorizationContext authorizationContext = + OAuth2AuthorizationContext.forClient(this.clientRegistration.getRegistrationId()) + .principal(this.principal) + .attribute(HttpServletRequest.class.getName(), new MockHttpServletRequest()) + .build(); + assertThatThrownBy(() -> this.authorizedClientProvider.authorize(authorizationContext)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("The context attribute cannot be null 'javax.servlet.http.HttpServletResponse'"); + } + + @Test + public void authorizeWhenClientRegistrationNotFoundThenThrowIllegalArgumentException() { + OAuth2AuthorizationContext authorizationContext = + OAuth2AuthorizationContext.forClient(this.clientRegistration.getRegistrationId()) + .principal(this.principal) + .attribute(HttpServletRequest.class.getName(), new MockHttpServletRequest()) + .attribute(HttpServletResponse.class.getName(), new MockHttpServletResponse()) + .build(); + assertThatThrownBy(() -> this.authorizedClientProvider.authorize(authorizationContext)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Could not find ClientRegistration with id '" + this.clientRegistration.getRegistrationId() + "'"); + } + + @Test + public void authorizeWhenAuthorizedThenReturnAuthorizedClient() { + when(this.clientRegistrationRepository.findByRegistrationId( + eq(this.clientRegistration.getRegistrationId()))).thenReturn(this.clientRegistration); + + when(this.authorizedClientRepository.loadAuthorizedClient(eq(this.clientRegistration.getRegistrationId()), + eq(this.principal), any(HttpServletRequest.class))).thenReturn(this.authorizedClient); + + OAuth2AuthorizationContext authorizationContext = + OAuth2AuthorizationContext.forClient(this.clientRegistration.getRegistrationId()) + .principal(this.principal) + .attribute(HttpServletRequest.class.getName(), new MockHttpServletRequest()) + .attribute(HttpServletResponse.class.getName(), new MockHttpServletResponse()) + .build(); + assertThat(this.authorizedClientProvider.authorize(authorizationContext)).isSameAs(this.authorizedClient); + } +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/DelegatingOAuth2AuthorizedClientProviderTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/DelegatingOAuth2AuthorizedClientProviderTests.java index efd2c51ce6a..ae1253c286a 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/DelegatingOAuth2AuthorizedClientProviderTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/DelegatingOAuth2AuthorizedClientProviderTests.java @@ -18,6 +18,7 @@ import org.junit.Test; import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.TestClientRegistrations; import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; @@ -26,8 +27,7 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; +import static org.mockito.Mockito.*; /** * Tests for {@link DelegatingOAuth2AuthorizedClientProvider}. @@ -44,6 +44,15 @@ public void constructorWhenProvidersIsEmptyThenThrowIllegalArgumentException() { .isInstanceOf(IllegalArgumentException.class); } + @Test + public void setDefaultAuthorizedClientProviderWhenNullThenThrowIllegalArgumentException() { + DelegatingOAuth2AuthorizedClientProvider delegate = new DelegatingOAuth2AuthorizedClientProvider( + mock(OAuth2AuthorizedClientProvider.class)); + assertThatThrownBy(() -> delegate.setDefaultAuthorizedClientProvider(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("authorizedClientProvider cannot be null"); + } + @Test public void authorizeWhenContextIsNullThenThrowIllegalArgumentException() { DelegatingOAuth2AuthorizedClientProvider delegate = new DelegatingOAuth2AuthorizedClientProvider( @@ -56,23 +65,26 @@ public void authorizeWhenContextIsNullThenThrowIllegalArgumentException() { @Test public void authorizeWhenProviderCanAuthorizeThenReturnAuthorizedClient() { Authentication principal = new TestingAuthenticationToken("principal", "password"); + ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( - TestClientRegistrations.clientRegistration().build(), principal.getName(), TestOAuth2AccessTokens.noScopes()); + clientRegistration, principal.getName(), TestOAuth2AccessTokens.noScopes()); OAuth2AuthorizedClientProvider authorizedClientProvider = mock(OAuth2AuthorizedClientProvider.class); when(authorizedClientProvider.authorize(any())).thenReturn(authorizedClient); DelegatingOAuth2AuthorizedClientProvider delegate = new DelegatingOAuth2AuthorizedClientProvider( mock(OAuth2AuthorizedClientProvider.class), mock(OAuth2AuthorizedClientProvider.class), authorizedClientProvider); - OAuth2AuthorizationContext context = OAuth2AuthorizationContext.forReauthorization(authorizedClient).principal(principal).build(); + OAuth2AuthorizationContext context = OAuth2AuthorizationContext.forClient(clientRegistration.getRegistrationId()) + .principal(principal) + .build(); OAuth2AuthorizedClient reauthorizedClient = delegate.authorize(context); assertThat(reauthorizedClient).isSameAs(authorizedClient); } @Test public void authorizeWhenProviderCantAuthorizeThenReturnNull() { - OAuth2AuthorizationContext context = OAuth2AuthorizationContext - .forAuthorization(TestClientRegistrations.clientRegistration().build()) + ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); + OAuth2AuthorizationContext context = OAuth2AuthorizationContext.forClient(clientRegistration.getRegistrationId()) .principal(new TestingAuthenticationToken("principal", "password")) .build(); @@ -80,4 +92,19 @@ public void authorizeWhenProviderCantAuthorizeThenReturnNull() { mock(OAuth2AuthorizedClientProvider.class), mock(OAuth2AuthorizedClientProvider.class)); assertThat(delegate.authorize(context)).isNull(); } + + @Test + public void authorizeWhenProviderCantAuthorizeThenDefaultCalled() { + ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); + OAuth2AuthorizationContext context = OAuth2AuthorizationContext.forClient(clientRegistration.getRegistrationId()) + .principal(new TestingAuthenticationToken("principal", "password")) + .build(); + + DelegatingOAuth2AuthorizedClientProvider delegate = new DelegatingOAuth2AuthorizedClientProvider( + mock(OAuth2AuthorizedClientProvider.class), mock(OAuth2AuthorizedClientProvider.class)); + OAuth2AuthorizedClientProvider defaultAuthorizedClientProvider = mock(OAuth2AuthorizedClientProvider.class); + delegate.setDefaultAuthorizedClientProvider(defaultAuthorizedClientProvider); + delegate.authorize(context); + verify(defaultAuthorizedClientProvider).authorize(eq(context)); + } } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/OAuth2AuthorizationContextTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/OAuth2AuthorizationContextTests.java index 2dc50c59c61..42ff33a2cb3 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/OAuth2AuthorizationContextTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/OAuth2AuthorizationContextTests.java @@ -44,60 +44,38 @@ public void setup() { } @Test - public void authorizeWhenClientRegistrationIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> OAuth2AuthorizationContext.forAuthorization(null).build()) + public void forClientWhenClientRegistrationIdIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> OAuth2AuthorizationContext.forClient(null).build()) .isInstanceOf(IllegalArgumentException.class) - .hasMessage("clientRegistration cannot be null"); + .hasMessage("clientRegistrationId cannot be empty"); } @Test - public void authorizeWhenPrincipalIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> OAuth2AuthorizationContext.forAuthorization(this.clientRegistration).build()) + public void forClientWhenPrincipalIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> OAuth2AuthorizationContext.forClient(this.clientRegistration.getRegistrationId()).build()) .isInstanceOf(IllegalArgumentException.class) .hasMessage("principal cannot be null"); } @Test - public void authorizeWhenAllValuesProvidedThenAllValuesAreSet() { - OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext.forAuthorization(this.clientRegistration) - .principal(this.principal) - .attribute("attribute1", "value1") - .attribute("attribute2", "value2") - .build(); - assertThat(authorizationContext.getClientRegistration()).isSameAs(this.clientRegistration); - assertThat(authorizationContext.getPrincipal()).isSameAs(this.principal); - assertThat(authorizationContext.getAuthorizedClient()).isNull(); - assertThat(authorizationContext.getAttributes()).contains( - entry("attribute1", "value1"), entry("attribute2", "value2")); - assertThat(authorizationContext.authorizationRequested()).isTrue(); - } - - @Test - public void reauthorizeWhenAuthorizedClientIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> OAuth2AuthorizationContext.forReauthorization(null).build()) + public void forClientWhenPrincipalNameIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> OAuth2AuthorizationContext.forClient(this.clientRegistration.getRegistrationId()) + .principal((String) null) + .build()) .isInstanceOf(IllegalArgumentException.class) - .hasMessage("authorizedClient cannot be null"); - } - - @Test - public void reauthorizeWhenPrincipalIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> OAuth2AuthorizationContext.forReauthorization(this.authorizedClient).build()) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("principal cannot be null"); + .hasMessage("principalName cannot be empty"); } @Test - public void reauthorizeWhenAllValuesProvidedThenAllValuesAreSet() { - OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext.forReauthorization(this.authorizedClient) + public void forClientWhenAllValuesProvidedThenAllValuesAreSet() { + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext.forClient(this.clientRegistration.getRegistrationId()) .principal(this.principal) .attribute("attribute1", "value1") .attribute("attribute2", "value2") .build(); - assertThat(authorizationContext.getAuthorizedClient()).isSameAs(this.authorizedClient); - assertThat(authorizationContext.getClientRegistration()).isSameAs(this.authorizedClient.getClientRegistration()); + assertThat(authorizationContext.getClientRegistrationId()).isSameAs(this.clientRegistration.getRegistrationId()); assertThat(authorizationContext.getPrincipal()).isSameAs(this.principal); assertThat(authorizationContext.getAttributes()).contains( entry("attribute1", "value1"), entry("attribute2", "value2")); - assertThat(authorizationContext.reauthorizationRequested()).isTrue(); } } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProviderTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProviderTests.java index a361e3cc7f6..a183d84239b 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProviderTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProviderTests.java @@ -28,6 +28,7 @@ import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; import org.springframework.security.oauth2.client.registration.TestClientRegistrations; import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; +import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; @@ -35,9 +36,10 @@ import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; +import java.time.Duration; +import java.time.Instant; import java.util.Arrays; import java.util.HashSet; -import java.util.Set; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -69,8 +71,12 @@ public void setup() { this.authorizedClientProvider.setAccessTokenResponseClient(this.accessTokenResponseClient); this.clientRegistration = TestClientRegistrations.clientRegistration().build(); this.principal = new TestingAuthenticationToken("principal", "password"); + Instant issuedAt = Instant.now().minus(Duration.ofDays(1)); + Instant expiresAt = issuedAt.plus(Duration.ofMinutes(60)); + OAuth2AccessToken expiredAccessToken = new OAuth2AccessToken( + OAuth2AccessToken.TokenType.BEARER, "access-token-1234", issuedAt, expiresAt); this.authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, this.principal.getName(), - TestOAuth2AccessTokens.scopes("read", "write"), TestOAuth2RefreshTokens.refreshToken()); + expiredAccessToken, TestOAuth2RefreshTokens.refreshToken()); } @Test @@ -95,32 +101,32 @@ public void setAccessTokenResponseClientWhenClientIsNullThenThrowIllegalArgument } @Test - public void authorizeWhenContextIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.authorizedClientProvider.authorize(null)) + public void setClockSkewWhenNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.authorizedClientProvider.setClockSkew(null)) .isInstanceOf(IllegalArgumentException.class) - .hasMessage("context cannot be null"); + .hasMessage("clockSkew cannot be null"); } @Test - public void authorizeWhenNotAuthorizedThenUnableToReauthorize() { - OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.forAuthorization(this.clientRegistration).principal(this.principal).build(); - assertThat(this.authorizedClientProvider.authorize(authorizationContext)).isNull(); + public void setClockSkewWhenNegativeSecondsThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.authorizedClientProvider.setClockSkew(Duration.ofSeconds(-1))) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("clockSkew must be >= 0"); } @Test - public void authorizeWhenAuthorizedAndRefreshTokenIsNullThenUnableToReauthorize() { - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( - this.clientRegistration, this.principal.getName(), this.authorizedClient.getAccessToken()); - OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.forReauthorization(authorizedClient).principal(this.principal).build(); - assertThat(this.authorizedClientProvider.authorize(authorizationContext)).isNull(); + public void authorizeWhenContextIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.authorizedClientProvider.authorize(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("context cannot be null"); } @Test public void authorizeWhenHttpServletRequestIsNullThenThrowIllegalArgumentException() { OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.forReauthorization(this.authorizedClient).principal(this.principal).build(); + OAuth2AuthorizationContext.forClient(this.clientRegistration.getRegistrationId()) + .principal(this.principal) + .build(); assertThatThrownBy(() -> this.authorizedClientProvider.authorize(authorizationContext)) .isInstanceOf(IllegalArgumentException.class) .hasMessage("The context attribute cannot be null 'javax.servlet.http.HttpServletRequest'"); @@ -129,7 +135,7 @@ public void authorizeWhenHttpServletRequestIsNullThenThrowIllegalArgumentExcepti @Test public void authorizeWhenHttpServletResponseIsNullThenThrowIllegalArgumentException() { OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.forReauthorization(this.authorizedClient) + OAuth2AuthorizationContext.forClient(this.clientRegistration.getRegistrationId()) .principal(this.principal) .attribute(HttpServletRequest.class.getName(), new MockHttpServletRequest()) .build(); @@ -139,14 +145,85 @@ public void authorizeWhenHttpServletResponseIsNullThenThrowIllegalArgumentExcept } @Test - public void authorizeWhenAuthorizedWithRefreshTokenThenReauthorize() { + public void authorizeWhenClientRegistrationNotFoundThenThrowIllegalArgumentException() { + OAuth2AuthorizationContext authorizationContext = + OAuth2AuthorizationContext.forClient(this.clientRegistration.getRegistrationId()) + .principal(this.principal) + .attribute(HttpServletRequest.class.getName(), new MockHttpServletRequest()) + .attribute(HttpServletResponse.class.getName(), new MockHttpServletResponse()) + .build(); + assertThatThrownBy(() -> this.authorizedClientProvider.authorize(authorizationContext)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Could not find ClientRegistration with id '" + this.clientRegistration.getRegistrationId() + "'"); + } + + @Test + public void authorizeWhenNotAuthorizedThenUnableToReauthorize() { + when(this.clientRegistrationRepository.findByRegistrationId( + eq(this.clientRegistration.getRegistrationId()))).thenReturn(this.clientRegistration); + + OAuth2AuthorizationContext authorizationContext = + OAuth2AuthorizationContext.forClient(this.clientRegistration.getRegistrationId()) + .principal(this.principal) + .attribute(HttpServletRequest.class.getName(), new MockHttpServletRequest()) + .attribute(HttpServletResponse.class.getName(), new MockHttpServletResponse()) + .build(); + assertThat(this.authorizedClientProvider.authorize(authorizationContext)).isNull(); + } + + @Test + public void authorizeWhenAuthorizedAndRefreshTokenIsNullThenUnableToReauthorize() { + when(this.clientRegistrationRepository.findByRegistrationId( + eq(this.clientRegistration.getRegistrationId()))).thenReturn(this.clientRegistration); + + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( + this.clientRegistration, this.principal.getName(), this.authorizedClient.getAccessToken()); + when(this.authorizedClientRepository.loadAuthorizedClient(eq(this.clientRegistration.getRegistrationId()), + eq(this.principal), any(HttpServletRequest.class))).thenReturn(authorizedClient); + + OAuth2AuthorizationContext authorizationContext = + OAuth2AuthorizationContext.forClient(this.clientRegistration.getRegistrationId()) + .principal(this.principal) + .attribute(HttpServletRequest.class.getName(), new MockHttpServletRequest()) + .attribute(HttpServletResponse.class.getName(), new MockHttpServletResponse()) + .build(); + assertThat(this.authorizedClientProvider.authorize(authorizationContext)).isNull(); + } + + @Test + public void authorizeWhenAuthorizedAndAccessTokenNotExpiredThenNotReauthorize() { + when(this.clientRegistrationRepository.findByRegistrationId( + eq(this.clientRegistration.getRegistrationId()))).thenReturn(this.clientRegistration); + + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, + this.principal.getName(), TestOAuth2AccessTokens.noScopes(), this.authorizedClient.getRefreshToken()); + when(this.authorizedClientRepository.loadAuthorizedClient(eq(this.clientRegistration.getRegistrationId()), + eq(this.principal), any(HttpServletRequest.class))).thenReturn(authorizedClient); + + OAuth2AuthorizationContext authorizationContext = + OAuth2AuthorizationContext.forClient(this.clientRegistration.getRegistrationId()) + .principal(this.principal) + .attribute(HttpServletRequest.class.getName(), new MockHttpServletRequest()) + .attribute(HttpServletResponse.class.getName(), new MockHttpServletResponse()) + .build(); + assertThat(this.authorizedClientProvider.authorize(authorizationContext)).isNull(); + } + + @Test + public void authorizeWhenAuthorizedAndAccessTokenExpiredThenReauthorize() { + when(this.clientRegistrationRepository.findByRegistrationId( + eq(this.clientRegistration.getRegistrationId()))).thenReturn(this.clientRegistration); + + when(this.authorizedClientRepository.loadAuthorizedClient(eq(this.clientRegistration.getRegistrationId()), + eq(this.principal), any(HttpServletRequest.class))).thenReturn(this.authorizedClient); + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse() .refreshToken("new-refresh-token") .build(); when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse); OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.forReauthorization(this.authorizedClient) + OAuth2AuthorizationContext.forClient(this.clientRegistration.getRegistrationId()) .principal(this.principal) .attribute(HttpServletRequest.class.getName(), new MockHttpServletRequest()) .attribute(HttpServletResponse.class.getName(), new MockHttpServletResponse()) @@ -164,21 +241,25 @@ public void authorizeWhenAuthorizedWithRefreshTokenThenReauthorize() { } @Test - public void authorizeWhenAuthorizedAndSpaceDelimitedScopeProvidedThenScopeRequested() { + public void authorizeWhenAuthorizedAndRequestScopeProvidedThenScopeRequested() { + when(this.clientRegistrationRepository.findByRegistrationId( + eq(this.clientRegistration.getRegistrationId()))).thenReturn(this.clientRegistration); + + when(this.authorizedClientRepository.loadAuthorizedClient(eq(this.clientRegistration.getRegistrationId()), + eq(this.principal), any(HttpServletRequest.class))).thenReturn(this.authorizedClient); + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse() .refreshToken("new-refresh-token") .build(); when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse); - String scope = "read write"; - Set scopes = new HashSet<>(Arrays.asList("read", "write")); - + String[] requestScope = new String[] { "read", "write" }; OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.forReauthorization(this.authorizedClient) + OAuth2AuthorizationContext.forClient(this.clientRegistration.getRegistrationId()) .principal(this.principal) .attribute(HttpServletRequest.class.getName(), new MockHttpServletRequest()) .attribute(HttpServletResponse.class.getName(), new MockHttpServletResponse()) - .attribute(RefreshTokenOAuth2AuthorizedClientProvider.REQUEST_SCOPE_ATTRIBUTE_NAME, scope) + .attribute(RefreshTokenOAuth2AuthorizedClientProvider.REQUEST_SCOPE_ATTRIBUTE_NAME, requestScope) .build(); this.authorizedClientProvider.authorize(authorizationContext); @@ -186,32 +267,29 @@ public void authorizeWhenAuthorizedAndSpaceDelimitedScopeProvidedThenScopeReques ArgumentCaptor refreshTokenGrantRequestArgCaptor = ArgumentCaptor.forClass(OAuth2RefreshTokenGrantRequest.class); verify(this.accessTokenResponseClient).getTokenResponse(refreshTokenGrantRequestArgCaptor.capture()); - assertThat(refreshTokenGrantRequestArgCaptor.getValue().getScopes()).isEqualTo(scopes); + assertThat(refreshTokenGrantRequestArgCaptor.getValue().getScopes()).isEqualTo(new HashSet<>(Arrays.asList(requestScope))); } @Test - public void authorizeWhenAuthorizedAndCommaDelimitedScopeProvidedThenScopeRequested() { - OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse() - .refreshToken("new-refresh-token") - .build(); - when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse); + public void authorizeWhenAuthorizedAndInvalidRequestScopeProvidedThenThrowIllegalArgumentException() { + when(this.clientRegistrationRepository.findByRegistrationId( + eq(this.clientRegistration.getRegistrationId()))).thenReturn(this.clientRegistration); - String scope = "read, write"; - Set scopes = new HashSet<>(Arrays.asList("read", "write")); + when(this.authorizedClientRepository.loadAuthorizedClient(eq(this.clientRegistration.getRegistrationId()), + eq(this.principal), any(HttpServletRequest.class))).thenReturn(this.authorizedClient); + String invalidRequestScope = "read write"; OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.forReauthorization(this.authorizedClient) + OAuth2AuthorizationContext.forClient(this.clientRegistration.getRegistrationId()) .principal(this.principal) .attribute(HttpServletRequest.class.getName(), new MockHttpServletRequest()) .attribute(HttpServletResponse.class.getName(), new MockHttpServletResponse()) - .attribute(RefreshTokenOAuth2AuthorizedClientProvider.REQUEST_SCOPE_ATTRIBUTE_NAME, scope) + .attribute(RefreshTokenOAuth2AuthorizedClientProvider.REQUEST_SCOPE_ATTRIBUTE_NAME, invalidRequestScope) .build(); - this.authorizedClientProvider.authorize(authorizationContext); - - ArgumentCaptor refreshTokenGrantRequestArgCaptor = - ArgumentCaptor.forClass(OAuth2RefreshTokenGrantRequest.class); - verify(this.accessTokenResponseClient).getTokenResponse(refreshTokenGrantRequestArgCaptor.capture()); - assertThat(refreshTokenGrantRequestArgCaptor.getValue().getScopes()).isEqualTo(scopes); + assertThatThrownBy(() -> this.authorizedClientProvider.authorize(authorizationContext)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageStartingWith("The context attribute must be of type String[] '" + + RefreshTokenOAuth2AuthorizedClientProvider.REQUEST_SCOPE_ATTRIBUTE_NAME + "'"); } } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolverTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolverTests.java index 491cfb52daa..e47f4ad1502 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolverTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolverTests.java @@ -186,21 +186,22 @@ public void resolveArgumentWhenRegistrationIdEmptyAndOAuth2AuthenticationThenRes SecurityContextHolder.setContext(securityContext); MethodParameter methodParameter = this.getMethodParameter("registrationIdEmpty", OAuth2AuthorizedClient.class); assertThat(this.argumentResolver.resolveArgument( - methodParameter, null, new ServletWebRequest(this.request), null)).isSameAs(this.authorizedClient1); + methodParameter, null, new ServletWebRequest(this.request, this.response), null)).isSameAs(this.authorizedClient1); } @Test public void resolveArgumentWhenAuthorizedClientFoundThenResolves() throws Exception { MethodParameter methodParameter = this.getMethodParameter("paramTypeAuthorizedClient", OAuth2AuthorizedClient.class); assertThat(this.argumentResolver.resolveArgument( - methodParameter, null, new ServletWebRequest(this.request), null)).isSameAs(this.authorizedClient1); + methodParameter, null, new ServletWebRequest(this.request, this.response), null)).isSameAs(this.authorizedClient1); } @Test - public void resolveArgumentWhenRegistrationIdInvalidThenDoesNotResolve() throws Exception { + public void resolveArgumentWhenRegistrationIdInvalidThenThrowIllegalArgumentException() { MethodParameter methodParameter = this.getMethodParameter("registrationIdInvalid", OAuth2AuthorizedClient.class); - assertThat(this.argumentResolver.resolveArgument( - methodParameter, null, new ServletWebRequest(this.request), null)).isNull(); + assertThatThrownBy(() -> this.argumentResolver.resolveArgument(methodParameter, null, new ServletWebRequest(this.request, this.response), null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Could not find ClientRegistration with id 'invalid'"); } @Test @@ -208,7 +209,7 @@ public void resolveArgumentWhenAuthorizedClientNotFoundForAuthorizationCodeClien when(this.authorizedClientRepository.loadAuthorizedClient(anyString(), any(), any(HttpServletRequest.class))) .thenReturn(null); MethodParameter methodParameter = this.getMethodParameter("paramTypeAuthorizedClient", OAuth2AuthorizedClient.class); - assertThatThrownBy(() -> this.argumentResolver.resolveArgument(methodParameter, null, new ServletWebRequest(this.request), null)) + assertThatThrownBy(() -> this.argumentResolver.resolveArgument(methodParameter, null, new ServletWebRequest(this.request, this.response), null)) .isInstanceOf(ClientAuthorizationRequiredException.class); } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java index 0ad69a0e13e..f1f11e4f945 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java @@ -48,6 +48,7 @@ import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.oauth2.client.AuthorizationCodeOAuth2AuthorizedClientProvider; import org.springframework.security.oauth2.client.ClientCredentialsOAuth2AuthorizedClientProvider; +import org.springframework.security.oauth2.client.DefaultOAuth2AuthorizedClientProvider; import org.springframework.security.oauth2.client.DelegatingOAuth2AuthorizedClientProvider; import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProvider; @@ -73,6 +74,7 @@ import org.springframework.web.reactive.function.client.ClientRequest; import org.springframework.web.reactive.function.client.WebClient; +import javax.servlet.http.HttpServletRequest; import java.net.URI; import java.time.Duration; import java.time.Instant; @@ -136,18 +138,21 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests { @Before public void setup() { this.authentication = new TestingAuthenticationToken("test", "this"); - this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction( - this.clientRegistrationRepository, this.authorizedClientRepository); + AuthorizationCodeOAuth2AuthorizedClientProvider authorizationCodeAuthorizedClientProvider = + new AuthorizationCodeOAuth2AuthorizedClientProvider(this.clientRegistrationRepository, this.authorizedClientRepository); ClientCredentialsOAuth2AuthorizedClientProvider clientCredentialsAuthorizedClientProvider = new ClientCredentialsOAuth2AuthorizedClientProvider(this.clientRegistrationRepository, this.authorizedClientRepository); clientCredentialsAuthorizedClientProvider.setAccessTokenResponseClient(this.clientCredentialsTokenResponseClient); RefreshTokenOAuth2AuthorizedClientProvider refreshTokenAuthorizedClientProvider = new RefreshTokenOAuth2AuthorizedClientProvider(this.clientRegistrationRepository, this.authorizedClientRepository); refreshTokenAuthorizedClientProvider.setAccessTokenResponseClient(this.refreshTokenTokenResponseClient); - this.authorizedClientProvider = new DelegatingOAuth2AuthorizedClientProvider( - new AuthorizationCodeOAuth2AuthorizedClientProvider(), - clientCredentialsAuthorizedClientProvider, - refreshTokenAuthorizedClientProvider); + DelegatingOAuth2AuthorizedClientProvider delegate = new DelegatingOAuth2AuthorizedClientProvider( + authorizationCodeAuthorizedClientProvider, clientCredentialsAuthorizedClientProvider, refreshTokenAuthorizedClientProvider); + delegate.setDefaultAuthorizedClientProvider( + new DefaultOAuth2AuthorizedClientProvider(this.clientRegistrationRepository, this.authorizedClientRepository)); + this.authorizedClientProvider = delegate; + this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction( + this.clientRegistrationRepository, this.authorizedClientRepository); this.function.setAuthorizedClientProvider(this.authorizedClientProvider); } @@ -366,8 +371,17 @@ public void filterWhenAuthorizedClientNullThenAuthorizationHeaderNull() { public void filterWhenAuthorizedClientThenAuthorizationHeader() { OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", this.accessToken); + + when(this.clientRegistrationRepository.findByRegistrationId( + eq(this.registration.getRegistrationId()))).thenReturn(this.registration); + when(this.authorizedClientRepository.loadAuthorizedClient( + eq(this.registration.getRegistrationId()), any(Authentication.class), + any(HttpServletRequest.class))).thenReturn(authorizedClient); + ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) .attributes(oauth2AuthorizedClient(authorizedClient)) + .attributes(httpServletRequest(new MockHttpServletRequest())) + .attributes(httpServletResponse(new MockHttpServletResponse())) .build(); this.function.filter(request, this.exchange).block(); @@ -379,9 +393,18 @@ public void filterWhenAuthorizedClientThenAuthorizationHeader() { public void filterWhenExistingAuthorizationThenSingleAuthorizationHeader() { OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", this.accessToken); + + when(this.clientRegistrationRepository.findByRegistrationId( + eq(this.registration.getRegistrationId()))).thenReturn(this.registration); + when(this.authorizedClientRepository.loadAuthorizedClient( + eq(this.registration.getRegistrationId()), any(Authentication.class), + any(HttpServletRequest.class))).thenReturn(authorizedClient); + ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) .header(HttpHeaders.AUTHORIZATION, "Existing") .attributes(oauth2AuthorizedClient(authorizedClient)) + .attributes(httpServletRequest(new MockHttpServletRequest())) + .attributes(httpServletResponse(new MockHttpServletResponse())) .build(); this.function.filter(request, this.exchange).block(); @@ -398,17 +421,23 @@ public void filterWhenRefreshRequiredThenRefresh() { .refreshToken("refresh-1") .build(); when(this.refreshTokenTokenResponseClient.getTokenResponse(any())).thenReturn(response); + Instant issuedAt = Instant.now().minus(Duration.ofDays(1)); Instant accessTokenExpiresAt = issuedAt.plus(Duration.ofHours(1)); - this.accessToken = new OAuth2AccessToken(this.accessToken.getTokenType(), this.accessToken.getTokenValue(), issuedAt, accessTokenExpiresAt); - OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", issuedAt); OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", this.accessToken, refreshToken); + + when(this.clientRegistrationRepository.findByRegistrationId( + eq(this.registration.getRegistrationId()))).thenReturn(this.registration); + when(this.authorizedClientRepository.loadAuthorizedClient( + eq(this.registration.getRegistrationId()), eq(this.authentication), + any(HttpServletRequest.class))).thenReturn(authorizedClient); + ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) .attributes(oauth2AuthorizedClient(authorizedClient)) .attributes(authentication(this.authentication)) @@ -419,7 +448,8 @@ public void filterWhenRefreshRequiredThenRefresh() { this.function.filter(request, this.exchange).block(); verify(this.refreshTokenTokenResponseClient).getTokenResponse(any()); - verify(this.authorizedClientRepository).saveAuthorizedClient(this.authorizedClientCaptor.capture(), eq(this.authentication), any(), any()); + verify(this.authorizedClientRepository).saveAuthorizedClient( + this.authorizedClientCaptor.capture(), eq(this.authentication), any(), any()); OAuth2AuthorizedClient newAuthorizedClient = authorizedClientCaptor.getValue(); assertThat(newAuthorizedClient.getAccessToken()).isEqualTo(response.getAccessToken()); @@ -456,15 +486,20 @@ public void filterWhenRefreshRequiredThenRefreshAndResponseDoesNotContainRefresh Instant issuedAt = Instant.now().minus(Duration.ofDays(1)); Instant accessTokenExpiresAt = issuedAt.plus(Duration.ofHours(1)); - this.accessToken = new OAuth2AccessToken(this.accessToken.getTokenType(), this.accessToken.getTokenValue(), issuedAt, accessTokenExpiresAt); - OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", issuedAt); OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", this.accessToken, refreshToken); + + when(this.clientRegistrationRepository.findByRegistrationId( + eq(this.registration.getRegistrationId()))).thenReturn(this.registration); + when(this.authorizedClientRepository.loadAuthorizedClient( + eq(this.registration.getRegistrationId()), eq(this.authentication), + any(HttpServletRequest.class))).thenReturn(authorizedClient); + ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) .attributes(oauth2AuthorizedClient(authorizedClient)) .attributes(authentication(this.authentication)) @@ -496,9 +531,18 @@ public void filterWhenClientCredentialsTokenNotExpiredThenUseCurrentToken() { this.registration = TestClientRegistrations.clientCredentials().build(); OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", this.accessToken, null); + + when(this.clientRegistrationRepository.findByRegistrationId( + eq(this.registration.getRegistrationId()))).thenReturn(this.registration); + when(this.authorizedClientRepository.loadAuthorizedClient( + eq(this.registration.getRegistrationId()), eq(this.authentication), + any(HttpServletRequest.class))).thenReturn(authorizedClient); + ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) .attributes(oauth2AuthorizedClient(authorizedClient)) .attributes(authentication(this.authentication)) + .attributes(httpServletRequest(new MockHttpServletRequest())) + .attributes(httpServletResponse(new MockHttpServletResponse())) .build(); this.function.filter(request, this.exchange).block(); @@ -528,7 +572,6 @@ public void filterWhenClientCredentialsTokenExpiredThenGetNewToken() { Instant issuedAt = Instant.now().minus(Duration.ofDays(1)); Instant accessTokenExpiresAt = issuedAt.plus(Duration.ofHours(1)); - this.accessToken = new OAuth2AccessToken(this.accessToken.getTokenType(), this.accessToken.getTokenValue(), issuedAt, @@ -536,6 +579,13 @@ public void filterWhenClientCredentialsTokenExpiredThenGetNewToken() { OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", this.accessToken, null); + + when(this.clientRegistrationRepository.findByRegistrationId( + eq(this.registration.getRegistrationId()))).thenReturn(this.registration); + when(this.authorizedClientRepository.loadAuthorizedClient( + eq(this.registration.getRegistrationId()), eq(this.authentication), + any(HttpServletRequest.class))).thenReturn(authorizedClient); + ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) .attributes(oauth2AuthorizedClient(authorizedClient)) .attributes(authentication(this.authentication)) @@ -567,17 +617,21 @@ public void filterWhenRefreshRequiredAndEmptyReactiveSecurityContextThenSaved() .refreshToken("refresh-1") .build(); when(this.refreshTokenTokenResponseClient.getTokenResponse(any())).thenReturn(response); + Instant issuedAt = Instant.now().minus(Duration.ofDays(1)); Instant accessTokenExpiresAt = issuedAt.plus(Duration.ofHours(1)); - this.accessToken = new OAuth2AccessToken(this.accessToken.getTokenType(), - this.accessToken.getTokenValue(), - issuedAt, - accessTokenExpiresAt); - + this.accessToken.getTokenValue(), issuedAt, accessTokenExpiresAt); OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", issuedAt); OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", this.accessToken, refreshToken); + + when(this.clientRegistrationRepository.findByRegistrationId( + eq(this.registration.getRegistrationId()))).thenReturn(this.registration); + when(this.authorizedClientRepository.loadAuthorizedClient( + eq(this.registration.getRegistrationId()), any(Authentication.class), + any(HttpServletRequest.class))).thenReturn(authorizedClient); + ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) .attributes(oauth2AuthorizedClient(authorizedClient)) .attributes(httpServletRequest(new MockHttpServletRequest())) @@ -603,8 +657,17 @@ public void filterWhenRefreshRequiredAndEmptyReactiveSecurityContextThenSaved() public void filterWhenRefreshTokenNullThenShouldRefreshFalse() { OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", this.accessToken); + + when(this.clientRegistrationRepository.findByRegistrationId( + eq(this.registration.getRegistrationId()))).thenReturn(this.registration); + when(this.authorizedClientRepository.loadAuthorizedClient( + eq(this.registration.getRegistrationId()), any(Authentication.class), + any(HttpServletRequest.class))).thenReturn(authorizedClient); + ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) .attributes(oauth2AuthorizedClient(authorizedClient)) + .attributes(httpServletRequest(new MockHttpServletRequest())) + .attributes(httpServletResponse(new MockHttpServletResponse())) .build(); this.function.filter(request, this.exchange).block(); @@ -624,8 +687,17 @@ public void filterWhenNotExpiredThenShouldRefreshFalse() { OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", this.accessToken.getIssuedAt()); OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", this.accessToken, refreshToken); + + when(this.clientRegistrationRepository.findByRegistrationId( + eq(this.registration.getRegistrationId()))).thenReturn(this.registration); + when(this.authorizedClientRepository.loadAuthorizedClient( + eq(this.registration.getRegistrationId()), any(Authentication.class), + any(HttpServletRequest.class))).thenReturn(authorizedClient); + ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) .attributes(oauth2AuthorizedClient(authorizedClient)) + .attributes(httpServletRequest(new MockHttpServletRequest())) + .attributes(httpServletResponse(new MockHttpServletResponse())) .build(); this.function.filter(request, this.exchange).block(); @@ -656,8 +728,12 @@ public void filterWhenChainedThenDefaultsStillAvailable() throws Exception { user, authorities, this.registration.getRegistrationId()); SecurityContextHolder.getContext().setAuthentication(authentication); + when(this.clientRegistrationRepository.findByRegistrationId( + eq(this.registration.getRegistrationId()))).thenReturn(this.registration); + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( this.registration, "principalName", this.accessToken); + when(this.authorizedClientRepository.loadAuthorizedClient(eq(authentication.getAuthorizedClientRegistrationId()), eq(authentication), eq(servletRequest))).thenReturn(authorizedClient); From 8f046db6042b42fb4d41e7ba74af6aa7c02a9d9a Mon Sep 17 00:00:00 2001 From: Joe Grandja Date: Thu, 4 Jul 2019 09:42:56 -0400 Subject: [PATCH 13/19] Introduce OAuth2AuthorizedClientManager - #73 Introduce OAuth2AuthorizedClientManager - #74 Integrate OAuth2AuthorizedClientManager with OAuth2AuthorizedClientProvider(s) - #81 Add builder for OAuth2AuthorizedClientProvider --- .../OAuth2ClientConfiguration.java | 33 +-- .../OAuth2ClientConfigurationTests.java | 1 + ...ionCodeOAuth2AuthorizedClientProvider.java | 52 +--- ...entialsOAuth2AuthorizedClientProvider.java | 58 +--- ...DefaultOAuth2AuthorizedClientProvider.java | 88 ------ ...egatingOAuth2AuthorizedClientProvider.java | 18 +- .../client/OAuth2AuthorizationContext.java | 124 ++++---- .../OAuth2AuthorizedClientProvider.java | 4 +- ...OAuth2AuthorizedClientProviderBuilder.java | 267 ++++++++++++++++++ ...shTokenOAuth2AuthorizedClientProvider.java | 51 +--- .../DefaultOAuth2AuthorizedClientManager.java | 144 ++++++++++ .../web/OAuth2AuthorizedClientManager.java | 81 ++++++ ...Auth2AuthorizedClientArgumentResolver.java | 82 +++--- ...uthorizedClientExchangeFilterFunction.java | 175 +++++++----- ...deOAuth2AuthorizedClientProviderTests.java | 91 +----- ...lsOAuth2AuthorizedClientProviderTests.java | 106 +------ ...ltOAuth2AuthorizedClientProviderTests.java | 137 --------- ...ngOAuth2AuthorizedClientProviderTests.java | 31 +- .../OAuth2AuthorizationContextTests.java | 25 +- ...2AuthorizedClientProviderBuilderTests.java | 202 +++++++++++++ ...enOAuth2AuthorizedClientProviderTests.java | 132 +-------- ...ultOAuth2AuthorizedClientManagerTests.java | 265 +++++++++++++++++ ...AuthorizedClientArgumentResolverTests.java | 13 +- ...izedClientExchangeFilterFunctionTests.java | 109 ++----- 24 files changed, 1265 insertions(+), 1024 deletions(-) delete mode 100644 oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/DefaultOAuth2AuthorizedClientProvider.java create mode 100644 oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientProviderBuilder.java create mode 100644 oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizedClientManager.java create mode 100644 oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizedClientManager.java delete mode 100644 oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/DefaultOAuth2AuthorizedClientProviderTests.java create mode 100644 oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientProviderBuilderTests.java create mode 100644 oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizedClientManagerTests.java diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfiguration.java b/config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfiguration.java index be486ffb545..0efadd2824a 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfiguration.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfiguration.java @@ -20,14 +20,12 @@ import org.springframework.context.annotation.Import; import org.springframework.context.annotation.ImportSelector; import org.springframework.core.type.AnnotationMetadata; -import org.springframework.security.oauth2.client.AuthorizationCodeOAuth2AuthorizedClientProvider; -import org.springframework.security.oauth2.client.ClientCredentialsOAuth2AuthorizedClientProvider; -import org.springframework.security.oauth2.client.DefaultOAuth2AuthorizedClientProvider; -import org.springframework.security.oauth2.client.DelegatingOAuth2AuthorizedClientProvider; -import org.springframework.security.oauth2.client.RefreshTokenOAuth2AuthorizedClientProvider; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProvider; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProviderBuilder; import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest; import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; +import org.springframework.security.oauth2.client.web.DefaultOAuth2AuthorizedClientManager; import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.client.web.method.annotation.OAuth2AuthorizedClientArgumentResolver; import org.springframework.util.ClassUtils; @@ -77,21 +75,16 @@ public void addArgumentResolvers(List argumentRes new OAuth2AuthorizedClientArgumentResolver( this.clientRegistrationRepository, this.authorizedClientRepository); if (this.accessTokenResponseClient != null) { - ClientCredentialsOAuth2AuthorizedClientProvider clientCredentialsAuthorizedClientProvider = - new ClientCredentialsOAuth2AuthorizedClientProvider( - this.clientRegistrationRepository, this.authorizedClientRepository); - clientCredentialsAuthorizedClientProvider.setAccessTokenResponseClient(this.accessTokenResponseClient); - AuthorizationCodeOAuth2AuthorizedClientProvider authorizationCodeAuthorizedClientProvider = - new AuthorizationCodeOAuth2AuthorizedClientProvider( - this.clientRegistrationRepository, this.authorizedClientRepository); - RefreshTokenOAuth2AuthorizedClientProvider refreshTokenAuthorizedClientProvider = - new RefreshTokenOAuth2AuthorizedClientProvider( - this.clientRegistrationRepository, this.authorizedClientRepository); - DelegatingOAuth2AuthorizedClientProvider authorizedClientProvider = new DelegatingOAuth2AuthorizedClientProvider( - authorizationCodeAuthorizedClientProvider, refreshTokenAuthorizedClientProvider, clientCredentialsAuthorizedClientProvider); - authorizedClientProvider.setDefaultAuthorizedClientProvider( - new DefaultOAuth2AuthorizedClientProvider(this.clientRegistrationRepository, this.authorizedClientRepository)); - authorizedClientArgumentResolver.setAuthorizedClientProvider(authorizedClientProvider); + OAuth2AuthorizedClientProvider authorizedClientProvider = + OAuth2AuthorizedClientProviderBuilder.withProvider() + .authorizationCode() + .refreshToken() + .clientCredentials(configurer -> configurer.accessTokenResponseClient(this.accessTokenResponseClient)) + .build(); + DefaultOAuth2AuthorizedClientManager authorizedClientManager = new DefaultOAuth2AuthorizedClientManager( + this.clientRegistrationRepository, this.authorizedClientRepository); + authorizedClientManager.setAuthorizedClientProvider(authorizedClientProvider); + authorizedClientArgumentResolver.setAuthorizedClientManager(authorizedClientManager); } argumentResolvers.add(authorizedClientArgumentResolver); } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfigurationTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfigurationTests.java index 43f9523ee7a..d96df5495aa 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfigurationTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfigurationTests.java @@ -75,6 +75,7 @@ public void requestWhenAuthorizedClientFoundThenMethodArgumentResolved() throws OAuth2AuthorizedClientRepository authorizedClientRepository = mock(OAuth2AuthorizedClientRepository.class); OAuth2AuthorizedClient authorizedClient = mock(OAuth2AuthorizedClient.class); + when(authorizedClient.getClientRegistration()).thenReturn(clientRegistration); when(authorizedClientRepository.loadAuthorizedClient( eq(clientRegistrationId), eq(authentication), any(HttpServletRequest.class))) .thenReturn(authorizedClient); diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/AuthorizationCodeOAuth2AuthorizedClientProvider.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/AuthorizationCodeOAuth2AuthorizedClientProvider.java index eb775e25eda..8472472d79c 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/AuthorizationCodeOAuth2AuthorizedClientProvider.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/AuthorizationCodeOAuth2AuthorizedClientProvider.java @@ -17,14 +17,9 @@ import org.springframework.lang.Nullable; import org.springframework.security.oauth2.client.registration.ClientRegistration; -import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; -import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.util.Assert; -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; - /** * An implementation of an {@link OAuth2AuthorizedClientProvider} * for the {@link AuthorizationGrantType#AUTHORIZATION_CODE authorization_code} grant. @@ -34,38 +29,16 @@ * @see OAuth2AuthorizedClientProvider */ public final class AuthorizationCodeOAuth2AuthorizedClientProvider implements OAuth2AuthorizedClientProvider { - private static final String HTTP_SERVLET_REQUEST_ATTRIBUTE_NAME = HttpServletRequest.class.getName(); - private static final String HTTP_SERVLET_RESPONSE_ATTRIBUTE_NAME = HttpServletResponse.class.getName(); - private final ClientRegistrationRepository clientRegistrationRepository; - private final OAuth2AuthorizedClientRepository authorizedClientRepository; - /** - * Constructs an {@code AuthorizationCodeOAuth2AuthorizedClientProvider} using the provided parameters. - * - * @param clientRegistrationRepository the repository of client registrations - * @param authorizedClientRepository the repository of authorized clients - */ - public AuthorizationCodeOAuth2AuthorizedClientProvider(ClientRegistrationRepository clientRegistrationRepository, - OAuth2AuthorizedClientRepository authorizedClientRepository) { - Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null"); - Assert.notNull(authorizedClientRepository, "authorizedClientRepository cannot be null"); - this.clientRegistrationRepository = clientRegistrationRepository; - this.authorizedClientRepository = authorizedClientRepository; + public AuthorizationCodeOAuth2AuthorizedClientProvider() { } /** - * Attempt to authorize the {@link OAuth2AuthorizationContext#getClientRegistrationId() client} in the provided {@code context}. + * Attempt to authorize the {@link OAuth2AuthorizationContext#getClientRegistration() client} in the provided {@code context}. * Returns {@code null} if authorization is not supported, * e.g. the client's {@link ClientRegistration#getAuthorizationGrantType() authorization grant type} * is not {@link AuthorizationGrantType#AUTHORIZATION_CODE authorization_code} OR the client is already authorized. * - *

- * The following {@link OAuth2AuthorizationContext#getAttributes() context attributes} are supported: - *

    - *
  1. {@code "javax.servlet.http.HttpServletRequest"} (required) - the {@code HttpServletRequest}
  2. - *
  3. {@code "javax.servlet.http.HttpServletResponse"} (required) - the {@code HttpServletResponse}
  4. - *
- * * @param context the context that holds authorization-specific state for the client * @return the {@link OAuth2AuthorizedClient} or {@code null} if authorization is not supported */ @@ -74,26 +47,11 @@ public AuthorizationCodeOAuth2AuthorizedClientProvider(ClientRegistrationReposit public OAuth2AuthorizedClient authorize(OAuth2AuthorizationContext context) { Assert.notNull(context, "context cannot be null"); - HttpServletRequest request = context.getAttribute(HTTP_SERVLET_REQUEST_ATTRIBUTE_NAME); - HttpServletResponse response = context.getAttribute(HTTP_SERVLET_RESPONSE_ATTRIBUTE_NAME); - Assert.notNull(request, "The context attribute cannot be null '" + HTTP_SERVLET_REQUEST_ATTRIBUTE_NAME + "'"); - Assert.notNull(response, "The context attribute cannot be null '" + HTTP_SERVLET_RESPONSE_ATTRIBUTE_NAME + "'"); - - String clientRegistrationId = context.getClientRegistrationId(); - ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId); - Assert.notNull(clientRegistration, "Could not find ClientRegistration with id '" + clientRegistrationId + "'"); - - if (!AuthorizationGrantType.AUTHORIZATION_CODE.equals(clientRegistration.getAuthorizationGrantType())) { - return null; - } - - OAuth2AuthorizedClient authorizedClient = this.authorizedClientRepository.loadAuthorizedClient( - clientRegistrationId, context.getPrincipal(), request); - if (authorizedClient == null) { + if (AuthorizationGrantType.AUTHORIZATION_CODE.equals(context.getClientRegistration().getAuthorizationGrantType()) && + context.getAuthorizedClient() == null) { // ClientAuthorizationRequiredException is caught by OAuth2AuthorizationRequestRedirectFilter which initiates authorization - throw new ClientAuthorizationRequiredException(clientRegistrationId); + throw new ClientAuthorizationRequiredException(context.getClientRegistration().getRegistrationId()); } - return null; } } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ClientCredentialsOAuth2AuthorizedClientProvider.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ClientCredentialsOAuth2AuthorizedClientProvider.java index 62de43008b6..c2087ea15e6 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ClientCredentialsOAuth2AuthorizedClientProvider.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ClientCredentialsOAuth2AuthorizedClientProvider.java @@ -20,15 +20,11 @@ import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest; import org.springframework.security.oauth2.client.registration.ClientRegistration; -import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; -import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.core.AbstractOAuth2Token; import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.util.Assert; -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; import java.time.Duration; import java.time.Instant; @@ -42,42 +38,20 @@ * @see DefaultClientCredentialsTokenResponseClient */ public final class ClientCredentialsOAuth2AuthorizedClientProvider implements OAuth2AuthorizedClientProvider { - private static final String HTTP_SERVLET_REQUEST_ATTRIBUTE_NAME = HttpServletRequest.class.getName(); - private static final String HTTP_SERVLET_RESPONSE_ATTRIBUTE_NAME = HttpServletResponse.class.getName(); - private final ClientRegistrationRepository clientRegistrationRepository; - private final OAuth2AuthorizedClientRepository authorizedClientRepository; private OAuth2AccessTokenResponseClient accessTokenResponseClient = new DefaultClientCredentialsTokenResponseClient(); private Duration clockSkew = Duration.ofSeconds(60); - /** - * Constructs a {@code ClientCredentialsOAuth2AuthorizedClientProvider} using the provided parameters. - * - * @param clientRegistrationRepository the repository of client registrations - * @param authorizedClientRepository the repository of authorized clients - */ - public ClientCredentialsOAuth2AuthorizedClientProvider(ClientRegistrationRepository clientRegistrationRepository, - OAuth2AuthorizedClientRepository authorizedClientRepository) { - Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null"); - Assert.notNull(authorizedClientRepository, "authorizedClientRepository cannot be null"); - this.clientRegistrationRepository = clientRegistrationRepository; - this.authorizedClientRepository = authorizedClientRepository; + public ClientCredentialsOAuth2AuthorizedClientProvider() { } /** - * Attempt to authorize (or re-authorize) the {@link OAuth2AuthorizationContext#getClientRegistrationId() client} in the provided {@code context}. + * Attempt to authorize (or re-authorize) the {@link OAuth2AuthorizationContext#getClientRegistration() client} in the provided {@code context}. * Returns {@code null} if authorization (or re-authorization) is not supported, * e.g. the client's {@link ClientRegistration#getAuthorizationGrantType() authorization grant type} * is not {@link AuthorizationGrantType#CLIENT_CREDENTIALS client_credentials} OR * the {@link OAuth2AuthorizedClient#getAccessToken() access token} is not expired. * - *

- * The following {@link OAuth2AuthorizationContext#getAttributes() context attributes} are supported: - *

    - *
  1. {@code "javax.servlet.http.HttpServletRequest"} (required) - the {@code HttpServletRequest}
  2. - *
  3. {@code "javax.servlet.http.HttpServletResponse"} (required) - the {@code HttpServletResponse}
  4. - *
- * * @param context the context that holds authorization-specific state for the client * @return the {@link OAuth2AuthorizedClient} or {@code null} if authorization (or re-authorization) is not supported */ @@ -86,22 +60,10 @@ public ClientCredentialsOAuth2AuthorizedClientProvider(ClientRegistrationReposit public OAuth2AuthorizedClient authorize(OAuth2AuthorizationContext context) { Assert.notNull(context, "context cannot be null"); - HttpServletRequest request = context.getAttribute(HTTP_SERVLET_REQUEST_ATTRIBUTE_NAME); - HttpServletResponse response = context.getAttribute(HTTP_SERVLET_RESPONSE_ATTRIBUTE_NAME); - Assert.notNull(request, "The context attribute cannot be null '" + HTTP_SERVLET_REQUEST_ATTRIBUTE_NAME + "'"); - Assert.notNull(response, "The context attribute cannot be null '" + HTTP_SERVLET_RESPONSE_ATTRIBUTE_NAME + "'"); - - String clientRegistrationId = context.getClientRegistrationId(); - ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId); - Assert.notNull(clientRegistration, "Could not find ClientRegistration with id '" + clientRegistrationId + "'"); - - if (!AuthorizationGrantType.CLIENT_CREDENTIALS.equals(clientRegistration.getAuthorizationGrantType())) { - return null; - } - - OAuth2AuthorizedClient authorizedClient = this.authorizedClientRepository.loadAuthorizedClient( - clientRegistrationId, context.getPrincipal(), request); - if (authorizedClient != null && !hasTokenExpired(authorizedClient.getAccessToken())) { + ClientRegistration clientRegistration = context.getClientRegistration(); + OAuth2AuthorizedClient authorizedClient = context.getAuthorizedClient(); + if (!AuthorizationGrantType.CLIENT_CREDENTIALS.equals(clientRegistration.getAuthorizationGrantType()) || + (authorizedClient != null && !hasTokenExpired(authorizedClient.getAccessToken()))) { return null; } @@ -117,13 +79,7 @@ public OAuth2AuthorizedClient authorize(OAuth2AuthorizationContext context) { OAuth2AccessTokenResponse tokenResponse = this.accessTokenResponseClient.getTokenResponse(clientCredentialsGrantRequest); - authorizedClient = new OAuth2AuthorizedClient( - clientRegistration, context.getPrincipal().getName(), tokenResponse.getAccessToken()); - - this.authorizedClientRepository.saveAuthorizedClient( - authorizedClient, context.getPrincipal(), request, response); - - return authorizedClient; + return new OAuth2AuthorizedClient(clientRegistration, context.getPrincipal().getName(), tokenResponse.getAccessToken()); } private boolean hasTokenExpired(AbstractOAuth2Token token) { diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/DefaultOAuth2AuthorizedClientProvider.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/DefaultOAuth2AuthorizedClientProvider.java deleted file mode 100644 index 494f14f96dd..00000000000 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/DefaultOAuth2AuthorizedClientProvider.java +++ /dev/null @@ -1,88 +0,0 @@ -/* - * Copyright 2002-2019 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.springframework.security.oauth2.client; - -import org.springframework.lang.Nullable; -import org.springframework.security.core.Authentication; -import org.springframework.security.oauth2.client.registration.ClientRegistration; -import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; -import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; -import org.springframework.util.Assert; - -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; - -/** - * The default implementation of an {@link OAuth2AuthorizedClientProvider} that simply - * {@link OAuth2AuthorizedClientRepository#loadAuthorizedClient(String, Authentication, HttpServletRequest) loads} - * an {@link OAuth2AuthorizedClient} from the authorized client repository. - * - * @author Joe Grandja - * @since 5.2 - * @see OAuth2AuthorizedClientProvider - */ -public final class DefaultOAuth2AuthorizedClientProvider implements OAuth2AuthorizedClientProvider { - private static final String HTTP_SERVLET_REQUEST_ATTRIBUTE_NAME = HttpServletRequest.class.getName(); - private static final String HTTP_SERVLET_RESPONSE_ATTRIBUTE_NAME = HttpServletResponse.class.getName(); - private final ClientRegistrationRepository clientRegistrationRepository; - private final OAuth2AuthorizedClientRepository authorizedClientRepository; - - /** - * Constructs an {@code DefaultOAuth2AuthorizedClientProvider} using the provided parameters. - * - * @param clientRegistrationRepository the repository of client registrations - * @param authorizedClientRepository the repository of authorized clients - */ - public DefaultOAuth2AuthorizedClientProvider(ClientRegistrationRepository clientRegistrationRepository, - OAuth2AuthorizedClientRepository authorizedClientRepository) { - Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null"); - Assert.notNull(authorizedClientRepository, "authorizedClientRepository cannot be null"); - this.clientRegistrationRepository = clientRegistrationRepository; - this.authorizedClientRepository = authorizedClientRepository; - } - - /** - * Attempts to {@link OAuth2AuthorizedClientRepository#loadAuthorizedClient(String, Authentication, HttpServletRequest) load} - * an {@link OAuth2AuthorizedClient} using the {@link OAuth2AuthorizationContext#getClientRegistrationId() client} - * in the provided {@code context}. Returns {@code null} if the client is not authorized. - * - *

- * The following {@link OAuth2AuthorizationContext#getAttributes() context attributes} are supported: - *

    - *
  1. {@code "javax.servlet.http.HttpServletRequest"} (required) - the {@code HttpServletRequest}
  2. - *
  3. {@code "javax.servlet.http.HttpServletResponse"} (required) - the {@code HttpServletResponse}
  4. - *
- * - * @param context the context that holds authorization-specific state for the client - * @return the {@link OAuth2AuthorizedClient} or {@code null} if the client is not authorized - */ - @Override - @Nullable - public OAuth2AuthorizedClient authorize(OAuth2AuthorizationContext context) { - Assert.notNull(context, "context cannot be null"); - - HttpServletRequest request = context.getAttribute(HTTP_SERVLET_REQUEST_ATTRIBUTE_NAME); - HttpServletResponse response = context.getAttribute(HTTP_SERVLET_RESPONSE_ATTRIBUTE_NAME); - Assert.notNull(request, "The context attribute cannot be null '" + HTTP_SERVLET_REQUEST_ATTRIBUTE_NAME + "'"); - Assert.notNull(response, "The context attribute cannot be null '" + HTTP_SERVLET_RESPONSE_ATTRIBUTE_NAME + "'"); - - String clientRegistrationId = context.getClientRegistrationId(); - ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId); - Assert.notNull(clientRegistration, "Could not find ClientRegistration with id '" + clientRegistrationId + "'"); - - return this.authorizedClientRepository.loadAuthorizedClient(clientRegistrationId, context.getPrincipal(), request); - } -} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/DelegatingOAuth2AuthorizedClientProvider.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/DelegatingOAuth2AuthorizedClientProvider.java index 3167abbc9be..0343b96071c 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/DelegatingOAuth2AuthorizedClientProvider.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/DelegatingOAuth2AuthorizedClientProvider.java @@ -30,7 +30,7 @@ *

* Each provider is given a chance to * {@link OAuth2AuthorizedClientProvider#authorize(OAuth2AuthorizationContext) authorize} - * the {@link OAuth2AuthorizationContext#getClientRegistrationId() client} in the provided context + * the {@link OAuth2AuthorizationContext#getClientRegistration() client} in the provided context * with the first {@code non-null} {@link OAuth2AuthorizedClient} being returned. * * @author Joe Grandja @@ -39,7 +39,6 @@ */ public final class DelegatingOAuth2AuthorizedClientProvider implements OAuth2AuthorizedClientProvider { private final List authorizedClientProviders; - private OAuth2AuthorizedClientProvider defaultAuthorizedClientProvider; /** * Constructs a {@code DelegatingOAuth2AuthorizedClientProvider} using the provided parameters. @@ -69,19 +68,6 @@ public OAuth2AuthorizedClient authorize(OAuth2AuthorizationContext context) { .map(authorizedClientProvider -> authorizedClientProvider.authorize(context)) .filter(Objects::nonNull) .findFirst() - .orElse(this.defaultAuthorizedClientProvider != null ? - this.defaultAuthorizedClientProvider.authorize(context) : null); - } - - /** - * Sets the default {@link OAuth2AuthorizedClientProvider} used if none of the - * {@link OAuth2AuthorizedClientProvider}(s) in the {@code List} - * are able to authorize the {@link OAuth2AuthorizationContext#getClientRegistrationId() client}. - * - * @param authorizedClientProvider the default {@link OAuth2AuthorizedClientProvider} - */ - public void setDefaultAuthorizedClientProvider(OAuth2AuthorizedClientProvider authorizedClientProvider) { - Assert.notNull(authorizedClientProvider, "authorizedClientProvider cannot be null"); - this.defaultAuthorizedClientProvider = authorizedClientProvider; + .orElse(null); } } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizationContext.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizationContext.java index 55790d962cd..a2f70645ce6 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizationContext.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizationContext.java @@ -17,12 +17,10 @@ import org.springframework.lang.Nullable; import org.springframework.security.core.Authentication; -import org.springframework.security.core.GrantedAuthority; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; -import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.LinkedHashMap; @@ -37,7 +35,8 @@ * @see OAuth2AuthorizedClientProvider */ public final class OAuth2AuthorizationContext { - private String clientRegistrationId; + private ClientRegistration clientRegistration; + private OAuth2AuthorizedClient authorizedClient; private Authentication principal; private Map attributes; @@ -45,12 +44,23 @@ private OAuth2AuthorizationContext() { } /** - * Returns the {@link ClientRegistration client registration} identifier. + * Returns the {@link ClientRegistration client registration}. * - * @return the client registration identifier + * @return the {@link ClientRegistration} */ - public String getClientRegistrationId() { - return this.clientRegistrationId; + public ClientRegistration getClientRegistration() { + return this.clientRegistration; + } + + /** + * Returns the {@link OAuth2AuthorizedClient authorized client} or {@code null} + * if the {@link #forClient(ClientRegistration) client registration} was supplied. + * + * @return the {@link OAuth2AuthorizedClient} or {@code null} if the client registration was supplied + */ + @Nullable + public OAuth2AuthorizedClient getAuthorizedClient() { + return this.authorizedClient; } /** @@ -72,7 +82,7 @@ public Map getAttributes() { } /** - * Returns the value of an attribute associated to the context, or {@code null} if not available. + * Returns the value of an attribute associated to the context or {@code null} if not available. * * @param name the name of the attribute * @param the type of the attribute @@ -85,26 +95,42 @@ public T getAttribute(String name) { } /** - * Returns a new {@link Builder} initialized with the {@link ClientRegistration client registration} identifier. + * Returns a new {@link Builder} initialized with the {@link ClientRegistration}. * - * @param clientRegistrationId the {@link ClientRegistration client registration} identifier + * @param clientRegistration the {@link ClientRegistration client registration} * @return the {@link Builder} */ - public static Builder forClient(String clientRegistrationId) { - return new Builder(clientRegistrationId); + public static Builder forClient(ClientRegistration clientRegistration) { + return new Builder(clientRegistration); + } + + /** + * Returns a new {@link Builder} initialized with the {@link OAuth2AuthorizedClient}. + * + * @param authorizedClient the {@link OAuth2AuthorizedClient authorized client} + * @return the {@link Builder} + */ + public static Builder forClient(OAuth2AuthorizedClient authorizedClient) { + return new Builder(authorizedClient); } /** * A builder for {@link OAuth2AuthorizationContext}. */ public static class Builder { - private String clientRegistrationId; + private ClientRegistration clientRegistration; + private OAuth2AuthorizedClient authorizedClient; private Authentication principal; private Map attributes; - private Builder(String clientRegistrationId) { - Assert.hasText(clientRegistrationId, "clientRegistrationId cannot be empty"); - this.clientRegistrationId = clientRegistrationId; + private Builder(ClientRegistration clientRegistration) { + Assert.notNull(clientRegistration, "clientRegistration cannot be null"); + this.clientRegistration = clientRegistration; + } + + private Builder(OAuth2AuthorizedClient authorizedClient) { + Assert.notNull(authorizedClient, "authorizedClient cannot be null"); + this.authorizedClient = authorizedClient; } /** @@ -118,17 +144,6 @@ public Builder principal(Authentication principal) { return this; } - /** - * Sets the {@code Principal}'s name (to be) associated to the authorized client. - * - * @param principalName the {@code Principal}'s name (to be) associated to the authorized client - * @return the {@link Builder} - */ - public Builder principal(String principalName) { - this.principal = new PrincipalNameAuthentication(principalName); - return this; - } - /** * Sets the attributes associated to the context. * @@ -163,7 +178,12 @@ public Builder attribute(String name, Object value) { public OAuth2AuthorizationContext build() { Assert.notNull(this.principal, "principal cannot be null"); OAuth2AuthorizationContext context = new OAuth2AuthorizationContext(); - context.clientRegistrationId = this.clientRegistrationId; + if (this.authorizedClient != null) { + context.clientRegistration = this.authorizedClient.getClientRegistration(); + context.authorizedClient = this.authorizedClient; + } else { + context.clientRegistration = this.clientRegistration; + } context.principal = this.principal; context.attributes = Collections.unmodifiableMap( CollectionUtils.isEmpty(this.attributes) ? @@ -171,52 +191,4 @@ public OAuth2AuthorizationContext build() { return context; } } - - private static class PrincipalNameAuthentication implements Authentication { - private final String principalName; - - private PrincipalNameAuthentication(String principalName) { - Assert.hasText(principalName, "principalName cannot be empty"); - this.principalName = principalName; - } - - @Override - public Collection getAuthorities() { - throw unsupported(); - } - - @Override - public Object getCredentials() { - throw unsupported(); - } - - @Override - public Object getDetails() { - throw unsupported(); - } - - @Override - public Object getPrincipal() { - return getName(); - } - - @Override - public boolean isAuthenticated() { - throw unsupported(); - } - - @Override - public void setAuthenticated(boolean isAuthenticated) throws IllegalArgumentException { - throw unsupported(); - } - - @Override - public String getName() { - return this.principalName; - } - - private UnsupportedOperationException unsupported() { - return new UnsupportedOperationException("Not Supported"); - } - } } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientProvider.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientProvider.java index a063fe976fb..b73fc8aafcb 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientProvider.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientProvider.java @@ -15,9 +15,9 @@ */ package org.springframework.security.oauth2.client; +import org.springframework.lang.Nullable; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.core.AuthorizationGrantType; -import reactor.util.annotation.Nullable; /** * A strategy for authorizing (or re-authorizing) an OAuth 2.0 Client. @@ -32,7 +32,7 @@ public interface OAuth2AuthorizedClientProvider { /** - * Attempt to authorize (or re-authorize) the {@link OAuth2AuthorizationContext#getClientRegistrationId() client} in the provided context. + * Attempt to authorize (or re-authorize) the {@link OAuth2AuthorizationContext#getClientRegistration() client} in the provided context. * Implementations must return {@code null} if authorization is not supported for the specified client, * e.g. the provider doesn't support the {@link ClientRegistration#getAuthorizationGrantType() authorization grant} type configured for the client. * diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientProviderBuilder.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientProviderBuilder.java new file mode 100644 index 00000000000..ef554d55aef --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientProviderBuilder.java @@ -0,0 +1,267 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client; + +import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; +import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest; +import org.springframework.security.oauth2.client.endpoint.OAuth2RefreshTokenGrantRequest; +import org.springframework.util.Assert; + +import java.time.Duration; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Consumer; +import java.util.stream.Collectors; + +/** + * A builder that builds a {@link DelegatingOAuth2AuthorizedClientProvider} composed of + * one or more {@link OAuth2AuthorizedClientProvider}(s) that implement specific authorization grants. + * The supported authorization grants are {@link #authorizationCode() authorization_code}, + * {@link #refreshToken() refresh_token} and {@link #clientCredentials() client_credentials}. + * In addition to the standard authorization grants, an implementation of an extension grant + * may be supplied via {@link #provider(OAuth2AuthorizedClientProvider)}. + * + * @author Joe Grandja + * @since 5.2 + * @see OAuth2AuthorizedClientProvider + * @see AuthorizationCodeOAuth2AuthorizedClientProvider + * @see RefreshTokenOAuth2AuthorizedClientProvider + * @see ClientCredentialsOAuth2AuthorizedClientProvider + * @see DelegatingOAuth2AuthorizedClientProvider + */ +public final class OAuth2AuthorizedClientProviderBuilder { + private final Map, Builder> builders = new HashMap<>(); + + private OAuth2AuthorizedClientProviderBuilder() { + } + + /** + * Returns a new {@link OAuth2AuthorizedClientProviderBuilder} for configuring the supported authorization grant(s). + * + * @return the {@link OAuth2AuthorizedClientProviderBuilder} + */ + public static OAuth2AuthorizedClientProviderBuilder withProvider() { + return new OAuth2AuthorizedClientProviderBuilder(); + } + + /** + * Configures an {@link OAuth2AuthorizedClientProvider} to be composed with the {@link DelegatingOAuth2AuthorizedClientProvider}. + * This may be used for implementations of extension authorization grants. + * + * @return the {@link OAuth2AuthorizedClientProviderBuilder} + */ + public OAuth2AuthorizedClientProviderBuilder provider(OAuth2AuthorizedClientProvider provider) { + Assert.notNull(provider, "provider cannot be null"); + this.builders.computeIfAbsent(provider.getClass(), k -> () -> provider); + return OAuth2AuthorizedClientProviderBuilder.this; + } + + /** + * Configures support for the {@code authorization_code} grant. + * + * @return the {@link OAuth2AuthorizedClientProviderBuilder} + */ + public OAuth2AuthorizedClientProviderBuilder authorizationCode() { + this.builders.computeIfAbsent(AuthorizationCodeGrantBuilder.class, k -> new AuthorizationCodeGrantBuilder()); + return OAuth2AuthorizedClientProviderBuilder.this; + } + + /** + * A builder for the {@code authorization_code} grant. + */ + public class AuthorizationCodeGrantBuilder implements Builder { + + private AuthorizationCodeGrantBuilder() { + } + + /** + * Builds an instance of {@link AuthorizationCodeOAuth2AuthorizedClientProvider}. + * + * @return the {@link AuthorizationCodeOAuth2AuthorizedClientProvider} + */ + @Override + public OAuth2AuthorizedClientProvider build() { + return new AuthorizationCodeOAuth2AuthorizedClientProvider(); + } + } + + /** + * Configures support for the {@code refresh_token} grant. + * + * @return the {@link OAuth2AuthorizedClientProviderBuilder} + */ + public OAuth2AuthorizedClientProviderBuilder refreshToken() { + this.builders.computeIfAbsent(RefreshTokenGrantBuilder.class, k -> new RefreshTokenGrantBuilder()); + return OAuth2AuthorizedClientProviderBuilder.this; + } + + /** + * Configures support for the {@code refresh_token} grant. + * + * @param builderConsumer a {@code Consumer} of {@link RefreshTokenGrantBuilder} used for further configuration + * @return the {@link OAuth2AuthorizedClientProviderBuilder} + */ + public OAuth2AuthorizedClientProviderBuilder refreshToken(Consumer builderConsumer) { + RefreshTokenGrantBuilder builder = (RefreshTokenGrantBuilder) this.builders.computeIfAbsent( + RefreshTokenGrantBuilder.class, k -> new RefreshTokenGrantBuilder()); + builderConsumer.accept(builder); + return OAuth2AuthorizedClientProviderBuilder.this; + } + + /** + * A builder for the {@code refresh_token} grant. + */ + public class RefreshTokenGrantBuilder implements Builder { + private OAuth2AccessTokenResponseClient accessTokenResponseClient; + private Duration clockSkew; + + private RefreshTokenGrantBuilder() { + } + + /** + * Sets the client used when requesting an access token credential at the Token Endpoint. + * + * @param accessTokenResponseClient the client used when requesting an access token credential at the Token Endpoint + * @return the {@link RefreshTokenGrantBuilder} + */ + public RefreshTokenGrantBuilder accessTokenResponseClient(OAuth2AccessTokenResponseClient accessTokenResponseClient) { + this.accessTokenResponseClient = accessTokenResponseClient; + return this; + } + + /** + * Sets the maximum acceptable clock skew, which is used when checking the access token expiry. + * An access token is considered expired if it's before {@code Instant.now() - clockSkew}. + * + * @param clockSkew the maximum acceptable clock skew + * @return the {@link RefreshTokenGrantBuilder} + */ + public RefreshTokenGrantBuilder clockSkew(Duration clockSkew) { + this.clockSkew = clockSkew; + return this; + } + + /** + * Builds an instance of {@link RefreshTokenOAuth2AuthorizedClientProvider}. + * + * @return the {@link RefreshTokenOAuth2AuthorizedClientProvider} + */ + @Override + public OAuth2AuthorizedClientProvider build() { + RefreshTokenOAuth2AuthorizedClientProvider authorizedClientProvider = new RefreshTokenOAuth2AuthorizedClientProvider(); + if (this.accessTokenResponseClient != null) { + authorizedClientProvider.setAccessTokenResponseClient(this.accessTokenResponseClient); + } + if (this.clockSkew != null) { + authorizedClientProvider.setClockSkew(this.clockSkew); + } + return authorizedClientProvider; + } + } + + /** + * Configures support for the {@code client_credentials} grant. + * + * @return the {@link OAuth2AuthorizedClientProviderBuilder} + */ + public OAuth2AuthorizedClientProviderBuilder clientCredentials() { + this.builders.computeIfAbsent(ClientCredentialsGrantBuilder.class, k -> new ClientCredentialsGrantBuilder()); + return OAuth2AuthorizedClientProviderBuilder.this; + } + + /** + * Configures support for the {@code client_credentials} grant. + * + * @param builderConsumer a {@code Consumer} of {@link ClientCredentialsGrantBuilder} used for further configuration + * @return the {@link OAuth2AuthorizedClientProviderBuilder} + */ + public OAuth2AuthorizedClientProviderBuilder clientCredentials(Consumer builderConsumer) { + ClientCredentialsGrantBuilder builder = (ClientCredentialsGrantBuilder) this.builders.computeIfAbsent( + ClientCredentialsGrantBuilder.class, k -> new ClientCredentialsGrantBuilder()); + builderConsumer.accept(builder); + return OAuth2AuthorizedClientProviderBuilder.this; + } + + /** + * A builder for the {@code client_credentials} grant. + */ + public class ClientCredentialsGrantBuilder implements Builder { + private OAuth2AccessTokenResponseClient accessTokenResponseClient; + private Duration clockSkew; + + private ClientCredentialsGrantBuilder() { + } + + /** + * Sets the client used when requesting an access token credential at the Token Endpoint. + * + * @param accessTokenResponseClient the client used when requesting an access token credential at the Token Endpoint + * @return the {@link ClientCredentialsGrantBuilder} + */ + public ClientCredentialsGrantBuilder accessTokenResponseClient(OAuth2AccessTokenResponseClient accessTokenResponseClient) { + this.accessTokenResponseClient = accessTokenResponseClient; + return this; + } + + /** + * Sets the maximum acceptable clock skew, which is used when checking the access token expiry. + * An access token is considered expired if it's before {@code Instant.now() - clockSkew}. + * + * @param clockSkew the maximum acceptable clock skew + * @return the {@link ClientCredentialsGrantBuilder} + */ + public ClientCredentialsGrantBuilder clockSkew(Duration clockSkew) { + this.clockSkew = clockSkew; + return this; + } + + /** + * Builds an instance of {@link ClientCredentialsOAuth2AuthorizedClientProvider}. + * + * @return the {@link ClientCredentialsOAuth2AuthorizedClientProvider} + */ + @Override + public OAuth2AuthorizedClientProvider build() { + ClientCredentialsOAuth2AuthorizedClientProvider authorizedClientProvider = new ClientCredentialsOAuth2AuthorizedClientProvider(); + if (this.accessTokenResponseClient != null) { + authorizedClientProvider.setAccessTokenResponseClient(this.accessTokenResponseClient); + } + if (this.clockSkew != null) { + authorizedClientProvider.setClockSkew(this.clockSkew); + } + return authorizedClientProvider; + } + } + + /** + * Builds an instance of {@link DelegatingOAuth2AuthorizedClientProvider} + * composed of one or more {@link OAuth2AuthorizedClientProvider}(s). + * + * @return the {@link DelegatingOAuth2AuthorizedClientProvider} + */ + public OAuth2AuthorizedClientProvider build() { + List authorizedClientProviders = + this.builders.values().stream() + .map(Builder::build) + .collect(Collectors.toList()); + return new DelegatingOAuth2AuthorizedClientProvider(authorizedClientProviders); + } + + interface Builder { + OAuth2AuthorizedClientProvider build(); + } +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProvider.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProvider.java index cc31ec46436..5e6c1aa4fd3 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProvider.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProvider.java @@ -19,16 +19,11 @@ import org.springframework.security.oauth2.client.endpoint.DefaultRefreshTokenTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.OAuth2RefreshTokenGrantRequest; -import org.springframework.security.oauth2.client.registration.ClientRegistration; -import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; -import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.core.AbstractOAuth2Token; import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.util.Assert; -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; import java.time.Duration; import java.time.Instant; import java.util.Arrays; @@ -50,35 +45,19 @@ public final class RefreshTokenOAuth2AuthorizedClientProvider implements OAuth2A * The name of the {@link OAuth2AuthorizationContext#getAttribute(String) attribute} * in the {@link OAuth2AuthorizationContext context} associated to the value for the "requested scope(s)". * The value of the attribute is a {@code String[]} of scope(s) to be requested - * by the {@link OAuth2AuthorizationContext#getClientRegistrationId() client}. + * by the {@link OAuth2AuthorizationContext#getClientRegistration() client}. */ public static final String REQUEST_SCOPE_ATTRIBUTE_NAME = "org.springframework.security.oauth2.client.REQUEST_SCOPE"; - private static final String HTTP_SERVLET_REQUEST_ATTRIBUTE_NAME = HttpServletRequest.class.getName(); - private static final String HTTP_SERVLET_RESPONSE_ATTRIBUTE_NAME = HttpServletResponse.class.getName(); - - private final ClientRegistrationRepository clientRegistrationRepository; - private final OAuth2AuthorizedClientRepository authorizedClientRepository; private OAuth2AccessTokenResponseClient accessTokenResponseClient = new DefaultRefreshTokenTokenResponseClient(); private Duration clockSkew = Duration.ofSeconds(60); - /** - * Constructs a {@code RefreshTokenOAuth2AuthorizedClientProvider} using the provided parameters. - * - * @param clientRegistrationRepository the repository of client registrations - * @param authorizedClientRepository the repository of authorized clients - */ - public RefreshTokenOAuth2AuthorizedClientProvider(ClientRegistrationRepository clientRegistrationRepository, - OAuth2AuthorizedClientRepository authorizedClientRepository) { - Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null"); - Assert.notNull(authorizedClientRepository, "authorizedClientRepository cannot be null"); - this.clientRegistrationRepository = clientRegistrationRepository; - this.authorizedClientRepository = authorizedClientRepository; + public RefreshTokenOAuth2AuthorizedClientProvider() { } /** - * Attempt to re-authorize the {@link OAuth2AuthorizationContext#getClientRegistrationId() client} in the provided {@code context}. + * Attempt to re-authorize the {@link OAuth2AuthorizationContext#getClientRegistration() client} in the provided {@code context}. * Returns {@code null} if re-authorization is not supported, * e.g. the client is not authorized OR the {@link OAuth2AuthorizedClient#getRefreshToken() refresh token} * is not available for the authorized client OR the {@link OAuth2AuthorizedClient#getAccessToken() access token} is not expired. @@ -86,9 +65,8 @@ public RefreshTokenOAuth2AuthorizedClientProvider(ClientRegistrationRepository c *

* The following {@link OAuth2AuthorizationContext#getAttributes() context attributes} are supported: *

    - *
  1. {@code "javax.servlet.http.HttpServletRequest"} (required) - the {@code HttpServletRequest}
  2. - *
  3. {@code "javax.servlet.http.HttpServletResponse"} (required) - the {@code HttpServletResponse}
  4. - *
  5. {@code "org.springframework.security.oauth2.client.REQUEST_SCOPE"} (optional) - a {@code String[]} of scope(s) to be requested by the {@link OAuth2AuthorizationContext#getClientRegistrationId() client}
  6. + *
  7. {@code "org.springframework.security.oauth2.client.REQUEST_SCOPE"} (optional) - a {@code String[]} of scope(s) + * to be requested by the {@link OAuth2AuthorizationContext#getClientRegistration() client}
  8. *
* * @param context the context that holds authorization-specific state for the client @@ -99,17 +77,7 @@ public RefreshTokenOAuth2AuthorizedClientProvider(ClientRegistrationRepository c public OAuth2AuthorizedClient authorize(OAuth2AuthorizationContext context) { Assert.notNull(context, "context cannot be null"); - HttpServletRequest request = context.getAttribute(HTTP_SERVLET_REQUEST_ATTRIBUTE_NAME); - HttpServletResponse response = context.getAttribute(HTTP_SERVLET_RESPONSE_ATTRIBUTE_NAME); - Assert.notNull(request, "The context attribute cannot be null '" + HTTP_SERVLET_REQUEST_ATTRIBUTE_NAME + "'"); - Assert.notNull(response, "The context attribute cannot be null '" + HTTP_SERVLET_RESPONSE_ATTRIBUTE_NAME + "'"); - - String clientRegistrationId = context.getClientRegistrationId(); - ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId); - Assert.notNull(clientRegistration, "Could not find ClientRegistration with id '" + clientRegistrationId + "'"); - - OAuth2AuthorizedClient authorizedClient = this.authorizedClientRepository.loadAuthorizedClient( - clientRegistrationId, context.getPrincipal(), request); + OAuth2AuthorizedClient authorizedClient = context.getAuthorizedClient(); if (authorizedClient == null || authorizedClient.getRefreshToken() == null || !hasTokenExpired(authorizedClient.getAccessToken())) { @@ -129,13 +97,8 @@ public OAuth2AuthorizedClient authorize(OAuth2AuthorizationContext context) { OAuth2AccessTokenResponse tokenResponse = this.accessTokenResponseClient.getTokenResponse(refreshTokenGrantRequest); - authorizedClient = new OAuth2AuthorizedClient(clientRegistration, + return new OAuth2AuthorizedClient(context.getAuthorizedClient().getClientRegistration(), context.getPrincipal().getName(), tokenResponse.getAccessToken(), tokenResponse.getRefreshToken()); - - this.authorizedClientRepository.saveAuthorizedClient( - authorizedClient, context.getPrincipal(), request, response); - - return authorizedClient; } private boolean hasTokenExpired(AbstractOAuth2Token token) { diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizedClientManager.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizedClientManager.java new file mode 100644 index 00000000000..40402896c39 --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizedClientManager.java @@ -0,0 +1,144 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client.web; + +import org.springframework.lang.Nullable; +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.client.OAuth2AuthorizationContext; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProvider; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; +import org.springframework.util.Assert; + +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import java.util.Collections; +import java.util.Map; +import java.util.function.BiFunction; + +/** + * The default implementation of an {@link OAuth2AuthorizedClientManager}. + * + * @author Joe Grandja + * @since 5.2 + * @see OAuth2AuthorizedClientManager + * @see OAuth2AuthorizedClientProvider + */ +public final class DefaultOAuth2AuthorizedClientManager implements OAuth2AuthorizedClientManager { + private final ClientRegistrationRepository clientRegistrationRepository; + private final OAuth2AuthorizedClientRepository authorizedClientRepository; + private OAuth2AuthorizedClientProvider authorizedClientProvider = context -> null; + private BiFunction> contextAttributesMapper = + (clientRegistration, request) -> Collections.emptyMap(); + + /** + * Constructs a {@code DefaultOAuth2AuthorizedClientManager} using the provided parameters. + * + * @param clientRegistrationRepository the repository of client registrations + * @param authorizedClientRepository the repository of authorized clients + */ + public DefaultOAuth2AuthorizedClientManager(ClientRegistrationRepository clientRegistrationRepository, + OAuth2AuthorizedClientRepository authorizedClientRepository) { + Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null"); + Assert.notNull(authorizedClientRepository, "authorizedClientRepository cannot be null"); + this.clientRegistrationRepository = clientRegistrationRepository; + this.authorizedClientRepository = authorizedClientRepository; + } + + @Nullable + @Override + public OAuth2AuthorizedClient authorize(String clientRegistrationId, Authentication principal, + HttpServletRequest request, HttpServletResponse response) { + + Assert.hasText(clientRegistrationId, "clientRegistrationId cannot be empty"); + Assert.notNull(principal, "principal cannot be null"); + Assert.notNull(request, "request cannot be null"); + Assert.notNull(response, "response cannot be null"); + + ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId); + Assert.notNull(clientRegistration, "Could not find ClientRegistration with id '" + clientRegistrationId + "'"); + + OAuth2AuthorizedClient authorizedClient = this.authorizedClientRepository.loadAuthorizedClient( + clientRegistrationId, principal, request); + if (authorizedClient != null) { + return reauthorizeIfNecessary(authorizedClient, principal, request, response); + } + + OAuth2AuthorizationContext authorizationContext = + OAuth2AuthorizationContext.forClient(clientRegistration) + .principal(principal) + .attributes(this.contextAttributesMapper.apply(clientRegistration, request)) + .build(); + authorizedClient = this.authorizedClientProvider.authorize(authorizationContext); + if (authorizedClient != null) { + this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, principal, request, response); + } + + return authorizedClient; + } + + @Override + public OAuth2AuthorizedClient reauthorize(OAuth2AuthorizedClient authorizedClient, Authentication principal, + HttpServletRequest request, HttpServletResponse response) { + + Assert.notNull(authorizedClient, "authorizedClient cannot be null"); + Assert.notNull(principal, "principal cannot be null"); + Assert.notNull(request, "request cannot be null"); + Assert.notNull(response, "response cannot be null"); + + return reauthorizeIfNecessary(authorizedClient, principal, request, response); + } + + private OAuth2AuthorizedClient reauthorizeIfNecessary(OAuth2AuthorizedClient authorizedClient, Authentication principal, + HttpServletRequest request, HttpServletResponse response) { + + OAuth2AuthorizationContext authorizationContext = + OAuth2AuthorizationContext.forClient(authorizedClient) + .principal(principal) + .attributes(this.contextAttributesMapper.apply(authorizedClient.getClientRegistration(), request)) + .build(); + OAuth2AuthorizedClient reauthorizedClient = this.authorizedClientProvider.authorize(authorizationContext); + if (reauthorizedClient != null) { + this.authorizedClientRepository.saveAuthorizedClient(reauthorizedClient, principal, request, response); + return reauthorizedClient; + } + + return authorizedClient; + } + + /** + * Sets the {@link OAuth2AuthorizedClientProvider} used for authorizing (or re-authorizing) an OAuth 2.0 Client. + * + * @param authorizedClientProvider the {@link OAuth2AuthorizedClientProvider} used for authorizing (or re-authorizing) an OAuth 2.0 Client + */ + public void setAuthorizedClientProvider(OAuth2AuthorizedClientProvider authorizedClientProvider) { + Assert.notNull(authorizedClientProvider, "authorizedClientProvider cannot be null"); + this.authorizedClientProvider = authorizedClientProvider; + } + + /** + * Sets the {@code BiFunction} used for mapping attribute(s) from the {@link ClientRegistration} and/or {@code HttpServletRequest} + * to a {@code Map} of attributes to be associated to the {@link OAuth2AuthorizationContext#getAttributes() authorization context}. + * + * @param contextAttributesMapper the {@code BiFunction} used for supplying the {@code Map} of attributes + * to the {@link OAuth2AuthorizationContext#getAttributes() authorization context} + */ + public void setContextAttributesMapper(BiFunction> contextAttributesMapper) { + Assert.notNull(contextAttributesMapper, "contextAttributesMapper cannot be null"); + this.contextAttributesMapper = contextAttributesMapper; + } +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizedClientManager.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizedClientManager.java new file mode 100644 index 00000000000..9079c021b08 --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizedClientManager.java @@ -0,0 +1,81 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client.web; + +import org.springframework.lang.Nullable; +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProvider; +import org.springframework.security.oauth2.client.registration.ClientRegistration; + +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + +/** + * Implementations of this interface are responsible for the overall management + * of {@link OAuth2AuthorizedClient Authorized Client(s)}. + * + *

+ * The primary responsibilities include: + *

    + *
  1. Authorizing (or re-authorizing) an OAuth 2.0 Client + * by leveraging an {@link OAuth2AuthorizedClientProvider}(s).
  2. + *
  3. Managing the persistence of an {@link OAuth2AuthorizedClient} between requests, + * typically using an {@link OAuth2AuthorizedClientRepository}.
  4. + *
+ * + * @author Joe Grandja + * @since 5.2 + * @see OAuth2AuthorizedClient + * @see OAuth2AuthorizedClientProvider + * @see OAuth2AuthorizedClientRepository + */ +public interface OAuth2AuthorizedClientManager { + + /** + * Attempt to authorize or re-authorize (if required) the {@link ClientRegistration client} + * identified by the provided {@code clientRegistrationId}. + * Implementations must return {@code null} if authorization is not supported for the specified client, + * e.g. the associated {@link OAuth2AuthorizedClientProvider}(s) does not support + * the {@link ClientRegistration#getAuthorizationGrantType() authorization grant} type configured for the client. + * + * @param clientRegistrationId the identifier for the client's registration + * @param principal the {@code Principal} {@link Authentication} (to be) associated to the authorized client + * @param request the {@code HttpServletRequest} + * @param response the {@code HttpServletResponse} + * @return the {@link OAuth2AuthorizedClient} or {@code null} if authorization is not supported for the specified client + */ + @Nullable + OAuth2AuthorizedClient authorize(String clientRegistrationId, Authentication principal, + HttpServletRequest request, HttpServletResponse response); + + /** + * Attempt to re-authorize (if required) the provided {@link OAuth2AuthorizedClient authorized client}. + * Implementations must return the provided {@code authorizedClient} if re-authorization is not supported + * for the {@link OAuth2AuthorizedClient#getClientRegistration() client} OR is not required, + * e.g. a {@link OAuth2AuthorizedClient#getRefreshToken() refresh token} is not available OR + * the {@link OAuth2AuthorizedClient#getAccessToken() access token} is not expired. + * + * @param authorizedClient the authorized client + * @param principal the {@code Principal} {@link Authentication} associated to the authorized client + * @param request the {@code HttpServletRequest} + * @param response the {@code HttpServletResponse} + * @return the re-authorized {@link OAuth2AuthorizedClient} or the provided {@code authorizedClient} if not re-authorized + */ + OAuth2AuthorizedClient reauthorize(OAuth2AuthorizedClient authorizedClient, Authentication principal, + HttpServletRequest request, HttpServletResponse response); + +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolver.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolver.java index 946b9bd4689..a33c7cff6ab 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolver.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolver.java @@ -19,22 +19,23 @@ import org.springframework.core.annotation.AnnotatedElementUtils; import org.springframework.lang.NonNull; import org.springframework.lang.Nullable; +import org.springframework.security.authentication.AnonymousAuthenticationToken; import org.springframework.security.core.Authentication; +import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.core.context.SecurityContextHolder; -import org.springframework.security.oauth2.client.AuthorizationCodeOAuth2AuthorizedClientProvider; import org.springframework.security.oauth2.client.ClientCredentialsOAuth2AuthorizedClientProvider; -import org.springframework.security.oauth2.client.DefaultOAuth2AuthorizedClientProvider; -import org.springframework.security.oauth2.client.DelegatingOAuth2AuthorizedClientProvider; -import org.springframework.security.oauth2.client.OAuth2AuthorizationContext; import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProvider; -import org.springframework.security.oauth2.client.RefreshTokenOAuth2AuthorizedClientProvider; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProviderBuilder; import org.springframework.security.oauth2.client.annotation.RegisteredOAuth2AuthorizedClient; import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; import org.springframework.security.oauth2.client.endpoint.DefaultClientCredentialsTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest; +import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; +import org.springframework.security.oauth2.client.web.DefaultOAuth2AuthorizedClientManager; +import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientManager; import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; import org.springframework.util.Assert; import org.springframework.util.StringUtils; @@ -67,9 +68,11 @@ * @see RegisteredOAuth2AuthorizedClient */ public final class OAuth2AuthorizedClientArgumentResolver implements HandlerMethodArgumentResolver { + private static final Authentication ANONYMOUS_AUTHENTICATION = new AnonymousAuthenticationToken( + "anonymous", "anonymousUser", AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS")); private final ClientRegistrationRepository clientRegistrationRepository; private final OAuth2AuthorizedClientRepository authorizedClientRepository; - private OAuth2AuthorizedClientProvider authorizedClientProvider; + private OAuth2AuthorizedClientManager authorizedClientManager; /** * Constructs an {@code OAuth2AuthorizedClientArgumentResolver} using the provided parameters. @@ -83,7 +86,7 @@ public OAuth2AuthorizedClientArgumentResolver(ClientRegistrationRepository clien Assert.notNull(authorizedClientRepository, "authorizedClientRepository cannot be null"); this.clientRegistrationRepository = clientRegistrationRepository; this.authorizedClientRepository = authorizedClientRepository; - this.authorizedClientProvider = createAuthorizedClientProvider(new DefaultClientCredentialsTokenResponseClient()); + this.authorizedClientManager = createAuthorizedClientManager(new DefaultClientCredentialsTokenResponseClient()); } @Override @@ -108,21 +111,18 @@ public Object resolveArgument(MethodParameter parameter, "@RegisteredOAuth2AuthorizedClient(registrationId = \"client1\")."); } + ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId); + Assert.notNull(clientRegistration, "Could not find ClientRegistration with id '" + clientRegistrationId + "'"); + Authentication principal = SecurityContextHolder.getContext().getAuthentication(); + if (principal == null) { + principal = ANONYMOUS_AUTHENTICATION; + } HttpServletRequest servletRequest = webRequest.getNativeRequest(HttpServletRequest.class); HttpServletResponse servletResponse = webRequest.getNativeResponse(HttpServletResponse.class); - OAuth2AuthorizationContext.Builder contextBuilder = OAuth2AuthorizationContext.forClient(clientRegistrationId); - if (principal != null) { - contextBuilder.principal(principal); - } else { - contextBuilder.principal("anonymousUser"); - } - OAuth2AuthorizationContext authorizationContext = contextBuilder - .attribute(HttpServletRequest.class.getName(), servletRequest) - .attribute(HttpServletResponse.class.getName(), servletResponse) - .build(); - return this.authorizedClientProvider.authorize(authorizationContext); + return this.authorizedClientManager.authorize(clientRegistration.getRegistrationId(), + principal, servletRequest, servletResponse); } private String resolveClientRegistrationId(MethodParameter parameter) { @@ -144,20 +144,23 @@ private String resolveClientRegistrationId(MethodParameter parameter) { } /** - * Sets the {@link OAuth2AuthorizedClientProvider} used for authorizing (or re-authorizing) an OAuth 2.0 Client. + * Sets the {@link OAuth2AuthorizedClientManager} which manages the {@link OAuth2AuthorizedClient Authorized Client(s)}. * * @since 5.2 - * @param authorizedClientProvider the {@link OAuth2AuthorizedClientProvider} used for authorizing (or re-authorizing) an OAuth 2.0 Client. + * @param authorizedClientManager the {@link OAuth2AuthorizedClientManager} which manages the authorized client(s) */ - public void setAuthorizedClientProvider(OAuth2AuthorizedClientProvider authorizedClientProvider) { - Assert.notNull(authorizedClientProvider, "authorizedClientProvider cannot be null"); - this.authorizedClientProvider = authorizedClientProvider; + public void setAuthorizedClientManager(OAuth2AuthorizedClientManager authorizedClientManager) { + Assert.notNull(authorizedClientManager, "authorizedClientManager cannot be null"); + this.authorizedClientManager = authorizedClientManager; } /** * Sets the client used when requesting an access token credential at the Token Endpoint for the {@code client_credentials} grant. * - * @deprecated Use {@link #setAuthorizedClientProvider(OAuth2AuthorizedClientProvider)} instead by providing it an instance of {@link ClientCredentialsOAuth2AuthorizedClientProvider} configured with a {@link ClientCredentialsOAuth2AuthorizedClientProvider#setAccessTokenResponseClient(OAuth2AccessTokenResponseClient) DefaultClientCredentialsTokenResponseClient} or a custom one. + * @deprecated Use {@link #setAuthorizedClientManager(OAuth2AuthorizedClientManager)} instead. + * Create an instance of {@link ClientCredentialsOAuth2AuthorizedClientProvider} configured with a + * {@link ClientCredentialsOAuth2AuthorizedClientProvider#setAccessTokenResponseClient(OAuth2AccessTokenResponseClient) DefaultClientCredentialsTokenResponseClient} + * (or a custom one) and than supply it to {@link DefaultOAuth2AuthorizedClientManager#setAuthorizedClientProvider(OAuth2AuthorizedClientProvider) DefaultOAuth2AuthorizedClientManager}. * * @param clientCredentialsTokenResponseClient the client used when requesting an access token credential at the Token Endpoint for the {@code client_credentials} grant */ @@ -165,24 +168,23 @@ public void setAuthorizedClientProvider(OAuth2AuthorizedClientProvider authorize public final void setClientCredentialsTokenResponseClient( OAuth2AccessTokenResponseClient clientCredentialsTokenResponseClient) { Assert.notNull(clientCredentialsTokenResponseClient, "clientCredentialsTokenResponseClient cannot be null"); - this.authorizedClientProvider = createAuthorizedClientProvider(clientCredentialsTokenResponseClient); + this.authorizedClientManager = createAuthorizedClientManager(clientCredentialsTokenResponseClient); } - private OAuth2AuthorizedClientProvider createAuthorizedClientProvider( + private OAuth2AuthorizedClientManager createAuthorizedClientManager( OAuth2AccessTokenResponseClient clientCredentialsTokenResponseClient) { - ClientCredentialsOAuth2AuthorizedClientProvider clientCredentialsAuthorizedClientProvider = - new ClientCredentialsOAuth2AuthorizedClientProvider(this.clientRegistrationRepository, this.authorizedClientRepository); - clientCredentialsAuthorizedClientProvider.setAccessTokenResponseClient(clientCredentialsTokenResponseClient); - AuthorizationCodeOAuth2AuthorizedClientProvider authorizationCodeAuthorizedClientProvider = - new AuthorizationCodeOAuth2AuthorizedClientProvider(this.clientRegistrationRepository, this.authorizedClientRepository); - RefreshTokenOAuth2AuthorizedClientProvider refreshTokenAuthorizedClientProvider = - new RefreshTokenOAuth2AuthorizedClientProvider(this.clientRegistrationRepository, this.authorizedClientRepository); - - DelegatingOAuth2AuthorizedClientProvider delegate = new DelegatingOAuth2AuthorizedClientProvider( - authorizationCodeAuthorizedClientProvider, refreshTokenAuthorizedClientProvider, clientCredentialsAuthorizedClientProvider); - delegate.setDefaultAuthorizedClientProvider( - new DefaultOAuth2AuthorizedClientProvider(this.clientRegistrationRepository, this.authorizedClientRepository)); - - return delegate; + + OAuth2AuthorizedClientProvider authorizedClientProvider = + OAuth2AuthorizedClientProviderBuilder.withProvider() + .authorizationCode() + .refreshToken() + .clientCredentials(configurer -> configurer.accessTokenResponseClient(clientCredentialsTokenResponseClient)) + .build(); + + DefaultOAuth2AuthorizedClientManager authorizedClientManager = new DefaultOAuth2AuthorizedClientManager( + this.clientRegistrationRepository, this.authorizedClientRepository); + authorizedClientManager.setAuthorizedClientProvider(authorizedClientProvider); + + return authorizedClientManager; } } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java index a402d3768f0..a8d6af84670 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java @@ -20,21 +20,20 @@ import org.springframework.beans.factory.DisposableBean; import org.springframework.beans.factory.InitializingBean; import org.springframework.security.core.Authentication; +import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.context.SecurityContextHolder; -import org.springframework.security.oauth2.client.AuthorizationCodeOAuth2AuthorizedClientProvider; import org.springframework.security.oauth2.client.ClientCredentialsOAuth2AuthorizedClientProvider; -import org.springframework.security.oauth2.client.DefaultOAuth2AuthorizedClientProvider; -import org.springframework.security.oauth2.client.DelegatingOAuth2AuthorizedClientProvider; -import org.springframework.security.oauth2.client.OAuth2AuthorizationContext; import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProvider; -import org.springframework.security.oauth2.client.RefreshTokenOAuth2AuthorizedClientProvider; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProviderBuilder; import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; import org.springframework.security.oauth2.client.endpoint.DefaultClientCredentialsTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; +import org.springframework.security.oauth2.client.web.DefaultOAuth2AuthorizedClientManager; +import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientManager; import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; import org.springframework.util.Assert; import org.springframework.web.context.request.RequestContextHolder; @@ -53,6 +52,7 @@ import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import java.time.Duration; +import java.util.Collection; import java.util.Map; import java.util.function.Consumer; @@ -81,7 +81,7 @@ * are true: * *
    - *
  • The {@link #setAuthorizedClientProvider(OAuth2AuthorizedClientProvider) OAuth2AuthorizedClientProvider} on the + *
  • The {@link #setAuthorizedClientManager(OAuth2AuthorizedClientManager) OAuth2AuthorizedClientManager} on the * {@link ServletOAuth2AuthorizedClientExchangeFilterFunction} is not null
  • *
  • A refresh token is present on the {@link OAuth2AuthorizedClient}
  • *
  • The access token will be expired in @@ -115,7 +115,7 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction private OAuth2AuthorizedClientRepository authorizedClientRepository; - private OAuth2AuthorizedClientProvider authorizedClientProvider; + private OAuth2AuthorizedClientManager authorizedClientManager; private boolean defaultOAuth2AuthorizedClient; @@ -129,24 +129,24 @@ public ServletOAuth2AuthorizedClientExchangeFilterFunction( OAuth2AuthorizedClientRepository authorizedClientRepository) { this.clientRegistrationRepository = clientRegistrationRepository; this.authorizedClientRepository = authorizedClientRepository; - this.authorizedClientProvider = createDefaultAuthorizedClientProvider(clientRegistrationRepository, authorizedClientRepository); + this.authorizedClientManager = createDefaultAuthorizedClientManager(clientRegistrationRepository, authorizedClientRepository); } - private static OAuth2AuthorizedClientProvider createDefaultAuthorizedClientProvider( + private static OAuth2AuthorizedClientManager createDefaultAuthorizedClientManager( ClientRegistrationRepository clientRegistrationRepository, OAuth2AuthorizedClientRepository authorizedClientRepository) { - AuthorizationCodeOAuth2AuthorizedClientProvider authorizationCodeAuthorizedClientProvider = - new AuthorizationCodeOAuth2AuthorizedClientProvider(clientRegistrationRepository, authorizedClientRepository); - RefreshTokenOAuth2AuthorizedClientProvider refreshTokenAuthorizedClientProvider = - new RefreshTokenOAuth2AuthorizedClientProvider(clientRegistrationRepository, authorizedClientRepository); - ClientCredentialsOAuth2AuthorizedClientProvider clientCredentialsAuthorizedClientProvider = - new ClientCredentialsOAuth2AuthorizedClientProvider(clientRegistrationRepository, authorizedClientRepository); - DelegatingOAuth2AuthorizedClientProvider delegate = new DelegatingOAuth2AuthorizedClientProvider( - authorizationCodeAuthorizedClientProvider, refreshTokenAuthorizedClientProvider, clientCredentialsAuthorizedClientProvider); - delegate.setDefaultAuthorizedClientProvider( - new DefaultOAuth2AuthorizedClientProvider(clientRegistrationRepository, authorizedClientRepository)); + OAuth2AuthorizedClientProvider authorizedClientProvider = + OAuth2AuthorizedClientProviderBuilder.withProvider() + .authorizationCode() + .refreshToken() + .clientCredentials() + .build(); - return delegate; + DefaultOAuth2AuthorizedClientManager authorizedClientManager = new DefaultOAuth2AuthorizedClientManager( + clientRegistrationRepository, authorizedClientRepository); + authorizedClientManager.setAuthorizedClientProvider(authorizedClientProvider); + + return authorizedClientManager; } @Override @@ -160,20 +160,23 @@ public void destroy() throws Exception { } /** - * Sets the {@link OAuth2AuthorizedClientProvider} used for authorizing (or re-authorizing) an OAuth 2.0 Client. + * Sets the {@link OAuth2AuthorizedClientManager} which manages the {@link OAuth2AuthorizedClient Authorized Client(s)}. * * @since 5.2 - * @param authorizedClientProvider the {@link OAuth2AuthorizedClientProvider} used for authorizing (or re-authorizing) an OAuth 2.0 Client. + * @param authorizedClientManager the {@link OAuth2AuthorizedClientManager} which manages the authorized client(s) */ - public void setAuthorizedClientProvider(OAuth2AuthorizedClientProvider authorizedClientProvider) { - Assert.notNull(authorizedClientProvider, "authorizedClientProvider cannot be null"); - this.authorizedClientProvider = authorizedClientProvider; + public void setAuthorizedClientManager(OAuth2AuthorizedClientManager authorizedClientManager) { + Assert.notNull(authorizedClientManager, "authorizedClientManager cannot be null"); + this.authorizedClientManager = authorizedClientManager; } /** * Sets the {@link OAuth2AccessTokenResponseClient} used for getting an {@link OAuth2AuthorizedClient} for the client_credentials grant. * - * @deprecated Use {@link #setAuthorizedClientProvider(OAuth2AuthorizedClientProvider)} instead by providing it an instance of {@link ClientCredentialsOAuth2AuthorizedClientProvider} configured with a {@link ClientCredentialsOAuth2AuthorizedClientProvider#setAccessTokenResponseClient(OAuth2AccessTokenResponseClient) DefaultClientCredentialsTokenResponseClient} or a custom one. + * @deprecated Use {@link #setAuthorizedClientManager(OAuth2AuthorizedClientManager)} instead. + * Create an instance of {@link ClientCredentialsOAuth2AuthorizedClientProvider} configured with a + * {@link ClientCredentialsOAuth2AuthorizedClientProvider#setAccessTokenResponseClient(OAuth2AccessTokenResponseClient) DefaultClientCredentialsTokenResponseClient} + * (or a custom one) and than supply it to {@link DefaultOAuth2AuthorizedClientManager#setAuthorizedClientProvider(OAuth2AuthorizedClientProvider) DefaultOAuth2AuthorizedClientManager}. * * @param clientCredentialsTokenResponseClient the client to use */ @@ -181,30 +184,30 @@ public void setAuthorizedClientProvider(OAuth2AuthorizedClientProvider authorize public void setClientCredentialsTokenResponseClient( OAuth2AccessTokenResponseClient clientCredentialsTokenResponseClient) { Assert.notNull(clientCredentialsTokenResponseClient, "clientCredentialsTokenResponseClient cannot be null"); - this.authorizedClientProvider = createAuthorizedClientProvider(clientCredentialsTokenResponseClient); + this.authorizedClientManager = createAuthorizedClientManager(clientCredentialsTokenResponseClient); } - private OAuth2AuthorizedClientProvider createAuthorizedClientProvider() { - return createAuthorizedClientProvider(new DefaultClientCredentialsTokenResponseClient()); + private OAuth2AuthorizedClientManager createAuthorizedClientManager() { + return createAuthorizedClientManager(new DefaultClientCredentialsTokenResponseClient()); } - private OAuth2AuthorizedClientProvider createAuthorizedClientProvider( + private OAuth2AuthorizedClientManager createAuthorizedClientManager( OAuth2AccessTokenResponseClient clientCredentialsTokenResponseClient) { - ClientCredentialsOAuth2AuthorizedClientProvider clientCredentialsAuthorizedClientProvider = - new ClientCredentialsOAuth2AuthorizedClientProvider(this.clientRegistrationRepository, this.authorizedClientRepository); - clientCredentialsAuthorizedClientProvider.setAccessTokenResponseClient(clientCredentialsTokenResponseClient); - clientCredentialsAuthorizedClientProvider.setClockSkew(this.accessTokenExpiresSkew); - AuthorizationCodeOAuth2AuthorizedClientProvider authorizationCodeAuthorizedClientProvider = - new AuthorizationCodeOAuth2AuthorizedClientProvider(this.clientRegistrationRepository, this.authorizedClientRepository); - RefreshTokenOAuth2AuthorizedClientProvider refreshTokenAuthorizedClientProvider = - new RefreshTokenOAuth2AuthorizedClientProvider(this.clientRegistrationRepository, this.authorizedClientRepository); - refreshTokenAuthorizedClientProvider.setClockSkew(this.accessTokenExpiresSkew); - - DelegatingOAuth2AuthorizedClientProvider delegate = new DelegatingOAuth2AuthorizedClientProvider( - authorizationCodeAuthorizedClientProvider, refreshTokenAuthorizedClientProvider, clientCredentialsAuthorizedClientProvider); - delegate.setDefaultAuthorizedClientProvider( - new DefaultOAuth2AuthorizedClientProvider(this.clientRegistrationRepository, this.authorizedClientRepository)); - return delegate; + + OAuth2AuthorizedClientProvider authorizedClientProvider = + OAuth2AuthorizedClientProviderBuilder.withProvider() + .authorizationCode() + .refreshToken(configurer -> configurer.clockSkew(this.accessTokenExpiresSkew)) + .clientCredentials(configurer -> configurer + .accessTokenResponseClient(clientCredentialsTokenResponseClient) + .clockSkew(this.accessTokenExpiresSkew)) + .build(); + + DefaultOAuth2AuthorizedClientManager authorizedClientManager = new DefaultOAuth2AuthorizedClientManager( + this.clientRegistrationRepository, this.authorizedClientRepository); + authorizedClientManager.setAuthorizedClientProvider(authorizedClientProvider); + + return authorizedClientManager; } /** @@ -325,7 +328,7 @@ public static Consumer> httpServletResponse(HttpServletRespo public void setAccessTokenExpiresSkew(Duration accessTokenExpiresSkew) { Assert.notNull(accessTokenExpiresSkew, "accessTokenExpiresSkew cannot be null"); this.accessTokenExpiresSkew = accessTokenExpiresSkew; - this.authorizedClientProvider = createAuthorizedClientProvider(); + this.authorizedClientManager = createAuthorizedClientManager(); } @Override @@ -405,35 +408,23 @@ private void populateDefaultOAuth2AuthorizedClient(Map attrs) { OAuth2AuthorizedClient authorizedClient = this.authorizedClientRepository.loadAuthorizedClient( clientRegistrationId, authentication, request); if (authorizedClient == null) { - authorizedClient = this.authorizedClientProvider.authorize( - createAuthorizationContext(clientRegistrationId, attrs)); + authorizedClient = this.authorizedClientManager.authorize( + clientRegistrationId, authentication, request, getResponse(attrs)); } oauth2AuthorizedClient(authorizedClient).accept(attrs); } } - private Mono authorizedClient( - OAuth2AuthorizedClient authorizedClient, ClientRequest request) { - if (this.authorizedClientProvider == null) { + private Mono authorizedClient(OAuth2AuthorizedClient authorizedClient, ClientRequest request) { + if (this.authorizedClientManager == null) { return Mono.just(authorizedClient); } - OAuth2AuthorizationContext authorizationContext = createAuthorizationContext( - authorizedClient.getClientRegistration().getRegistrationId(), request.attributes()); - return Mono.fromSupplier(() -> this.authorizedClientProvider.authorize(authorizationContext)); - } - - private OAuth2AuthorizationContext createAuthorizationContext(String clientRegistrationId, Map attributes) { - OAuth2AuthorizationContext.Builder contextBuilder = OAuth2AuthorizationContext.forClient(clientRegistrationId); - Authentication authentication = getAuthentication(attributes); - if (authentication != null) { - contextBuilder.principal(authentication); - } else { - contextBuilder.principal("anonymousUser"); - } - return contextBuilder - .attribute(HttpServletRequest.class.getName(), getRequest(attributes)) - .attribute(HttpServletResponse.class.getName(), getResponse(attributes)) - .build(); + Map attrs = request.attributes(); + Authentication authentication = getAuthentication(attrs) != null ? + getAuthentication(attrs) : new PrincipalNameAuthentication(authorizedClient.getPrincipalName()); + HttpServletRequest servletRequest = getRequest(attrs); + HttpServletResponse servletResponse = getResponse(attrs); + return Mono.fromSupplier(() -> this.authorizedClientManager.reauthorize(authorizedClient, authentication, servletRequest, servletResponse)); } private ClientRequest bearer(ClientRequest request, OAuth2AuthorizedClient authorizedClient) { @@ -475,6 +466,54 @@ static HttpServletResponse getResponse(Map attrs) { return (HttpServletResponse) attrs.get(HTTP_SERVLET_RESPONSE_ATTR_NAME); } + private static class PrincipalNameAuthentication implements Authentication { + private final String principalName; + + private PrincipalNameAuthentication(String principalName) { + Assert.hasText(principalName, "principalName cannot be empty"); + this.principalName = principalName; + } + + @Override + public Collection getAuthorities() { + throw unsupported(); + } + + @Override + public Object getCredentials() { + throw unsupported(); + } + + @Override + public Object getDetails() { + throw unsupported(); + } + + @Override + public Object getPrincipal() { + return getName(); + } + + @Override + public boolean isAuthenticated() { + throw unsupported(); + } + + @Override + public void setAuthenticated(boolean isAuthenticated) throws IllegalArgumentException { + throw unsupported(); + } + + @Override + public String getName() { + return this.principalName; + } + + private UnsupportedOperationException unsupported() { + return new UnsupportedOperationException("Not Supported"); + } + } + private static class RequestContextSubscriber implements CoreSubscriber { private static final String CONTEXT_DEFAULTED_ATTR_NAME = RequestContextSubscriber.class.getName().concat(".CONTEXT_DEFAULTED_ATTR_NAME"); private final CoreSubscriber delegate; diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/AuthorizationCodeOAuth2AuthorizedClientProviderTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/AuthorizationCodeOAuth2AuthorizedClientProviderTests.java index 6998636dfdf..ad021fe8f42 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/AuthorizationCodeOAuth2AuthorizedClientProviderTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/AuthorizationCodeOAuth2AuthorizedClientProviderTests.java @@ -17,25 +17,14 @@ import org.junit.Before; import org.junit.Test; -import org.springframework.mock.web.MockHttpServletRequest; -import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.oauth2.client.registration.ClientRegistration; -import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; import org.springframework.security.oauth2.client.registration.TestClientRegistrations; -import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; - import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; /** * Tests for {@link AuthorizationCodeOAuth2AuthorizedClientProvider}. @@ -43,8 +32,6 @@ * @author Joe Grandja */ public class AuthorizationCodeOAuth2AuthorizedClientProviderTests { - private ClientRegistrationRepository clientRegistrationRepository; - private OAuth2AuthorizedClientRepository authorizedClientRepository; private AuthorizationCodeOAuth2AuthorizedClientProvider authorizedClientProvider; private ClientRegistration clientRegistration; private OAuth2AuthorizedClient authorizedClient; @@ -52,114 +39,44 @@ public class AuthorizationCodeOAuth2AuthorizedClientProviderTests { @Before public void setup() { - this.clientRegistrationRepository = mock(ClientRegistrationRepository.class); - this.authorizedClientRepository = mock(OAuth2AuthorizedClientRepository.class); - this.authorizedClientProvider = new AuthorizationCodeOAuth2AuthorizedClientProvider( - this.clientRegistrationRepository, this.authorizedClientRepository); + this.authorizedClientProvider = new AuthorizationCodeOAuth2AuthorizedClientProvider(); this.clientRegistration = TestClientRegistrations.clientRegistration().build(); this.authorizedClient = new OAuth2AuthorizedClient( this.clientRegistration, "principal", TestOAuth2AccessTokens.scopes("read", "write")); this.principal = new TestingAuthenticationToken("principal", "password"); } - @Test - public void constructorWhenClientRegistrationRepositoryIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new AuthorizationCodeOAuth2AuthorizedClientProvider(null, this.authorizedClientRepository)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("clientRegistrationRepository cannot be null"); - } - - @Test - public void constructorWhenOAuth2AuthorizedClientRepositoryIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new AuthorizationCodeOAuth2AuthorizedClientProvider(this.clientRegistrationRepository, null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("authorizedClientRepository cannot be null"); - } - @Test public void authorizeWhenContextIsNullThenThrowIllegalArgumentException() { assertThatThrownBy(() -> this.authorizedClientProvider.authorize(null)) .isInstanceOf(IllegalArgumentException.class); } - @Test - public void authorizeWhenHttpServletRequestIsNullThenThrowIllegalArgumentException() { - OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.forClient(this.clientRegistration.getRegistrationId()) - .principal(this.principal) - .build(); - assertThatThrownBy(() -> this.authorizedClientProvider.authorize(authorizationContext)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("The context attribute cannot be null 'javax.servlet.http.HttpServletRequest'"); - } - - @Test - public void authorizeWhenHttpServletResponseIsNullThenThrowIllegalArgumentException() { - OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.forClient(this.clientRegistration.getRegistrationId()) - .principal(this.principal) - .attribute(HttpServletRequest.class.getName(), new MockHttpServletRequest()) - .build(); - assertThatThrownBy(() -> this.authorizedClientProvider.authorize(authorizationContext)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("The context attribute cannot be null 'javax.servlet.http.HttpServletResponse'"); - } - - @Test - public void authorizeWhenClientRegistrationNotFoundThenThrowIllegalArgumentException() { - OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.forClient(this.clientRegistration.getRegistrationId()) - .principal(this.principal) - .attribute(HttpServletRequest.class.getName(), new MockHttpServletRequest()) - .attribute(HttpServletResponse.class.getName(), new MockHttpServletResponse()) - .build(); - assertThatThrownBy(() -> this.authorizedClientProvider.authorize(authorizationContext)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("Could not find ClientRegistration with id '" + this.clientRegistration.getRegistrationId() + "'"); - } - @Test public void authorizeWhenNotAuthorizationCodeThenUnableToAuthorize() { ClientRegistration clientCredentialsClient = TestClientRegistrations.clientCredentials().build(); - when(this.clientRegistrationRepository.findByRegistrationId( - eq(clientCredentialsClient.getRegistrationId()))).thenReturn(clientCredentialsClient); OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.forClient(clientCredentialsClient.getRegistrationId()) + OAuth2AuthorizationContext.forClient(clientCredentialsClient) .principal(this.principal) - .attribute(HttpServletRequest.class.getName(), new MockHttpServletRequest()) - .attribute(HttpServletResponse.class.getName(), new MockHttpServletResponse()) .build(); assertThat(this.authorizedClientProvider.authorize(authorizationContext)).isNull(); } @Test public void authorizeWhenAuthorizationCodeAndAuthorizedThenNotAuthorize() { - when(this.clientRegistrationRepository.findByRegistrationId( - eq(this.clientRegistration.getRegistrationId()))).thenReturn(this.clientRegistration); - - when(this.authorizedClientRepository.loadAuthorizedClient(eq(this.clientRegistration.getRegistrationId()), - eq(this.principal), any(HttpServletRequest.class))).thenReturn(this.authorizedClient); - OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.forClient(this.clientRegistration.getRegistrationId()) + OAuth2AuthorizationContext.forClient(this.authorizedClient) .principal(this.principal) - .attribute(HttpServletRequest.class.getName(), new MockHttpServletRequest()) - .attribute(HttpServletResponse.class.getName(), new MockHttpServletResponse()) .build(); assertThat(this.authorizedClientProvider.authorize(authorizationContext)).isNull(); } @Test public void authorizeWhenAuthorizationCodeAndNotAuthorizedThenAuthorize() { - when(this.clientRegistrationRepository.findByRegistrationId( - eq(this.clientRegistration.getRegistrationId()))).thenReturn(this.clientRegistration); - OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.forClient(this.clientRegistration.getRegistrationId()) + OAuth2AuthorizationContext.forClient(this.clientRegistration) .principal(this.principal) - .attribute(HttpServletRequest.class.getName(), new MockHttpServletRequest()) - .attribute(HttpServletResponse.class.getName(), new MockHttpServletResponse()) .build(); assertThatThrownBy(() -> this.authorizedClientProvider.authorize(authorizationContext)) .isInstanceOf(ClientAuthorizationRequiredException.class); diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/ClientCredentialsOAuth2AuthorizedClientProviderTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/ClientCredentialsOAuth2AuthorizedClientProviderTests.java index a2a94538102..8e589583843 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/ClientCredentialsOAuth2AuthorizedClientProviderTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/ClientCredentialsOAuth2AuthorizedClientProviderTests.java @@ -17,30 +17,25 @@ import org.junit.Before; import org.junit.Test; -import org.springframework.mock.web.MockHttpServletRequest; -import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest; import org.springframework.security.oauth2.client.registration.ClientRegistration; -import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; import org.springframework.security.oauth2.client.registration.TestClientRegistrations; -import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses; -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; import java.time.Duration; import java.time.Instant; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.*; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; /** * Tests for {@link ClientCredentialsOAuth2AuthorizedClientProvider}. @@ -48,8 +43,6 @@ * @author Joe Grandja */ public class ClientCredentialsOAuth2AuthorizedClientProviderTests { - private ClientRegistrationRepository clientRegistrationRepository; - private OAuth2AuthorizedClientRepository authorizedClientRepository; private ClientCredentialsOAuth2AuthorizedClientProvider authorizedClientProvider; private OAuth2AccessTokenResponseClient accessTokenResponseClient; private ClientRegistration clientRegistration; @@ -57,30 +50,13 @@ public class ClientCredentialsOAuth2AuthorizedClientProviderTests { @Before public void setup() { - this.clientRegistrationRepository = mock(ClientRegistrationRepository.class); - this.authorizedClientRepository = mock(OAuth2AuthorizedClientRepository.class); - this.authorizedClientProvider = new ClientCredentialsOAuth2AuthorizedClientProvider( - this.clientRegistrationRepository, this.authorizedClientRepository); + this.authorizedClientProvider = new ClientCredentialsOAuth2AuthorizedClientProvider(); this.accessTokenResponseClient = mock(OAuth2AccessTokenResponseClient.class); this.authorizedClientProvider.setAccessTokenResponseClient(this.accessTokenResponseClient); this.clientRegistration = TestClientRegistrations.clientCredentials().build(); this.principal = new TestingAuthenticationToken("principal", "password"); } - @Test - public void constructorWhenClientRegistrationRepositoryIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new ClientCredentialsOAuth2AuthorizedClientProvider(null, this.authorizedClientRepository)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("clientRegistrationRepository cannot be null"); - } - - @Test - public void constructorWhenOAuth2AuthorizedClientRepositoryIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new ClientCredentialsOAuth2AuthorizedClientProvider(this.clientRegistrationRepository, null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("authorizedClientRepository cannot be null"); - } - @Test public void setAccessTokenResponseClientWhenClientIsNullThenThrowIllegalArgumentException() { assertThatThrownBy(() -> this.authorizedClientProvider.setAccessTokenResponseClient(null)) @@ -109,133 +85,65 @@ public void authorizeWhenContextIsNullThenThrowIllegalArgumentException() { .hasMessage("context cannot be null"); } - @Test - public void authorizeWhenHttpServletRequestIsNullThenThrowIllegalArgumentException() { - OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.forClient(this.clientRegistration.getRegistrationId()) - .principal(this.principal) - .build(); - assertThatThrownBy(() -> this.authorizedClientProvider.authorize(authorizationContext)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("The context attribute cannot be null 'javax.servlet.http.HttpServletRequest'"); - } - - @Test - public void authorizeWhenHttpServletResponseIsNullThenThrowIllegalArgumentException() { - OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.forClient(this.clientRegistration.getRegistrationId()) - .principal(this.principal) - .attribute(HttpServletRequest.class.getName(), new MockHttpServletRequest()) - .build(); - assertThatThrownBy(() -> this.authorizedClientProvider.authorize(authorizationContext)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("The context attribute cannot be null 'javax.servlet.http.HttpServletResponse'"); - } - - @Test - public void authorizeWhenClientRegistrationNotFoundThenThrowIllegalArgumentException() { - OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.forClient(this.clientRegistration.getRegistrationId()) - .principal(this.principal) - .attribute(HttpServletRequest.class.getName(), new MockHttpServletRequest()) - .attribute(HttpServletResponse.class.getName(), new MockHttpServletResponse()) - .build(); - assertThatThrownBy(() -> this.authorizedClientProvider.authorize(authorizationContext)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("Could not find ClientRegistration with id '" + this.clientRegistration.getRegistrationId() + "'"); - } - @Test public void authorizeWhenNotClientCredentialsThenUnableToAuthorize() { ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); - when(this.clientRegistrationRepository.findByRegistrationId( - eq(clientRegistration.getRegistrationId()))).thenReturn(clientRegistration); OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.forClient(clientRegistration.getRegistrationId()) + OAuth2AuthorizationContext.forClient(clientRegistration) .principal(this.principal) - .attribute(HttpServletRequest.class.getName(), new MockHttpServletRequest()) - .attribute(HttpServletResponse.class.getName(), new MockHttpServletResponse()) .build(); assertThat(this.authorizedClientProvider.authorize(authorizationContext)).isNull(); } @Test public void authorizeWhenClientCredentialsAndNotAuthorizedThenAuthorize() { - when(this.clientRegistrationRepository.findByRegistrationId( - eq(this.clientRegistration.getRegistrationId()))).thenReturn(this.clientRegistration); - OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse); OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.forClient(this.clientRegistration.getRegistrationId()) + OAuth2AuthorizationContext.forClient(this.clientRegistration) .principal(this.principal) - .attribute(HttpServletRequest.class.getName(), new MockHttpServletRequest()) - .attribute(HttpServletResponse.class.getName(), new MockHttpServletResponse()) .build(); - OAuth2AuthorizedClient authorizedClient = this.authorizedClientProvider.authorize(authorizationContext); assertThat(authorizedClient.getClientRegistration()).isSameAs(this.clientRegistration); assertThat(authorizedClient.getPrincipalName()).isEqualTo(this.principal.getName()); assertThat(authorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken()); - verify(this.authorizedClientRepository).saveAuthorizedClient( - eq(authorizedClient), eq(this.principal), - any(HttpServletRequest.class), any(HttpServletResponse.class)); } @Test public void authorizeWhenClientCredentialsAndTokenExpiredThenReauthorize() { - when(this.clientRegistrationRepository.findByRegistrationId( - eq(this.clientRegistration.getRegistrationId()))).thenReturn(this.clientRegistration); - Instant issuedAt = Instant.now().minus(Duration.ofDays(1)); Instant expiresAt = issuedAt.plus(Duration.ofMinutes(60)); OAuth2AccessToken accessToken = new OAuth2AccessToken( OAuth2AccessToken.TokenType.BEARER, "access-token-1234", issuedAt, expiresAt); OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( this.clientRegistration, this.principal.getName(), accessToken); - when(this.authorizedClientRepository.loadAuthorizedClient(eq(this.clientRegistration.getRegistrationId()), - eq(this.principal), any(HttpServletRequest.class))).thenReturn(authorizedClient); OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse); OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.forClient(this.clientRegistration.getRegistrationId()) + OAuth2AuthorizationContext.forClient(authorizedClient) .principal(this.principal) - .attribute(HttpServletRequest.class.getName(), new MockHttpServletRequest()) - .attribute(HttpServletResponse.class.getName(), new MockHttpServletResponse()) .build(); - authorizedClient = this.authorizedClientProvider.authorize(authorizationContext); assertThat(authorizedClient.getClientRegistration()).isSameAs(this.clientRegistration); assertThat(authorizedClient.getPrincipalName()).isEqualTo(this.principal.getName()); assertThat(authorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken()); - verify(this.authorizedClientRepository).saveAuthorizedClient( - eq(authorizedClient), eq(this.principal), - any(HttpServletRequest.class), any(HttpServletResponse.class)); } @Test public void authorizeWhenClientCredentialsAndTokenNotExpiredThenNotReauthorize() { - when(this.clientRegistrationRepository.findByRegistrationId( - eq(this.clientRegistration.getRegistrationId()))).thenReturn(this.clientRegistration); - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( this.clientRegistration, this.principal.getName(), TestOAuth2AccessTokens.noScopes()); - when(this.authorizedClientRepository.loadAuthorizedClient(eq(this.clientRegistration.getRegistrationId()), - eq(this.principal), any(HttpServletRequest.class))).thenReturn(authorizedClient); OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.forClient(this.clientRegistration.getRegistrationId()) + OAuth2AuthorizationContext.forClient(authorizedClient) .principal(this.principal) - .attribute(HttpServletRequest.class.getName(), new MockHttpServletRequest()) - .attribute(HttpServletResponse.class.getName(), new MockHttpServletResponse()) .build(); - assertThat(this.authorizedClientProvider.authorize(authorizationContext)).isNull(); } } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/DefaultOAuth2AuthorizedClientProviderTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/DefaultOAuth2AuthorizedClientProviderTests.java deleted file mode 100644 index ff17df67d00..00000000000 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/DefaultOAuth2AuthorizedClientProviderTests.java +++ /dev/null @@ -1,137 +0,0 @@ -/* - * Copyright 2002-2019 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.springframework.security.oauth2.client; - -import org.junit.Before; -import org.junit.Test; -import org.springframework.mock.web.MockHttpServletRequest; -import org.springframework.mock.web.MockHttpServletResponse; -import org.springframework.security.authentication.TestingAuthenticationToken; -import org.springframework.security.core.Authentication; -import org.springframework.security.oauth2.client.registration.ClientRegistration; -import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; -import org.springframework.security.oauth2.client.registration.TestClientRegistrations; -import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; -import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; - -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; - -/** - * Tests for {@link DefaultOAuth2AuthorizedClientProvider}. - * - * @author Joe Grandja - */ -public class DefaultOAuth2AuthorizedClientProviderTests { - private ClientRegistrationRepository clientRegistrationRepository; - private OAuth2AuthorizedClientRepository authorizedClientRepository; - private DefaultOAuth2AuthorizedClientProvider authorizedClientProvider; - private ClientRegistration clientRegistration; - private OAuth2AuthorizedClient authorizedClient; - private Authentication principal; - - @Before - public void setup() { - this.clientRegistrationRepository = mock(ClientRegistrationRepository.class); - this.authorizedClientRepository = mock(OAuth2AuthorizedClientRepository.class); - this.authorizedClientProvider = new DefaultOAuth2AuthorizedClientProvider( - this.clientRegistrationRepository, this.authorizedClientRepository); - this.clientRegistration = TestClientRegistrations.clientRegistration().build(); - this.authorizedClient = new OAuth2AuthorizedClient( - this.clientRegistration, "principal", TestOAuth2AccessTokens.scopes("read", "write")); - this.principal = new TestingAuthenticationToken("principal", "password"); - } - - @Test - public void constructorWhenClientRegistrationRepositoryIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new DefaultOAuth2AuthorizedClientProvider(null, this.authorizedClientRepository)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("clientRegistrationRepository cannot be null"); - } - - @Test - public void constructorWhenOAuth2AuthorizedClientRepositoryIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new DefaultOAuth2AuthorizedClientProvider(this.clientRegistrationRepository, null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("authorizedClientRepository cannot be null"); - } - - @Test - public void authorizeWhenContextIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.authorizedClientProvider.authorize(null)) - .isInstanceOf(IllegalArgumentException.class); - } - - @Test - public void authorizeWhenHttpServletRequestIsNullThenThrowIllegalArgumentException() { - OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.forClient(this.clientRegistration.getRegistrationId()) - .principal(this.principal) - .build(); - assertThatThrownBy(() -> this.authorizedClientProvider.authorize(authorizationContext)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("The context attribute cannot be null 'javax.servlet.http.HttpServletRequest'"); - } - - @Test - public void authorizeWhenHttpServletResponseIsNullThenThrowIllegalArgumentException() { - OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.forClient(this.clientRegistration.getRegistrationId()) - .principal(this.principal) - .attribute(HttpServletRequest.class.getName(), new MockHttpServletRequest()) - .build(); - assertThatThrownBy(() -> this.authorizedClientProvider.authorize(authorizationContext)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("The context attribute cannot be null 'javax.servlet.http.HttpServletResponse'"); - } - - @Test - public void authorizeWhenClientRegistrationNotFoundThenThrowIllegalArgumentException() { - OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.forClient(this.clientRegistration.getRegistrationId()) - .principal(this.principal) - .attribute(HttpServletRequest.class.getName(), new MockHttpServletRequest()) - .attribute(HttpServletResponse.class.getName(), new MockHttpServletResponse()) - .build(); - assertThatThrownBy(() -> this.authorizedClientProvider.authorize(authorizationContext)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("Could not find ClientRegistration with id '" + this.clientRegistration.getRegistrationId() + "'"); - } - - @Test - public void authorizeWhenAuthorizedThenReturnAuthorizedClient() { - when(this.clientRegistrationRepository.findByRegistrationId( - eq(this.clientRegistration.getRegistrationId()))).thenReturn(this.clientRegistration); - - when(this.authorizedClientRepository.loadAuthorizedClient(eq(this.clientRegistration.getRegistrationId()), - eq(this.principal), any(HttpServletRequest.class))).thenReturn(this.authorizedClient); - - OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.forClient(this.clientRegistration.getRegistrationId()) - .principal(this.principal) - .attribute(HttpServletRequest.class.getName(), new MockHttpServletRequest()) - .attribute(HttpServletResponse.class.getName(), new MockHttpServletResponse()) - .build(); - assertThat(this.authorizedClientProvider.authorize(authorizationContext)).isSameAs(this.authorizedClient); - } -} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/DelegatingOAuth2AuthorizedClientProviderTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/DelegatingOAuth2AuthorizedClientProviderTests.java index ae1253c286a..d4e5d20d7e3 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/DelegatingOAuth2AuthorizedClientProviderTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/DelegatingOAuth2AuthorizedClientProviderTests.java @@ -27,7 +27,8 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.*; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; /** * Tests for {@link DelegatingOAuth2AuthorizedClientProvider}. @@ -44,15 +45,6 @@ public void constructorWhenProvidersIsEmptyThenThrowIllegalArgumentException() { .isInstanceOf(IllegalArgumentException.class); } - @Test - public void setDefaultAuthorizedClientProviderWhenNullThenThrowIllegalArgumentException() { - DelegatingOAuth2AuthorizedClientProvider delegate = new DelegatingOAuth2AuthorizedClientProvider( - mock(OAuth2AuthorizedClientProvider.class)); - assertThatThrownBy(() -> delegate.setDefaultAuthorizedClientProvider(null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("authorizedClientProvider cannot be null"); - } - @Test public void authorizeWhenContextIsNullThenThrowIllegalArgumentException() { DelegatingOAuth2AuthorizedClientProvider delegate = new DelegatingOAuth2AuthorizedClientProvider( @@ -74,7 +66,7 @@ public void authorizeWhenProviderCanAuthorizeThenReturnAuthorizedClient() { DelegatingOAuth2AuthorizedClientProvider delegate = new DelegatingOAuth2AuthorizedClientProvider( mock(OAuth2AuthorizedClientProvider.class), mock(OAuth2AuthorizedClientProvider.class), authorizedClientProvider); - OAuth2AuthorizationContext context = OAuth2AuthorizationContext.forClient(clientRegistration.getRegistrationId()) + OAuth2AuthorizationContext context = OAuth2AuthorizationContext.forClient(clientRegistration) .principal(principal) .build(); OAuth2AuthorizedClient reauthorizedClient = delegate.authorize(context); @@ -84,7 +76,7 @@ public void authorizeWhenProviderCanAuthorizeThenReturnAuthorizedClient() { @Test public void authorizeWhenProviderCantAuthorizeThenReturnNull() { ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); - OAuth2AuthorizationContext context = OAuth2AuthorizationContext.forClient(clientRegistration.getRegistrationId()) + OAuth2AuthorizationContext context = OAuth2AuthorizationContext.forClient(clientRegistration) .principal(new TestingAuthenticationToken("principal", "password")) .build(); @@ -92,19 +84,4 @@ public void authorizeWhenProviderCantAuthorizeThenReturnNull() { mock(OAuth2AuthorizedClientProvider.class), mock(OAuth2AuthorizedClientProvider.class)); assertThat(delegate.authorize(context)).isNull(); } - - @Test - public void authorizeWhenProviderCantAuthorizeThenDefaultCalled() { - ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); - OAuth2AuthorizationContext context = OAuth2AuthorizationContext.forClient(clientRegistration.getRegistrationId()) - .principal(new TestingAuthenticationToken("principal", "password")) - .build(); - - DelegatingOAuth2AuthorizedClientProvider delegate = new DelegatingOAuth2AuthorizedClientProvider( - mock(OAuth2AuthorizedClientProvider.class), mock(OAuth2AuthorizedClientProvider.class)); - OAuth2AuthorizedClientProvider defaultAuthorizedClientProvider = mock(OAuth2AuthorizedClientProvider.class); - delegate.setDefaultAuthorizedClientProvider(defaultAuthorizedClientProvider); - delegate.authorize(context); - verify(defaultAuthorizedClientProvider).authorize(eq(context)); - } } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/OAuth2AuthorizationContextTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/OAuth2AuthorizationContextTests.java index 42ff33a2cb3..a1351d4c476 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/OAuth2AuthorizationContextTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/OAuth2AuthorizationContextTests.java @@ -44,36 +44,35 @@ public void setup() { } @Test - public void forClientWhenClientRegistrationIdIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> OAuth2AuthorizationContext.forClient(null).build()) + public void forClientWhenClientRegistrationIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> OAuth2AuthorizationContext.forClient((ClientRegistration) null).build()) .isInstanceOf(IllegalArgumentException.class) - .hasMessage("clientRegistrationId cannot be empty"); + .hasMessage("clientRegistration cannot be null"); } @Test - public void forClientWhenPrincipalIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> OAuth2AuthorizationContext.forClient(this.clientRegistration.getRegistrationId()).build()) + public void forClientWhenAuthorizedClientIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> OAuth2AuthorizationContext.forClient((OAuth2AuthorizedClient) null).build()) .isInstanceOf(IllegalArgumentException.class) - .hasMessage("principal cannot be null"); + .hasMessage("authorizedClient cannot be null"); } @Test - public void forClientWhenPrincipalNameIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> OAuth2AuthorizationContext.forClient(this.clientRegistration.getRegistrationId()) - .principal((String) null) - .build()) + public void forClientWhenPrincipalIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> OAuth2AuthorizationContext.forClient(this.clientRegistration).build()) .isInstanceOf(IllegalArgumentException.class) - .hasMessage("principalName cannot be empty"); + .hasMessage("principal cannot be null"); } @Test public void forClientWhenAllValuesProvidedThenAllValuesAreSet() { - OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext.forClient(this.clientRegistration.getRegistrationId()) + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext.forClient(this.authorizedClient) .principal(this.principal) .attribute("attribute1", "value1") .attribute("attribute2", "value2") .build(); - assertThat(authorizationContext.getClientRegistrationId()).isSameAs(this.clientRegistration.getRegistrationId()); + assertThat(authorizationContext.getClientRegistration()).isSameAs(this.clientRegistration); + assertThat(authorizationContext.getAuthorizedClient()).isSameAs(this.authorizedClient); assertThat(authorizationContext.getPrincipal()).isSameAs(this.principal); assertThat(authorizationContext.getAttributes()).contains( entry("attribute1", "value1"), entry("attribute2", "value2")); diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientProviderBuilderTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientProviderBuilderTests.java new file mode 100644 index 00000000000..a9c400731c0 --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientProviderBuilderTests.java @@ -0,0 +1,202 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client; + +import org.junit.Before; +import org.junit.Test; +import org.springframework.http.HttpStatus; +import org.springframework.http.RequestEntity; +import org.springframework.http.ResponseEntity; +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.client.endpoint.DefaultClientCredentialsTokenResponseClient; +import org.springframework.security.oauth2.client.endpoint.DefaultRefreshTokenTokenResponseClient; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses; +import org.springframework.web.client.RestOperations; + +import java.time.Duration; +import java.time.Instant; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.*; + +/** + * Tests for {@link OAuth2AuthorizedClientProviderBuilder}. + * + * @author Joe Grandja + */ +public class OAuth2AuthorizedClientProviderBuilderTests { + private RestOperations accessTokenClient; + private DefaultClientCredentialsTokenResponseClient clientCredentialsTokenResponseClient; + private DefaultRefreshTokenTokenResponseClient refreshTokenTokenResponseClient; + private Authentication principal; + + @SuppressWarnings("unchecked") + @Before + public void setup() { + OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); + this.accessTokenClient = mock(RestOperations.class); + when(this.accessTokenClient.exchange(any(RequestEntity.class), eq(OAuth2AccessTokenResponse.class))) + .thenReturn(new ResponseEntity(accessTokenResponse, HttpStatus.OK)); + this.refreshTokenTokenResponseClient = new DefaultRefreshTokenTokenResponseClient(); + this.refreshTokenTokenResponseClient.setRestOperations(this.accessTokenClient); + this.clientCredentialsTokenResponseClient = new DefaultClientCredentialsTokenResponseClient(); + this.clientCredentialsTokenResponseClient.setRestOperations(this.accessTokenClient); + this.principal = new TestingAuthenticationToken("principal", "password"); + } + + @Test + public void providerWhenNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> OAuth2AuthorizedClientProviderBuilder.withProvider().provider(null)) + .isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void buildWhenAuthorizationCodeProviderThenProviderAuthorizes() { + OAuth2AuthorizedClientProvider authorizedClientProvider = + OAuth2AuthorizedClientProviderBuilder.withProvider() + .authorizationCode() + .build(); + + OAuth2AuthorizationContext authorizationContext = + OAuth2AuthorizationContext.forClient(TestClientRegistrations.clientRegistration().build()) + .principal(this.principal) + .build(); + assertThatThrownBy(() -> authorizedClientProvider.authorize(authorizationContext)) + .isInstanceOf(ClientAuthorizationRequiredException.class); + } + + @Test + public void buildWhenRefreshTokenProviderThenProviderReauthorizes() { + OAuth2AuthorizedClientProvider authorizedClientProvider = + OAuth2AuthorizedClientProviderBuilder.withProvider() + .refreshToken(configurer -> configurer.accessTokenResponseClient(this.refreshTokenTokenResponseClient)) + .build(); + + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( + TestClientRegistrations.clientRegistration().build(), + this.principal.getName(), + expiredAccessToken(), + TestOAuth2RefreshTokens.refreshToken()); + + OAuth2AuthorizationContext authorizationContext = + OAuth2AuthorizationContext.forClient(authorizedClient) + .principal(this.principal) + .build(); + OAuth2AuthorizedClient reauthorizedClient = authorizedClientProvider.authorize(authorizationContext); + + assertThat(reauthorizedClient).isNotNull(); + verify(this.accessTokenClient).exchange(any(RequestEntity.class), eq(OAuth2AccessTokenResponse.class)); + } + + @Test + public void buildWhenClientCredentialsProviderThenProviderAuthorizes() { + OAuth2AuthorizedClientProvider authorizedClientProvider = + OAuth2AuthorizedClientProviderBuilder.withProvider() + .clientCredentials(configurer -> configurer.accessTokenResponseClient(this.clientCredentialsTokenResponseClient)) + .build(); + + OAuth2AuthorizationContext authorizationContext = + OAuth2AuthorizationContext.forClient(TestClientRegistrations.clientCredentials().build()) + .principal(this.principal) + .build(); + OAuth2AuthorizedClient authorizedClient = authorizedClientProvider.authorize(authorizationContext); + + assertThat(authorizedClient).isNotNull(); + verify(this.accessTokenClient).exchange(any(RequestEntity.class), eq(OAuth2AccessTokenResponse.class)); + } + + @Test + public void buildWhenAllProvidersThenProvidersAuthorize() { + OAuth2AuthorizedClientProvider authorizedClientProvider = + OAuth2AuthorizedClientProviderBuilder.withProvider() + .authorizationCode() + .refreshToken(configurer -> configurer.accessTokenResponseClient(this.refreshTokenTokenResponseClient)) + .clientCredentials(configurer -> configurer.accessTokenResponseClient(this.clientCredentialsTokenResponseClient)) + .build(); + + ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); + + + // authorization_code + OAuth2AuthorizationContext authorizationCodeContext = + OAuth2AuthorizationContext.forClient(clientRegistration) + .principal(this.principal) + .build(); + assertThatThrownBy(() -> authorizedClientProvider.authorize(authorizationCodeContext)) + .isInstanceOf(ClientAuthorizationRequiredException.class); + + + // refresh_token + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( + clientRegistration, + this.principal.getName(), + expiredAccessToken(), + TestOAuth2RefreshTokens.refreshToken()); + + OAuth2AuthorizationContext refreshTokenContext = + OAuth2AuthorizationContext.forClient(authorizedClient) + .principal(this.principal) + .build(); + OAuth2AuthorizedClient reauthorizedClient = authorizedClientProvider.authorize(refreshTokenContext); + + assertThat(reauthorizedClient).isNotNull(); + verify(this.accessTokenClient, times(1)).exchange(any(RequestEntity.class), eq(OAuth2AccessTokenResponse.class)); + + + // client_credentials + OAuth2AuthorizationContext clientCredentialsContext = + OAuth2AuthorizationContext.forClient(TestClientRegistrations.clientCredentials().build()) + .principal(this.principal) + .build(); + authorizedClient = authorizedClientProvider.authorize(clientCredentialsContext); + + assertThat(authorizedClient).isNotNull(); + verify(this.accessTokenClient, times(2)).exchange(any(RequestEntity.class), eq(OAuth2AccessTokenResponse.class)); + } + + @Test + public void buildWhenCustomProviderThenProviderCalled() { + OAuth2AuthorizedClientProvider customProvider = mock(OAuth2AuthorizedClientProvider.class); + + OAuth2AuthorizedClientProvider authorizedClientProvider = + OAuth2AuthorizedClientProviderBuilder.withProvider() + .provider(customProvider) + .build(); + + OAuth2AuthorizationContext authorizationContext = + OAuth2AuthorizationContext.forClient(TestClientRegistrations.clientRegistration().build()) + .principal(this.principal) + .build(); + authorizedClientProvider.authorize(authorizationContext); + + verify(customProvider).authorize(any(OAuth2AuthorizationContext.class)); + } + + private OAuth2AccessToken expiredAccessToken() { + Instant issuedAt = Instant.now().minus(Duration.ofDays(1)); + Instant expiresAt = issuedAt.plus(Duration.ofMinutes(60)); + return new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, "access-token-1234", issuedAt, expiresAt); + } +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProviderTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProviderTests.java index a183d84239b..cd4b6725e14 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProviderTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProviderTests.java @@ -18,24 +18,18 @@ import org.junit.Before; import org.junit.Test; import org.mockito.ArgumentCaptor; -import org.springframework.mock.web.MockHttpServletRequest; -import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.OAuth2RefreshTokenGrantRequest; import org.springframework.security.oauth2.client.registration.ClientRegistration; -import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; import org.springframework.security.oauth2.client.registration.TestClientRegistrations; -import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses; -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; import java.time.Duration; import java.time.Instant; import java.util.Arrays; @@ -44,7 +38,6 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.*; /** @@ -53,8 +46,6 @@ * @author Joe Grandja */ public class RefreshTokenOAuth2AuthorizedClientProviderTests { - private ClientRegistrationRepository clientRegistrationRepository; - private OAuth2AuthorizedClientRepository authorizedClientRepository; private RefreshTokenOAuth2AuthorizedClientProvider authorizedClientProvider; private OAuth2AccessTokenResponseClient accessTokenResponseClient; private ClientRegistration clientRegistration; @@ -63,10 +54,7 @@ public class RefreshTokenOAuth2AuthorizedClientProviderTests { @Before public void setup() { - this.clientRegistrationRepository = mock(ClientRegistrationRepository.class); - this.authorizedClientRepository = mock(OAuth2AuthorizedClientRepository.class); - this.authorizedClientProvider = new RefreshTokenOAuth2AuthorizedClientProvider( - this.clientRegistrationRepository, this.authorizedClientRepository); + this.authorizedClientProvider = new RefreshTokenOAuth2AuthorizedClientProvider(); this.accessTokenResponseClient = mock(OAuth2AccessTokenResponseClient.class); this.authorizedClientProvider.setAccessTokenResponseClient(this.accessTokenResponseClient); this.clientRegistration = TestClientRegistrations.clientRegistration().build(); @@ -79,20 +67,6 @@ public void setup() { expiredAccessToken, TestOAuth2RefreshTokens.refreshToken()); } - @Test - public void constructorWhenClientRegistrationRepositoryIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new RefreshTokenOAuth2AuthorizedClientProvider(null, this.authorizedClientRepository)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("clientRegistrationRepository cannot be null"); - } - - @Test - public void constructorWhenOAuth2AuthorizedClientRepositoryIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new RefreshTokenOAuth2AuthorizedClientProvider(this.clientRegistrationRepository, null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("authorizedClientRepository cannot be null"); - } - @Test public void setAccessTokenResponseClientWhenClientIsNullThenThrowIllegalArgumentException() { assertThatThrownBy(() -> this.authorizedClientProvider.setAccessTokenResponseClient(null)) @@ -121,133 +95,61 @@ public void authorizeWhenContextIsNullThenThrowIllegalArgumentException() { .hasMessage("context cannot be null"); } - @Test - public void authorizeWhenHttpServletRequestIsNullThenThrowIllegalArgumentException() { - OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.forClient(this.clientRegistration.getRegistrationId()) - .principal(this.principal) - .build(); - assertThatThrownBy(() -> this.authorizedClientProvider.authorize(authorizationContext)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("The context attribute cannot be null 'javax.servlet.http.HttpServletRequest'"); - } - - @Test - public void authorizeWhenHttpServletResponseIsNullThenThrowIllegalArgumentException() { - OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.forClient(this.clientRegistration.getRegistrationId()) - .principal(this.principal) - .attribute(HttpServletRequest.class.getName(), new MockHttpServletRequest()) - .build(); - assertThatThrownBy(() -> this.authorizedClientProvider.authorize(authorizationContext)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("The context attribute cannot be null 'javax.servlet.http.HttpServletResponse'"); - } - - @Test - public void authorizeWhenClientRegistrationNotFoundThenThrowIllegalArgumentException() { - OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.forClient(this.clientRegistration.getRegistrationId()) - .principal(this.principal) - .attribute(HttpServletRequest.class.getName(), new MockHttpServletRequest()) - .attribute(HttpServletResponse.class.getName(), new MockHttpServletResponse()) - .build(); - assertThatThrownBy(() -> this.authorizedClientProvider.authorize(authorizationContext)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("Could not find ClientRegistration with id '" + this.clientRegistration.getRegistrationId() + "'"); - } - @Test public void authorizeWhenNotAuthorizedThenUnableToReauthorize() { - when(this.clientRegistrationRepository.findByRegistrationId( - eq(this.clientRegistration.getRegistrationId()))).thenReturn(this.clientRegistration); - OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.forClient(this.clientRegistration.getRegistrationId()) + OAuth2AuthorizationContext.forClient(this.clientRegistration) .principal(this.principal) - .attribute(HttpServletRequest.class.getName(), new MockHttpServletRequest()) - .attribute(HttpServletResponse.class.getName(), new MockHttpServletResponse()) .build(); assertThat(this.authorizedClientProvider.authorize(authorizationContext)).isNull(); } @Test public void authorizeWhenAuthorizedAndRefreshTokenIsNullThenUnableToReauthorize() { - when(this.clientRegistrationRepository.findByRegistrationId( - eq(this.clientRegistration.getRegistrationId()))).thenReturn(this.clientRegistration); - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( this.clientRegistration, this.principal.getName(), this.authorizedClient.getAccessToken()); - when(this.authorizedClientRepository.loadAuthorizedClient(eq(this.clientRegistration.getRegistrationId()), - eq(this.principal), any(HttpServletRequest.class))).thenReturn(authorizedClient); OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.forClient(this.clientRegistration.getRegistrationId()) + OAuth2AuthorizationContext.forClient(authorizedClient) .principal(this.principal) - .attribute(HttpServletRequest.class.getName(), new MockHttpServletRequest()) - .attribute(HttpServletResponse.class.getName(), new MockHttpServletResponse()) .build(); assertThat(this.authorizedClientProvider.authorize(authorizationContext)).isNull(); } @Test public void authorizeWhenAuthorizedAndAccessTokenNotExpiredThenNotReauthorize() { - when(this.clientRegistrationRepository.findByRegistrationId( - eq(this.clientRegistration.getRegistrationId()))).thenReturn(this.clientRegistration); - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, this.principal.getName(), TestOAuth2AccessTokens.noScopes(), this.authorizedClient.getRefreshToken()); - when(this.authorizedClientRepository.loadAuthorizedClient(eq(this.clientRegistration.getRegistrationId()), - eq(this.principal), any(HttpServletRequest.class))).thenReturn(authorizedClient); OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.forClient(this.clientRegistration.getRegistrationId()) + OAuth2AuthorizationContext.forClient(authorizedClient) .principal(this.principal) - .attribute(HttpServletRequest.class.getName(), new MockHttpServletRequest()) - .attribute(HttpServletResponse.class.getName(), new MockHttpServletResponse()) .build(); assertThat(this.authorizedClientProvider.authorize(authorizationContext)).isNull(); } @Test public void authorizeWhenAuthorizedAndAccessTokenExpiredThenReauthorize() { - when(this.clientRegistrationRepository.findByRegistrationId( - eq(this.clientRegistration.getRegistrationId()))).thenReturn(this.clientRegistration); - - when(this.authorizedClientRepository.loadAuthorizedClient(eq(this.clientRegistration.getRegistrationId()), - eq(this.principal), any(HttpServletRequest.class))).thenReturn(this.authorizedClient); - OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse() .refreshToken("new-refresh-token") .build(); when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse); OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.forClient(this.clientRegistration.getRegistrationId()) + OAuth2AuthorizationContext.forClient(this.authorizedClient) .principal(this.principal) - .attribute(HttpServletRequest.class.getName(), new MockHttpServletRequest()) - .attribute(HttpServletResponse.class.getName(), new MockHttpServletResponse()) .build(); - OAuth2AuthorizedClient authorizedClient = this.authorizedClientProvider.authorize(authorizationContext); + OAuth2AuthorizedClient reauthorizedClient = this.authorizedClientProvider.authorize(authorizationContext); - assertThat(authorizedClient.getClientRegistration()).isSameAs(this.clientRegistration); - assertThat(authorizedClient.getPrincipalName()).isEqualTo(this.principal.getName()); - assertThat(authorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken()); - assertThat(authorizedClient.getRefreshToken()).isEqualTo(accessTokenResponse.getRefreshToken()); - verify(this.authorizedClientRepository).saveAuthorizedClient( - eq(authorizedClient), eq(this.principal), - any(HttpServletRequest.class), any(HttpServletResponse.class)); + assertThat(reauthorizedClient.getClientRegistration()).isSameAs(this.clientRegistration); + assertThat(reauthorizedClient.getPrincipalName()).isEqualTo(this.principal.getName()); + assertThat(reauthorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken()); + assertThat(reauthorizedClient.getRefreshToken()).isEqualTo(accessTokenResponse.getRefreshToken()); } @Test public void authorizeWhenAuthorizedAndRequestScopeProvidedThenScopeRequested() { - when(this.clientRegistrationRepository.findByRegistrationId( - eq(this.clientRegistration.getRegistrationId()))).thenReturn(this.clientRegistration); - - when(this.authorizedClientRepository.loadAuthorizedClient(eq(this.clientRegistration.getRegistrationId()), - eq(this.principal), any(HttpServletRequest.class))).thenReturn(this.authorizedClient); - OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse() .refreshToken("new-refresh-token") .build(); @@ -255,10 +157,8 @@ public void authorizeWhenAuthorizedAndRequestScopeProvidedThenScopeRequested() { String[] requestScope = new String[] { "read", "write" }; OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.forClient(this.clientRegistration.getRegistrationId()) + OAuth2AuthorizationContext.forClient(this.authorizedClient) .principal(this.principal) - .attribute(HttpServletRequest.class.getName(), new MockHttpServletRequest()) - .attribute(HttpServletResponse.class.getName(), new MockHttpServletResponse()) .attribute(RefreshTokenOAuth2AuthorizedClientProvider.REQUEST_SCOPE_ATTRIBUTE_NAME, requestScope) .build(); @@ -272,18 +172,10 @@ public void authorizeWhenAuthorizedAndRequestScopeProvidedThenScopeRequested() { @Test public void authorizeWhenAuthorizedAndInvalidRequestScopeProvidedThenThrowIllegalArgumentException() { - when(this.clientRegistrationRepository.findByRegistrationId( - eq(this.clientRegistration.getRegistrationId()))).thenReturn(this.clientRegistration); - - when(this.authorizedClientRepository.loadAuthorizedClient(eq(this.clientRegistration.getRegistrationId()), - eq(this.principal), any(HttpServletRequest.class))).thenReturn(this.authorizedClient); - String invalidRequestScope = "read write"; OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.forClient(this.clientRegistration.getRegistrationId()) + OAuth2AuthorizationContext.forClient(this.authorizedClient) .principal(this.principal) - .attribute(HttpServletRequest.class.getName(), new MockHttpServletRequest()) - .attribute(HttpServletResponse.class.getName(), new MockHttpServletResponse()) .attribute(RefreshTokenOAuth2AuthorizedClientProvider.REQUEST_SCOPE_ATTRIBUTE_NAME, invalidRequestScope) .build(); diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizedClientManagerTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizedClientManagerTests.java new file mode 100644 index 00000000000..91b9b59abbc --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizedClientManagerTests.java @@ -0,0 +1,265 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client.web; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.client.OAuth2AuthorizationContext; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProvider; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; +import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens; + +import java.util.function.BiFunction; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.*; + +/** + * Tests for {@link DefaultOAuth2AuthorizedClientManager}. + * + * @author Joe Grandja + */ +public class DefaultOAuth2AuthorizedClientManagerTests { + private ClientRegistrationRepository clientRegistrationRepository; + private OAuth2AuthorizedClientRepository authorizedClientRepository; + private OAuth2AuthorizedClientProvider authorizedClientProvider; + private BiFunction contextAttributesMapper; + private DefaultOAuth2AuthorizedClientManager authorizedClientManager; + private ClientRegistration clientRegistration; + private Authentication principal; + private OAuth2AuthorizedClient authorizedClient; + private MockHttpServletRequest request; + private MockHttpServletResponse response; + private ArgumentCaptor authorizationContextCaptor; + + @SuppressWarnings("unchecked") + @Before + public void setup() { + this.clientRegistrationRepository = mock(ClientRegistrationRepository.class); + this.authorizedClientRepository = mock(OAuth2AuthorizedClientRepository.class); + this.authorizedClientProvider = mock(OAuth2AuthorizedClientProvider.class); + this.contextAttributesMapper = mock(BiFunction.class); + this.authorizedClientManager = new DefaultOAuth2AuthorizedClientManager( + this.clientRegistrationRepository, this.authorizedClientRepository); + this.authorizedClientManager.setAuthorizedClientProvider(this.authorizedClientProvider); + this.authorizedClientManager.setContextAttributesMapper(this.contextAttributesMapper); + this.clientRegistration = TestClientRegistrations.clientRegistration().build(); + this.principal = new TestingAuthenticationToken("principal", "password"); + this.authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, this.principal.getName(), + TestOAuth2AccessTokens.scopes("read", "write"), TestOAuth2RefreshTokens.refreshToken()); + this.request = new MockHttpServletRequest(); + this.response = new MockHttpServletResponse(); + this.authorizationContextCaptor = ArgumentCaptor.forClass(OAuth2AuthorizationContext.class); + } + + @Test + public void constructorWhenClientRegistrationRepositoryIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new DefaultOAuth2AuthorizedClientManager(null, this.authorizedClientRepository)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("clientRegistrationRepository cannot be null"); + } + + @Test + public void constructorWhenOAuth2AuthorizedClientRepositoryIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new DefaultOAuth2AuthorizedClientManager(this.clientRegistrationRepository, null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("authorizedClientRepository cannot be null"); + } + + @Test + public void setAuthorizedClientProviderWhenNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.authorizedClientManager.setAuthorizedClientProvider(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("authorizedClientProvider cannot be null"); + } + + @Test + public void setContextAttributesMapperWhenNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.authorizedClientManager.setContextAttributesMapper(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("contextAttributesMapper cannot be null"); + } + + @Test + public void authorizeWhenArgumentsInvalidThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.authorizedClientManager.authorize(null, this.principal, this.request, this.response)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("clientRegistrationId cannot be empty"); + assertThatThrownBy(() -> this.authorizedClientManager.authorize(this.clientRegistration.getRegistrationId(), null, this.request, this.response)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("principal cannot be null"); + assertThatThrownBy(() -> this.authorizedClientManager.authorize(this.clientRegistration.getRegistrationId(), this.principal, null, this.response)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("request cannot be null"); + assertThatThrownBy(() -> this.authorizedClientManager.authorize(this.clientRegistration.getRegistrationId(), this.principal, this.request, null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("response cannot be null"); + } + + @Test + public void authorizeWhenClientRegistrationNotFoundThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.authorizedClientManager.authorize("invalid-registration-id", this.principal, this.request, this.response)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Could not find ClientRegistration with id 'invalid-registration-id'"); + } + + @SuppressWarnings("unchecked") + @Test + public void authorizeWhenNotAuthorizedAndUnsupportedProviderThenNotAuthorized() { + when(this.clientRegistrationRepository.findByRegistrationId( + eq(this.clientRegistration.getRegistrationId()))).thenReturn(this.clientRegistration); + + OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize( + this.clientRegistration.getRegistrationId(), this.principal, this.request, this.response); + + verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); + verify(this.contextAttributesMapper).apply(eq(this.clientRegistration), eq(this.request)); + + OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue(); + assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration); + assertThat(authorizationContext.getAuthorizedClient()).isNull(); + assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); + + assertThat(authorizedClient).isNull(); + verify(this.authorizedClientRepository, never()).saveAuthorizedClient( + any(OAuth2AuthorizedClient.class), eq(this.principal), eq(this.request), eq(this.response)); + } + + @SuppressWarnings("unchecked") + @Test + public void authorizeWhenNotAuthorizedAndSupportedProviderThenAuthorized() { + when(this.clientRegistrationRepository.findByRegistrationId( + eq(this.clientRegistration.getRegistrationId()))).thenReturn(this.clientRegistration); + + when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(this.authorizedClient); + + OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize( + this.clientRegistration.getRegistrationId(), this.principal, this.request, this.response); + + verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); + verify(this.contextAttributesMapper).apply(eq(this.clientRegistration), eq(this.request)); + + OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue(); + assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration); + assertThat(authorizationContext.getAuthorizedClient()).isNull(); + assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); + + assertThat(authorizedClient).isSameAs(this.authorizedClient); + verify(this.authorizedClientRepository).saveAuthorizedClient( + eq(this.authorizedClient), eq(this.principal), eq(this.request), eq(this.response)); + } + + @SuppressWarnings("unchecked") + @Test + public void authorizeWhenAuthorizedAndSupportedProviderThenReauthorized() { + when(this.clientRegistrationRepository.findByRegistrationId( + eq(this.clientRegistration.getRegistrationId()))).thenReturn(this.clientRegistration); + when(this.authorizedClientRepository.loadAuthorizedClient( + eq(this.clientRegistration.getRegistrationId()), eq(this.principal), eq(this.request))).thenReturn(this.authorizedClient); + + OAuth2AuthorizedClient reauthorizedClient = new OAuth2AuthorizedClient( + this.clientRegistration, this.principal.getName(), + TestOAuth2AccessTokens.noScopes(), TestOAuth2RefreshTokens.refreshToken()); + + when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(reauthorizedClient); + + OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize( + this.clientRegistration.getRegistrationId(), this.principal, this.request, this.response); + + verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); + verify(this.contextAttributesMapper).apply(eq(this.clientRegistration), eq(this.request)); + + OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue(); + assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration); + assertThat(authorizationContext.getAuthorizedClient()).isSameAs(this.authorizedClient); + assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); + + assertThat(authorizedClient).isSameAs(reauthorizedClient); + verify(this.authorizedClientRepository).saveAuthorizedClient( + eq(reauthorizedClient), eq(this.principal), eq(this.request), eq(this.response)); + } + + @Test + public void reauthorizeWhenArgumentsInvalidThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.authorizedClientManager.reauthorize(null, this.principal, this.request, this.response)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("authorizedClient cannot be null"); + assertThatThrownBy(() -> this.authorizedClientManager.reauthorize(this.authorizedClient, null, this.request, this.response)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("principal cannot be null"); + assertThatThrownBy(() -> this.authorizedClientManager.reauthorize(this.authorizedClient, this.principal, null, this.response)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("request cannot be null"); + assertThatThrownBy(() -> this.authorizedClientManager.reauthorize(this.authorizedClient, this.principal, this.request, null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("response cannot be null"); + } + + @SuppressWarnings("unchecked") + @Test + public void reauthorizeWhenUnsupportedProviderThenNotReauthorized() { + OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.reauthorize( + this.authorizedClient, this.principal, this.request, this.response); + + verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); + verify(this.contextAttributesMapper).apply(eq(this.clientRegistration), eq(this.request)); + + OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue(); + assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration); + assertThat(authorizationContext.getAuthorizedClient()).isSameAs(this.authorizedClient); + assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); + + assertThat(authorizedClient).isSameAs(this.authorizedClient); + verify(this.authorizedClientRepository, never()).saveAuthorizedClient( + any(OAuth2AuthorizedClient.class), eq(this.principal), eq(this.request), eq(this.response)); + } + + @SuppressWarnings("unchecked") + @Test + public void reauthorizeWhenSupportedProviderThenReauthorized() { + OAuth2AuthorizedClient reauthorizedClient = new OAuth2AuthorizedClient( + this.clientRegistration, this.principal.getName(), + TestOAuth2AccessTokens.noScopes(), TestOAuth2RefreshTokens.refreshToken()); + + when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(reauthorizedClient); + + OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.reauthorize( + this.authorizedClient, this.principal, this.request, this.response); + + verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); + verify(this.contextAttributesMapper).apply(eq(this.clientRegistration), eq(this.request)); + + OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue(); + assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration); + assertThat(authorizationContext.getAuthorizedClient()).isSameAs(this.authorizedClient); + assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); + + assertThat(authorizedClient).isSameAs(reauthorizedClient); + verify(this.authorizedClientRepository).saveAuthorizedClient( + eq(reauthorizedClient), eq(this.principal), eq(this.request), eq(this.response)); + } +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolverTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolverTests.java index e47f4ad1502..652233674c8 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolverTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolverTests.java @@ -35,6 +35,7 @@ import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; import org.springframework.security.oauth2.client.registration.InMemoryClientRegistrationRepository; +import org.springframework.security.oauth2.client.web.DefaultOAuth2AuthorizedClientManager; import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.ClientAuthenticationMethod; @@ -134,8 +135,8 @@ public void constructorWhenOAuth2AuthorizedClientRepositoryIsNullThenThrowIllega } @Test - public void setAuthorizedClientProviderWhenProviderIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.argumentResolver.setAuthorizedClientProvider(null)) + public void setAuthorizedClientManagerWhenProviderIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.argumentResolver.setAuthorizedClientManager(null)) .isInstanceOf(IllegalArgumentException.class); } @@ -219,10 +220,12 @@ public void resolveArgumentWhenAuthorizedClientNotFoundForClientCredentialsClien OAuth2AccessTokenResponseClient clientCredentialsTokenResponseClient = mock(OAuth2AccessTokenResponseClient.class); ClientCredentialsOAuth2AuthorizedClientProvider clientCredentialsAuthorizedClientProvider = - new ClientCredentialsOAuth2AuthorizedClientProvider( - this.clientRegistrationRepository, this.authorizedClientRepository); + new ClientCredentialsOAuth2AuthorizedClientProvider(); clientCredentialsAuthorizedClientProvider.setAccessTokenResponseClient(clientCredentialsTokenResponseClient); - this.argumentResolver.setAuthorizedClientProvider(clientCredentialsAuthorizedClientProvider); + DefaultOAuth2AuthorizedClientManager authorizedClientManager = new DefaultOAuth2AuthorizedClientManager( + this.clientRegistrationRepository, this.authorizedClientRepository); + authorizedClientManager.setAuthorizedClientProvider(clientCredentialsAuthorizedClientProvider); + this.argumentResolver.setAuthorizedClientManager(authorizedClientManager); OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse .withToken("access-token-1234") diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java index f1f11e4f945..6e016b746b3 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java @@ -46,12 +46,9 @@ import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.core.context.SecurityContextHolder; -import org.springframework.security.oauth2.client.AuthorizationCodeOAuth2AuthorizedClientProvider; -import org.springframework.security.oauth2.client.ClientCredentialsOAuth2AuthorizedClientProvider; -import org.springframework.security.oauth2.client.DefaultOAuth2AuthorizedClientProvider; -import org.springframework.security.oauth2.client.DelegatingOAuth2AuthorizedClientProvider; import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProvider; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProviderBuilder; import org.springframework.security.oauth2.client.RefreshTokenOAuth2AuthorizedClientProvider; import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; import org.springframework.security.oauth2.client.endpoint.DefaultRefreshTokenTokenResponseClient; @@ -61,6 +58,8 @@ import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.client.web.DefaultOAuth2AuthorizedClientManager; +import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientManager; import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2RefreshToken; @@ -74,7 +73,6 @@ import org.springframework.web.reactive.function.client.ClientRequest; import org.springframework.web.reactive.function.client.WebClient; -import javax.servlet.http.HttpServletRequest; import java.net.URI; import java.time.Duration; import java.time.Instant; @@ -122,7 +120,7 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests { private ServletOAuth2AuthorizedClientExchangeFilterFunction function; - private OAuth2AuthorizedClientProvider authorizedClientProvider; + private OAuth2AuthorizedClientManager authorizedClientManager; private MockExchangeFunction exchange = new MockExchangeFunction(); @@ -138,22 +136,18 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests { @Before public void setup() { this.authentication = new TestingAuthenticationToken("test", "this"); - AuthorizationCodeOAuth2AuthorizedClientProvider authorizationCodeAuthorizedClientProvider = - new AuthorizationCodeOAuth2AuthorizedClientProvider(this.clientRegistrationRepository, this.authorizedClientRepository); - ClientCredentialsOAuth2AuthorizedClientProvider clientCredentialsAuthorizedClientProvider = - new ClientCredentialsOAuth2AuthorizedClientProvider(this.clientRegistrationRepository, this.authorizedClientRepository); - clientCredentialsAuthorizedClientProvider.setAccessTokenResponseClient(this.clientCredentialsTokenResponseClient); - RefreshTokenOAuth2AuthorizedClientProvider refreshTokenAuthorizedClientProvider = - new RefreshTokenOAuth2AuthorizedClientProvider(this.clientRegistrationRepository, this.authorizedClientRepository); - refreshTokenAuthorizedClientProvider.setAccessTokenResponseClient(this.refreshTokenTokenResponseClient); - DelegatingOAuth2AuthorizedClientProvider delegate = new DelegatingOAuth2AuthorizedClientProvider( - authorizationCodeAuthorizedClientProvider, clientCredentialsAuthorizedClientProvider, refreshTokenAuthorizedClientProvider); - delegate.setDefaultAuthorizedClientProvider( - new DefaultOAuth2AuthorizedClientProvider(this.clientRegistrationRepository, this.authorizedClientRepository)); - this.authorizedClientProvider = delegate; + OAuth2AuthorizedClientProvider authorizedClientProvider = + OAuth2AuthorizedClientProviderBuilder.withProvider() + .authorizationCode() + .refreshToken(configurer -> configurer.accessTokenResponseClient(this.refreshTokenTokenResponseClient)) + .clientCredentials(configurer -> configurer.accessTokenResponseClient(this.clientCredentialsTokenResponseClient)) + .build(); + DefaultOAuth2AuthorizedClientManager authorizedClientManager = new DefaultOAuth2AuthorizedClientManager( + this.clientRegistrationRepository, this.authorizedClientRepository); + authorizedClientManager.setAuthorizedClientProvider(authorizedClientProvider); this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction( this.clientRegistrationRepository, this.authorizedClientRepository); - this.function.setAuthorizedClientProvider(this.authorizedClientProvider); + this.function.setAuthorizedClientManager(authorizedClientManager); } @After @@ -163,8 +157,8 @@ public void cleanup() { } @Test - public void setAuthorizedClientProviderWhenProviderIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.function.setAuthorizedClientProvider(null)) + public void setAuthorizedClientManagerWhenManagerIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.function.setAuthorizedClientManager(null)) .isInstanceOf(IllegalArgumentException.class); } @@ -307,13 +301,14 @@ public void defaultRequestWhenClientCredentialsThenAuthorizedClient() { MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletResponse response = new MockHttpServletResponse(); RequestContextHolder.setRequestAttributes(new ServletRequestAttributes(request, response)); + SecurityContextHolder.getContext().setAuthentication(this.authentication); Map attrs = getDefaultRequestAttributes(); OAuth2AuthorizedClient authorizedClient = getOAuth2AuthorizedClient(attrs); assertThat(authorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken()); assertThat(authorizedClient.getClientRegistration()).isEqualTo(this.registration); - assertThat(authorizedClient.getPrincipalName()).isEqualTo("anonymousUser"); + assertThat(authorizedClient.getPrincipalName()).isEqualTo("test"); assertThat(authorizedClient.getRefreshToken()).isEqualTo(accessTokenResponse.getRefreshToken()); } @@ -329,13 +324,14 @@ public void defaultRequestWhenDefaultClientRegistrationIdThenAuthorizedClient() MockHttpServletRequest request = new MockHttpServletRequest(); MockHttpServletResponse response = new MockHttpServletResponse(); RequestContextHolder.setRequestAttributes(new ServletRequestAttributes(request, response)); + SecurityContextHolder.getContext().setAuthentication(this.authentication); Map attrs = getDefaultRequestAttributes(); OAuth2AuthorizedClient authorizedClient = getOAuth2AuthorizedClient(attrs); assertThat(authorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken()); assertThat(authorizedClient.getClientRegistration()).isEqualTo(this.registration); - assertThat(authorizedClient.getPrincipalName()).isEqualTo("anonymousUser"); + assertThat(authorizedClient.getPrincipalName()).isEqualTo("test"); assertThat(authorizedClient.getRefreshToken()).isEqualTo(accessTokenResponse.getRefreshToken()); } @@ -372,12 +368,6 @@ public void filterWhenAuthorizedClientThenAuthorizationHeader() { OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", this.accessToken); - when(this.clientRegistrationRepository.findByRegistrationId( - eq(this.registration.getRegistrationId()))).thenReturn(this.registration); - when(this.authorizedClientRepository.loadAuthorizedClient( - eq(this.registration.getRegistrationId()), any(Authentication.class), - any(HttpServletRequest.class))).thenReturn(authorizedClient); - ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) .attributes(oauth2AuthorizedClient(authorizedClient)) .attributes(httpServletRequest(new MockHttpServletRequest())) @@ -394,12 +384,6 @@ public void filterWhenExistingAuthorizationThenSingleAuthorizationHeader() { OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", this.accessToken); - when(this.clientRegistrationRepository.findByRegistrationId( - eq(this.registration.getRegistrationId()))).thenReturn(this.registration); - when(this.authorizedClientRepository.loadAuthorizedClient( - eq(this.registration.getRegistrationId()), any(Authentication.class), - any(HttpServletRequest.class))).thenReturn(authorizedClient); - ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) .header(HttpHeaders.AUTHORIZATION, "Existing") .attributes(oauth2AuthorizedClient(authorizedClient)) @@ -432,12 +416,6 @@ public void filterWhenRefreshRequiredThenRefresh() { OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", this.accessToken, refreshToken); - when(this.clientRegistrationRepository.findByRegistrationId( - eq(this.registration.getRegistrationId()))).thenReturn(this.registration); - when(this.authorizedClientRepository.loadAuthorizedClient( - eq(this.registration.getRegistrationId()), eq(this.authentication), - any(HttpServletRequest.class))).thenReturn(authorizedClient); - ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) .attributes(oauth2AuthorizedClient(authorizedClient)) .attributes(authentication(this.authentication)) @@ -479,10 +457,12 @@ public void filterWhenRefreshRequiredThenRefreshAndResponseDoesNotContainRefresh DefaultRefreshTokenTokenResponseClient refreshTokenTokenResponseClient = new DefaultRefreshTokenTokenResponseClient(); refreshTokenTokenResponseClient.setRestOperations(refreshTokenClient); - RefreshTokenOAuth2AuthorizedClientProvider authorizedClientProvider = new RefreshTokenOAuth2AuthorizedClientProvider( - this.clientRegistrationRepository, this.authorizedClientRepository); + RefreshTokenOAuth2AuthorizedClientProvider authorizedClientProvider = new RefreshTokenOAuth2AuthorizedClientProvider(); authorizedClientProvider.setAccessTokenResponseClient(refreshTokenTokenResponseClient); - this.function.setAuthorizedClientProvider(authorizedClientProvider); + DefaultOAuth2AuthorizedClientManager authorizedClientManager = new DefaultOAuth2AuthorizedClientManager( + this.clientRegistrationRepository, this.authorizedClientRepository); + authorizedClientManager.setAuthorizedClientProvider(authorizedClientProvider); + this.function.setAuthorizedClientManager(authorizedClientManager); Instant issuedAt = Instant.now().minus(Duration.ofDays(1)); Instant accessTokenExpiresAt = issuedAt.plus(Duration.ofHours(1)); @@ -494,12 +474,6 @@ public void filterWhenRefreshRequiredThenRefreshAndResponseDoesNotContainRefresh OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", this.accessToken, refreshToken); - when(this.clientRegistrationRepository.findByRegistrationId( - eq(this.registration.getRegistrationId()))).thenReturn(this.registration); - when(this.authorizedClientRepository.loadAuthorizedClient( - eq(this.registration.getRegistrationId()), eq(this.authentication), - any(HttpServletRequest.class))).thenReturn(authorizedClient); - ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) .attributes(oauth2AuthorizedClient(authorizedClient)) .attributes(authentication(this.authentication)) @@ -532,12 +506,6 @@ public void filterWhenClientCredentialsTokenNotExpiredThenUseCurrentToken() { OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", this.accessToken, null); - when(this.clientRegistrationRepository.findByRegistrationId( - eq(this.registration.getRegistrationId()))).thenReturn(this.registration); - when(this.authorizedClientRepository.loadAuthorizedClient( - eq(this.registration.getRegistrationId()), eq(this.authentication), - any(HttpServletRequest.class))).thenReturn(authorizedClient); - ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) .attributes(oauth2AuthorizedClient(authorizedClient)) .attributes(authentication(this.authentication)) @@ -580,12 +548,6 @@ public void filterWhenClientCredentialsTokenExpiredThenGetNewToken() { OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", this.accessToken, null); - when(this.clientRegistrationRepository.findByRegistrationId( - eq(this.registration.getRegistrationId()))).thenReturn(this.registration); - when(this.authorizedClientRepository.loadAuthorizedClient( - eq(this.registration.getRegistrationId()), eq(this.authentication), - any(HttpServletRequest.class))).thenReturn(authorizedClient); - ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) .attributes(oauth2AuthorizedClient(authorizedClient)) .attributes(authentication(this.authentication)) @@ -626,12 +588,6 @@ public void filterWhenRefreshRequiredAndEmptyReactiveSecurityContextThenSaved() OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", this.accessToken, refreshToken); - when(this.clientRegistrationRepository.findByRegistrationId( - eq(this.registration.getRegistrationId()))).thenReturn(this.registration); - when(this.authorizedClientRepository.loadAuthorizedClient( - eq(this.registration.getRegistrationId()), any(Authentication.class), - any(HttpServletRequest.class))).thenReturn(authorizedClient); - ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) .attributes(oauth2AuthorizedClient(authorizedClient)) .attributes(httpServletRequest(new MockHttpServletRequest())) @@ -658,12 +614,6 @@ public void filterWhenRefreshTokenNullThenShouldRefreshFalse() { OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", this.accessToken); - when(this.clientRegistrationRepository.findByRegistrationId( - eq(this.registration.getRegistrationId()))).thenReturn(this.registration); - when(this.authorizedClientRepository.loadAuthorizedClient( - eq(this.registration.getRegistrationId()), any(Authentication.class), - any(HttpServletRequest.class))).thenReturn(authorizedClient); - ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) .attributes(oauth2AuthorizedClient(authorizedClient)) .attributes(httpServletRequest(new MockHttpServletRequest())) @@ -688,12 +638,6 @@ public void filterWhenNotExpiredThenShouldRefreshFalse() { OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", this.accessToken, refreshToken); - when(this.clientRegistrationRepository.findByRegistrationId( - eq(this.registration.getRegistrationId()))).thenReturn(this.registration); - when(this.authorizedClientRepository.loadAuthorizedClient( - eq(this.registration.getRegistrationId()), any(Authentication.class), - any(HttpServletRequest.class))).thenReturn(authorizedClient); - ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) .attributes(oauth2AuthorizedClient(authorizedClient)) .attributes(httpServletRequest(new MockHttpServletRequest())) @@ -728,9 +672,6 @@ public void filterWhenChainedThenDefaultsStillAvailable() throws Exception { user, authorities, this.registration.getRegistrationId()); SecurityContextHolder.getContext().setAuthentication(authentication); - when(this.clientRegistrationRepository.findByRegistrationId( - eq(this.registration.getRegistrationId()))).thenReturn(this.registration); - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( this.registration, "principalName", this.accessToken); From 3eaa30f4be82bcbbeb2ab1e8955a631a6176c36a Mon Sep 17 00:00:00 2001 From: Joe Grandja Date: Wed, 10 Jul 2019 11:45:40 -0400 Subject: [PATCH 14/19] Update OAuth2AuthorizedClientManager from review --- .../OAuth2ClientConfiguration.java | 27 ++--- .../DefaultOAuth2AuthorizedClientManager.java | 57 +++++---- .../client/web/OAuth2AuthorizeRequest.java | 87 ++++++++++++++ .../web/OAuth2AuthorizedClientManager.java | 28 ++--- .../client/web/OAuth2ReauthorizeRequest.java | 57 +++++++++ ...Auth2AuthorizedClientArgumentResolver.java | 87 ++++++++------ ...uthorizedClientExchangeFilterFunction.java | 109 ++++++++++-------- ...ultOAuth2AuthorizedClientManagerTests.java | 65 +++++------ .../web/OAuth2AuthorizeRequestTests.java | 78 +++++++++++++ .../web/OAuth2ReauthorizeRequestTests.java | 83 +++++++++++++ ...AuthorizedClientArgumentResolverTests.java | 18 ++- ...izedClientExchangeFilterFunctionTests.java | 24 ++-- .../java/sample/config/WebClientConfig.java | 17 ++- 13 files changed, 531 insertions(+), 206 deletions(-) create mode 100644 oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizeRequest.java create mode 100644 oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2ReauthorizeRequest.java create mode 100644 oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizeRequestTests.java create mode 100644 oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2ReauthorizeRequestTests.java diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfiguration.java b/config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfiguration.java index 0efadd2824a..d3f41f889bb 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfiguration.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfiguration.java @@ -71,22 +71,17 @@ static class OAuth2ClientWebMvcSecurityConfiguration implements WebMvcConfigurer @Override public void addArgumentResolvers(List argumentResolvers) { if (this.clientRegistrationRepository != null && this.authorizedClientRepository != null) { - OAuth2AuthorizedClientArgumentResolver authorizedClientArgumentResolver = - new OAuth2AuthorizedClientArgumentResolver( - this.clientRegistrationRepository, this.authorizedClientRepository); - if (this.accessTokenResponseClient != null) { - OAuth2AuthorizedClientProvider authorizedClientProvider = - OAuth2AuthorizedClientProviderBuilder.withProvider() - .authorizationCode() - .refreshToken() - .clientCredentials(configurer -> configurer.accessTokenResponseClient(this.accessTokenResponseClient)) - .build(); - DefaultOAuth2AuthorizedClientManager authorizedClientManager = new DefaultOAuth2AuthorizedClientManager( - this.clientRegistrationRepository, this.authorizedClientRepository); - authorizedClientManager.setAuthorizedClientProvider(authorizedClientProvider); - authorizedClientArgumentResolver.setAuthorizedClientManager(authorizedClientManager); - } - argumentResolvers.add(authorizedClientArgumentResolver); + OAuth2AuthorizedClientProvider authorizedClientProvider = + OAuth2AuthorizedClientProviderBuilder.withProvider() + .authorizationCode() + .refreshToken() + .clientCredentials(configurer -> + Optional.ofNullable(this.accessTokenResponseClient).ifPresent(configurer::accessTokenResponseClient)) + .build(); + DefaultOAuth2AuthorizedClientManager authorizedClientManager = new DefaultOAuth2AuthorizedClientManager( + this.clientRegistrationRepository, this.authorizedClientRepository); + authorizedClientManager.setAuthorizedClientProvider(authorizedClientProvider); + argumentResolvers.add(new OAuth2AuthorizedClientArgumentResolver(authorizedClientManager)); } } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizedClientManager.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizedClientManager.java index 40402896c39..45272a3c4e3 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizedClientManager.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizedClientManager.java @@ -28,7 +28,7 @@ import javax.servlet.http.HttpServletResponse; import java.util.Collections; import java.util.Map; -import java.util.function.BiFunction; +import java.util.function.Function; /** * The default implementation of an {@link OAuth2AuthorizedClientManager}. @@ -42,8 +42,7 @@ public final class DefaultOAuth2AuthorizedClientManager implements OAuth2Authori private final ClientRegistrationRepository clientRegistrationRepository; private final OAuth2AuthorizedClientRepository authorizedClientRepository; private OAuth2AuthorizedClientProvider authorizedClientProvider = context -> null; - private BiFunction> contextAttributesMapper = - (clientRegistration, request) -> Collections.emptyMap(); + private Function> contextAttributesMapper = authorizeRequest -> Collections.emptyMap(); /** * Constructs a {@code DefaultOAuth2AuthorizedClientManager} using the provided parameters. @@ -61,59 +60,55 @@ public DefaultOAuth2AuthorizedClientManager(ClientRegistrationRepository clientR @Nullable @Override - public OAuth2AuthorizedClient authorize(String clientRegistrationId, Authentication principal, - HttpServletRequest request, HttpServletResponse response) { + public OAuth2AuthorizedClient authorize(OAuth2AuthorizeRequest authorizeRequest) { + Assert.notNull(authorizeRequest, "authorizeRequest cannot be null"); - Assert.hasText(clientRegistrationId, "clientRegistrationId cannot be empty"); - Assert.notNull(principal, "principal cannot be null"); - Assert.notNull(request, "request cannot be null"); - Assert.notNull(response, "response cannot be null"); + String clientRegistrationId = authorizeRequest.getClientRegistrationId(); + Authentication principal = authorizeRequest.getPrincipal(); + HttpServletRequest servletRequest = authorizeRequest.getServletRequest(); + HttpServletResponse servletResponse = authorizeRequest.getServletResponse(); ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId); Assert.notNull(clientRegistration, "Could not find ClientRegistration with id '" + clientRegistrationId + "'"); OAuth2AuthorizedClient authorizedClient = this.authorizedClientRepository.loadAuthorizedClient( - clientRegistrationId, principal, request); + clientRegistrationId, principal, servletRequest); if (authorizedClient != null) { - return reauthorizeIfNecessary(authorizedClient, principal, request, response); + OAuth2ReauthorizeRequest reauthorizeRequest = new OAuth2ReauthorizeRequest( + authorizedClient, principal, servletRequest, servletResponse); + return reauthorize(reauthorizeRequest); } OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext.forClient(clientRegistration) .principal(principal) - .attributes(this.contextAttributesMapper.apply(clientRegistration, request)) + .attributes(this.contextAttributesMapper.apply(authorizeRequest)) .build(); authorizedClient = this.authorizedClientProvider.authorize(authorizationContext); if (authorizedClient != null) { - this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, principal, request, response); + this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, principal, servletRequest, servletResponse); } return authorizedClient; } @Override - public OAuth2AuthorizedClient reauthorize(OAuth2AuthorizedClient authorizedClient, Authentication principal, - HttpServletRequest request, HttpServletResponse response) { + public OAuth2AuthorizedClient reauthorize(OAuth2ReauthorizeRequest reauthorizeRequest) { + Assert.notNull(reauthorizeRequest, "reauthorizeRequest cannot be null"); - Assert.notNull(authorizedClient, "authorizedClient cannot be null"); - Assert.notNull(principal, "principal cannot be null"); - Assert.notNull(request, "request cannot be null"); - Assert.notNull(response, "response cannot be null"); - - return reauthorizeIfNecessary(authorizedClient, principal, request, response); - } - - private OAuth2AuthorizedClient reauthorizeIfNecessary(OAuth2AuthorizedClient authorizedClient, Authentication principal, - HttpServletRequest request, HttpServletResponse response) { + OAuth2AuthorizedClient authorizedClient = reauthorizeRequest.getAuthorizedClient(); + Authentication principal = reauthorizeRequest.getPrincipal(); + HttpServletRequest servletRequest = reauthorizeRequest.getServletRequest(); + HttpServletResponse servletResponse = reauthorizeRequest.getServletResponse(); OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext.forClient(authorizedClient) .principal(principal) - .attributes(this.contextAttributesMapper.apply(authorizedClient.getClientRegistration(), request)) + .attributes(this.contextAttributesMapper.apply(reauthorizeRequest)) .build(); OAuth2AuthorizedClient reauthorizedClient = this.authorizedClientProvider.authorize(authorizationContext); if (reauthorizedClient != null) { - this.authorizedClientRepository.saveAuthorizedClient(reauthorizedClient, principal, request, response); + this.authorizedClientRepository.saveAuthorizedClient(reauthorizedClient, principal, servletRequest, servletResponse); return reauthorizedClient; } @@ -131,13 +126,13 @@ public void setAuthorizedClientProvider(OAuth2AuthorizedClientProvider authorize } /** - * Sets the {@code BiFunction} used for mapping attribute(s) from the {@link ClientRegistration} and/or {@code HttpServletRequest} - * to a {@code Map} of attributes to be associated to the {@link OAuth2AuthorizationContext#getAttributes() authorization context}. + * Sets the {@code Function} used for mapping attribute(s) from the {@link OAuth2AuthorizeRequest} to a {@code Map} of attributes + * to be associated to the {@link OAuth2AuthorizationContext#getAttributes() authorization context}. * - * @param contextAttributesMapper the {@code BiFunction} used for supplying the {@code Map} of attributes + * @param contextAttributesMapper the {@code Function} used for supplying the {@code Map} of attributes * to the {@link OAuth2AuthorizationContext#getAttributes() authorization context} */ - public void setContextAttributesMapper(BiFunction> contextAttributesMapper) { + public void setContextAttributesMapper(Function> contextAttributesMapper) { Assert.notNull(contextAttributesMapper, "contextAttributesMapper cannot be null"); this.contextAttributesMapper = contextAttributesMapper; } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizeRequest.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizeRequest.java new file mode 100644 index 00000000000..6bf4eae846d --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizeRequest.java @@ -0,0 +1,87 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client.web; + +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.util.Assert; + +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + +/** + * Represents a request the {@link OAuth2AuthorizedClientManager} uses to + * {@link OAuth2AuthorizedClientManager#authorize(OAuth2AuthorizeRequest) authorize} (or re-authorize) + * the {@link ClientRegistration client} identified by the provided {@link #getClientRegistrationId() clientRegistrationId}. + * + * @author Joe Grandja + * @since 5.2 + * @see OAuth2AuthorizedClientManager + */ +public class OAuth2AuthorizeRequest { + private final String clientRegistrationId; + private final Authentication principal; + private final HttpServletRequest servletRequest; + private final HttpServletResponse servletResponse; + + public OAuth2AuthorizeRequest(String clientRegistrationId, Authentication principal, + HttpServletRequest servletRequest, HttpServletResponse servletResponse) { + Assert.hasText(clientRegistrationId, "clientRegistrationId cannot be empty"); + Assert.notNull(principal, "principal cannot be null"); + Assert.notNull(servletRequest, "servletRequest cannot be null"); + Assert.notNull(servletResponse, "servletResponse cannot be null"); + this.clientRegistrationId = clientRegistrationId; + this.principal = principal; + this.servletRequest = servletRequest; + this.servletResponse = servletResponse; + } + + /** + * Returns the identifier for the {@link ClientRegistration client registration}. + * + * @return the identifier for the client registration + */ + public String getClientRegistrationId() { + return this.clientRegistrationId; + } + + /** + * Returns the {@code Principal} (to be) associated to the authorized client. + * + * @return the {@code Principal} (to be) associated to the authorized client + */ + public Authentication getPrincipal() { + return this.principal; + } + + /** + * Returns the {@code HttpServletRequest}. + * + * @return the {@code HttpServletRequest} + */ + public HttpServletRequest getServletRequest() { + return this.servletRequest; + } + + /** + * Returns the {@code HttpServletResponse}. + * + * @return the {@code HttpServletResponse} + */ + public HttpServletResponse getServletResponse() { + return this.servletResponse; + } +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizedClientManager.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizedClientManager.java index 9079c021b08..1b3638cb4d2 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizedClientManager.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizedClientManager.java @@ -16,14 +16,10 @@ package org.springframework.security.oauth2.client.web; import org.springframework.lang.Nullable; -import org.springframework.security.core.Authentication; import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProvider; import org.springframework.security.oauth2.client.registration.ClientRegistration; -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; - /** * Implementations of this interface are responsible for the overall management * of {@link OAuth2AuthorizedClient Authorized Client(s)}. @@ -47,35 +43,27 @@ public interface OAuth2AuthorizedClientManager { /** * Attempt to authorize or re-authorize (if required) the {@link ClientRegistration client} - * identified by the provided {@code clientRegistrationId}. + * identified by the provided {@link OAuth2AuthorizeRequest#getClientRegistrationId() clientRegistrationId}. * Implementations must return {@code null} if authorization is not supported for the specified client, * e.g. the associated {@link OAuth2AuthorizedClientProvider}(s) does not support * the {@link ClientRegistration#getAuthorizationGrantType() authorization grant} type configured for the client. * - * @param clientRegistrationId the identifier for the client's registration - * @param principal the {@code Principal} {@link Authentication} (to be) associated to the authorized client - * @param request the {@code HttpServletRequest} - * @param response the {@code HttpServletResponse} + * @param authorizeRequest the authorize request * @return the {@link OAuth2AuthorizedClient} or {@code null} if authorization is not supported for the specified client */ @Nullable - OAuth2AuthorizedClient authorize(String clientRegistrationId, Authentication principal, - HttpServletRequest request, HttpServletResponse response); + OAuth2AuthorizedClient authorize(OAuth2AuthorizeRequest authorizeRequest); /** - * Attempt to re-authorize (if required) the provided {@link OAuth2AuthorizedClient authorized client}. - * Implementations must return the provided {@code authorizedClient} if re-authorization is not supported + * Attempt to re-authorize (if required) the provided {@link OAuth2ReauthorizeRequest#getAuthorizedClient() authorized client}. + * Implementations must return the provided authorized client if re-authorization is not supported * for the {@link OAuth2AuthorizedClient#getClientRegistration() client} OR is not required, * e.g. a {@link OAuth2AuthorizedClient#getRefreshToken() refresh token} is not available OR * the {@link OAuth2AuthorizedClient#getAccessToken() access token} is not expired. * - * @param authorizedClient the authorized client - * @param principal the {@code Principal} {@link Authentication} associated to the authorized client - * @param request the {@code HttpServletRequest} - * @param response the {@code HttpServletResponse} - * @return the re-authorized {@link OAuth2AuthorizedClient} or the provided {@code authorizedClient} if not re-authorized + * @param reauthorizeRequest the re-authorize request + * @return the re-authorized {@link OAuth2AuthorizedClient} or the provided {@link OAuth2ReauthorizeRequest#getAuthorizedClient() authorized client} if not re-authorized */ - OAuth2AuthorizedClient reauthorize(OAuth2AuthorizedClient authorizedClient, Authentication principal, - HttpServletRequest request, HttpServletResponse response); + OAuth2AuthorizedClient reauthorize(OAuth2ReauthorizeRequest reauthorizeRequest); } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2ReauthorizeRequest.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2ReauthorizeRequest.java new file mode 100644 index 00000000000..80beafe115a --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2ReauthorizeRequest.java @@ -0,0 +1,57 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client.web; + +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; +import org.springframework.util.Assert; + +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + +/** + * Represents a request the {@link OAuth2AuthorizedClientManager} uses to + * {@link OAuth2AuthorizedClientManager#reauthorize(OAuth2ReauthorizeRequest) re-authorize} + * the provided {@link OAuth2AuthorizedClient#getClientRegistration() client}. + * + * @author Joe Grandja + * @since 5.2 + * @see OAuth2AuthorizeRequest + * @see OAuth2AuthorizedClientManager + */ +public class OAuth2ReauthorizeRequest extends OAuth2AuthorizeRequest { + private OAuth2AuthorizedClient authorizedClient; + + public OAuth2ReauthorizeRequest(OAuth2AuthorizedClient authorizedClient, Authentication principal, + HttpServletRequest servletRequest, HttpServletResponse servletResponse) { + super(getClientRegistrationId(authorizedClient), principal, servletRequest, servletResponse); + this.authorizedClient = authorizedClient; + } + + private static String getClientRegistrationId(OAuth2AuthorizedClient authorizedClient) { + Assert.notNull(authorizedClient, "authorizedClient cannot be null"); + return authorizedClient.getClientRegistration().getRegistrationId(); + } + + /** + * Returns the {@link OAuth2AuthorizedClient authorized client}. + * + * @return the {@link OAuth2AuthorizedClient} + */ + public OAuth2AuthorizedClient getAuthorizedClient() { + return this.authorizedClient; + } +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolver.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolver.java index a33c7cff6ab..9f8d1b744aa 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolver.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolver.java @@ -29,12 +29,11 @@ import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProviderBuilder; import org.springframework.security.oauth2.client.annotation.RegisteredOAuth2AuthorizedClient; import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; -import org.springframework.security.oauth2.client.endpoint.DefaultClientCredentialsTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest; -import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; import org.springframework.security.oauth2.client.web.DefaultOAuth2AuthorizedClientManager; +import org.springframework.security.oauth2.client.web.OAuth2AuthorizeRequest; import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientManager; import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; import org.springframework.util.Assert; @@ -70,23 +69,50 @@ public final class OAuth2AuthorizedClientArgumentResolver implements HandlerMethodArgumentResolver { private static final Authentication ANONYMOUS_AUTHENTICATION = new AnonymousAuthenticationToken( "anonymous", "anonymousUser", AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS")); - private final ClientRegistrationRepository clientRegistrationRepository; - private final OAuth2AuthorizedClientRepository authorizedClientRepository; private OAuth2AuthorizedClientManager authorizedClientManager; /** * Constructs an {@code OAuth2AuthorizedClientArgumentResolver} using the provided parameters. * + * @since 5.2 + * @param authorizedClientManager the {@link OAuth2AuthorizedClientManager} which manages the authorized client(s) + */ + public OAuth2AuthorizedClientArgumentResolver(OAuth2AuthorizedClientManager authorizedClientManager) { + Assert.notNull(authorizedClientManager, "authorizedClientManager cannot be null"); + this.authorizedClientManager = authorizedClientManager; + } + + /** + * Constructs an {@code OAuth2AuthorizedClientArgumentResolver} using the provided parameters. + * + * @deprecated Use {@link #OAuth2AuthorizedClientArgumentResolver(OAuth2AuthorizedClientManager)} instead. + * See {@link DefaultOAuth2AuthorizedClientManager} and {@link OAuth2AuthorizedClientProviderBuilder}. + * * @param clientRegistrationRepository the repository of client registrations * @param authorizedClientRepository the repository of authorized clients */ + @Deprecated public OAuth2AuthorizedClientArgumentResolver(ClientRegistrationRepository clientRegistrationRepository, OAuth2AuthorizedClientRepository authorizedClientRepository) { Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null"); Assert.notNull(authorizedClientRepository, "authorizedClientRepository cannot be null"); - this.clientRegistrationRepository = clientRegistrationRepository; - this.authorizedClientRepository = authorizedClientRepository; - this.authorizedClientManager = createAuthorizedClientManager(new DefaultClientCredentialsTokenResponseClient()); + this.authorizedClientManager = createDefaultAuthorizedClientManager(clientRegistrationRepository, authorizedClientRepository); + } + + private static OAuth2AuthorizedClientManager createDefaultAuthorizedClientManager( + ClientRegistrationRepository clientRegistrationRepository, OAuth2AuthorizedClientRepository authorizedClientRepository) { + + OAuth2AuthorizedClientProvider authorizedClientProvider = + OAuth2AuthorizedClientProviderBuilder.withProvider() + .authorizationCode() + .refreshToken() + .clientCredentials() + .build(); + DefaultOAuth2AuthorizedClientManager authorizedClientManager = new DefaultOAuth2AuthorizedClientManager( + clientRegistrationRepository, authorizedClientRepository); + authorizedClientManager.setAuthorizedClientProvider(authorizedClientProvider); + + return authorizedClientManager; } @Override @@ -111,9 +137,6 @@ public Object resolveArgument(MethodParameter parameter, "@RegisteredOAuth2AuthorizedClient(registrationId = \"client1\")."); } - ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId); - Assert.notNull(clientRegistration, "Could not find ClientRegistration with id '" + clientRegistrationId + "'"); - Authentication principal = SecurityContextHolder.getContext().getAuthentication(); if (principal == null) { principal = ANONYMOUS_AUTHENTICATION; @@ -121,8 +144,10 @@ public Object resolveArgument(MethodParameter parameter, HttpServletRequest servletRequest = webRequest.getNativeRequest(HttpServletRequest.class); HttpServletResponse servletResponse = webRequest.getNativeResponse(HttpServletResponse.class); - return this.authorizedClientManager.authorize(clientRegistration.getRegistrationId(), - principal, servletRequest, servletResponse); + OAuth2AuthorizeRequest authorizeRequest = new OAuth2AuthorizeRequest( + clientRegistrationId, principal, servletRequest, servletResponse); + + return this.authorizedClientManager.authorize(authorizeRequest); } private String resolveClientRegistrationId(MethodParameter parameter) { @@ -143,21 +168,10 @@ private String resolveClientRegistrationId(MethodParameter parameter) { return clientRegistrationId; } - /** - * Sets the {@link OAuth2AuthorizedClientManager} which manages the {@link OAuth2AuthorizedClient Authorized Client(s)}. - * - * @since 5.2 - * @param authorizedClientManager the {@link OAuth2AuthorizedClientManager} which manages the authorized client(s) - */ - public void setAuthorizedClientManager(OAuth2AuthorizedClientManager authorizedClientManager) { - Assert.notNull(authorizedClientManager, "authorizedClientManager cannot be null"); - this.authorizedClientManager = authorizedClientManager; - } - /** * Sets the client used when requesting an access token credential at the Token Endpoint for the {@code client_credentials} grant. * - * @deprecated Use {@link #setAuthorizedClientManager(OAuth2AuthorizedClientManager)} instead. + * @deprecated Use {@link #OAuth2AuthorizedClientArgumentResolver(OAuth2AuthorizedClientManager)} instead. * Create an instance of {@link ClientCredentialsOAuth2AuthorizedClientProvider} configured with a * {@link ClientCredentialsOAuth2AuthorizedClientProvider#setAccessTokenResponseClient(OAuth2AccessTokenResponseClient) DefaultClientCredentialsTokenResponseClient} * (or a custom one) and than supply it to {@link DefaultOAuth2AuthorizedClientManager#setAuthorizedClientProvider(OAuth2AuthorizedClientProvider) DefaultOAuth2AuthorizedClientManager}. @@ -168,23 +182,20 @@ public void setAuthorizedClientManager(OAuth2AuthorizedClientManager authorizedC public final void setClientCredentialsTokenResponseClient( OAuth2AccessTokenResponseClient clientCredentialsTokenResponseClient) { Assert.notNull(clientCredentialsTokenResponseClient, "clientCredentialsTokenResponseClient cannot be null"); - this.authorizedClientManager = createAuthorizedClientManager(clientCredentialsTokenResponseClient); + updateAuthorizedClientManager(clientCredentialsTokenResponseClient); } - private OAuth2AuthorizedClientManager createAuthorizedClientManager( + private void updateAuthorizedClientManager( OAuth2AccessTokenResponseClient clientCredentialsTokenResponseClient) { - OAuth2AuthorizedClientProvider authorizedClientProvider = - OAuth2AuthorizedClientProviderBuilder.withProvider() - .authorizationCode() - .refreshToken() - .clientCredentials(configurer -> configurer.accessTokenResponseClient(clientCredentialsTokenResponseClient)) - .build(); - - DefaultOAuth2AuthorizedClientManager authorizedClientManager = new DefaultOAuth2AuthorizedClientManager( - this.clientRegistrationRepository, this.authorizedClientRepository); - authorizedClientManager.setAuthorizedClientProvider(authorizedClientProvider); - - return authorizedClientManager; + if (this.authorizedClientManager instanceof DefaultOAuth2AuthorizedClientManager) { + OAuth2AuthorizedClientProvider authorizedClientProvider = + OAuth2AuthorizedClientProviderBuilder.withProvider() + .authorizationCode() + .refreshToken() + .clientCredentials(configurer -> configurer.accessTokenResponseClient(clientCredentialsTokenResponseClient)) + .build(); + ((DefaultOAuth2AuthorizedClientManager) this.authorizedClientManager).setAuthorizedClientProvider(authorizedClientProvider); + } } } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java index a8d6af84670..198afee50d0 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java @@ -19,8 +19,10 @@ import org.reactivestreams.Subscription; import org.springframework.beans.factory.DisposableBean; import org.springframework.beans.factory.InitializingBean; +import org.springframework.security.authentication.AnonymousAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.core.GrantedAuthority; +import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.oauth2.client.ClientCredentialsOAuth2AuthorizedClientProvider; import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; @@ -33,8 +35,10 @@ import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; import org.springframework.security.oauth2.client.web.DefaultOAuth2AuthorizedClientManager; +import org.springframework.security.oauth2.client.web.OAuth2AuthorizeRequest; import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientManager; import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; +import org.springframework.security.oauth2.client.web.OAuth2ReauthorizeRequest; import org.springframework.util.Assert; import org.springframework.web.context.request.RequestContextHolder; import org.springframework.web.context.request.ServletRequestAttributes; @@ -81,8 +85,7 @@ * are true: * *
      - *
    • The {@link #setAuthorizedClientManager(OAuth2AuthorizedClientManager) OAuth2AuthorizedClientManager} on the - * {@link ServletOAuth2AuthorizedClientExchangeFilterFunction} is not null
    • + *
    • The {@link OAuth2AuthorizedClientManager} is not null
    • *
    • A refresh token is present on the {@link OAuth2AuthorizedClient}
    • *
    • The access token will be expired in * {@link #setAccessTokenExpiresSkew(Duration)}
    • @@ -94,6 +97,7 @@ * @author Rob Winch * @author Joe Grandja * @since 5.1 + * @see OAuth2AuthorizedClientManager */ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implements ExchangeFilterFunction, InitializingBean, DisposableBean { @@ -109,11 +113,10 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction private static final String REQUEST_CONTEXT_OPERATOR_KEY = RequestContextSubscriber.class.getName(); - private Duration accessTokenExpiresSkew = Duration.ofMinutes(1); - - private ClientRegistrationRepository clientRegistrationRepository; + private static final Authentication ANONYMOUS_AUTHENTICATION = new AnonymousAuthenticationToken( + "anonymous", "anonymousUser", AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS")); - private OAuth2AuthorizedClientRepository authorizedClientRepository; + private Duration accessTokenExpiresSkew = Duration.ofMinutes(1); private OAuth2AuthorizedClientManager authorizedClientManager; @@ -124,11 +127,30 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction public ServletOAuth2AuthorizedClientExchangeFilterFunction() { } + /** + * Constructs a {@code ServletOAuth2AuthorizedClientExchangeFilterFunction} using the provided parameters. + * + * @since 5.2 + * @param authorizedClientManager the {@link OAuth2AuthorizedClientManager} which manages the authorized client(s) + */ + public ServletOAuth2AuthorizedClientExchangeFilterFunction(OAuth2AuthorizedClientManager authorizedClientManager) { + Assert.notNull(authorizedClientManager, "authorizedClientManager cannot be null"); + this.authorizedClientManager = authorizedClientManager; + } + + /** + * Constructs a {@code ServletOAuth2AuthorizedClientExchangeFilterFunction} using the provided parameters. + * + * @deprecated Use {@link #ServletOAuth2AuthorizedClientExchangeFilterFunction(OAuth2AuthorizedClientManager)} instead. + * See {@link DefaultOAuth2AuthorizedClientManager} and {@link OAuth2AuthorizedClientProviderBuilder}. + * + * @param clientRegistrationRepository the repository of client registrations + * @param authorizedClientRepository the repository of authorized clients + */ + @Deprecated public ServletOAuth2AuthorizedClientExchangeFilterFunction( ClientRegistrationRepository clientRegistrationRepository, OAuth2AuthorizedClientRepository authorizedClientRepository) { - this.clientRegistrationRepository = clientRegistrationRepository; - this.authorizedClientRepository = authorizedClientRepository; this.authorizedClientManager = createDefaultAuthorizedClientManager(clientRegistrationRepository, authorizedClientRepository); } @@ -141,7 +163,6 @@ private static OAuth2AuthorizedClientManager createDefaultAuthorizedClientManage .refreshToken() .clientCredentials() .build(); - DefaultOAuth2AuthorizedClientManager authorizedClientManager = new DefaultOAuth2AuthorizedClientManager( clientRegistrationRepository, authorizedClientRepository); authorizedClientManager.setAuthorizedClientProvider(authorizedClientProvider); @@ -159,21 +180,10 @@ public void destroy() throws Exception { Hooks.resetOnLastOperator(REQUEST_CONTEXT_OPERATOR_KEY); } - /** - * Sets the {@link OAuth2AuthorizedClientManager} which manages the {@link OAuth2AuthorizedClient Authorized Client(s)}. - * - * @since 5.2 - * @param authorizedClientManager the {@link OAuth2AuthorizedClientManager} which manages the authorized client(s) - */ - public void setAuthorizedClientManager(OAuth2AuthorizedClientManager authorizedClientManager) { - Assert.notNull(authorizedClientManager, "authorizedClientManager cannot be null"); - this.authorizedClientManager = authorizedClientManager; - } - /** * Sets the {@link OAuth2AccessTokenResponseClient} used for getting an {@link OAuth2AuthorizedClient} for the client_credentials grant. * - * @deprecated Use {@link #setAuthorizedClientManager(OAuth2AuthorizedClientManager)} instead. + * @deprecated Use {@link #ServletOAuth2AuthorizedClientExchangeFilterFunction(OAuth2AuthorizedClientManager)} instead. * Create an instance of {@link ClientCredentialsOAuth2AuthorizedClientProvider} configured with a * {@link ClientCredentialsOAuth2AuthorizedClientProvider#setAccessTokenResponseClient(OAuth2AccessTokenResponseClient) DefaultClientCredentialsTokenResponseClient} * (or a custom one) and than supply it to {@link DefaultOAuth2AuthorizedClientManager#setAuthorizedClientProvider(OAuth2AuthorizedClientProvider) DefaultOAuth2AuthorizedClientManager}. @@ -184,30 +194,27 @@ public void setAuthorizedClientManager(OAuth2AuthorizedClientManager authorizedC public void setClientCredentialsTokenResponseClient( OAuth2AccessTokenResponseClient clientCredentialsTokenResponseClient) { Assert.notNull(clientCredentialsTokenResponseClient, "clientCredentialsTokenResponseClient cannot be null"); - this.authorizedClientManager = createAuthorizedClientManager(clientCredentialsTokenResponseClient); + updateAuthorizedClientManager(clientCredentialsTokenResponseClient); } - private OAuth2AuthorizedClientManager createAuthorizedClientManager() { - return createAuthorizedClientManager(new DefaultClientCredentialsTokenResponseClient()); + private void updateAuthorizedClientManager() { + updateAuthorizedClientManager(new DefaultClientCredentialsTokenResponseClient()); } - private OAuth2AuthorizedClientManager createAuthorizedClientManager( + private void updateAuthorizedClientManager( OAuth2AccessTokenResponseClient clientCredentialsTokenResponseClient) { - OAuth2AuthorizedClientProvider authorizedClientProvider = - OAuth2AuthorizedClientProviderBuilder.withProvider() - .authorizationCode() - .refreshToken(configurer -> configurer.clockSkew(this.accessTokenExpiresSkew)) - .clientCredentials(configurer -> configurer - .accessTokenResponseClient(clientCredentialsTokenResponseClient) - .clockSkew(this.accessTokenExpiresSkew)) - .build(); - - DefaultOAuth2AuthorizedClientManager authorizedClientManager = new DefaultOAuth2AuthorizedClientManager( - this.clientRegistrationRepository, this.authorizedClientRepository); - authorizedClientManager.setAuthorizedClientProvider(authorizedClientProvider); - - return authorizedClientManager; + if (this.authorizedClientManager instanceof DefaultOAuth2AuthorizedClientManager) { + OAuth2AuthorizedClientProvider authorizedClientProvider = + OAuth2AuthorizedClientProviderBuilder.withProvider() + .authorizationCode() + .refreshToken(configurer -> configurer.clockSkew(this.accessTokenExpiresSkew)) + .clientCredentials(configurer -> configurer + .accessTokenResponseClient(clientCredentialsTokenResponseClient) + .clockSkew(this.accessTokenExpiresSkew)) + .build(); + ((DefaultOAuth2AuthorizedClientManager) this.authorizedClientManager).setAuthorizedClientProvider(authorizedClientProvider); + } } /** @@ -328,7 +335,7 @@ public static Consumer> httpServletResponse(HttpServletRespo public void setAccessTokenExpiresSkew(Duration accessTokenExpiresSkew) { Assert.notNull(accessTokenExpiresSkew, "accessTokenExpiresSkew cannot be null"); this.accessTokenExpiresSkew = accessTokenExpiresSkew; - this.authorizedClientManager = createAuthorizedClientManager(); + updateAuthorizedClientManager(); } @Override @@ -388,7 +395,7 @@ private void populateDefaultAuthentication(Map attrs) { } private void populateDefaultOAuth2AuthorizedClient(Map attrs) { - if (this.authorizedClientRepository == null || + if (this.authorizedClientManager == null || attrs.containsKey(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME)) { return; } @@ -405,12 +412,12 @@ private void populateDefaultOAuth2AuthorizedClient(Map attrs) { } if (clientRegistrationId != null) { HttpServletRequest request = getRequest(attrs); - OAuth2AuthorizedClient authorizedClient = this.authorizedClientRepository.loadAuthorizedClient( - clientRegistrationId, authentication, request); - if (authorizedClient == null) { - authorizedClient = this.authorizedClientManager.authorize( - clientRegistrationId, authentication, request, getResponse(attrs)); + if (authentication == null) { + authentication = ANONYMOUS_AUTHENTICATION; } + OAuth2AuthorizeRequest authorizeRequest = new OAuth2AuthorizeRequest( + clientRegistrationId, authentication, request, getResponse(attrs)); + OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest); oauth2AuthorizedClient(authorizedClient).accept(attrs); } } @@ -420,11 +427,15 @@ private Mono authorizedClient(OAuth2AuthorizedClient aut return Mono.just(authorizedClient); } Map attrs = request.attributes(); - Authentication authentication = getAuthentication(attrs) != null ? - getAuthentication(attrs) : new PrincipalNameAuthentication(authorizedClient.getPrincipalName()); + Authentication authentication = getAuthentication(attrs); + if (authentication == null) { + authentication = new PrincipalNameAuthentication(authorizedClient.getPrincipalName()); + } HttpServletRequest servletRequest = getRequest(attrs); HttpServletResponse servletResponse = getResponse(attrs); - return Mono.fromSupplier(() -> this.authorizedClientManager.reauthorize(authorizedClient, authentication, servletRequest, servletResponse)); + OAuth2ReauthorizeRequest reauthorizeRequest = new OAuth2ReauthorizeRequest( + authorizedClient, authentication, servletRequest, servletResponse); + return Mono.fromSupplier(() -> this.authorizedClientManager.reauthorize(reauthorizeRequest)); } private ClientRequest bearer(ClientRequest request, OAuth2AuthorizedClient authorizedClient) { diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizedClientManagerTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizedClientManagerTests.java index 91b9b59abbc..a6f96e71a7f 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizedClientManagerTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizedClientManagerTests.java @@ -31,7 +31,7 @@ import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens; -import java.util.function.BiFunction; +import java.util.function.Function; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -47,7 +47,7 @@ public class DefaultOAuth2AuthorizedClientManagerTests { private ClientRegistrationRepository clientRegistrationRepository; private OAuth2AuthorizedClientRepository authorizedClientRepository; private OAuth2AuthorizedClientProvider authorizedClientProvider; - private BiFunction contextAttributesMapper; + private Function contextAttributesMapper; private DefaultOAuth2AuthorizedClientManager authorizedClientManager; private ClientRegistration clientRegistration; private Authentication principal; @@ -62,7 +62,7 @@ public void setup() { this.clientRegistrationRepository = mock(ClientRegistrationRepository.class); this.authorizedClientRepository = mock(OAuth2AuthorizedClientRepository.class); this.authorizedClientProvider = mock(OAuth2AuthorizedClientProvider.class); - this.contextAttributesMapper = mock(BiFunction.class); + this.contextAttributesMapper = mock(Function.class); this.authorizedClientManager = new DefaultOAuth2AuthorizedClientManager( this.clientRegistrationRepository, this.authorizedClientRepository); this.authorizedClientManager.setAuthorizedClientProvider(this.authorizedClientProvider); @@ -105,24 +105,17 @@ public void setContextAttributesMapperWhenNullThenThrowIllegalArgumentException( } @Test - public void authorizeWhenArgumentsInvalidThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.authorizedClientManager.authorize(null, this.principal, this.request, this.response)) + public void authorizeWhenRequestIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.authorizedClientManager.authorize(null)) .isInstanceOf(IllegalArgumentException.class) - .hasMessage("clientRegistrationId cannot be empty"); - assertThatThrownBy(() -> this.authorizedClientManager.authorize(this.clientRegistration.getRegistrationId(), null, this.request, this.response)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("principal cannot be null"); - assertThatThrownBy(() -> this.authorizedClientManager.authorize(this.clientRegistration.getRegistrationId(), this.principal, null, this.response)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("request cannot be null"); - assertThatThrownBy(() -> this.authorizedClientManager.authorize(this.clientRegistration.getRegistrationId(), this.principal, this.request, null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("response cannot be null"); + .hasMessage("authorizeRequest cannot be null"); } @Test public void authorizeWhenClientRegistrationNotFoundThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.authorizedClientManager.authorize("invalid-registration-id", this.principal, this.request, this.response)) + OAuth2AuthorizeRequest authorizeRequest = new OAuth2AuthorizeRequest( + "invalid-registration-id", this.principal, this.request, this.response); + assertThatThrownBy(() -> this.authorizedClientManager.authorize(authorizeRequest)) .isInstanceOf(IllegalArgumentException.class) .hasMessage("Could not find ClientRegistration with id 'invalid-registration-id'"); } @@ -133,11 +126,12 @@ public void authorizeWhenNotAuthorizedAndUnsupportedProviderThenNotAuthorized() when(this.clientRegistrationRepository.findByRegistrationId( eq(this.clientRegistration.getRegistrationId()))).thenReturn(this.clientRegistration); - OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize( + OAuth2AuthorizeRequest authorizeRequest = new OAuth2AuthorizeRequest( this.clientRegistration.getRegistrationId(), this.principal, this.request, this.response); + OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest); verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); - verify(this.contextAttributesMapper).apply(eq(this.clientRegistration), eq(this.request)); + verify(this.contextAttributesMapper).apply(eq(authorizeRequest)); OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue(); assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration); @@ -157,11 +151,12 @@ public void authorizeWhenNotAuthorizedAndSupportedProviderThenAuthorized() { when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(this.authorizedClient); - OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize( + OAuth2AuthorizeRequest authorizeRequest = new OAuth2AuthorizeRequest( this.clientRegistration.getRegistrationId(), this.principal, this.request, this.response); + OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest); verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); - verify(this.contextAttributesMapper).apply(eq(this.clientRegistration), eq(this.request)); + verify(this.contextAttributesMapper).apply(eq(authorizeRequest)); OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue(); assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration); @@ -187,11 +182,12 @@ public void authorizeWhenAuthorizedAndSupportedProviderThenReauthorized() { when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(reauthorizedClient); - OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize( + OAuth2AuthorizeRequest authorizeRequest = new OAuth2AuthorizeRequest( this.clientRegistration.getRegistrationId(), this.principal, this.request, this.response); + OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest); verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); - verify(this.contextAttributesMapper).apply(eq(this.clientRegistration), eq(this.request)); + verify(this.contextAttributesMapper).apply(any()); OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue(); assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration); @@ -204,29 +200,21 @@ public void authorizeWhenAuthorizedAndSupportedProviderThenReauthorized() { } @Test - public void reauthorizeWhenArgumentsInvalidThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.authorizedClientManager.reauthorize(null, this.principal, this.request, this.response)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("authorizedClient cannot be null"); - assertThatThrownBy(() -> this.authorizedClientManager.reauthorize(this.authorizedClient, null, this.request, this.response)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("principal cannot be null"); - assertThatThrownBy(() -> this.authorizedClientManager.reauthorize(this.authorizedClient, this.principal, null, this.response)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("request cannot be null"); - assertThatThrownBy(() -> this.authorizedClientManager.reauthorize(this.authorizedClient, this.principal, this.request, null)) + public void reauthorizeWhenRequestIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.authorizedClientManager.reauthorize(null)) .isInstanceOf(IllegalArgumentException.class) - .hasMessage("response cannot be null"); + .hasMessage("reauthorizeRequest cannot be null"); } @SuppressWarnings("unchecked") @Test public void reauthorizeWhenUnsupportedProviderThenNotReauthorized() { - OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.reauthorize( + OAuth2ReauthorizeRequest reauthorizeRequest = new OAuth2ReauthorizeRequest( this.authorizedClient, this.principal, this.request, this.response); + OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.reauthorize(reauthorizeRequest); verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); - verify(this.contextAttributesMapper).apply(eq(this.clientRegistration), eq(this.request)); + verify(this.contextAttributesMapper).apply(eq(reauthorizeRequest)); OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue(); assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration); @@ -247,11 +235,12 @@ public void reauthorizeWhenSupportedProviderThenReauthorized() { when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(reauthorizedClient); - OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.reauthorize( + OAuth2ReauthorizeRequest reauthorizeRequest = new OAuth2ReauthorizeRequest( this.authorizedClient, this.principal, this.request, this.response); + OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.reauthorize(reauthorizeRequest); verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); - verify(this.contextAttributesMapper).apply(eq(this.clientRegistration), eq(this.request)); + verify(this.contextAttributesMapper).apply(eq(reauthorizeRequest)); OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue(); assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration); diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizeRequestTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizeRequestTests.java new file mode 100644 index 00000000000..1a2c114df5c --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizeRequestTests.java @@ -0,0 +1,78 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client.web; + +import org.junit.Test; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Tests for {@link OAuth2AuthorizeRequest}. + * + * @author Joe Grandja + */ +public class OAuth2AuthorizeRequestTests { + private ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); + private Authentication principal = new TestingAuthenticationToken("principal", "password"); + private MockHttpServletRequest servletRequest = new MockHttpServletRequest(); + private MockHttpServletResponse servletResponse = new MockHttpServletResponse(); + + @Test + public void constructorWhenClientRegistrationIdIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new OAuth2AuthorizeRequest(null, this.principal, this.servletRequest, this.servletResponse)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("clientRegistrationId cannot be empty"); + } + + @Test + public void constructorWhenPrincipalIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new OAuth2AuthorizeRequest(this.clientRegistration.getRegistrationId(), null, this.servletRequest, this.servletResponse)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("principal cannot be null"); + } + + @Test + public void constructorWhenServletRequestIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new OAuth2AuthorizeRequest(this.clientRegistration.getRegistrationId(), this.principal, null, this.servletResponse)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("servletRequest cannot be null"); + } + + @Test + public void constructorWhenServletResponseIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new OAuth2AuthorizeRequest(this.clientRegistration.getRegistrationId(), this.principal, this.servletRequest, null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("servletResponse cannot be null"); + } + + @Test + public void constructorWhenAllValuesProvidedThenAllValuesAreSet() { + OAuth2AuthorizeRequest authorizeRequest = new OAuth2AuthorizeRequest( + this.clientRegistration.getRegistrationId(), this.principal, this.servletRequest, this.servletResponse); + + assertThat(authorizeRequest.getClientRegistrationId()).isEqualTo(this.clientRegistration.getRegistrationId()); + assertThat(authorizeRequest.getPrincipal()).isEqualTo(this.principal); + assertThat(authorizeRequest.getServletRequest()).isEqualTo(this.servletRequest); + assertThat(authorizeRequest.getServletResponse()).isEqualTo(this.servletResponse); + } +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2ReauthorizeRequestTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2ReauthorizeRequestTests.java new file mode 100644 index 00000000000..a82d9170ba5 --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2ReauthorizeRequestTests.java @@ -0,0 +1,83 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client.web; + +import org.junit.Test; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Tests for {@link OAuth2ReauthorizeRequest}. + * + * @author Joe Grandja + */ +public class OAuth2ReauthorizeRequestTests { + private ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); + private Authentication principal = new TestingAuthenticationToken("principal", "password"); + private OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( + this.clientRegistration, this.principal.getName(), TestOAuth2AccessTokens.noScopes()); + private MockHttpServletRequest servletRequest = new MockHttpServletRequest(); + private MockHttpServletResponse servletResponse = new MockHttpServletResponse(); + + @Test + public void constructorWhenAuthorizedClientIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new OAuth2ReauthorizeRequest(null, this.principal, this.servletRequest, this.servletResponse)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("authorizedClient cannot be null"); + } + + @Test + public void constructorWhenPrincipalIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new OAuth2ReauthorizeRequest(this.authorizedClient, null, this.servletRequest, this.servletResponse)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("principal cannot be null"); + } + + @Test + public void constructorWhenServletRequestIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new OAuth2ReauthorizeRequest(this.authorizedClient, this.principal, null, this.servletResponse)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("servletRequest cannot be null"); + } + + @Test + public void constructorWhenServletResponseIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new OAuth2ReauthorizeRequest(this.authorizedClient, this.principal, this.servletRequest, null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("servletResponse cannot be null"); + } + + @Test + public void constructorWhenAllValuesProvidedThenAllValuesAreSet() { + OAuth2ReauthorizeRequest reauthorizeRequest = new OAuth2ReauthorizeRequest( + this.authorizedClient, this.principal, this.servletRequest, this.servletResponse); + + assertThat(reauthorizeRequest.getAuthorizedClient()).isEqualTo(this.authorizedClient); + assertThat(reauthorizeRequest.getClientRegistrationId()).isEqualTo(this.clientRegistration.getRegistrationId()); + assertThat(reauthorizeRequest.getPrincipal()).isEqualTo(this.principal); + assertThat(reauthorizeRequest.getServletRequest()).isEqualTo(this.servletRequest); + assertThat(reauthorizeRequest.getServletResponse()).isEqualTo(this.servletResponse); + } +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolverTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolverTests.java index 652233674c8..07548e86bf2 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolverTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolverTests.java @@ -28,6 +28,8 @@ import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException; import org.springframework.security.oauth2.client.ClientCredentialsOAuth2AuthorizedClientProvider; import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProvider; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProviderBuilder; import org.springframework.security.oauth2.client.annotation.RegisteredOAuth2AuthorizedClient; import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; @@ -103,8 +105,16 @@ public void setup() { .build(); this.clientRegistrationRepository = new InMemoryClientRegistrationRepository(this.registration1, this.registration2); this.authorizedClientRepository = mock(OAuth2AuthorizedClientRepository.class); - this.argumentResolver = new OAuth2AuthorizedClientArgumentResolver( + OAuth2AuthorizedClientProvider authorizedClientProvider = + OAuth2AuthorizedClientProviderBuilder.withProvider() + .authorizationCode() + .refreshToken() + .clientCredentials() + .build(); + DefaultOAuth2AuthorizedClientManager authorizedClientManager = new DefaultOAuth2AuthorizedClientManager( this.clientRegistrationRepository, this.authorizedClientRepository); + authorizedClientManager.setAuthorizedClientProvider(authorizedClientProvider); + this.argumentResolver = new OAuth2AuthorizedClientArgumentResolver(authorizedClientManager); this.authorizedClient1 = new OAuth2AuthorizedClient(this.registration1, this.principalName, mock(OAuth2AccessToken.class)); when(this.authorizedClientRepository.loadAuthorizedClient( eq(this.registration1.getRegistrationId()), any(Authentication.class), any(HttpServletRequest.class))) @@ -135,8 +145,8 @@ public void constructorWhenOAuth2AuthorizedClientRepositoryIsNullThenThrowIllega } @Test - public void setAuthorizedClientManagerWhenProviderIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.argumentResolver.setAuthorizedClientManager(null)) + public void constructorWhenAuthorizedClientManagerIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new OAuth2AuthorizedClientArgumentResolver(null)) .isInstanceOf(IllegalArgumentException.class); } @@ -225,7 +235,7 @@ public void resolveArgumentWhenAuthorizedClientNotFoundForClientCredentialsClien DefaultOAuth2AuthorizedClientManager authorizedClientManager = new DefaultOAuth2AuthorizedClientManager( this.clientRegistrationRepository, this.authorizedClientRepository); authorizedClientManager.setAuthorizedClientProvider(clientCredentialsAuthorizedClientProvider); - this.argumentResolver.setAuthorizedClientManager(authorizedClientManager); + this.argumentResolver = new OAuth2AuthorizedClientArgumentResolver(authorizedClientManager); OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse .withToken("access-token-1234") diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java index 6e016b746b3..e27f37170d6 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java @@ -59,7 +59,6 @@ import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; import org.springframework.security.oauth2.client.registration.TestClientRegistrations; import org.springframework.security.oauth2.client.web.DefaultOAuth2AuthorizedClientManager; -import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientManager; import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2RefreshToken; @@ -120,8 +119,6 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests { private ServletOAuth2AuthorizedClientExchangeFilterFunction function; - private OAuth2AuthorizedClientManager authorizedClientManager; - private MockExchangeFunction exchange = new MockExchangeFunction(); private Authentication authentication; @@ -145,9 +142,7 @@ public void setup() { DefaultOAuth2AuthorizedClientManager authorizedClientManager = new DefaultOAuth2AuthorizedClientManager( this.clientRegistrationRepository, this.authorizedClientRepository); authorizedClientManager.setAuthorizedClientProvider(authorizedClientProvider); - this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction( - this.clientRegistrationRepository, this.authorizedClientRepository); - this.function.setAuthorizedClientManager(authorizedClientManager); + this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(authorizedClientManager); } @After @@ -157,8 +152,8 @@ public void cleanup() { } @Test - public void setAuthorizedClientManagerWhenManagerIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.function.setAuthorizedClientManager(null)) + public void constructorWhenAuthorizedClientManagerIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new ServletOAuth2AuthorizedClientExchangeFilterFunction(null)) .isInstanceOf(IllegalArgumentException.class); } @@ -238,7 +233,10 @@ public void defaultRequestOAuth2AuthorizedClientWhenDefaultTrueAndAuthentication OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", this.accessToken); when(this.authorizedClientRepository.loadAuthorizedClient(any(), any(), any())).thenReturn(authorizedClient); + when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(this.registration); authentication(token).accept(this.result); + httpServletRequest(new MockHttpServletRequest()).accept(this.result); + httpServletResponse(new MockHttpServletResponse()).accept(this.result); Map attrs = getDefaultRequestAttributes(); @@ -266,8 +264,11 @@ public void defaultRequestOAuth2AuthorizedClientWhenAuthenticationAndClientRegis OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", this.accessToken); when(this.authorizedClientRepository.loadAuthorizedClient(any(), any(), any())).thenReturn(authorizedClient); + when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(this.registration); authentication(token).accept(this.result); clientRegistrationId("explicit").accept(this.result); + httpServletRequest(new MockHttpServletRequest()).accept(this.result); + httpServletResponse(new MockHttpServletResponse()).accept(this.result); Map attrs = getDefaultRequestAttributes(); @@ -277,10 +278,13 @@ public void defaultRequestOAuth2AuthorizedClientWhenAuthenticationAndClientRegis @Test public void defaultRequestOAuth2AuthorizedClientWhenAuthenticationNullAndClientRegistrationIdThenOAuth2AuthorizedClient() { + when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(this.registration); OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", this.accessToken); when(this.authorizedClientRepository.loadAuthorizedClient(any(), any(), any())).thenReturn(authorizedClient); clientRegistrationId("id").accept(this.result); + httpServletRequest(new MockHttpServletRequest()).accept(this.result); + httpServletResponse(new MockHttpServletResponse()).accept(this.result); Map attrs = getDefaultRequestAttributes(); @@ -462,7 +466,7 @@ public void filterWhenRefreshRequiredThenRefreshAndResponseDoesNotContainRefresh DefaultOAuth2AuthorizedClientManager authorizedClientManager = new DefaultOAuth2AuthorizedClientManager( this.clientRegistrationRepository, this.authorizedClientRepository); authorizedClientManager.setAuthorizedClientProvider(authorizedClientProvider); - this.function.setAuthorizedClientManager(authorizedClientManager); + this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(authorizedClientManager); Instant issuedAt = Instant.now().minus(Duration.ofDays(1)); Instant accessTokenExpiresAt = issuedAt.plus(Duration.ofHours(1)); @@ -678,6 +682,8 @@ public void filterWhenChainedThenDefaultsStillAvailable() throws Exception { when(this.authorizedClientRepository.loadAuthorizedClient(eq(authentication.getAuthorizedClientRegistrationId()), eq(authentication), eq(servletRequest))).thenReturn(authorizedClient); + when(this.clientRegistrationRepository.findByRegistrationId(eq(authentication.getAuthorizedClientRegistrationId()))).thenReturn(this.registration); + // Default request attributes set final ClientRequest request1 = ClientRequest.create(GET, URI.create("https://example1.com")) .attributes(attrs -> attrs.putAll(getDefaultRequestAttributes())).build(); diff --git a/samples/boot/oauth2webclient/src/main/java/sample/config/WebClientConfig.java b/samples/boot/oauth2webclient/src/main/java/sample/config/WebClientConfig.java index b995ffb61db..7874a270425 100644 --- a/samples/boot/oauth2webclient/src/main/java/sample/config/WebClientConfig.java +++ b/samples/boot/oauth2webclient/src/main/java/sample/config/WebClientConfig.java @@ -18,7 +18,10 @@ import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProvider; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProviderBuilder; import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; +import org.springframework.security.oauth2.client.web.DefaultOAuth2AuthorizedClientManager; import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction; import org.springframework.web.reactive.function.client.WebClient; @@ -32,8 +35,20 @@ public class WebClientConfig { @Bean WebClient webClient(ClientRegistrationRepository clientRegistrationRepository, OAuth2AuthorizedClientRepository authorizedClientRepository) { - ServletOAuth2AuthorizedClientExchangeFilterFunction oauth2 = new ServletOAuth2AuthorizedClientExchangeFilterFunction(clientRegistrationRepository, authorizedClientRepository); + OAuth2AuthorizedClientProvider authorizedClientProvider = + OAuth2AuthorizedClientProviderBuilder.withProvider() + .authorizationCode() + .refreshToken() + .clientCredentials() + .build(); + DefaultOAuth2AuthorizedClientManager authorizedClientManager = new DefaultOAuth2AuthorizedClientManager( + clientRegistrationRepository, authorizedClientRepository); + authorizedClientManager.setAuthorizedClientProvider(authorizedClientProvider); + + ServletOAuth2AuthorizedClientExchangeFilterFunction oauth2 = + new ServletOAuth2AuthorizedClientExchangeFilterFunction(authorizedClientManager); oauth2.setDefaultOAuth2AuthorizedClient(true); + return WebClient.builder() .apply(oauth2.oauth2Configuration()) .build(); From 468b9291007efa77606885df4de194991e7c5ec5 Mon Sep 17 00:00:00 2001 From: Joe Grandja Date: Fri, 12 Jul 2019 15:31:06 -0400 Subject: [PATCH 15/19] Polish --- ...ionCodeOAuth2AuthorizedClientProvider.java | 3 -- ...entialsOAuth2AuthorizedClientProvider.java | 3 -- ...shTokenOAuth2AuthorizedClientProvider.java | 3 -- ...Auth2AuthorizedClientArgumentResolver.java | 24 +++++---- ...uthorizedClientExchangeFilterFunction.java | 54 ++++++++++++------- ...AuthorizedClientArgumentResolverTests.java | 12 ++++- ...izedClientExchangeFilterFunctionTests.java | 24 +++++++++ 7 files changed, 82 insertions(+), 41 deletions(-) diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/AuthorizationCodeOAuth2AuthorizedClientProvider.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/AuthorizationCodeOAuth2AuthorizedClientProvider.java index 8472472d79c..7ff23c3ceb5 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/AuthorizationCodeOAuth2AuthorizedClientProvider.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/AuthorizationCodeOAuth2AuthorizedClientProvider.java @@ -30,9 +30,6 @@ */ public final class AuthorizationCodeOAuth2AuthorizedClientProvider implements OAuth2AuthorizedClientProvider { - public AuthorizationCodeOAuth2AuthorizedClientProvider() { - } - /** * Attempt to authorize the {@link OAuth2AuthorizationContext#getClientRegistration() client} in the provided {@code context}. * Returns {@code null} if authorization is not supported, diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ClientCredentialsOAuth2AuthorizedClientProvider.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ClientCredentialsOAuth2AuthorizedClientProvider.java index c2087ea15e6..c3b686cd08b 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ClientCredentialsOAuth2AuthorizedClientProvider.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ClientCredentialsOAuth2AuthorizedClientProvider.java @@ -42,9 +42,6 @@ public final class ClientCredentialsOAuth2AuthorizedClientProvider implements OA new DefaultClientCredentialsTokenResponseClient(); private Duration clockSkew = Duration.ofSeconds(60); - public ClientCredentialsOAuth2AuthorizedClientProvider() { - } - /** * Attempt to authorize (or re-authorize) the {@link OAuth2AuthorizationContext#getClientRegistration() client} in the provided {@code context}. * Returns {@code null} if authorization (or re-authorization) is not supported, diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProvider.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProvider.java index 5e6c1aa4fd3..af760431ddb 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProvider.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProvider.java @@ -53,9 +53,6 @@ public final class RefreshTokenOAuth2AuthorizedClientProvider implements OAuth2A new DefaultRefreshTokenTokenResponseClient(); private Duration clockSkew = Duration.ofSeconds(60); - public RefreshTokenOAuth2AuthorizedClientProvider() { - } - /** * Attempt to re-authorize the {@link OAuth2AuthorizationContext#getClientRegistration() client} in the provided {@code context}. * Returns {@code null} if re-authorization is not supported, diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolver.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolver.java index 9f8d1b744aa..3eb03412e08 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolver.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolver.java @@ -70,6 +70,7 @@ public final class OAuth2AuthorizedClientArgumentResolver implements HandlerMeth private static final Authentication ANONYMOUS_AUTHENTICATION = new AnonymousAuthenticationToken( "anonymous", "anonymousUser", AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS")); private OAuth2AuthorizedClientManager authorizedClientManager; + private boolean defaultAuthorizedClientManager; /** * Constructs an {@code OAuth2AuthorizedClientArgumentResolver} using the provided parameters. @@ -97,6 +98,7 @@ public OAuth2AuthorizedClientArgumentResolver(ClientRegistrationRepository clien Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null"); Assert.notNull(authorizedClientRepository, "authorizedClientRepository cannot be null"); this.authorizedClientManager = createDefaultAuthorizedClientManager(clientRegistrationRepository, authorizedClientRepository); + this.defaultAuthorizedClientManager = true; } private static OAuth2AuthorizedClientManager createDefaultAuthorizedClientManager( @@ -182,20 +184,20 @@ private String resolveClientRegistrationId(MethodParameter parameter) { public final void setClientCredentialsTokenResponseClient( OAuth2AccessTokenResponseClient clientCredentialsTokenResponseClient) { Assert.notNull(clientCredentialsTokenResponseClient, "clientCredentialsTokenResponseClient cannot be null"); - updateAuthorizedClientManager(clientCredentialsTokenResponseClient); + Assert.state(this.defaultAuthorizedClientManager, "The client cannot be set when the constructor used is \"OAuth2AuthorizedClientArgumentResolver(OAuth2AuthorizedClientManager)\". " + + "Instead, use the constructor \"OAuth2AuthorizedClientArgumentResolver(ClientRegistrationRepository, OAuth2AuthorizedClientRepository)\"."); + updateDefaultAuthorizedClientManager(clientCredentialsTokenResponseClient); } - private void updateAuthorizedClientManager( + private void updateDefaultAuthorizedClientManager( OAuth2AccessTokenResponseClient clientCredentialsTokenResponseClient) { - if (this.authorizedClientManager instanceof DefaultOAuth2AuthorizedClientManager) { - OAuth2AuthorizedClientProvider authorizedClientProvider = - OAuth2AuthorizedClientProviderBuilder.withProvider() - .authorizationCode() - .refreshToken() - .clientCredentials(configurer -> configurer.accessTokenResponseClient(clientCredentialsTokenResponseClient)) - .build(); - ((DefaultOAuth2AuthorizedClientManager) this.authorizedClientManager).setAuthorizedClientProvider(authorizedClientProvider); - } + OAuth2AuthorizedClientProvider authorizedClientProvider = + OAuth2AuthorizedClientProviderBuilder.withProvider() + .authorizationCode() + .refreshToken() + .clientCredentials(configurer -> configurer.accessTokenResponseClient(clientCredentialsTokenResponseClient)) + .build(); + ((DefaultOAuth2AuthorizedClientManager) this.authorizedClientManager).setAuthorizedClientProvider(authorizedClientProvider); } } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java index 198afee50d0..624e02f6f97 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java @@ -28,8 +28,8 @@ import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProvider; import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProviderBuilder; +import org.springframework.security.oauth2.client.RefreshTokenOAuth2AuthorizedClientProvider; import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; -import org.springframework.security.oauth2.client.endpoint.DefaultClientCredentialsTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest; import org.springframework.security.oauth2.client.registration.ClientRegistration; @@ -87,8 +87,7 @@ *
        *
      • The {@link OAuth2AuthorizedClientManager} is not null
      • *
      • A refresh token is present on the {@link OAuth2AuthorizedClient}
      • - *
      • The access token will be expired in - * {@link #setAccessTokenExpiresSkew(Duration)}
      • + *
      • The access token is expired
      • *
      • The {@link SecurityContextHolder} will be used to attempt to save * the token. If it is empty, then the principal name on the {@link OAuth2AuthorizedClient} * will be used to create an Authentication for saving.
      • @@ -116,10 +115,16 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction private static final Authentication ANONYMOUS_AUTHENTICATION = new AnonymousAuthenticationToken( "anonymous", "anonymousUser", AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS")); + @Deprecated private Duration accessTokenExpiresSkew = Duration.ofMinutes(1); + @Deprecated + private OAuth2AccessTokenResponseClient clientCredentialsTokenResponseClient; + private OAuth2AuthorizedClientManager authorizedClientManager; + private boolean defaultAuthorizedClientManager; + private boolean defaultOAuth2AuthorizedClient; private String defaultClientRegistrationId; @@ -152,6 +157,7 @@ public ServletOAuth2AuthorizedClientExchangeFilterFunction( ClientRegistrationRepository clientRegistrationRepository, OAuth2AuthorizedClientRepository authorizedClientRepository) { this.authorizedClientManager = createDefaultAuthorizedClientManager(clientRegistrationRepository, authorizedClientRepository); + this.defaultAuthorizedClientManager = true; } private static OAuth2AuthorizedClientManager createDefaultAuthorizedClientManager( @@ -194,27 +200,27 @@ public void destroy() throws Exception { public void setClientCredentialsTokenResponseClient( OAuth2AccessTokenResponseClient clientCredentialsTokenResponseClient) { Assert.notNull(clientCredentialsTokenResponseClient, "clientCredentialsTokenResponseClient cannot be null"); - updateAuthorizedClientManager(clientCredentialsTokenResponseClient); + Assert.state(this.defaultAuthorizedClientManager, "The client cannot be set when the constructor used is \"ServletOAuth2AuthorizedClientExchangeFilterFunction(OAuth2AuthorizedClientManager)\". " + + "Instead, use the constructor \"ServletOAuth2AuthorizedClientExchangeFilterFunction(ClientRegistrationRepository, OAuth2AuthorizedClientRepository)\"."); + this.clientCredentialsTokenResponseClient = clientCredentialsTokenResponseClient; + updateDefaultAuthorizedClientManager(); } - private void updateAuthorizedClientManager() { - updateAuthorizedClientManager(new DefaultClientCredentialsTokenResponseClient()); + private void updateDefaultAuthorizedClientManager() { + OAuth2AuthorizedClientProvider authorizedClientProvider = + OAuth2AuthorizedClientProviderBuilder.withProvider() + .authorizationCode() + .refreshToken(configurer -> configurer.clockSkew(this.accessTokenExpiresSkew)) + .clientCredentials(this::updateClientCredentialsProvider) + .build(); + ((DefaultOAuth2AuthorizedClientManager) this.authorizedClientManager).setAuthorizedClientProvider(authorizedClientProvider); } - private void updateAuthorizedClientManager( - OAuth2AccessTokenResponseClient clientCredentialsTokenResponseClient) { - - if (this.authorizedClientManager instanceof DefaultOAuth2AuthorizedClientManager) { - OAuth2AuthorizedClientProvider authorizedClientProvider = - OAuth2AuthorizedClientProviderBuilder.withProvider() - .authorizationCode() - .refreshToken(configurer -> configurer.clockSkew(this.accessTokenExpiresSkew)) - .clientCredentials(configurer -> configurer - .accessTokenResponseClient(clientCredentialsTokenResponseClient) - .clockSkew(this.accessTokenExpiresSkew)) - .build(); - ((DefaultOAuth2AuthorizedClientManager) this.authorizedClientManager).setAuthorizedClientProvider(authorizedClientProvider); + private void updateClientCredentialsProvider(OAuth2AuthorizedClientProviderBuilder.ClientCredentialsGrantBuilder builder) { + if (this.clientCredentialsTokenResponseClient != null) { + builder.accessTokenResponseClient(this.clientCredentialsTokenResponseClient); } + builder.clockSkew(this.accessTokenExpiresSkew); } /** @@ -330,12 +336,20 @@ public static Consumer> httpServletResponse(HttpServletRespo /** * An access token will be considered expired by comparing its expiration to now + * this skewed Duration. The default is 1 minute. + * + * @deprecated The {@code accessTokenExpiresSkew} should be configured with the specific {@link OAuth2AuthorizedClientProvider} implementation, + * e.g. {@link ClientCredentialsOAuth2AuthorizedClientProvider#setClockSkew(Duration) ClientCredentialsOAuth2AuthorizedClientProvider} or + * {@link RefreshTokenOAuth2AuthorizedClientProvider#setClockSkew(Duration) RefreshTokenOAuth2AuthorizedClientProvider}. + * * @param accessTokenExpiresSkew the Duration to use. */ + @Deprecated public void setAccessTokenExpiresSkew(Duration accessTokenExpiresSkew) { Assert.notNull(accessTokenExpiresSkew, "accessTokenExpiresSkew cannot be null"); + Assert.state(this.defaultAuthorizedClientManager, "The accessTokenExpiresSkew cannot be set when the constructor used is \"ServletOAuth2AuthorizedClientExchangeFilterFunction(OAuth2AuthorizedClientManager)\". " + + "Instead, use the constructor \"ServletOAuth2AuthorizedClientExchangeFilterFunction(ClientRegistrationRepository, OAuth2AuthorizedClientRepository)\"."); this.accessTokenExpiresSkew = accessTokenExpiresSkew; - updateAuthorizedClientManager(); + updateDefaultAuthorizedClientManager(); } @Override diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolverTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolverTests.java index 07548e86bf2..b0ee98b411e 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolverTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolverTests.java @@ -32,6 +32,7 @@ import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProviderBuilder; import org.springframework.security.oauth2.client.annotation.RegisteredOAuth2AuthorizedClient; import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; +import org.springframework.security.oauth2.client.endpoint.DefaultClientCredentialsTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest; import org.springframework.security.oauth2.client.registration.ClientRegistration; @@ -153,7 +154,16 @@ public void constructorWhenAuthorizedClientManagerIsNullThenThrowIllegalArgument @Test public void setClientCredentialsTokenResponseClientWhenClientIsNullThenThrowIllegalArgumentException() { assertThatThrownBy(() -> this.argumentResolver.setClientCredentialsTokenResponseClient(null)) - .isInstanceOf(IllegalArgumentException.class); + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("clientCredentialsTokenResponseClient cannot be null"); + } + + @Test + public void setClientCredentialsTokenResponseClientWhenNotDefaultAuthorizedClientManagerThenThrowIllegalStateException() { + assertThatThrownBy(() -> this.argumentResolver.setClientCredentialsTokenResponseClient(new DefaultClientCredentialsTokenResponseClient())) + .isInstanceOf(IllegalStateException.class) + .hasMessage("The client cannot be set when the constructor used is \"OAuth2AuthorizedClientArgumentResolver(OAuth2AuthorizedClientManager)\". " + + "Instead, use the constructor \"OAuth2AuthorizedClientArgumentResolver(ClientRegistrationRepository, OAuth2AuthorizedClientRepository)\"."); } @Test diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java index e27f37170d6..a23f1c79ee9 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java @@ -51,6 +51,7 @@ import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProviderBuilder; import org.springframework.security.oauth2.client.RefreshTokenOAuth2AuthorizedClientProvider; import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; +import org.springframework.security.oauth2.client.endpoint.DefaultClientCredentialsTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.DefaultRefreshTokenTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest; @@ -157,6 +158,29 @@ public void constructorWhenAuthorizedClientManagerIsNullThenThrowIllegalArgument .isInstanceOf(IllegalArgumentException.class); } + @Test + public void setClientCredentialsTokenResponseClientWhenClientIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.function.setClientCredentialsTokenResponseClient(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("clientCredentialsTokenResponseClient cannot be null"); + } + + @Test + public void setClientCredentialsTokenResponseClientWhenNotDefaultAuthorizedClientManagerThenThrowIllegalStateException() { + assertThatThrownBy(() -> this.function.setClientCredentialsTokenResponseClient(new DefaultClientCredentialsTokenResponseClient())) + .isInstanceOf(IllegalStateException.class) + .hasMessage("The client cannot be set when the constructor used is \"ServletOAuth2AuthorizedClientExchangeFilterFunction(OAuth2AuthorizedClientManager)\". " + + "Instead, use the constructor \"ServletOAuth2AuthorizedClientExchangeFilterFunction(ClientRegistrationRepository, OAuth2AuthorizedClientRepository)\"."); + } + + @Test + public void setAccessTokenExpiresSkewWhenNotDefaultAuthorizedClientManagerThenThrowIllegalStateException() { + assertThatThrownBy(() -> this.function.setAccessTokenExpiresSkew(Duration.ofSeconds(30))) + .isInstanceOf(IllegalStateException.class) + .hasMessage("The accessTokenExpiresSkew cannot be set when the constructor used is \"ServletOAuth2AuthorizedClientExchangeFilterFunction(OAuth2AuthorizedClientManager)\". " + + "Instead, use the constructor \"ServletOAuth2AuthorizedClientExchangeFilterFunction(ClientRegistrationRepository, OAuth2AuthorizedClientRepository)\"."); + } + @Test public void defaultRequestRequestResponseWhenNullRequestContextThenRequestAndResponseNull() { Map attrs = getDefaultRequestAttributes(); From 0fcbc6cf3d9691fd882ae0659620b10172a05b65 Mon Sep 17 00:00:00 2001 From: Joe Grandja Date: Mon, 15 Jul 2019 22:13:32 -0400 Subject: [PATCH 16/19] Fix package tangles --- ...shTokenOAuth2AuthorizedClientProvider.java | 5 +- ...efaultRefreshTokenTokenResponseClient.java | 4 +- .../OAuth2RefreshTokenGrantRequest.java | 70 +++++++++++++------ ...freshTokenGrantRequestEntityConverter.java | 6 +- ...tRefreshTokenTokenResponseClientTests.java | 37 +++++----- ...TokenGrantRequestEntityConverterTests.java | 18 +++-- .../OAuth2RefreshTokenGrantRequestTests.java | 41 ++++++----- 7 files changed, 110 insertions(+), 71 deletions(-) diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProvider.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProvider.java index af760431ddb..112ce629ca4 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProvider.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProvider.java @@ -89,8 +89,9 @@ public OAuth2AuthorizedClient authorize(OAuth2AuthorizationContext context) { scopes = new HashSet<>(Arrays.asList((String[]) requestScope)); } - OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = - new OAuth2RefreshTokenGrantRequest(authorizedClient, scopes); + OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest( + authorizedClient.getClientRegistration(), authorizedClient.getAccessToken(), + authorizedClient.getRefreshToken(), scopes); OAuth2AccessTokenResponse tokenResponse = this.accessTokenResponseClient.getTokenResponse(refreshTokenGrantRequest); diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/DefaultRefreshTokenTokenResponseClient.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/DefaultRefreshTokenTokenResponseClient.java index d9c19b23f4f..0efd37d8ebd 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/DefaultRefreshTokenTokenResponseClient.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/DefaultRefreshTokenTokenResponseClient.java @@ -89,12 +89,12 @@ public OAuth2AccessTokenResponse getTokenResponse(OAuth2RefreshTokenGrantRequest // https://tools.ietf.org/html/rfc6749#section-5.1 // If AccessTokenResponse.scope is empty, then default to the scope // originally requested by the client in the Token Request - tokenResponseBuilder.scopes(refreshTokenGrantRequest.getAuthorizedClient().getAccessToken().getScopes()); + tokenResponseBuilder.scopes(refreshTokenGrantRequest.getAccessToken().getScopes()); } if (tokenResponse.getRefreshToken() == null) { // Reuse existing refresh token - tokenResponseBuilder.refreshToken(refreshTokenGrantRequest.getAuthorizedClient().getRefreshToken().getTokenValue()); + tokenResponseBuilder.refreshToken(refreshTokenGrantRequest.getRefreshToken().getTokenValue()); } tokenResponse = tokenResponseBuilder.build(); diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2RefreshTokenGrantRequest.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2RefreshTokenGrantRequest.java index e84c4e365f5..a93b76fdd41 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2RefreshTokenGrantRequest.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2RefreshTokenGrantRequest.java @@ -15,8 +15,10 @@ */ package org.springframework.security.oauth2.client.endpoint; -import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; +import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.OAuth2RefreshToken; import org.springframework.util.Assert; import java.util.Collections; @@ -24,57 +26,85 @@ import java.util.Set; /** - * An OAuth 2.0 Refresh Token Grant request that holds - * the {@link OAuth2AuthorizedClient authorized client}. + * An OAuth 2.0 Refresh Token Grant request that holds the {@link OAuth2RefreshToken refresh token} credential + * granted to the {@link #getClientRegistration() client}. * * @author Joe Grandja * @since 5.2 * @see AbstractOAuth2AuthorizationGrantRequest - * @see OAuth2AuthorizedClient + * @see OAuth2RefreshToken * @see Section 6 Refreshing an Access Token */ public class OAuth2RefreshTokenGrantRequest extends AbstractOAuth2AuthorizationGrantRequest { - private final OAuth2AuthorizedClient authorizedClient; + private final ClientRegistration clientRegistration; + private final OAuth2AccessToken accessToken; + private final OAuth2RefreshToken refreshToken; private final Set scopes; /** * Constructs an {@code OAuth2RefreshTokenGrantRequest} using the provided parameters. * - * @param authorizedClient the authorized client + * @param clientRegistration the authorized client's registration + * @param accessToken the access token credential granted + * @param refreshToken the refresh token credential granted */ - public OAuth2RefreshTokenGrantRequest(OAuth2AuthorizedClient authorizedClient) { - this(authorizedClient, Collections.emptySet()); + public OAuth2RefreshTokenGrantRequest(ClientRegistration clientRegistration, OAuth2AccessToken accessToken, + OAuth2RefreshToken refreshToken) { + this(clientRegistration, accessToken, refreshToken, Collections.emptySet()); } /** * Constructs an {@code OAuth2RefreshTokenGrantRequest} using the provided parameters. * - * @param authorizedClient the authorized client - * @param scopes the scopes + * @param clientRegistration the authorized client's registration + * @param accessToken the access token credential granted + * @param refreshToken the refresh token credential granted + * @param scopes the scopes to request */ - public OAuth2RefreshTokenGrantRequest(OAuth2AuthorizedClient authorizedClient, Set scopes) { + public OAuth2RefreshTokenGrantRequest(ClientRegistration clientRegistration, OAuth2AccessToken accessToken, + OAuth2RefreshToken refreshToken, Set scopes) { super(AuthorizationGrantType.REFRESH_TOKEN); - Assert.notNull(authorizedClient, "authorizedClient cannot be null"); - Assert.notNull(authorizedClient.getRefreshToken(), "authorizedClient.refreshToken cannot be null"); - this.authorizedClient = authorizedClient; + Assert.notNull(clientRegistration, "clientRegistration cannot be null"); + Assert.notNull(accessToken, "accessToken cannot be null"); + Assert.notNull(refreshToken, "refreshToken cannot be null"); + this.clientRegistration = clientRegistration; + this.accessToken = accessToken; + this.refreshToken = refreshToken; this.scopes = Collections.unmodifiableSet(scopes != null ? new LinkedHashSet<>(scopes) : Collections.emptySet()); + } + /** + * Returns the authorized client's {@link ClientRegistration registration}. + * + * @return the {@link ClientRegistration} + */ + public ClientRegistration getClientRegistration() { + return this.clientRegistration; + } + + /** + * Returns the {@link OAuth2AccessToken access token} credential granted. + * + * @return the {@link OAuth2AccessToken} + */ + public OAuth2AccessToken getAccessToken() { + return this.accessToken; } /** - * Returns the {@link OAuth2AuthorizedClient authorized client}. + * Returns the {@link OAuth2RefreshToken refresh token} credential granted. * - * @return the {@link OAuth2AuthorizedClient} + * @return the {@link OAuth2RefreshToken} */ - public OAuth2AuthorizedClient getAuthorizedClient() { - return this.authorizedClient; + public OAuth2RefreshToken getRefreshToken() { + return this.refreshToken; } /** - * Returns the scope(s). + * Returns the scope(s) to request. * - * @return the scope(s) + * @return the scope(s) to request */ public Set getScopes() { return this.scopes; diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2RefreshTokenGrantRequestEntityConverter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2RefreshTokenGrantRequestEntityConverter.java index f3bdeb71d17..00cac8beed0 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2RefreshTokenGrantRequestEntityConverter.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2RefreshTokenGrantRequestEntityConverter.java @@ -51,7 +51,7 @@ public class OAuth2RefreshTokenGrantRequestEntityConverter implements Converter< */ @Override public RequestEntity convert(OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest) { - ClientRegistration clientRegistration = refreshTokenGrantRequest.getAuthorizedClient().getClientRegistration(); + ClientRegistration clientRegistration = refreshTokenGrantRequest.getClientRegistration(); HttpHeaders headers = OAuth2AuthorizationGrantRequestEntityUtils.getTokenRequestHeaders(clientRegistration); MultiValueMap formParameters = buildFormParameters(refreshTokenGrantRequest); @@ -69,12 +69,12 @@ public RequestEntity convert(OAuth2RefreshTokenGrantRequest refreshTokenGrant * @return a {@link MultiValueMap} of the form parameters used for the Access Token Request body */ private MultiValueMap buildFormParameters(OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest) { - ClientRegistration clientRegistration = refreshTokenGrantRequest.getAuthorizedClient().getClientRegistration(); + ClientRegistration clientRegistration = refreshTokenGrantRequest.getClientRegistration(); MultiValueMap formParameters = new LinkedMultiValueMap<>(); formParameters.add(OAuth2ParameterNames.GRANT_TYPE, refreshTokenGrantRequest.getGrantType().getValue()); formParameters.add(OAuth2ParameterNames.REFRESH_TOKEN, - refreshTokenGrantRequest.getAuthorizedClient().getRefreshToken().getTokenValue()); + refreshTokenGrantRequest.getRefreshToken().getTokenValue()); if (!CollectionUtils.isEmpty(refreshTokenGrantRequest.getScopes())) { formParameters.add(OAuth2ParameterNames.SCOPE, StringUtils.collectionToDelimitedString(refreshTokenGrantRequest.getScopes(), " ")); diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/DefaultRefreshTokenTokenResponseClientTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/DefaultRefreshTokenTokenResponseClientTests.java index 1284ce16737..5902eb4543f 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/DefaultRefreshTokenTokenResponseClientTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/DefaultRefreshTokenTokenResponseClientTests.java @@ -24,12 +24,12 @@ import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; import org.springframework.http.MediaType; -import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.TestClientRegistrations; import org.springframework.security.oauth2.core.ClientAuthenticationMethod; import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AuthorizationException; +import org.springframework.security.oauth2.core.OAuth2RefreshToken; import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; @@ -48,7 +48,8 @@ public class DefaultRefreshTokenTokenResponseClientTests { private DefaultRefreshTokenTokenResponseClient tokenResponseClient = new DefaultRefreshTokenTokenResponseClient(); private ClientRegistration.Builder clientRegistrationBuilder; - private OAuth2AuthorizedClient authorizedClient; + private OAuth2AccessToken accessToken; + private OAuth2RefreshToken refreshToken; private MockWebServer server; @Before @@ -57,8 +58,8 @@ public void setup() throws Exception { this.server.start(); String tokenUri = this.server.url("/oauth2/token").toString(); this.clientRegistrationBuilder = TestClientRegistrations.clientRegistration().tokenUri(tokenUri); - this.authorizedClient = new OAuth2AuthorizedClient(this.clientRegistrationBuilder.build(), - "principal", TestOAuth2AccessTokens.scopes("read", "write"), TestOAuth2RefreshTokens.refreshToken()); + this.accessToken = TestOAuth2AccessTokens.scopes("read", "write"); + this.refreshToken = TestOAuth2RefreshTokens.refreshToken(); } @After @@ -95,8 +96,8 @@ public void getTokenResponseWhenSuccessResponseThenReturnAccessTokenResponse() t Instant expiresAtBefore = Instant.now().plusSeconds(3600); - OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = - new OAuth2RefreshTokenGrantRequest(this.authorizedClient); + OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest( + this.clientRegistrationBuilder.build(), this.accessToken, this.refreshToken); OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(refreshTokenGrantRequest); @@ -115,8 +116,8 @@ public void getTokenResponseWhenSuccessResponseThenReturnAccessTokenResponse() t assertThat(accessTokenResponse.getAccessToken().getTokenValue()).isEqualTo("access-token-1234"); assertThat(accessTokenResponse.getAccessToken().getTokenType()).isEqualTo(OAuth2AccessToken.TokenType.BEARER); assertThat(accessTokenResponse.getAccessToken().getExpiresAt()).isBetween(expiresAtBefore, expiresAtAfter); - assertThat(accessTokenResponse.getAccessToken().getScopes()).containsExactly(this.authorizedClient.getAccessToken().getScopes().toArray(new String[0])); - assertThat(accessTokenResponse.getRefreshToken().getTokenValue()).isEqualTo(this.authorizedClient.getRefreshToken().getTokenValue()); + assertThat(accessTokenResponse.getAccessToken().getScopes()).containsExactly(this.accessToken.getScopes().toArray(new String[0])); + assertThat(accessTokenResponse.getRefreshToken().getTokenValue()).isEqualTo(this.refreshToken.getTokenValue()); } @Test @@ -131,11 +132,9 @@ public void getTokenResponseWhenClientAuthenticationPostThenFormParametersAreSen ClientRegistration clientRegistration = this.clientRegistrationBuilder .clientAuthenticationMethod(ClientAuthenticationMethod.POST) .build(); - this.authorizedClient = new OAuth2AuthorizedClient(clientRegistration, - "principal", TestOAuth2AccessTokens.scopes("read", "write"), TestOAuth2RefreshTokens.refreshToken()); OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = - new OAuth2RefreshTokenGrantRequest(this.authorizedClient); + new OAuth2RefreshTokenGrantRequest(clientRegistration, this.accessToken, this.refreshToken); this.tokenResponseClient.getTokenResponse(refreshTokenGrantRequest); @@ -156,8 +155,8 @@ public void getTokenResponseWhenSuccessResponseAndNotBearerTokenTypeThenThrowOAu "}\n"; this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); - OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = - new OAuth2RefreshTokenGrantRequest(this.authorizedClient); + OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest( + this.clientRegistrationBuilder.build(), this.accessToken, this.refreshToken); assertThatThrownBy(() -> this.tokenResponseClient.getTokenResponse(refreshTokenGrantRequest)) .isInstanceOf(OAuth2AuthorizationException.class) @@ -175,8 +174,8 @@ public void getTokenResponseWhenSuccessResponseIncludesScopeThenAccessTokenHasRe "}\n"; this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); - OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = - new OAuth2RefreshTokenGrantRequest(this.authorizedClient, Collections.singleton("read")); + OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest( + this.clientRegistrationBuilder.build(), this.accessToken, this.refreshToken, Collections.singleton("read")); OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(refreshTokenGrantRequest); @@ -194,8 +193,8 @@ public void getTokenResponseWhenErrorResponseThenThrowOAuth2AuthorizationExcepti "}\n"; this.server.enqueue(jsonResponse(accessTokenErrorResponse).setResponseCode(400)); - OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = - new OAuth2RefreshTokenGrantRequest(this.authorizedClient); + OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest( + this.clientRegistrationBuilder.build(), this.accessToken, this.refreshToken); assertThatThrownBy(() -> this.tokenResponseClient.getTokenResponse(refreshTokenGrantRequest)) .isInstanceOf(OAuth2AuthorizationException.class) @@ -206,8 +205,8 @@ public void getTokenResponseWhenErrorResponseThenThrowOAuth2AuthorizationExcepti public void getTokenResponseWhenServerErrorResponseThenThrowOAuth2AuthorizationException() { this.server.enqueue(new MockResponse().setResponseCode(500)); - OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = - new OAuth2RefreshTokenGrantRequest(this.authorizedClient); + OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest( + this.clientRegistrationBuilder.build(), this.accessToken, this.refreshToken); assertThatThrownBy(() -> this.tokenResponseClient.getTokenResponse(refreshTokenGrantRequest)) .isInstanceOf(OAuth2AuthorizationException.class) diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2RefreshTokenGrantRequestEntityConverterTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2RefreshTokenGrantRequestEntityConverterTests.java index 9925b80a62d..2f73174039f 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2RefreshTokenGrantRequestEntityConverterTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2RefreshTokenGrantRequestEntityConverterTests.java @@ -21,9 +21,10 @@ import org.springframework.http.HttpMethod; import org.springframework.http.MediaType; import org.springframework.http.RequestEntity; -import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; +import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.TestClientRegistrations; import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.OAuth2RefreshToken; import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; @@ -45,9 +46,11 @@ public class OAuth2RefreshTokenGrantRequestEntityConverterTests { @Before public void setup() { - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(TestClientRegistrations.clientRegistration().build(), - "principal", TestOAuth2AccessTokens.scopes("read", "write"), TestOAuth2RefreshTokens.refreshToken()); - this.refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest(authorizedClient, Collections.singleton("read")); + this.refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest( + TestClientRegistrations.clientRegistration().build(), + TestOAuth2AccessTokens.scopes("read", "write"), + TestOAuth2RefreshTokens.refreshToken(), + Collections.singleton("read")); } @SuppressWarnings("unchecked") @@ -55,11 +58,12 @@ public void setup() { public void convertWhenGrantRequestValidThenConverts() { RequestEntity requestEntity = this.converter.convert(this.refreshTokenGrantRequest); - OAuth2AuthorizedClient authorizedClient = this.refreshTokenGrantRequest.getAuthorizedClient(); + ClientRegistration clientRegistration = this.refreshTokenGrantRequest.getClientRegistration(); + OAuth2RefreshToken refreshToken = this.refreshTokenGrantRequest.getRefreshToken(); assertThat(requestEntity.getMethod()).isEqualTo(HttpMethod.POST); assertThat(requestEntity.getUrl().toASCIIString()).isEqualTo( - authorizedClient.getClientRegistration().getProviderDetails().getTokenUri()); + clientRegistration.getProviderDetails().getTokenUri()); HttpHeaders headers = requestEntity.getHeaders(); assertThat(headers.getAccept()).contains(MediaType.APPLICATION_JSON_UTF8); @@ -71,7 +75,7 @@ public void convertWhenGrantRequestValidThenConverts() { assertThat(formParameters.getFirst(OAuth2ParameterNames.GRANT_TYPE)).isEqualTo( AuthorizationGrantType.REFRESH_TOKEN.getValue()); assertThat(formParameters.getFirst(OAuth2ParameterNames.REFRESH_TOKEN)).isEqualTo( - authorizedClient.getRefreshToken().getTokenValue()); + refreshToken.getTokenValue()); assertThat(formParameters.getFirst(OAuth2ParameterNames.SCOPE)).isEqualTo("read"); } } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2RefreshTokenGrantRequestTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2RefreshTokenGrantRequestTests.java index 0a0d1fd739b..dc90a388a5c 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2RefreshTokenGrantRequestTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/OAuth2RefreshTokenGrantRequestTests.java @@ -17,11 +17,10 @@ import org.junit.Before; import org.junit.Test; -import org.springframework.security.authentication.TestingAuthenticationToken; -import org.springframework.security.core.Authentication; -import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.OAuth2RefreshToken; import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens; @@ -39,39 +38,45 @@ */ public class OAuth2RefreshTokenGrantRequestTests { private ClientRegistration clientRegistration; - private Authentication principal; - private OAuth2AuthorizedClient authorizedClient; + private OAuth2AccessToken accessToken; + private OAuth2RefreshToken refreshToken; @Before public void setup() { this.clientRegistration = TestClientRegistrations.clientRegistration().build(); - this.principal = new TestingAuthenticationToken("principal", "password"); - this.authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, this.principal.getName(), - TestOAuth2AccessTokens.scopes("read", "write"), TestOAuth2RefreshTokens.refreshToken()); + this.accessToken = TestOAuth2AccessTokens.scopes("read", "write"); + this.refreshToken = TestOAuth2RefreshTokens.refreshToken(); } @Test - public void constructorWhenAuthorizedClientIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new OAuth2RefreshTokenGrantRequest(null)) + public void constructorWhenClientRegistrationIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new OAuth2RefreshTokenGrantRequest(null, this.accessToken, this.refreshToken)) .isInstanceOf(IllegalArgumentException.class) - .hasMessage("authorizedClient cannot be null"); + .hasMessage("clientRegistration cannot be null"); + } + + @Test + public void constructorWhenAccessTokenIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new OAuth2RefreshTokenGrantRequest(this.clientRegistration, null, this.refreshToken)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("accessToken cannot be null"); } @Test public void constructorWhenRefreshTokenIsNullThenThrowIllegalArgumentException() { - this.authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, - this.principal.getName(), TestOAuth2AccessTokens.scopes("read", "write")); - assertThatThrownBy(() -> new OAuth2RefreshTokenGrantRequest(this.authorizedClient)) + assertThatThrownBy(() -> new OAuth2RefreshTokenGrantRequest(this.clientRegistration, this.accessToken, null)) .isInstanceOf(IllegalArgumentException.class) - .hasMessage("authorizedClient.refreshToken cannot be null"); + .hasMessage("refreshToken cannot be null"); } @Test public void constructorWhenValidParametersProvidedThenCreated() { Set scopes = new HashSet<>(Arrays.asList("read", "write")); - OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = - new OAuth2RefreshTokenGrantRequest(this.authorizedClient, scopes); - assertThat(refreshTokenGrantRequest.getAuthorizedClient()).isSameAs(this.authorizedClient); + OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest( + this.clientRegistration, this.accessToken, this.refreshToken, scopes); + assertThat(refreshTokenGrantRequest.getClientRegistration()).isSameAs(this.clientRegistration); + assertThat(refreshTokenGrantRequest.getAccessToken()).isSameAs(this.accessToken); + assertThat(refreshTokenGrantRequest.getRefreshToken()).isSameAs(this.refreshToken); assertThat(refreshTokenGrantRequest.getScopes()).isEqualTo(scopes); } } From b0eff46dbffa6e2f96f81cea09c0de6c1ed54b6c Mon Sep 17 00:00:00 2001 From: Joe Grandja Date: Wed, 17 Jul 2019 08:54:18 -0400 Subject: [PATCH 17/19] Updates from review --- .../OAuth2ClientConfiguration.java | 2 +- ...entialsOAuth2AuthorizedClientProvider.java | 8 +++-- .../client/OAuth2AuthorizationContext.java | 15 ++++++-- ...OAuth2AuthorizedClientProviderBuilder.java | 16 ++++----- ...shTokenOAuth2AuthorizedClientProvider.java | 14 ++------ .../DefaultOAuth2AuthorizedClientManager.java | 27 ++++++++++++-- ...Auth2AuthorizedClientArgumentResolver.java | 4 +-- ...uthorizedClientExchangeFilterFunction.java | 4 +-- ...deOAuth2AuthorizedClientProviderTests.java | 6 ++-- ...lsOAuth2AuthorizedClientProviderTests.java | 8 ++--- ...ngOAuth2AuthorizedClientProviderTests.java | 4 +-- .../OAuth2AuthorizationContextTests.java | 8 ++--- ...2AuthorizedClientProviderBuilderTests.java | 26 +++++++------- ...enOAuth2AuthorizedClientProviderTests.java | 18 +++++----- ...ultOAuth2AuthorizedClientManagerTests.java | 35 +++++++++++++++++++ ...AuthorizedClientArgumentResolverTests.java | 2 +- ...izedClientExchangeFilterFunctionTests.java | 2 +- .../java/sample/config/WebClientConfig.java | 26 ++++++++------ 18 files changed, 146 insertions(+), 79 deletions(-) diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfiguration.java b/config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfiguration.java index d3f41f889bb..532d9078b55 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfiguration.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfiguration.java @@ -72,7 +72,7 @@ static class OAuth2ClientWebMvcSecurityConfiguration implements WebMvcConfigurer public void addArgumentResolvers(List argumentResolvers) { if (this.clientRegistrationRepository != null && this.authorizedClientRepository != null) { OAuth2AuthorizedClientProvider authorizedClientProvider = - OAuth2AuthorizedClientProviderBuilder.withProvider() + OAuth2AuthorizedClientProviderBuilder.builder() .authorizationCode() .refreshToken() .clientCredentials(configurer -> diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ClientCredentialsOAuth2AuthorizedClientProvider.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ClientCredentialsOAuth2AuthorizedClientProvider.java index c3b686cd08b..36dcb6a4848 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ClientCredentialsOAuth2AuthorizedClientProvider.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ClientCredentialsOAuth2AuthorizedClientProvider.java @@ -58,9 +58,13 @@ public OAuth2AuthorizedClient authorize(OAuth2AuthorizationContext context) { Assert.notNull(context, "context cannot be null"); ClientRegistration clientRegistration = context.getClientRegistration(); + if (!AuthorizationGrantType.CLIENT_CREDENTIALS.equals(clientRegistration.getAuthorizationGrantType())) { + return null; + } + OAuth2AuthorizedClient authorizedClient = context.getAuthorizedClient(); - if (!AuthorizationGrantType.CLIENT_CREDENTIALS.equals(clientRegistration.getAuthorizationGrantType()) || - (authorizedClient != null && !hasTokenExpired(authorizedClient.getAccessToken()))) { + if (authorizedClient != null && !hasTokenExpired(authorizedClient.getAccessToken())) { + // If client is already authorized but access token is NOT expired than no need for re-authorization return null; } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizationContext.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizationContext.java index a2f70645ce6..d7aa1aec171 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizationContext.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizationContext.java @@ -35,6 +35,15 @@ * @see OAuth2AuthorizedClientProvider */ public final class OAuth2AuthorizationContext { + /** + * The name of the {@link #getAttribute(String) attribute} + * in the {@link OAuth2AuthorizationContext context} + * associated to the value for the "request scope(s)". + * The value of the attribute is a {@code String[]} of scope(s) to be requested + * by the {@link OAuth2AuthorizationContext#getClientRegistration() client}. + */ + public static final String REQUEST_SCOPE_ATTRIBUTE_NAME = OAuth2AuthorizationContext.class.getName().concat(".REQUEST_SCOPE"); + private ClientRegistration clientRegistration; private OAuth2AuthorizedClient authorizedClient; private Authentication principal; @@ -54,7 +63,7 @@ public ClientRegistration getClientRegistration() { /** * Returns the {@link OAuth2AuthorizedClient authorized client} or {@code null} - * if the {@link #forClient(ClientRegistration) client registration} was supplied. + * if the {@link #withClientRegistration(ClientRegistration) client registration} was supplied. * * @return the {@link OAuth2AuthorizedClient} or {@code null} if the client registration was supplied */ @@ -100,7 +109,7 @@ public T getAttribute(String name) { * @param clientRegistration the {@link ClientRegistration client registration} * @return the {@link Builder} */ - public static Builder forClient(ClientRegistration clientRegistration) { + public static Builder withClientRegistration(ClientRegistration clientRegistration) { return new Builder(clientRegistration); } @@ -110,7 +119,7 @@ public static Builder forClient(ClientRegistration clientRegistration) { * @param authorizedClient the {@link OAuth2AuthorizedClient authorized client} * @return the {@link Builder} */ - public static Builder forClient(OAuth2AuthorizedClient authorizedClient) { + public static Builder withAuthorizedClient(OAuth2AuthorizedClient authorizedClient) { return new Builder(authorizedClient); } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientProviderBuilder.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientProviderBuilder.java index ef554d55aef..6405e3bc13e 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientProviderBuilder.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientProviderBuilder.java @@ -21,7 +21,7 @@ import org.springframework.util.Assert; import java.time.Duration; -import java.util.HashMap; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.function.Consumer; @@ -44,7 +44,7 @@ * @see DelegatingOAuth2AuthorizedClientProvider */ public final class OAuth2AuthorizedClientProviderBuilder { - private final Map, Builder> builders = new HashMap<>(); + private final Map, Builder> builders = new LinkedHashMap<>(); private OAuth2AuthorizedClientProviderBuilder() { } @@ -54,7 +54,7 @@ private OAuth2AuthorizedClientProviderBuilder() { * * @return the {@link OAuth2AuthorizedClientProviderBuilder} */ - public static OAuth2AuthorizedClientProviderBuilder withProvider() { + public static OAuth2AuthorizedClientProviderBuilder builder() { return new OAuth2AuthorizedClientProviderBuilder(); } @@ -76,7 +76,7 @@ public OAuth2AuthorizedClientProviderBuilder provider(OAuth2AuthorizedClientProv * @return the {@link OAuth2AuthorizedClientProviderBuilder} */ public OAuth2AuthorizedClientProviderBuilder authorizationCode() { - this.builders.computeIfAbsent(AuthorizationCodeGrantBuilder.class, k -> new AuthorizationCodeGrantBuilder()); + this.builders.computeIfAbsent(AuthorizationCodeOAuth2AuthorizedClientProvider.class, k -> new AuthorizationCodeGrantBuilder()); return OAuth2AuthorizedClientProviderBuilder.this; } @@ -105,7 +105,7 @@ public OAuth2AuthorizedClientProvider build() { * @return the {@link OAuth2AuthorizedClientProviderBuilder} */ public OAuth2AuthorizedClientProviderBuilder refreshToken() { - this.builders.computeIfAbsent(RefreshTokenGrantBuilder.class, k -> new RefreshTokenGrantBuilder()); + this.builders.computeIfAbsent(RefreshTokenOAuth2AuthorizedClientProvider.class, k -> new RefreshTokenGrantBuilder()); return OAuth2AuthorizedClientProviderBuilder.this; } @@ -117,7 +117,7 @@ public OAuth2AuthorizedClientProviderBuilder refreshToken() { */ public OAuth2AuthorizedClientProviderBuilder refreshToken(Consumer builderConsumer) { RefreshTokenGrantBuilder builder = (RefreshTokenGrantBuilder) this.builders.computeIfAbsent( - RefreshTokenGrantBuilder.class, k -> new RefreshTokenGrantBuilder()); + RefreshTokenOAuth2AuthorizedClientProvider.class, k -> new RefreshTokenGrantBuilder()); builderConsumer.accept(builder); return OAuth2AuthorizedClientProviderBuilder.this; } @@ -179,7 +179,7 @@ public OAuth2AuthorizedClientProvider build() { * @return the {@link OAuth2AuthorizedClientProviderBuilder} */ public OAuth2AuthorizedClientProviderBuilder clientCredentials() { - this.builders.computeIfAbsent(ClientCredentialsGrantBuilder.class, k -> new ClientCredentialsGrantBuilder()); + this.builders.computeIfAbsent(ClientCredentialsOAuth2AuthorizedClientProvider.class, k -> new ClientCredentialsGrantBuilder()); return OAuth2AuthorizedClientProviderBuilder.this; } @@ -191,7 +191,7 @@ public OAuth2AuthorizedClientProviderBuilder clientCredentials() { */ public OAuth2AuthorizedClientProviderBuilder clientCredentials(Consumer builderConsumer) { ClientCredentialsGrantBuilder builder = (ClientCredentialsGrantBuilder) this.builders.computeIfAbsent( - ClientCredentialsGrantBuilder.class, k -> new ClientCredentialsGrantBuilder()); + ClientCredentialsOAuth2AuthorizedClientProvider.class, k -> new ClientCredentialsGrantBuilder()); builderConsumer.accept(builder); return OAuth2AuthorizedClientProviderBuilder.this; } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProvider.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProvider.java index 112ce629ca4..36046118948 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProvider.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProvider.java @@ -41,14 +41,6 @@ * @see DefaultRefreshTokenTokenResponseClient */ public final class RefreshTokenOAuth2AuthorizedClientProvider implements OAuth2AuthorizedClientProvider { - /** - * The name of the {@link OAuth2AuthorizationContext#getAttribute(String) attribute} - * in the {@link OAuth2AuthorizationContext context} associated to the value for the "requested scope(s)". - * The value of the attribute is a {@code String[]} of scope(s) to be requested - * by the {@link OAuth2AuthorizationContext#getClientRegistration() client}. - */ - public static final String REQUEST_SCOPE_ATTRIBUTE_NAME = "org.springframework.security.oauth2.client.REQUEST_SCOPE"; - private OAuth2AccessTokenResponseClient accessTokenResponseClient = new DefaultRefreshTokenTokenResponseClient(); private Duration clockSkew = Duration.ofSeconds(60); @@ -62,7 +54,7 @@ public final class RefreshTokenOAuth2AuthorizedClientProvider implements OAuth2A *

        * The following {@link OAuth2AuthorizationContext#getAttributes() context attributes} are supported: *

          - *
        1. {@code "org.springframework.security.oauth2.client.REQUEST_SCOPE"} (optional) - a {@code String[]} of scope(s) + *
        2. {@link OAuth2AuthorizationContext#REQUEST_SCOPE_ATTRIBUTE_NAME} (optional) - a {@code String[]} of scope(s) * to be requested by the {@link OAuth2AuthorizationContext#getClientRegistration() client}
        3. *
        * @@ -81,11 +73,11 @@ public OAuth2AuthorizedClient authorize(OAuth2AuthorizationContext context) { return null; } - Object requestScope = context.getAttribute(REQUEST_SCOPE_ATTRIBUTE_NAME); + Object requestScope = context.getAttribute(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME); Set scopes = Collections.emptySet(); if (requestScope != null) { Assert.isInstanceOf(String[].class, requestScope, - "The context attribute must be of type String[] '" + REQUEST_SCOPE_ATTRIBUTE_NAME + "'"); + "The context attribute must be of type String[] '" + OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME + "'"); scopes = new HashSet<>(Arrays.asList((String[]) requestScope)); } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizedClientManager.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizedClientManager.java index 45272a3c4e3..3b9275b4def 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizedClientManager.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizedClientManager.java @@ -22,11 +22,14 @@ import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProvider; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.util.Assert; +import org.springframework.util.StringUtils; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import java.util.Collections; +import java.util.HashMap; import java.util.Map; import java.util.function.Function; @@ -42,7 +45,7 @@ public final class DefaultOAuth2AuthorizedClientManager implements OAuth2Authori private final ClientRegistrationRepository clientRegistrationRepository; private final OAuth2AuthorizedClientRepository authorizedClientRepository; private OAuth2AuthorizedClientProvider authorizedClientProvider = context -> null; - private Function> contextAttributesMapper = authorizeRequest -> Collections.emptyMap(); + private Function> contextAttributesMapper = new DefaultContextAttributesMapper(); /** * Constructs a {@code DefaultOAuth2AuthorizedClientManager} using the provided parameters. @@ -80,7 +83,7 @@ public OAuth2AuthorizedClient authorize(OAuth2AuthorizeRequest authorizeRequest) } OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.forClient(clientRegistration) + OAuth2AuthorizationContext.withClientRegistration(clientRegistration) .principal(principal) .attributes(this.contextAttributesMapper.apply(authorizeRequest)) .build(); @@ -102,7 +105,7 @@ public OAuth2AuthorizedClient reauthorize(OAuth2ReauthorizeRequest reauthorizeRe HttpServletResponse servletResponse = reauthorizeRequest.getServletResponse(); OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.forClient(authorizedClient) + OAuth2AuthorizationContext.withAuthorizedClient(authorizedClient) .principal(principal) .attributes(this.contextAttributesMapper.apply(reauthorizeRequest)) .build(); @@ -136,4 +139,22 @@ public void setContextAttributesMapper(Function> { + + @Override + public Map apply(OAuth2AuthorizeRequest authorizeRequest) { + Map contextAttributes = Collections.emptyMap(); + String scope = authorizeRequest.getServletRequest().getParameter(OAuth2ParameterNames.SCOPE); + if (StringUtils.hasText(scope)) { + contextAttributes = new HashMap<>(); + contextAttributes.put(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME, + StringUtils.delimitedListToStringArray(scope, " ")); + } + return contextAttributes; + } + } } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolver.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolver.java index 3eb03412e08..4f9f1d77880 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolver.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolver.java @@ -105,7 +105,7 @@ private static OAuth2AuthorizedClientManager createDefaultAuthorizedClientManage ClientRegistrationRepository clientRegistrationRepository, OAuth2AuthorizedClientRepository authorizedClientRepository) { OAuth2AuthorizedClientProvider authorizedClientProvider = - OAuth2AuthorizedClientProviderBuilder.withProvider() + OAuth2AuthorizedClientProviderBuilder.builder() .authorizationCode() .refreshToken() .clientCredentials() @@ -193,7 +193,7 @@ private void updateDefaultAuthorizedClientManager( OAuth2AccessTokenResponseClient clientCredentialsTokenResponseClient) { OAuth2AuthorizedClientProvider authorizedClientProvider = - OAuth2AuthorizedClientProviderBuilder.withProvider() + OAuth2AuthorizedClientProviderBuilder.builder() .authorizationCode() .refreshToken() .clientCredentials(configurer -> configurer.accessTokenResponseClient(clientCredentialsTokenResponseClient)) diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java index 624e02f6f97..f25eaea23cc 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java @@ -164,7 +164,7 @@ private static OAuth2AuthorizedClientManager createDefaultAuthorizedClientManage ClientRegistrationRepository clientRegistrationRepository, OAuth2AuthorizedClientRepository authorizedClientRepository) { OAuth2AuthorizedClientProvider authorizedClientProvider = - OAuth2AuthorizedClientProviderBuilder.withProvider() + OAuth2AuthorizedClientProviderBuilder.builder() .authorizationCode() .refreshToken() .clientCredentials() @@ -208,7 +208,7 @@ public void setClientCredentialsTokenResponseClient( private void updateDefaultAuthorizedClientManager() { OAuth2AuthorizedClientProvider authorizedClientProvider = - OAuth2AuthorizedClientProviderBuilder.withProvider() + OAuth2AuthorizedClientProviderBuilder.builder() .authorizationCode() .refreshToken(configurer -> configurer.clockSkew(this.accessTokenExpiresSkew)) .clientCredentials(this::updateClientCredentialsProvider) diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/AuthorizationCodeOAuth2AuthorizedClientProviderTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/AuthorizationCodeOAuth2AuthorizedClientProviderTests.java index ad021fe8f42..c393b9f3235 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/AuthorizationCodeOAuth2AuthorizedClientProviderTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/AuthorizationCodeOAuth2AuthorizedClientProviderTests.java @@ -57,7 +57,7 @@ public void authorizeWhenNotAuthorizationCodeThenUnableToAuthorize() { ClientRegistration clientCredentialsClient = TestClientRegistrations.clientCredentials().build(); OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.forClient(clientCredentialsClient) + OAuth2AuthorizationContext.withClientRegistration(clientCredentialsClient) .principal(this.principal) .build(); assertThat(this.authorizedClientProvider.authorize(authorizationContext)).isNull(); @@ -66,7 +66,7 @@ public void authorizeWhenNotAuthorizationCodeThenUnableToAuthorize() { @Test public void authorizeWhenAuthorizationCodeAndAuthorizedThenNotAuthorize() { OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.forClient(this.authorizedClient) + OAuth2AuthorizationContext.withAuthorizedClient(this.authorizedClient) .principal(this.principal) .build(); assertThat(this.authorizedClientProvider.authorize(authorizationContext)).isNull(); @@ -75,7 +75,7 @@ public void authorizeWhenAuthorizationCodeAndAuthorizedThenNotAuthorize() { @Test public void authorizeWhenAuthorizationCodeAndNotAuthorizedThenAuthorize() { OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.forClient(this.clientRegistration) + OAuth2AuthorizationContext.withClientRegistration(this.clientRegistration) .principal(this.principal) .build(); assertThatThrownBy(() -> this.authorizedClientProvider.authorize(authorizationContext)) diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/ClientCredentialsOAuth2AuthorizedClientProviderTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/ClientCredentialsOAuth2AuthorizedClientProviderTests.java index 8e589583843..10acb7cdac8 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/ClientCredentialsOAuth2AuthorizedClientProviderTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/ClientCredentialsOAuth2AuthorizedClientProviderTests.java @@ -90,7 +90,7 @@ public void authorizeWhenNotClientCredentialsThenUnableToAuthorize() { ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.forClient(clientRegistration) + OAuth2AuthorizationContext.withClientRegistration(clientRegistration) .principal(this.principal) .build(); assertThat(this.authorizedClientProvider.authorize(authorizationContext)).isNull(); @@ -102,7 +102,7 @@ public void authorizeWhenClientCredentialsAndNotAuthorizedThenAuthorize() { when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse); OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.forClient(this.clientRegistration) + OAuth2AuthorizationContext.withClientRegistration(this.clientRegistration) .principal(this.principal) .build(); OAuth2AuthorizedClient authorizedClient = this.authorizedClientProvider.authorize(authorizationContext); @@ -125,7 +125,7 @@ public void authorizeWhenClientCredentialsAndTokenExpiredThenReauthorize() { when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse); OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.forClient(authorizedClient) + OAuth2AuthorizationContext.withAuthorizedClient(authorizedClient) .principal(this.principal) .build(); authorizedClient = this.authorizedClientProvider.authorize(authorizationContext); @@ -141,7 +141,7 @@ public void authorizeWhenClientCredentialsAndTokenNotExpiredThenNotReauthorize() this.clientRegistration, this.principal.getName(), TestOAuth2AccessTokens.noScopes()); OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.forClient(authorizedClient) + OAuth2AuthorizationContext.withAuthorizedClient(authorizedClient) .principal(this.principal) .build(); assertThat(this.authorizedClientProvider.authorize(authorizationContext)).isNull(); diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/DelegatingOAuth2AuthorizedClientProviderTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/DelegatingOAuth2AuthorizedClientProviderTests.java index d4e5d20d7e3..f930233aa83 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/DelegatingOAuth2AuthorizedClientProviderTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/DelegatingOAuth2AuthorizedClientProviderTests.java @@ -66,7 +66,7 @@ public void authorizeWhenProviderCanAuthorizeThenReturnAuthorizedClient() { DelegatingOAuth2AuthorizedClientProvider delegate = new DelegatingOAuth2AuthorizedClientProvider( mock(OAuth2AuthorizedClientProvider.class), mock(OAuth2AuthorizedClientProvider.class), authorizedClientProvider); - OAuth2AuthorizationContext context = OAuth2AuthorizationContext.forClient(clientRegistration) + OAuth2AuthorizationContext context = OAuth2AuthorizationContext.withClientRegistration(clientRegistration) .principal(principal) .build(); OAuth2AuthorizedClient reauthorizedClient = delegate.authorize(context); @@ -76,7 +76,7 @@ public void authorizeWhenProviderCanAuthorizeThenReturnAuthorizedClient() { @Test public void authorizeWhenProviderCantAuthorizeThenReturnNull() { ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); - OAuth2AuthorizationContext context = OAuth2AuthorizationContext.forClient(clientRegistration) + OAuth2AuthorizationContext context = OAuth2AuthorizationContext.withClientRegistration(clientRegistration) .principal(new TestingAuthenticationToken("principal", "password")) .build(); diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/OAuth2AuthorizationContextTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/OAuth2AuthorizationContextTests.java index a1351d4c476..89236d4c4ff 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/OAuth2AuthorizationContextTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/OAuth2AuthorizationContextTests.java @@ -45,28 +45,28 @@ public void setup() { @Test public void forClientWhenClientRegistrationIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> OAuth2AuthorizationContext.forClient((ClientRegistration) null).build()) + assertThatThrownBy(() -> OAuth2AuthorizationContext.withClientRegistration((ClientRegistration) null).build()) .isInstanceOf(IllegalArgumentException.class) .hasMessage("clientRegistration cannot be null"); } @Test public void forClientWhenAuthorizedClientIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> OAuth2AuthorizationContext.forClient((OAuth2AuthorizedClient) null).build()) + assertThatThrownBy(() -> OAuth2AuthorizationContext.withAuthorizedClient((OAuth2AuthorizedClient) null).build()) .isInstanceOf(IllegalArgumentException.class) .hasMessage("authorizedClient cannot be null"); } @Test public void forClientWhenPrincipalIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> OAuth2AuthorizationContext.forClient(this.clientRegistration).build()) + assertThatThrownBy(() -> OAuth2AuthorizationContext.withClientRegistration(this.clientRegistration).build()) .isInstanceOf(IllegalArgumentException.class) .hasMessage("principal cannot be null"); } @Test public void forClientWhenAllValuesProvidedThenAllValuesAreSet() { - OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext.forClient(this.authorizedClient) + OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext.withAuthorizedClient(this.authorizedClient) .principal(this.principal) .attribute("attribute1", "value1") .attribute("attribute2", "value2") diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientProviderBuilderTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientProviderBuilderTests.java index a9c400731c0..cbc1880877c 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientProviderBuilderTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientProviderBuilderTests.java @@ -68,19 +68,19 @@ public void setup() { @Test public void providerWhenNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> OAuth2AuthorizedClientProviderBuilder.withProvider().provider(null)) + assertThatThrownBy(() -> OAuth2AuthorizedClientProviderBuilder.builder().provider(null)) .isInstanceOf(IllegalArgumentException.class); } @Test public void buildWhenAuthorizationCodeProviderThenProviderAuthorizes() { OAuth2AuthorizedClientProvider authorizedClientProvider = - OAuth2AuthorizedClientProviderBuilder.withProvider() + OAuth2AuthorizedClientProviderBuilder.builder() .authorizationCode() .build(); OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.forClient(TestClientRegistrations.clientRegistration().build()) + OAuth2AuthorizationContext.withClientRegistration(TestClientRegistrations.clientRegistration().build()) .principal(this.principal) .build(); assertThatThrownBy(() -> authorizedClientProvider.authorize(authorizationContext)) @@ -90,7 +90,7 @@ public void buildWhenAuthorizationCodeProviderThenProviderAuthorizes() { @Test public void buildWhenRefreshTokenProviderThenProviderReauthorizes() { OAuth2AuthorizedClientProvider authorizedClientProvider = - OAuth2AuthorizedClientProviderBuilder.withProvider() + OAuth2AuthorizedClientProviderBuilder.builder() .refreshToken(configurer -> configurer.accessTokenResponseClient(this.refreshTokenTokenResponseClient)) .build(); @@ -101,7 +101,7 @@ public void buildWhenRefreshTokenProviderThenProviderReauthorizes() { TestOAuth2RefreshTokens.refreshToken()); OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.forClient(authorizedClient) + OAuth2AuthorizationContext.withAuthorizedClient(authorizedClient) .principal(this.principal) .build(); OAuth2AuthorizedClient reauthorizedClient = authorizedClientProvider.authorize(authorizationContext); @@ -113,12 +113,12 @@ public void buildWhenRefreshTokenProviderThenProviderReauthorizes() { @Test public void buildWhenClientCredentialsProviderThenProviderAuthorizes() { OAuth2AuthorizedClientProvider authorizedClientProvider = - OAuth2AuthorizedClientProviderBuilder.withProvider() + OAuth2AuthorizedClientProviderBuilder.builder() .clientCredentials(configurer -> configurer.accessTokenResponseClient(this.clientCredentialsTokenResponseClient)) .build(); OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.forClient(TestClientRegistrations.clientCredentials().build()) + OAuth2AuthorizationContext.withClientRegistration(TestClientRegistrations.clientCredentials().build()) .principal(this.principal) .build(); OAuth2AuthorizedClient authorizedClient = authorizedClientProvider.authorize(authorizationContext); @@ -130,7 +130,7 @@ public void buildWhenClientCredentialsProviderThenProviderAuthorizes() { @Test public void buildWhenAllProvidersThenProvidersAuthorize() { OAuth2AuthorizedClientProvider authorizedClientProvider = - OAuth2AuthorizedClientProviderBuilder.withProvider() + OAuth2AuthorizedClientProviderBuilder.builder() .authorizationCode() .refreshToken(configurer -> configurer.accessTokenResponseClient(this.refreshTokenTokenResponseClient)) .clientCredentials(configurer -> configurer.accessTokenResponseClient(this.clientCredentialsTokenResponseClient)) @@ -141,7 +141,7 @@ public void buildWhenAllProvidersThenProvidersAuthorize() { // authorization_code OAuth2AuthorizationContext authorizationCodeContext = - OAuth2AuthorizationContext.forClient(clientRegistration) + OAuth2AuthorizationContext.withClientRegistration(clientRegistration) .principal(this.principal) .build(); assertThatThrownBy(() -> authorizedClientProvider.authorize(authorizationCodeContext)) @@ -156,7 +156,7 @@ public void buildWhenAllProvidersThenProvidersAuthorize() { TestOAuth2RefreshTokens.refreshToken()); OAuth2AuthorizationContext refreshTokenContext = - OAuth2AuthorizationContext.forClient(authorizedClient) + OAuth2AuthorizationContext.withAuthorizedClient(authorizedClient) .principal(this.principal) .build(); OAuth2AuthorizedClient reauthorizedClient = authorizedClientProvider.authorize(refreshTokenContext); @@ -167,7 +167,7 @@ public void buildWhenAllProvidersThenProvidersAuthorize() { // client_credentials OAuth2AuthorizationContext clientCredentialsContext = - OAuth2AuthorizationContext.forClient(TestClientRegistrations.clientCredentials().build()) + OAuth2AuthorizationContext.withClientRegistration(TestClientRegistrations.clientCredentials().build()) .principal(this.principal) .build(); authorizedClient = authorizedClientProvider.authorize(clientCredentialsContext); @@ -181,12 +181,12 @@ public void buildWhenCustomProviderThenProviderCalled() { OAuth2AuthorizedClientProvider customProvider = mock(OAuth2AuthorizedClientProvider.class); OAuth2AuthorizedClientProvider authorizedClientProvider = - OAuth2AuthorizedClientProviderBuilder.withProvider() + OAuth2AuthorizedClientProviderBuilder.builder() .provider(customProvider) .build(); OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.forClient(TestClientRegistrations.clientRegistration().build()) + OAuth2AuthorizationContext.withClientRegistration(TestClientRegistrations.clientRegistration().build()) .principal(this.principal) .build(); authorizedClientProvider.authorize(authorizationContext); diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProviderTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProviderTests.java index cd4b6725e14..06124d06b08 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProviderTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/RefreshTokenOAuth2AuthorizedClientProviderTests.java @@ -98,7 +98,7 @@ public void authorizeWhenContextIsNullThenThrowIllegalArgumentException() { @Test public void authorizeWhenNotAuthorizedThenUnableToReauthorize() { OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.forClient(this.clientRegistration) + OAuth2AuthorizationContext.withClientRegistration(this.clientRegistration) .principal(this.principal) .build(); assertThat(this.authorizedClientProvider.authorize(authorizationContext)).isNull(); @@ -110,7 +110,7 @@ public void authorizeWhenAuthorizedAndRefreshTokenIsNullThenUnableToReauthorize( this.clientRegistration, this.principal.getName(), this.authorizedClient.getAccessToken()); OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.forClient(authorizedClient) + OAuth2AuthorizationContext.withAuthorizedClient(authorizedClient) .principal(this.principal) .build(); assertThat(this.authorizedClientProvider.authorize(authorizationContext)).isNull(); @@ -122,7 +122,7 @@ public void authorizeWhenAuthorizedAndAccessTokenNotExpiredThenNotReauthorize() this.principal.getName(), TestOAuth2AccessTokens.noScopes(), this.authorizedClient.getRefreshToken()); OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.forClient(authorizedClient) + OAuth2AuthorizationContext.withAuthorizedClient(authorizedClient) .principal(this.principal) .build(); assertThat(this.authorizedClientProvider.authorize(authorizationContext)).isNull(); @@ -136,7 +136,7 @@ public void authorizeWhenAuthorizedAndAccessTokenExpiredThenReauthorize() { when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse); OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.forClient(this.authorizedClient) + OAuth2AuthorizationContext.withAuthorizedClient(this.authorizedClient) .principal(this.principal) .build(); @@ -157,9 +157,9 @@ public void authorizeWhenAuthorizedAndRequestScopeProvidedThenScopeRequested() { String[] requestScope = new String[] { "read", "write" }; OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.forClient(this.authorizedClient) + OAuth2AuthorizationContext.withAuthorizedClient(this.authorizedClient) .principal(this.principal) - .attribute(RefreshTokenOAuth2AuthorizedClientProvider.REQUEST_SCOPE_ATTRIBUTE_NAME, requestScope) + .attribute(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME, requestScope) .build(); this.authorizedClientProvider.authorize(authorizationContext); @@ -174,14 +174,14 @@ public void authorizeWhenAuthorizedAndRequestScopeProvidedThenScopeRequested() { public void authorizeWhenAuthorizedAndInvalidRequestScopeProvidedThenThrowIllegalArgumentException() { String invalidRequestScope = "read write"; OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.forClient(this.authorizedClient) + OAuth2AuthorizationContext.withAuthorizedClient(this.authorizedClient) .principal(this.principal) - .attribute(RefreshTokenOAuth2AuthorizedClientProvider.REQUEST_SCOPE_ATTRIBUTE_NAME, invalidRequestScope) + .attribute(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME, invalidRequestScope) .build(); assertThatThrownBy(() -> this.authorizedClientProvider.authorize(authorizationContext)) .isInstanceOf(IllegalArgumentException.class) .hasMessageStartingWith("The context attribute must be of type String[] '" + - RefreshTokenOAuth2AuthorizedClientProvider.REQUEST_SCOPE_ATTRIBUTE_NAME + "'"); + OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME + "'"); } } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizedClientManagerTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizedClientManagerTests.java index a6f96e71a7f..6abfa8057da 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizedClientManagerTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizedClientManagerTests.java @@ -30,6 +30,7 @@ import org.springframework.security.oauth2.client.registration.TestClientRegistrations; import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import java.util.function.Function; @@ -251,4 +252,38 @@ public void reauthorizeWhenSupportedProviderThenReauthorized() { verify(this.authorizedClientRepository).saveAuthorizedClient( eq(reauthorizedClient), eq(this.principal), eq(this.request), eq(this.response)); } + + @SuppressWarnings("unchecked") + @Test + public void reauthorizeWhenRequestScopeParameterThenMappedToContext() { + OAuth2AuthorizedClient reauthorizedClient = new OAuth2AuthorizedClient( + this.clientRegistration, this.principal.getName(), + TestOAuth2AccessTokens.noScopes(), TestOAuth2RefreshTokens.refreshToken()); + + when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(reauthorizedClient); + + // Override the mock with the default + this.authorizedClientManager.setContextAttributesMapper( + new DefaultOAuth2AuthorizedClientManager.DefaultContextAttributesMapper()); + + this.request.addParameter(OAuth2ParameterNames.SCOPE, "read write"); + + OAuth2ReauthorizeRequest reauthorizeRequest = new OAuth2ReauthorizeRequest( + this.authorizedClient, this.principal, this.request, this.response); + OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.reauthorize(reauthorizeRequest); + + verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); + + OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue(); + assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration); + assertThat(authorizationContext.getAuthorizedClient()).isSameAs(this.authorizedClient); + assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); + assertThat(authorizationContext.getAttributes()).containsKey(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME); + String[] requestScopeAttribute = authorizationContext.getAttribute(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME); + assertThat(requestScopeAttribute).contains("read", "write"); + + assertThat(authorizedClient).isSameAs(reauthorizedClient); + verify(this.authorizedClientRepository).saveAuthorizedClient( + eq(reauthorizedClient), eq(this.principal), eq(this.request), eq(this.response)); + } } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolverTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolverTests.java index b0ee98b411e..d5f7094e31e 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolverTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolverTests.java @@ -107,7 +107,7 @@ public void setup() { this.clientRegistrationRepository = new InMemoryClientRegistrationRepository(this.registration1, this.registration2); this.authorizedClientRepository = mock(OAuth2AuthorizedClientRepository.class); OAuth2AuthorizedClientProvider authorizedClientProvider = - OAuth2AuthorizedClientProviderBuilder.withProvider() + OAuth2AuthorizedClientProviderBuilder.builder() .authorizationCode() .refreshToken() .clientCredentials() diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java index a23f1c79ee9..86df1cec94f 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java @@ -135,7 +135,7 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests { public void setup() { this.authentication = new TestingAuthenticationToken("test", "this"); OAuth2AuthorizedClientProvider authorizedClientProvider = - OAuth2AuthorizedClientProviderBuilder.withProvider() + OAuth2AuthorizedClientProviderBuilder.builder() .authorizationCode() .refreshToken(configurer -> configurer.accessTokenResponseClient(this.refreshTokenTokenResponseClient)) .clientCredentials(configurer -> configurer.accessTokenResponseClient(this.clientCredentialsTokenResponseClient)) diff --git a/samples/boot/oauth2webclient/src/main/java/sample/config/WebClientConfig.java b/samples/boot/oauth2webclient/src/main/java/sample/config/WebClientConfig.java index 7874a270425..636bc53fd6f 100644 --- a/samples/boot/oauth2webclient/src/main/java/sample/config/WebClientConfig.java +++ b/samples/boot/oauth2webclient/src/main/java/sample/config/WebClientConfig.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2019 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,6 +22,7 @@ import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProviderBuilder; import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; import org.springframework.security.oauth2.client.web.DefaultOAuth2AuthorizedClientManager; +import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientManager; import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction; import org.springframework.web.reactive.function.client.WebClient; @@ -34,9 +35,20 @@ public class WebClientConfig { @Bean - WebClient webClient(ClientRegistrationRepository clientRegistrationRepository, OAuth2AuthorizedClientRepository authorizedClientRepository) { + WebClient webClient(OAuth2AuthorizedClientManager authorizedClientManager) { + ServletOAuth2AuthorizedClientExchangeFilterFunction oauth2 = + new ServletOAuth2AuthorizedClientExchangeFilterFunction(authorizedClientManager); + oauth2.setDefaultOAuth2AuthorizedClient(true); + return WebClient.builder() + .apply(oauth2.oauth2Configuration()) + .build(); + } + + @Bean + OAuth2AuthorizedClientManager authorizedClientManager(ClientRegistrationRepository clientRegistrationRepository, + OAuth2AuthorizedClientRepository authorizedClientRepository) { OAuth2AuthorizedClientProvider authorizedClientProvider = - OAuth2AuthorizedClientProviderBuilder.withProvider() + OAuth2AuthorizedClientProviderBuilder.builder() .authorizationCode() .refreshToken() .clientCredentials() @@ -45,12 +57,6 @@ WebClient webClient(ClientRegistrationRepository clientRegistrationRepository, O clientRegistrationRepository, authorizedClientRepository); authorizedClientManager.setAuthorizedClientProvider(authorizedClientProvider); - ServletOAuth2AuthorizedClientExchangeFilterFunction oauth2 = - new ServletOAuth2AuthorizedClientExchangeFilterFunction(authorizedClientManager); - oauth2.setDefaultOAuth2AuthorizedClient(true); - - return WebClient.builder() - .apply(oauth2.oauth2Configuration()) - .build(); + return authorizedClientManager; } } From 18285e55935fd23c3cf6a930e8902ffa20ba5646 Mon Sep 17 00:00:00 2001 From: Joe Grandja Date: Tue, 23 Jul 2019 15:19:10 -0400 Subject: [PATCH 18/19] Remove deprecation in ServletOAuth2AuthorizedClientExchangeFilterFunction --- .../annotation/OAuth2AuthorizedClientArgumentResolver.java | 4 ---- .../ServletOAuth2AuthorizedClientExchangeFilterFunction.java | 4 ---- 2 files changed, 8 deletions(-) diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolver.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolver.java index 4f9f1d77880..de931ba5218 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolver.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolver.java @@ -86,13 +86,9 @@ public OAuth2AuthorizedClientArgumentResolver(OAuth2AuthorizedClientManager auth /** * Constructs an {@code OAuth2AuthorizedClientArgumentResolver} using the provided parameters. * - * @deprecated Use {@link #OAuth2AuthorizedClientArgumentResolver(OAuth2AuthorizedClientManager)} instead. - * See {@link DefaultOAuth2AuthorizedClientManager} and {@link OAuth2AuthorizedClientProviderBuilder}. - * * @param clientRegistrationRepository the repository of client registrations * @param authorizedClientRepository the repository of authorized clients */ - @Deprecated public OAuth2AuthorizedClientArgumentResolver(ClientRegistrationRepository clientRegistrationRepository, OAuth2AuthorizedClientRepository authorizedClientRepository) { Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null"); diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java index f25eaea23cc..20f8991997c 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java @@ -146,13 +146,9 @@ public ServletOAuth2AuthorizedClientExchangeFilterFunction(OAuth2AuthorizedClien /** * Constructs a {@code ServletOAuth2AuthorizedClientExchangeFilterFunction} using the provided parameters. * - * @deprecated Use {@link #ServletOAuth2AuthorizedClientExchangeFilterFunction(OAuth2AuthorizedClientManager)} instead. - * See {@link DefaultOAuth2AuthorizedClientManager} and {@link OAuth2AuthorizedClientProviderBuilder}. - * * @param clientRegistrationRepository the repository of client registrations * @param authorizedClientRepository the repository of authorized clients */ - @Deprecated public ServletOAuth2AuthorizedClientExchangeFilterFunction( ClientRegistrationRepository clientRegistrationRepository, OAuth2AuthorizedClientRepository authorizedClientRepository) { From ac82955aac26f4f9cdb450cb15ad05eff3e74591 Mon Sep 17 00:00:00 2001 From: Joe Grandja Date: Tue, 23 Jul 2019 15:53:14 -0400 Subject: [PATCH 19/19] Remove OAuth2AuthorizedClientManager.reauthorize() --- .../DefaultOAuth2AuthorizedClientManager.java | 59 +++++-------- .../client/web/OAuth2AuthorizeRequest.java | 43 ++++++++++ .../web/OAuth2AuthorizedClientManager.java | 18 ++-- .../client/web/OAuth2ReauthorizeRequest.java | 57 ------------- ...uthorizedClientExchangeFilterFunction.java | 5 +- ...ultOAuth2AuthorizedClientManagerTests.java | 19 ++--- .../web/OAuth2AuthorizeRequestTests.java | 30 ++++++- .../web/OAuth2ReauthorizeRequestTests.java | 83 ------------------- 8 files changed, 108 insertions(+), 206 deletions(-) delete mode 100644 oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2ReauthorizeRequest.java delete mode 100644 oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2ReauthorizeRequestTests.java diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizedClientManager.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizedClientManager.java index 3b9275b4def..d3644186136 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizedClientManager.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizedClientManager.java @@ -67,52 +67,39 @@ public OAuth2AuthorizedClient authorize(OAuth2AuthorizeRequest authorizeRequest) Assert.notNull(authorizeRequest, "authorizeRequest cannot be null"); String clientRegistrationId = authorizeRequest.getClientRegistrationId(); + OAuth2AuthorizedClient authorizedClient = authorizeRequest.getAuthorizedClient(); Authentication principal = authorizeRequest.getPrincipal(); HttpServletRequest servletRequest = authorizeRequest.getServletRequest(); HttpServletResponse servletResponse = authorizeRequest.getServletResponse(); - ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId); - Assert.notNull(clientRegistration, "Could not find ClientRegistration with id '" + clientRegistrationId + "'"); - - OAuth2AuthorizedClient authorizedClient = this.authorizedClientRepository.loadAuthorizedClient( - clientRegistrationId, principal, servletRequest); + OAuth2AuthorizationContext.Builder contextBuilder; if (authorizedClient != null) { - OAuth2ReauthorizeRequest reauthorizeRequest = new OAuth2ReauthorizeRequest( - authorizedClient, principal, servletRequest, servletResponse); - return reauthorize(reauthorizeRequest); + contextBuilder = OAuth2AuthorizationContext.withAuthorizedClient(authorizedClient); + } else { + ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId); + Assert.notNull(clientRegistration, "Could not find ClientRegistration with id '" + clientRegistrationId + "'"); + authorizedClient = this.authorizedClientRepository.loadAuthorizedClient( + clientRegistrationId, principal, servletRequest); + if (authorizedClient != null) { + contextBuilder = OAuth2AuthorizationContext.withAuthorizedClient(authorizedClient); + } else { + contextBuilder = OAuth2AuthorizationContext.withClientRegistration(clientRegistration); + } } + OAuth2AuthorizationContext authorizationContext = contextBuilder + .principal(principal) + .attributes(this.contextAttributesMapper.apply(authorizeRequest)) + .build(); - OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.withClientRegistration(clientRegistration) - .principal(principal) - .attributes(this.contextAttributesMapper.apply(authorizeRequest)) - .build(); authorizedClient = this.authorizedClientProvider.authorize(authorizationContext); if (authorizedClient != null) { this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, principal, servletRequest, servletResponse); - } - - return authorizedClient; - } - - @Override - public OAuth2AuthorizedClient reauthorize(OAuth2ReauthorizeRequest reauthorizeRequest) { - Assert.notNull(reauthorizeRequest, "reauthorizeRequest cannot be null"); - - OAuth2AuthorizedClient authorizedClient = reauthorizeRequest.getAuthorizedClient(); - Authentication principal = reauthorizeRequest.getPrincipal(); - HttpServletRequest servletRequest = reauthorizeRequest.getServletRequest(); - HttpServletResponse servletResponse = reauthorizeRequest.getServletResponse(); - - OAuth2AuthorizationContext authorizationContext = - OAuth2AuthorizationContext.withAuthorizedClient(authorizedClient) - .principal(principal) - .attributes(this.contextAttributesMapper.apply(reauthorizeRequest)) - .build(); - OAuth2AuthorizedClient reauthorizedClient = this.authorizedClientProvider.authorize(authorizationContext); - if (reauthorizedClient != null) { - this.authorizedClientRepository.saveAuthorizedClient(reauthorizedClient, principal, servletRequest, servletResponse); - return reauthorizedClient; + } else { + // In the case of re-authorization, the returned `authorizedClient` may be null if re-authorization is not supported. + // For these cases, return the provided `authorizationContext.authorizedClient`. + if (authorizationContext.getAuthorizedClient() != null) { + return authorizationContext.getAuthorizedClient(); + } } return authorizedClient; diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizeRequest.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizeRequest.java index 6bf4eae846d..7f221183855 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizeRequest.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizeRequest.java @@ -15,7 +15,9 @@ */ package org.springframework.security.oauth2.client.web; +import org.springframework.lang.Nullable; import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.util.Assert; @@ -33,10 +35,19 @@ */ public class OAuth2AuthorizeRequest { private final String clientRegistrationId; + private final OAuth2AuthorizedClient authorizedClient; private final Authentication principal; private final HttpServletRequest servletRequest; private final HttpServletResponse servletResponse; + /** + * Constructs an {@code OAuth2AuthorizeRequest} using the provided parameters. + * + * @param clientRegistrationId the identifier for the {@link ClientRegistration client registration} + * @param principal the {@code Principal} (to be) associated to the authorized client + * @param servletRequest the {@code HttpServletRequest} + * @param servletResponse the {@code HttpServletResponse} + */ public OAuth2AuthorizeRequest(String clientRegistrationId, Authentication principal, HttpServletRequest servletRequest, HttpServletResponse servletResponse) { Assert.hasText(clientRegistrationId, "clientRegistrationId cannot be empty"); @@ -44,6 +55,28 @@ public OAuth2AuthorizeRequest(String clientRegistrationId, Authentication princi Assert.notNull(servletRequest, "servletRequest cannot be null"); Assert.notNull(servletResponse, "servletResponse cannot be null"); this.clientRegistrationId = clientRegistrationId; + this.authorizedClient = null; + this.principal = principal; + this.servletRequest = servletRequest; + this.servletResponse = servletResponse; + } + + /** + * Constructs an {@code OAuth2AuthorizeRequest} using the provided parameters. + * + * @param authorizedClient the {@link OAuth2AuthorizedClient authorized client} + * @param principal the {@code Principal} associated to the authorized client + * @param servletRequest the {@code HttpServletRequest} + * @param servletResponse the {@code HttpServletResponse} + */ + public OAuth2AuthorizeRequest(OAuth2AuthorizedClient authorizedClient, Authentication principal, + HttpServletRequest servletRequest, HttpServletResponse servletResponse) { + Assert.notNull(authorizedClient, "authorizedClient cannot be null"); + Assert.notNull(principal, "principal cannot be null"); + Assert.notNull(servletRequest, "servletRequest cannot be null"); + Assert.notNull(servletResponse, "servletResponse cannot be null"); + this.clientRegistrationId = authorizedClient.getClientRegistration().getRegistrationId(); + this.authorizedClient = authorizedClient; this.principal = principal; this.servletRequest = servletRequest; this.servletResponse = servletResponse; @@ -58,6 +91,16 @@ public String getClientRegistrationId() { return this.clientRegistrationId; } + /** + * Returns the {@link OAuth2AuthorizedClient authorized client} or {@code null} if it was not provided. + * + * @return the {@link OAuth2AuthorizedClient} or {@code null} if it was not provided + */ + @Nullable + public OAuth2AuthorizedClient getAuthorizedClient() { + return this.authorizedClient; + } + /** * Returns the {@code Principal} (to be) associated to the authorized client. * diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizedClientManager.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizedClientManager.java index 1b3638cb4d2..af90c1600e0 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizedClientManager.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizedClientManager.java @@ -48,22 +48,16 @@ public interface OAuth2AuthorizedClientManager { * e.g. the associated {@link OAuth2AuthorizedClientProvider}(s) does not support * the {@link ClientRegistration#getAuthorizationGrantType() authorization grant} type configured for the client. * + *

        + * In the case of re-authorization, implementations must return the provided {@link OAuth2AuthorizeRequest#getAuthorizedClient() authorized client} + * if re-authorization is not supported for the client OR is not required, + * e.g. a {@link OAuth2AuthorizedClient#getRefreshToken() refresh token} is not available OR + * the {@link OAuth2AuthorizedClient#getAccessToken() access token} is not expired. + * * @param authorizeRequest the authorize request * @return the {@link OAuth2AuthorizedClient} or {@code null} if authorization is not supported for the specified client */ @Nullable OAuth2AuthorizedClient authorize(OAuth2AuthorizeRequest authorizeRequest); - /** - * Attempt to re-authorize (if required) the provided {@link OAuth2ReauthorizeRequest#getAuthorizedClient() authorized client}. - * Implementations must return the provided authorized client if re-authorization is not supported - * for the {@link OAuth2AuthorizedClient#getClientRegistration() client} OR is not required, - * e.g. a {@link OAuth2AuthorizedClient#getRefreshToken() refresh token} is not available OR - * the {@link OAuth2AuthorizedClient#getAccessToken() access token} is not expired. - * - * @param reauthorizeRequest the re-authorize request - * @return the re-authorized {@link OAuth2AuthorizedClient} or the provided {@link OAuth2ReauthorizeRequest#getAuthorizedClient() authorized client} if not re-authorized - */ - OAuth2AuthorizedClient reauthorize(OAuth2ReauthorizeRequest reauthorizeRequest); - } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2ReauthorizeRequest.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2ReauthorizeRequest.java deleted file mode 100644 index 80beafe115a..00000000000 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2ReauthorizeRequest.java +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Copyright 2002-2019 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.springframework.security.oauth2.client.web; - -import org.springframework.security.core.Authentication; -import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; -import org.springframework.util.Assert; - -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; - -/** - * Represents a request the {@link OAuth2AuthorizedClientManager} uses to - * {@link OAuth2AuthorizedClientManager#reauthorize(OAuth2ReauthorizeRequest) re-authorize} - * the provided {@link OAuth2AuthorizedClient#getClientRegistration() client}. - * - * @author Joe Grandja - * @since 5.2 - * @see OAuth2AuthorizeRequest - * @see OAuth2AuthorizedClientManager - */ -public class OAuth2ReauthorizeRequest extends OAuth2AuthorizeRequest { - private OAuth2AuthorizedClient authorizedClient; - - public OAuth2ReauthorizeRequest(OAuth2AuthorizedClient authorizedClient, Authentication principal, - HttpServletRequest servletRequest, HttpServletResponse servletResponse) { - super(getClientRegistrationId(authorizedClient), principal, servletRequest, servletResponse); - this.authorizedClient = authorizedClient; - } - - private static String getClientRegistrationId(OAuth2AuthorizedClient authorizedClient) { - Assert.notNull(authorizedClient, "authorizedClient cannot be null"); - return authorizedClient.getClientRegistration().getRegistrationId(); - } - - /** - * Returns the {@link OAuth2AuthorizedClient authorized client}. - * - * @return the {@link OAuth2AuthorizedClient} - */ - public OAuth2AuthorizedClient getAuthorizedClient() { - return this.authorizedClient; - } -} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java index 20f8991997c..9919bf859f7 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java @@ -38,7 +38,6 @@ import org.springframework.security.oauth2.client.web.OAuth2AuthorizeRequest; import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientManager; import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; -import org.springframework.security.oauth2.client.web.OAuth2ReauthorizeRequest; import org.springframework.util.Assert; import org.springframework.web.context.request.RequestContextHolder; import org.springframework.web.context.request.ServletRequestAttributes; @@ -443,9 +442,9 @@ private Mono authorizedClient(OAuth2AuthorizedClient aut } HttpServletRequest servletRequest = getRequest(attrs); HttpServletResponse servletResponse = getResponse(attrs); - OAuth2ReauthorizeRequest reauthorizeRequest = new OAuth2ReauthorizeRequest( + OAuth2AuthorizeRequest reauthorizeRequest = new OAuth2AuthorizeRequest( authorizedClient, authentication, servletRequest, servletResponse); - return Mono.fromSupplier(() -> this.authorizedClientManager.reauthorize(reauthorizeRequest)); + return Mono.fromSupplier(() -> this.authorizedClientManager.authorize(reauthorizeRequest)); } private ClientRequest bearer(ClientRequest request, OAuth2AuthorizedClient authorizedClient) { diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizedClientManagerTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizedClientManagerTests.java index 6abfa8057da..1d200fc323d 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizedClientManagerTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizedClientManagerTests.java @@ -200,19 +200,12 @@ public void authorizeWhenAuthorizedAndSupportedProviderThenReauthorized() { eq(reauthorizedClient), eq(this.principal), eq(this.request), eq(this.response)); } - @Test - public void reauthorizeWhenRequestIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.authorizedClientManager.reauthorize(null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("reauthorizeRequest cannot be null"); - } - @SuppressWarnings("unchecked") @Test public void reauthorizeWhenUnsupportedProviderThenNotReauthorized() { - OAuth2ReauthorizeRequest reauthorizeRequest = new OAuth2ReauthorizeRequest( + OAuth2AuthorizeRequest reauthorizeRequest = new OAuth2AuthorizeRequest( this.authorizedClient, this.principal, this.request, this.response); - OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.reauthorize(reauthorizeRequest); + OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(reauthorizeRequest); verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); verify(this.contextAttributesMapper).apply(eq(reauthorizeRequest)); @@ -236,9 +229,9 @@ public void reauthorizeWhenSupportedProviderThenReauthorized() { when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(reauthorizedClient); - OAuth2ReauthorizeRequest reauthorizeRequest = new OAuth2ReauthorizeRequest( + OAuth2AuthorizeRequest reauthorizeRequest = new OAuth2AuthorizeRequest( this.authorizedClient, this.principal, this.request, this.response); - OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.reauthorize(reauthorizeRequest); + OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(reauthorizeRequest); verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); verify(this.contextAttributesMapper).apply(eq(reauthorizeRequest)); @@ -268,9 +261,9 @@ public void reauthorizeWhenRequestScopeParameterThenMappedToContext() { this.request.addParameter(OAuth2ParameterNames.SCOPE, "read write"); - OAuth2ReauthorizeRequest reauthorizeRequest = new OAuth2ReauthorizeRequest( + OAuth2AuthorizeRequest reauthorizeRequest = new OAuth2AuthorizeRequest( this.authorizedClient, this.principal, this.request, this.response); - OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.reauthorize(reauthorizeRequest); + OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(reauthorizeRequest); verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizeRequestTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizeRequestTests.java index 1a2c114df5c..6d7e687fcdd 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizeRequestTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizeRequestTests.java @@ -20,8 +20,11 @@ import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; +import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -34,16 +37,26 @@ public class OAuth2AuthorizeRequestTests { private ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); private Authentication principal = new TestingAuthenticationToken("principal", "password"); + private OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( + this.clientRegistration, this.principal.getName(), + TestOAuth2AccessTokens.scopes("read", "write"), TestOAuth2RefreshTokens.refreshToken()); private MockHttpServletRequest servletRequest = new MockHttpServletRequest(); private MockHttpServletResponse servletResponse = new MockHttpServletResponse(); @Test public void constructorWhenClientRegistrationIdIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new OAuth2AuthorizeRequest(null, this.principal, this.servletRequest, this.servletResponse)) + assertThatThrownBy(() -> new OAuth2AuthorizeRequest((String) null, this.principal, this.servletRequest, this.servletResponse)) .isInstanceOf(IllegalArgumentException.class) .hasMessage("clientRegistrationId cannot be empty"); } + @Test + public void constructorWhenAuthorizedClientIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new OAuth2AuthorizeRequest((OAuth2AuthorizedClient) null, this.principal, this.servletRequest, this.servletResponse)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("authorizedClient cannot be null"); + } + @Test public void constructorWhenPrincipalIsNullThenThrowIllegalArgumentException() { assertThatThrownBy(() -> new OAuth2AuthorizeRequest(this.clientRegistration.getRegistrationId(), null, this.servletRequest, this.servletResponse)) @@ -66,11 +79,24 @@ public void constructorWhenServletResponseIsNullThenThrowIllegalArgumentExceptio } @Test - public void constructorWhenAllValuesProvidedThenAllValuesAreSet() { + public void constructorClientRegistrationIdWhenAllValuesProvidedThenAllValuesAreSet() { OAuth2AuthorizeRequest authorizeRequest = new OAuth2AuthorizeRequest( this.clientRegistration.getRegistrationId(), this.principal, this.servletRequest, this.servletResponse); assertThat(authorizeRequest.getClientRegistrationId()).isEqualTo(this.clientRegistration.getRegistrationId()); + assertThat(authorizeRequest.getAuthorizedClient()).isNull(); + assertThat(authorizeRequest.getPrincipal()).isEqualTo(this.principal); + assertThat(authorizeRequest.getServletRequest()).isEqualTo(this.servletRequest); + assertThat(authorizeRequest.getServletResponse()).isEqualTo(this.servletResponse); + } + + @Test + public void constructorAuthorizedClientWhenAllValuesProvidedThenAllValuesAreSet() { + OAuth2AuthorizeRequest authorizeRequest = new OAuth2AuthorizeRequest( + this.authorizedClient, this.principal, this.servletRequest, this.servletResponse); + + assertThat(authorizeRequest.getClientRegistrationId()).isEqualTo(this.authorizedClient.getClientRegistration().getRegistrationId()); + assertThat(authorizeRequest.getAuthorizedClient()).isEqualTo(this.authorizedClient); assertThat(authorizeRequest.getPrincipal()).isEqualTo(this.principal); assertThat(authorizeRequest.getServletRequest()).isEqualTo(this.servletRequest); assertThat(authorizeRequest.getServletResponse()).isEqualTo(this.servletResponse); diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2ReauthorizeRequestTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2ReauthorizeRequestTests.java deleted file mode 100644 index a82d9170ba5..00000000000 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2ReauthorizeRequestTests.java +++ /dev/null @@ -1,83 +0,0 @@ -/* - * Copyright 2002-2019 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.springframework.security.oauth2.client.web; - -import org.junit.Test; -import org.springframework.mock.web.MockHttpServletRequest; -import org.springframework.mock.web.MockHttpServletResponse; -import org.springframework.security.authentication.TestingAuthenticationToken; -import org.springframework.security.core.Authentication; -import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; -import org.springframework.security.oauth2.client.registration.ClientRegistration; -import org.springframework.security.oauth2.client.registration.TestClientRegistrations; -import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; - -/** - * Tests for {@link OAuth2ReauthorizeRequest}. - * - * @author Joe Grandja - */ -public class OAuth2ReauthorizeRequestTests { - private ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); - private Authentication principal = new TestingAuthenticationToken("principal", "password"); - private OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( - this.clientRegistration, this.principal.getName(), TestOAuth2AccessTokens.noScopes()); - private MockHttpServletRequest servletRequest = new MockHttpServletRequest(); - private MockHttpServletResponse servletResponse = new MockHttpServletResponse(); - - @Test - public void constructorWhenAuthorizedClientIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new OAuth2ReauthorizeRequest(null, this.principal, this.servletRequest, this.servletResponse)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("authorizedClient cannot be null"); - } - - @Test - public void constructorWhenPrincipalIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new OAuth2ReauthorizeRequest(this.authorizedClient, null, this.servletRequest, this.servletResponse)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("principal cannot be null"); - } - - @Test - public void constructorWhenServletRequestIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new OAuth2ReauthorizeRequest(this.authorizedClient, this.principal, null, this.servletResponse)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("servletRequest cannot be null"); - } - - @Test - public void constructorWhenServletResponseIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new OAuth2ReauthorizeRequest(this.authorizedClient, this.principal, this.servletRequest, null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("servletResponse cannot be null"); - } - - @Test - public void constructorWhenAllValuesProvidedThenAllValuesAreSet() { - OAuth2ReauthorizeRequest reauthorizeRequest = new OAuth2ReauthorizeRequest( - this.authorizedClient, this.principal, this.servletRequest, this.servletResponse); - - assertThat(reauthorizeRequest.getAuthorizedClient()).isEqualTo(this.authorizedClient); - assertThat(reauthorizeRequest.getClientRegistrationId()).isEqualTo(this.clientRegistration.getRegistrationId()); - assertThat(reauthorizeRequest.getPrincipal()).isEqualTo(this.principal); - assertThat(reauthorizeRequest.getServletRequest()).isEqualTo(this.servletRequest); - assertThat(reauthorizeRequest.getServletResponse()).isEqualTo(this.servletResponse); - } -}