Skip to content

Commit c100e62

Browse files
committed
Add refresh_token OAuth2AuthorizedClientProvider
1 parent e3875ed commit c100e62

File tree

2 files changed

+324
-0
lines changed

2 files changed

+324
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
/*
2+
* Copyright 2002-2019 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.security.oauth2.client;
17+
18+
import org.springframework.lang.Nullable;
19+
import org.springframework.security.oauth2.client.endpoint.DefaultRefreshTokenTokenResponseClient;
20+
import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient;
21+
import org.springframework.security.oauth2.client.endpoint.OAuth2RefreshTokenGrantRequest;
22+
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
23+
import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
24+
import org.springframework.security.oauth2.core.AuthorizationGrantType;
25+
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
26+
import org.springframework.util.Assert;
27+
28+
import javax.servlet.http.HttpServletRequest;
29+
import javax.servlet.http.HttpServletResponse;
30+
import java.util.Set;
31+
32+
/**
33+
* An implementation of an {@link OAuth2AuthorizedClientProvider}
34+
* for the {@link AuthorizationGrantType#REFRESH_TOKEN refresh_token} grant.
35+
*
36+
* @author Joe Grandja
37+
* @since 5.2
38+
* @see OAuth2AuthorizedClientProvider
39+
* @see DefaultRefreshTokenTokenResponseClient
40+
*/
41+
public final class RefreshTokenOAuth2AuthorizedClientProvider implements OAuth2AuthorizedClientProvider {
42+
private static final String HTTP_SERVLET_REQUEST_ATTR_NAME = HttpServletRequest.class.getName();
43+
private static final String HTTP_SERVLET_RESPONSE_ATTR_NAME = HttpServletResponse.class.getName();
44+
private static final String SCOPE_ATTR_NAME = "SCOPE";
45+
private final ClientRegistrationRepository clientRegistrationRepository;
46+
private final OAuth2AuthorizedClientRepository authorizedClientRepository;
47+
private OAuth2AccessTokenResponseClient<OAuth2RefreshTokenGrantRequest> accessTokenResponseClient =
48+
new DefaultRefreshTokenTokenResponseClient();
49+
50+
/**
51+
* Constructs a {@code RefreshTokenOAuth2AuthorizedClientProvider} using the provided parameters.
52+
*
53+
* @param clientRegistrationRepository the repository of client registrations
54+
* @param authorizedClientRepository the repository of authorized clients
55+
*/
56+
public RefreshTokenOAuth2AuthorizedClientProvider(ClientRegistrationRepository clientRegistrationRepository,
57+
OAuth2AuthorizedClientRepository authorizedClientRepository) {
58+
Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null");
59+
Assert.notNull(authorizedClientRepository, "authorizedClientRepository cannot be null");
60+
this.clientRegistrationRepository = clientRegistrationRepository;
61+
this.authorizedClientRepository = authorizedClientRepository;
62+
}
63+
64+
@Override
65+
@Nullable
66+
public OAuth2AuthorizedClient authorize(OAuth2AuthorizationContext context) {
67+
Assert.notNull(context, "context cannot be null");
68+
if (!context.reauthorizationRequired() || context.getAuthorizedClient().getRefreshToken() == null) {
69+
return null;
70+
}
71+
72+
HttpServletRequest request = context.getAttribute(HTTP_SERVLET_REQUEST_ATTR_NAME);
73+
HttpServletResponse response = context.getAttribute(HTTP_SERVLET_RESPONSE_ATTR_NAME);
74+
Assert.notNull(request, "context.HttpServletRequest cannot be null");
75+
Assert.notNull(response, "context.HttpServletResponse cannot be null");
76+
77+
Object scopesObj = context.getAttribute(SCOPE_ATTR_NAME);
78+
Set<String> scopes = null;
79+
if (scopesObj != null) {
80+
Assert.isTrue(scopesObj instanceof Set, "The '" + SCOPE_ATTR_NAME + "' attribute must be of type " + Set.class.getName());
81+
scopes = (Set<String>) scopesObj;
82+
}
83+
84+
OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest =
85+
new OAuth2RefreshTokenGrantRequest(context.getAuthorizedClient(), scopes);
86+
OAuth2AccessTokenResponse tokenResponse =
87+
this.accessTokenResponseClient.getTokenResponse(refreshTokenGrantRequest);
88+
89+
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(
90+
context.getClientRegistration(),
91+
context.getPrincipal().getName(),
92+
tokenResponse.getAccessToken(),
93+
tokenResponse.getRefreshToken());
94+
95+
this.authorizedClientRepository.saveAuthorizedClient(
96+
authorizedClient,
97+
context.getPrincipal(),
98+
request,
99+
response);
100+
101+
return authorizedClient;
102+
}
103+
104+
/**
105+
* Sets the client used when requesting an access token credential at the Token Endpoint for the {@code refresh_token} grant.
106+
*
107+
* @param accessTokenResponseClient the client used when requesting an access token credential at the Token Endpoint for the {@code refresh_token} grant
108+
*/
109+
public void setAccessTokenResponseClient(OAuth2AccessTokenResponseClient<OAuth2RefreshTokenGrantRequest> accessTokenResponseClient) {
110+
Assert.notNull(accessTokenResponseClient, "accessTokenResponseClient cannot be null");
111+
this.accessTokenResponseClient = accessTokenResponseClient;
112+
}
113+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,211 @@
1+
/*
2+
* Copyright 2002-2019 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.security.oauth2.client;
17+
18+
import org.junit.Before;
19+
import org.junit.Test;
20+
import org.mockito.ArgumentCaptor;
21+
import org.springframework.mock.web.MockHttpServletRequest;
22+
import org.springframework.mock.web.MockHttpServletResponse;
23+
import org.springframework.security.authentication.TestingAuthenticationToken;
24+
import org.springframework.security.core.Authentication;
25+
import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient;
26+
import org.springframework.security.oauth2.client.endpoint.OAuth2RefreshTokenGrantRequest;
27+
import org.springframework.security.oauth2.client.registration.ClientRegistration;
28+
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
29+
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
30+
import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
31+
import org.springframework.security.oauth2.core.TestOAuth2AccessTokens;
32+
import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens;
33+
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
34+
import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses;
35+
36+
import javax.servlet.http.HttpServletRequest;
37+
import javax.servlet.http.HttpServletResponse;
38+
import java.util.Collections;
39+
import java.util.Set;
40+
41+
import static org.assertj.core.api.Assertions.assertThat;
42+
import static org.assertj.core.api.Assertions.assertThatThrownBy;
43+
import static org.mockito.ArgumentMatchers.any;
44+
import static org.mockito.ArgumentMatchers.eq;
45+
import static org.mockito.Mockito.*;
46+
47+
/**
48+
* Tests for {@link RefreshTokenOAuth2AuthorizedClientProvider}.
49+
*
50+
* @author Joe Grandja
51+
*/
52+
public class RefreshTokenOAuth2AuthorizedClientProviderTests {
53+
private ClientRegistrationRepository clientRegistrationRepository;
54+
private OAuth2AuthorizedClientRepository authorizedClientRepository;
55+
private RefreshTokenOAuth2AuthorizedClientProvider authorizedClientProvider;
56+
private OAuth2AccessTokenResponseClient<OAuth2RefreshTokenGrantRequest> accessTokenResponseClient;
57+
private ClientRegistration clientRegistration;
58+
private Authentication principal;
59+
private OAuth2AuthorizedClient authorizedClient;
60+
61+
@Before
62+
public void setup() {
63+
this.clientRegistrationRepository = mock(ClientRegistrationRepository.class);
64+
this.authorizedClientRepository = mock(OAuth2AuthorizedClientRepository.class);
65+
this.authorizedClientProvider = new RefreshTokenOAuth2AuthorizedClientProvider(
66+
this.clientRegistrationRepository, this.authorizedClientRepository);
67+
this.accessTokenResponseClient = mock(OAuth2AccessTokenResponseClient.class);
68+
this.authorizedClientProvider.setAccessTokenResponseClient(this.accessTokenResponseClient);
69+
this.clientRegistration = TestClientRegistrations.clientRegistration().build();
70+
this.principal = new TestingAuthenticationToken("principal", "password");
71+
this.authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, this.principal.getName(),
72+
TestOAuth2AccessTokens.scopes("read", "write"), TestOAuth2RefreshTokens.refreshToken());
73+
}
74+
75+
@Test
76+
public void constructorWhenClientRegistrationRepositoryIsNullThenThrowIllegalArgumentException() {
77+
assertThatThrownBy(() -> new RefreshTokenOAuth2AuthorizedClientProvider(null, this.authorizedClientRepository))
78+
.isInstanceOf(IllegalArgumentException.class)
79+
.hasMessage("clientRegistrationRepository cannot be null");
80+
}
81+
82+
@Test
83+
public void constructorWhenOAuth2AuthorizedClientRepositoryIsNullThenThrowIllegalArgumentException() {
84+
assertThatThrownBy(() -> new RefreshTokenOAuth2AuthorizedClientProvider(this.clientRegistrationRepository, null))
85+
.isInstanceOf(IllegalArgumentException.class)
86+
.hasMessage("authorizedClientRepository cannot be null");
87+
}
88+
89+
@Test
90+
public void setAccessTokenResponseClientWhenClientIsNullThenThrowIllegalArgumentException() {
91+
assertThatThrownBy(() -> this.authorizedClientProvider.setAccessTokenResponseClient(null))
92+
.isInstanceOf(IllegalArgumentException.class)
93+
.hasMessage("accessTokenResponseClient cannot be null");
94+
}
95+
96+
@Test
97+
public void authorizeWhenContextIsNullThenThrowIllegalArgumentException() {
98+
assertThatThrownBy(() -> this.authorizedClientProvider.authorize(null))
99+
.isInstanceOf(IllegalArgumentException.class)
100+
.hasMessage("context cannot be null");
101+
}
102+
103+
@Test
104+
public void authorizeWhenNotAuthorizedThenUnableToReauthorize() {
105+
OAuth2AuthorizationContext authorizationContext =
106+
OAuth2AuthorizationContext.authorize(this.clientRegistration).principal(this.principal).build();
107+
assertThat(this.authorizedClientProvider.authorize(authorizationContext)).isNull();
108+
}
109+
110+
@Test
111+
public void authorizeWhenAuthorizedAndRefreshTokenIsNullThenUnableToReauthorize() {
112+
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(
113+
this.clientRegistration, this.principal.getName(), this.authorizedClient.getAccessToken());
114+
OAuth2AuthorizationContext authorizationContext =
115+
OAuth2AuthorizationContext.reauthorize(authorizedClient).principal(this.principal).build();
116+
assertThat(this.authorizedClientProvider.authorize(authorizationContext)).isNull();
117+
}
118+
119+
@Test
120+
public void authorizeWhenHttpServletRequestIsNullThenThrowIllegalArgumentException() {
121+
OAuth2AuthorizationContext authorizationContext =
122+
OAuth2AuthorizationContext.reauthorize(this.authorizedClient).principal(this.principal).build();
123+
assertThatThrownBy(() -> this.authorizedClientProvider.authorize(authorizationContext))
124+
.isInstanceOf(IllegalArgumentException.class)
125+
.hasMessage("context.HttpServletRequest cannot be null");
126+
}
127+
128+
@Test
129+
public void authorizeWhenHttpServletResponseIsNullThenThrowIllegalArgumentException() {
130+
OAuth2AuthorizationContext authorizationContext =
131+
OAuth2AuthorizationContext.reauthorize(this.authorizedClient)
132+
.principal(this.principal)
133+
.attribute(HttpServletRequest.class.getName(), new MockHttpServletRequest())
134+
.build();
135+
assertThatThrownBy(() -> this.authorizedClientProvider.authorize(authorizationContext))
136+
.isInstanceOf(IllegalArgumentException.class)
137+
.hasMessage("context.HttpServletResponse cannot be null");
138+
}
139+
140+
@Test
141+
public void authorizeWhenAuthorizedWithRefreshTokenThenReauthorize() {
142+
OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse()
143+
.refreshToken("new-refresh-token")
144+
.build();
145+
when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse);
146+
147+
OAuth2AuthorizationContext authorizationContext =
148+
OAuth2AuthorizationContext.reauthorize(this.authorizedClient)
149+
.principal(this.principal)
150+
.attribute(HttpServletRequest.class.getName(), new MockHttpServletRequest())
151+
.attribute(HttpServletResponse.class.getName(), new MockHttpServletResponse())
152+
.build();
153+
154+
OAuth2AuthorizedClient authorizedClient = this.authorizedClientProvider.authorize(authorizationContext);
155+
156+
assertThat(authorizedClient.getClientRegistration()).isSameAs(this.clientRegistration);
157+
assertThat(authorizedClient.getPrincipalName()).isEqualTo(this.principal.getName());
158+
assertThat(authorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken());
159+
assertThat(authorizedClient.getRefreshToken()).isEqualTo(accessTokenResponse.getRefreshToken());
160+
verify(this.authorizedClientRepository).saveAuthorizedClient(
161+
eq(authorizedClient), eq(this.principal),
162+
any(HttpServletRequest.class), any(HttpServletResponse.class));
163+
}
164+
165+
@Test
166+
public void authorizeWhenAuthorizedAndScopeProvidedThenScopeRequested() {
167+
OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse()
168+
.refreshToken("new-refresh-token")
169+
.build();
170+
when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse);
171+
172+
Set<String> scope = Collections.singleton("read");
173+
174+
OAuth2AuthorizationContext authorizationContext =
175+
OAuth2AuthorizationContext.reauthorize(this.authorizedClient)
176+
.principal(this.principal)
177+
.attribute(HttpServletRequest.class.getName(), new MockHttpServletRequest())
178+
.attribute(HttpServletResponse.class.getName(), new MockHttpServletResponse())
179+
.attribute("SCOPE", scope)
180+
.build();
181+
182+
this.authorizedClientProvider.authorize(authorizationContext);
183+
184+
ArgumentCaptor<OAuth2RefreshTokenGrantRequest> refreshTokenGrantRequestArgCaptor =
185+
ArgumentCaptor.forClass(OAuth2RefreshTokenGrantRequest.class);
186+
verify(this.accessTokenResponseClient).getTokenResponse(refreshTokenGrantRequestArgCaptor.capture());
187+
assertThat(refreshTokenGrantRequestArgCaptor.getValue().getScopes()).isEqualTo(scope);
188+
}
189+
190+
@Test
191+
public void authorizeWhenAuthorizedAndInvalidScopeProvidedThenThrowIllegalArgumentException() {
192+
OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse()
193+
.refreshToken("new-refresh-token")
194+
.build();
195+
when(this.accessTokenResponseClient.getTokenResponse(any())).thenReturn(accessTokenResponse);
196+
197+
String scope = "read";
198+
199+
OAuth2AuthorizationContext authorizationContext =
200+
OAuth2AuthorizationContext.reauthorize(this.authorizedClient)
201+
.principal(this.principal)
202+
.attribute(HttpServletRequest.class.getName(), new MockHttpServletRequest())
203+
.attribute(HttpServletResponse.class.getName(), new MockHttpServletResponse())
204+
.attribute("SCOPE", scope)
205+
.build();
206+
207+
assertThatThrownBy(() -> this.authorizedClientProvider.authorize(authorizationContext))
208+
.isInstanceOf(IllegalArgumentException.class)
209+
.hasMessage("The 'SCOPE' attribute must be of type " + Set.class.getName());
210+
}
211+
}

0 commit comments

Comments
 (0)