From c9e3b12d4fa3ace1274d3418a1db68be56f54a4d Mon Sep 17 00:00:00 2001 From: ray0052 Date: Tue, 2 Aug 2016 13:36:01 -0600 Subject: [PATCH] Updated DefaultCsrfToken to protected against BREACH attacks --- .../WebMvcSecurityConfigurationTests.java | 4 +- ...SocketMessageBrokerConfigurerDocTests.java | 2 +- ...WebSocketMessageBrokerConfigurerTests.java | 2 +- .../web/csrf/CsrfChannelInterceptor.java | 2 +- .../web/csrf/CsrfChannelInterceptorTests.java | 4 +- .../CsrfTokenHandshakeInterceptorTests.java | 4 +- ...yMockMvcRequestBuildersFormLoginTests.java | 6 +- ...MockMvcRequestBuildersFormLogoutTests.java | 6 +- .../web/csrf/CookieCsrfTokenRepository.java | 11 +-- .../security/web/csrf/CsrfFilter.java | 2 +- .../security/web/csrf/CsrfToken.java | 8 ++- .../security/web/csrf/DefaultCsrfToken.java | 68 +++++++++++++++---- .../csrf/HttpSessionCsrfTokenRepository.java | 8 +-- .../web/csrf/LazyCsrfTokenRepository.java | 5 ++ .../csrf/CookieCsrfTokenRepositoryTests.java | 4 +- .../csrf/CsrfAuthenticationStrategyTests.java | 6 +- .../security/web/csrf/CsrfFilterTests.java | 10 +-- .../web/csrf/DefaultCsrfTokenTests.java | 31 +++++---- .../HttpSessionCsrfTokenRepositoryTests.java | 4 +- .../csrf/LazyCsrfTokenRepositoryTests.java | 4 +- .../CsrfTokenArgumentResolverTests.java | 2 +- .../CsrfRequestDataValueProcessorTests.java | 4 +- 22 files changed, 122 insertions(+), 75 deletions(-) diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configuration/WebMvcSecurityConfigurationTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configuration/WebMvcSecurityConfigurationTests.java index 01c071d2ccf..f154652b604 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configuration/WebMvcSecurityConfigurationTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configuration/WebMvcSecurityConfigurationTests.java @@ -89,7 +89,7 @@ public void deprecatedAuthenticationPrincipalResolved() throws Exception { @Test public void csrfToken() throws Exception { - CsrfToken csrfToken = new DefaultCsrfToken("headerName", "paramName", "token"); + CsrfToken csrfToken = new DefaultCsrfToken("headerName", "paramName"); MockHttpServletRequestBuilder request = get("/csrf").requestAttr( CsrfToken.class.getName(), csrfToken); @@ -132,4 +132,4 @@ public TestController testController() { } } -} \ No newline at end of file +} diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/socket/AbstractSecurityWebSocketMessageBrokerConfigurerDocTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/socket/AbstractSecurityWebSocketMessageBrokerConfigurerDocTests.java index 7a8b40fa323..d79d9b285c2 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/socket/AbstractSecurityWebSocketMessageBrokerConfigurerDocTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/socket/AbstractSecurityWebSocketMessageBrokerConfigurerDocTests.java @@ -60,7 +60,7 @@ public class AbstractSecurityWebSocketMessageBrokerConfigurerDocTests { @Before public void setup() { - token = new DefaultCsrfToken("header", "param", "token"); + token = new DefaultCsrfToken("header", "param"); sessionAttr = "sessionAttr"; messageUser = new TestingAuthenticationToken("user", "pass", "ROLE_USER"); } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/socket/AbstractSecurityWebSocketMessageBrokerConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/socket/AbstractSecurityWebSocketMessageBrokerConfigurerTests.java index 759fd7dbb1d..9f98d26d6d9 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/socket/AbstractSecurityWebSocketMessageBrokerConfigurerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/socket/AbstractSecurityWebSocketMessageBrokerConfigurerTests.java @@ -82,7 +82,7 @@ public class AbstractSecurityWebSocketMessageBrokerConfigurerTests { @Before public void setup() { - token = new DefaultCsrfToken("header", "param", "token"); + token = new DefaultCsrfToken("header", "param"); sessionAttr = "sessionAttr"; messageUser = new TestingAuthenticationToken("user", "pass", "ROLE_USER"); } diff --git a/messaging/src/main/java/org/springframework/security/messaging/web/csrf/CsrfChannelInterceptor.java b/messaging/src/main/java/org/springframework/security/messaging/web/csrf/CsrfChannelInterceptor.java index edfefaf05bf..255079e0dda 100644 --- a/messaging/src/main/java/org/springframework/security/messaging/web/csrf/CsrfChannelInterceptor.java +++ b/messaging/src/main/java/org/springframework/security/messaging/web/csrf/CsrfChannelInterceptor.java @@ -58,7 +58,7 @@ public Message preSend(Message message, MessageChannel channel) { String actualTokenValue = SimpMessageHeaderAccessor.wrap(message) .getFirstNativeHeader(expectedToken.getHeaderName()); - boolean csrfCheckPassed = expectedToken.getToken().equals(actualTokenValue); + boolean csrfCheckPassed = expectedToken.isValid(actualTokenValue); if (csrfCheckPassed) { return message; } diff --git a/messaging/src/test/java/org/springframework/security/messaging/web/csrf/CsrfChannelInterceptorTests.java b/messaging/src/test/java/org/springframework/security/messaging/web/csrf/CsrfChannelInterceptorTests.java index 848e2e47777..5d6b19580b4 100644 --- a/messaging/src/test/java/org/springframework/security/messaging/web/csrf/CsrfChannelInterceptorTests.java +++ b/messaging/src/test/java/org/springframework/security/messaging/web/csrf/CsrfChannelInterceptorTests.java @@ -46,7 +46,7 @@ public class CsrfChannelInterceptorTests { @Before public void setup() { - token = new DefaultCsrfToken("header", "param", "token"); + token = new DefaultCsrfToken("header", "param"); interceptor = new CsrfChannelInterceptor(); messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT); @@ -126,7 +126,7 @@ public void preSendNoToken() { @Test(expected = InvalidCsrfTokenException.class) public void preSendInvalidToken() { messageHeaders.setNativeHeader(token.getHeaderName(), token.getToken() - + "invalid"); + .replaceAll("^.{10}","INVALID000")); interceptor.preSend(message(), channel); } diff --git a/messaging/src/test/java/org/springframework/security/messaging/web/socket/server/CsrfTokenHandshakeInterceptorTests.java b/messaging/src/test/java/org/springframework/security/messaging/web/socket/server/CsrfTokenHandshakeInterceptorTests.java index f2e36b61dc7..38eb9e32422 100644 --- a/messaging/src/test/java/org/springframework/security/messaging/web/socket/server/CsrfTokenHandshakeInterceptorTests.java +++ b/messaging/src/test/java/org/springframework/security/messaging/web/socket/server/CsrfTokenHandshakeInterceptorTests.java @@ -70,7 +70,7 @@ public void beforeHandshakeNoAttribute() throws Exception { @Test public void beforeHandshake() throws Exception { - CsrfToken token = new DefaultCsrfToken("header", "param", "token"); + CsrfToken token = new DefaultCsrfToken("header", "param"); httpRequest.setAttribute(CsrfToken.class.getName(), token); interceptor.beforeHandshake(request, response, wsHandler, attributes); @@ -79,4 +79,4 @@ public void beforeHandshake() throws Exception { assertThat(attributes.values()).containsOnly(token); } -} \ No newline at end of file +} diff --git a/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLoginTests.java b/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLoginTests.java index bddfa3d218c..f62bd68dd39 100644 --- a/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLoginTests.java +++ b/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLoginTests.java @@ -44,8 +44,7 @@ public void defaults() throws Exception { assertThat(request.getParameter("username")).isEqualTo("user"); assertThat(request.getParameter("password")).isEqualTo("password"); assertThat(request.getMethod()).isEqualTo("POST"); - assertThat(request.getParameter(token.getParameterName())) - .isEqualTo(token.getToken()); + assertThat(token.isValid(request.getParameter(token.getParameterName()))); assertThat(request.getRequestURI()).isEqualTo("/login"); assertThat(request.getParameter("_csrf")).isNotNull(); } @@ -61,8 +60,7 @@ public void custom() throws Exception { assertThat(request.getParameter("username")).isEqualTo("admin"); assertThat(request.getParameter("password")).isEqualTo("secret"); assertThat(request.getMethod()).isEqualTo("POST"); - assertThat(request.getParameter(token.getParameterName())) - .isEqualTo(token.getToken()); + assertThat(token.isValid(request.getParameter(token.getParameterName()))); assertThat(request.getRequestURI()).isEqualTo("/login"); } diff --git a/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLogoutTests.java b/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLogoutTests.java index 16b7ebec068..3ee788d0c31 100644 --- a/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLogoutTests.java +++ b/test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLogoutTests.java @@ -40,8 +40,7 @@ public void defaults() throws Exception { CsrfToken token = (CsrfToken) request.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.ATTR_NAME); assertThat(request.getMethod()).isEqualTo("POST"); - assertThat(request.getParameter(token.getParameterName())).isEqualTo( - token.getToken()); + assertThat(token.isValid(request.getParameter(token.getParameterName()))); assertThat(request.getRequestURI()).isEqualTo("/logout"); } @@ -53,8 +52,7 @@ public void custom() throws Exception { CsrfToken token = (CsrfToken) request.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.ATTR_NAME); assertThat(request.getMethod()).isEqualTo("POST"); - assertThat(request.getParameter(token.getParameterName())).isEqualTo( - token.getToken()); + assertThat(token.isValid(request.getParameter(token.getParameterName()))); assertThat(request.getRequestURI()).isEqualTo("/admin/logout"); } diff --git a/web/src/main/java/org/springframework/security/web/csrf/CookieCsrfTokenRepository.java b/web/src/main/java/org/springframework/security/web/csrf/CookieCsrfTokenRepository.java index c275247c729..92106f2fd59 100644 --- a/web/src/main/java/org/springframework/security/web/csrf/CookieCsrfTokenRepository.java +++ b/web/src/main/java/org/springframework/security/web/csrf/CookieCsrfTokenRepository.java @@ -62,8 +62,7 @@ public CookieCsrfTokenRepository() { @Override public CsrfToken generateToken(HttpServletRequest request) { - return new DefaultCsrfToken(this.headerName, this.parameterName, - createNewToken()); + return new DefaultCsrfToken(this.headerName, this.parameterName); } @Override @@ -96,7 +95,7 @@ public CsrfToken loadToken(HttpServletRequest request) { if (!StringUtils.hasLength(token)) { return null; } - return new DefaultCsrfToken(this.headerName, this.parameterName, token); + return new DefaultCsrfToken(this.headerName, this.parameterName); } /** @@ -165,8 +164,4 @@ public static CookieCsrfTokenRepository withHttpOnlyFalse() { result.setCookieHttpOnly(false); return result; } - - private String createNewToken() { - return UUID.randomUUID().toString(); - } -} \ No newline at end of file +} diff --git a/web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java b/web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java index 0a62e303bc3..6d7fb4e1b4e 100644 --- a/web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java +++ b/web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java @@ -105,7 +105,7 @@ protected void doFilterInternal(HttpServletRequest request, if (actualToken == null) { actualToken = request.getParameter(csrfToken.getParameterName()); } - if (!csrfToken.getToken().equals(actualToken)) { + if (!csrfToken.isValid(actualToken)) { if (this.logger.isDebugEnabled()) { this.logger.debug("Invalid CSRF token found for " + UrlUtils.buildFullRequestUrl(request)); diff --git a/web/src/main/java/org/springframework/security/web/csrf/CsrfToken.java b/web/src/main/java/org/springframework/security/web/csrf/CsrfToken.java index 2e9b66a0633..5102a3dfb35 100644 --- a/web/src/main/java/org/springframework/security/web/csrf/CsrfToken.java +++ b/web/src/main/java/org/springframework/security/web/csrf/CsrfToken.java @@ -23,6 +23,7 @@ * @see DefaultCsrfToken * * @author Rob Winch + * @author John Ray * @since 3.2 * */ @@ -49,4 +50,9 @@ public interface CsrfToken extends Serializable { */ String getToken(); -} \ No newline at end of file + /** + * Check if a value returned by a previous call to getToken() matches this token. + * @return true if the token is Valid + */ + boolean isValid(String value); +} diff --git a/web/src/main/java/org/springframework/security/web/csrf/DefaultCsrfToken.java b/web/src/main/java/org/springframework/security/web/csrf/DefaultCsrfToken.java index 432d1181f78..56ddcc32996 100644 --- a/web/src/main/java/org/springframework/security/web/csrf/DefaultCsrfToken.java +++ b/web/src/main/java/org/springframework/security/web/csrf/DefaultCsrfToken.java @@ -15,37 +15,42 @@ */ package org.springframework.security.web.csrf; +import org.springframework.security.crypto.codec.Base64; import org.springframework.util.Assert; +import java.security.SecureRandom; + /** * A CSRF token that is used to protect against CSRF attacks. * * @author Rob Winch + * @author John Ray * @since 3.2 */ @SuppressWarnings("serial") public final class DefaultCsrfToken implements CsrfToken { - private final String token; + private static final int CSRF_VALUE_SIZE = 16; // 128 bit CSRF value + + private static final SecureRandom secureRandom = new SecureRandom(); private final String parameterName; private final String headerName; + private final byte[] csrfToken; + /** * Creates a new instance * @param headerName the HTTP header name to use * @param parameterName the HTTP parameter name to use - * @param token the value of the token (i.e. expected value of the HTTP parameter of - * parametername). */ - public DefaultCsrfToken(String headerName, String parameterName, String token) { + public DefaultCsrfToken(String headerName, String parameterName) { Assert.hasLength(headerName, "headerName cannot be null or empty"); Assert.hasLength(parameterName, "parameterName cannot be null or empty"); - Assert.hasLength(token, "token cannot be null or empty"); this.headerName = headerName; this.parameterName = parameterName; - this.token = token; + csrfToken = secureRandom.generateSeed(CSRF_VALUE_SIZE); } /* @@ -54,7 +59,7 @@ public DefaultCsrfToken(String headerName, String parameterName, String token) { * @see org.springframework.security.web.csrf.CsrfToken#getHeaderName() */ public String getHeaderName() { - return this.headerName; + return headerName; } /* @@ -63,15 +68,54 @@ public String getHeaderName() { * @see org.springframework.security.web.csrf.CsrfToken#getParameterName() */ public String getParameterName() { - return this.parameterName; + return parameterName; } - /* - * (non-Javadoc) + /** + * Get the CSRF token value. Each call to this method will return a unique + * value to defeat a possible BREACH attack. + * + *

The value consists of a 128 bit random mask followed by a 128 bit token + * XORed against the mask. The value is Base64 encoded. * - * @see org.springframework.security.web.csrf.CsrfToken#getToken() + * @return A unique CSRF token. */ public String getToken() { - return this.token; + byte[] encodedToken = new byte[CSRF_VALUE_SIZE*2]; + + byte[] mask = secureRandom.generateSeed(CSRF_VALUE_SIZE); + for (int i=0; i < CSRF_VALUE_SIZE; i++) { + encodedToken[i] = mask[i]; + encodedToken[i+CSRF_VALUE_SIZE] = (byte)(csrfToken[i] ^ mask[i]); + } + + return new String(Base64.encode(encodedToken)); } + + /* + * (non-Javadoc) + * + * @see org.springframework.security.web.csrf.CsrfToken#isValid() + */ + public boolean isValid(String value) { + if ((value == null) || (value.length() == 0)) + return false; + + byte[] encodedToken; + try { + encodedToken = Base64.decode(value.getBytes()); + } catch (IllegalArgumentException e) { + return false; + } + + if (encodedToken.length != (CSRF_VALUE_SIZE*2)) + return false; + + for (int i=0; i < CSRF_VALUE_SIZE; i++) + if (csrfToken[i] != (encodedToken[i] ^ encodedToken[i+CSRF_VALUE_SIZE])) + return false; + + return true; + } + } diff --git a/web/src/main/java/org/springframework/security/web/csrf/HttpSessionCsrfTokenRepository.java b/web/src/main/java/org/springframework/security/web/csrf/HttpSessionCsrfTokenRepository.java index 4e36ebd839a..e79cf46a977 100644 --- a/web/src/main/java/org/springframework/security/web/csrf/HttpSessionCsrfTokenRepository.java +++ b/web/src/main/java/org/springframework/security/web/csrf/HttpSessionCsrfTokenRepository.java @@ -15,8 +15,6 @@ */ package org.springframework.security.web.csrf; -import java.util.UUID; - import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpSession; @@ -87,8 +85,7 @@ public CsrfToken loadToken(HttpServletRequest request) { * servlet .http.HttpServletRequest) */ public CsrfToken generateToken(HttpServletRequest request) { - return new DefaultCsrfToken(this.headerName, this.parameterName, - createNewToken()); + return new DefaultCsrfToken(this.headerName, this.parameterName); } /** @@ -122,7 +119,4 @@ public void setSessionAttributeName(String sessionAttributeName) { this.sessionAttributeName = sessionAttributeName; } - private String createNewToken() { - return UUID.randomUUID().toString(); - } } diff --git a/web/src/main/java/org/springframework/security/web/csrf/LazyCsrfTokenRepository.java b/web/src/main/java/org/springframework/security/web/csrf/LazyCsrfTokenRepository.java index 5d726b7f387..3819a625d5b 100644 --- a/web/src/main/java/org/springframework/security/web/csrf/LazyCsrfTokenRepository.java +++ b/web/src/main/java/org/springframework/security/web/csrf/LazyCsrfTokenRepository.java @@ -129,6 +129,11 @@ public String getToken() { return this.delegate.getToken(); } + @Override + public boolean isValid(String value) { + return delegate.isValid(value); + } + @Override public String toString() { return "SaveOnAccessCsrfToken [delegate=" + this.delegate + "]"; diff --git a/web/src/test/java/org/springframework/security/web/csrf/CookieCsrfTokenRepositoryTests.java b/web/src/test/java/org/springframework/security/web/csrf/CookieCsrfTokenRepositoryTests.java index 6a19774fda0..5342548fd5c 100644 --- a/web/src/test/java/org/springframework/security/web/csrf/CookieCsrfTokenRepositoryTests.java +++ b/web/src/test/java/org/springframework/security/web/csrf/CookieCsrfTokenRepositoryTests.java @@ -82,7 +82,7 @@ public void saveToken() { .isEqualTo(CookieCsrfTokenRepository.DEFAULT_CSRF_COOKIE_NAME); assertThat(tokenCookie.getPath()).isEqualTo(this.request.getContextPath()); assertThat(tokenCookie.getSecure()).isEqualTo(this.request.isSecure()); - assertThat(tokenCookie.getValue()).isEqualTo(token.getToken()); + assertThat(token.isValid(tokenCookie.getValue())); assertThat(tokenCookie.isHttpOnly()).isEqualTo(true); } @@ -204,7 +204,7 @@ public void loadTokenCustom() { assertThat(loadToken).isNotNull(); assertThat(loadToken.getHeaderName()).isEqualTo(headerName); assertThat(loadToken.getParameterName()).isEqualTo(parameterName); - assertThat(loadToken.getToken()).isEqualTo(value); + assertThat(loadToken.isValid(value)); } @Test(expected = IllegalArgumentException.class) diff --git a/web/src/test/java/org/springframework/security/web/csrf/CsrfAuthenticationStrategyTests.java b/web/src/test/java/org/springframework/security/web/csrf/CsrfAuthenticationStrategyTests.java index d68949aaa0d..e727b253854 100644 --- a/web/src/test/java/org/springframework/security/web/csrf/CsrfAuthenticationStrategyTests.java +++ b/web/src/test/java/org/springframework/security/web/csrf/CsrfAuthenticationStrategyTests.java @@ -60,8 +60,8 @@ public void setup() { this.request = new MockHttpServletRequest(); this.request.setAttribute(HttpServletResponse.class.getName(), this.response); this.strategy = new CsrfAuthenticationStrategy(this.csrfTokenRepository); - this.existingToken = new DefaultCsrfToken("_csrf", "_csrf", "1"); - this.generatedToken = new DefaultCsrfToken("_csrf", "_csrf", "2"); + this.existingToken = new DefaultCsrfToken("_csrf", "_csrf"); + this.generatedToken = new DefaultCsrfToken("_csrf", "_csrf"); } @Test(expected = IllegalArgumentException.class) @@ -85,7 +85,7 @@ public void logoutRemovesCsrfTokenAndSavesNew() { // SEC-2404, SEC-2832 CsrfToken tokenInRequest = (CsrfToken) this.request .getAttribute(CsrfToken.class.getName()); - assertThat(tokenInRequest.getToken()).isSameAs(this.generatedToken.getToken()); + assertThat(tokenInRequest.isValid(this.generatedToken.getToken())); assertThat(tokenInRequest.getHeaderName()) .isSameAs(this.generatedToken.getHeaderName()); assertThat(tokenInRequest.getParameterName()) diff --git a/web/src/test/java/org/springframework/security/web/csrf/CsrfFilterTests.java b/web/src/test/java/org/springframework/security/web/csrf/CsrfFilterTests.java index 970276788cd..beddaab3688 100644 --- a/web/src/test/java/org/springframework/security/web/csrf/CsrfFilterTests.java +++ b/web/src/test/java/org/springframework/security/web/csrf/CsrfFilterTests.java @@ -69,7 +69,7 @@ public class CsrfFilterTests { @Before public void setup() { - this.token = new DefaultCsrfToken("headerName", "paramName", "csrfTokenValue"); + this.token = new DefaultCsrfToken("headerName", "paramName"); resetRequestResponse(); this.filter = createCsrfFilter(this.tokenRepository); } @@ -140,7 +140,7 @@ public void doFilterAccessDeniedIncorrectTokenPresent() when(this.requestMatcher.matches(this.request)).thenReturn(true); when(this.tokenRepository.loadToken(this.request)).thenReturn(this.token); this.request.setParameter(this.token.getParameterName(), - this.token.getToken() + " INVALID"); + this.token.getToken().replaceAll("^.{10}","INVALID000")); this.filter.doFilter(this.request, this.response, this.filterChain); @@ -160,7 +160,7 @@ public void doFilterAccessDeniedIncorrectTokenPresentHeader() when(this.requestMatcher.matches(this.request)).thenReturn(true); when(this.tokenRepository.loadToken(this.request)).thenReturn(this.token); this.request.addHeader(this.token.getHeaderName(), - this.token.getToken() + " INVALID"); + this.token.getToken().replaceAll("^.{10}","INVALID000")); this.filter.doFilter(this.request, this.response, this.filterChain); @@ -181,7 +181,7 @@ public void doFilterAccessDeniedIncorrectTokenPresentHeaderPreferredOverParamete when(this.tokenRepository.loadToken(this.request)).thenReturn(this.token); this.request.setParameter(this.token.getParameterName(), this.token.getToken()); this.request.addHeader(this.token.getHeaderName(), - this.token.getToken() + " INVALID"); + this.token.getToken().replaceAll("^.{10}","INVALID000")); this.filter.doFilter(this.request, this.response, this.filterChain); @@ -420,7 +420,7 @@ public CsrfTokenAssert isEqualTo(CsrfToken expected) { assertThat(this.actual.getHeaderName()).isEqualTo(expected.getHeaderName()); assertThat(this.actual.getParameterName()) .isEqualTo(expected.getParameterName()); - assertThat(this.actual.getToken()).isEqualTo(expected.getToken()); + assertThat(this.actual.isValid(expected.getToken())); return this; } } diff --git a/web/src/test/java/org/springframework/security/web/csrf/DefaultCsrfTokenTests.java b/web/src/test/java/org/springframework/security/web/csrf/DefaultCsrfTokenTests.java index c075f5a4cb6..13d0cc24daa 100644 --- a/web/src/test/java/org/springframework/security/web/csrf/DefaultCsrfTokenTests.java +++ b/web/src/test/java/org/springframework/security/web/csrf/DefaultCsrfTokenTests.java @@ -16,6 +16,7 @@ package org.springframework.security.web.csrf; import org.junit.Test; +import static org.assertj.core.api.Assertions.assertThat; /** * @author Rob Winch @@ -24,35 +25,41 @@ public class DefaultCsrfTokenTests { private final String headerName = "headerName"; private final String parameterName = "parameterName"; - private final String tokenValue = "tokenValue"; @Test(expected = IllegalArgumentException.class) public void constructorNullHeaderName() { - new DefaultCsrfToken(null, parameterName, tokenValue); + new DefaultCsrfToken(null, parameterName); } @Test(expected = IllegalArgumentException.class) public void constructorEmptyHeaderName() { - new DefaultCsrfToken("", parameterName, tokenValue); + new DefaultCsrfToken("", parameterName); } @Test(expected = IllegalArgumentException.class) public void constructorNullParameterName() { - new DefaultCsrfToken(headerName, null, tokenValue); + new DefaultCsrfToken(headerName, null); } @Test(expected = IllegalArgumentException.class) public void constructorEmptyParameterName() { - new DefaultCsrfToken(headerName, "", tokenValue); + new DefaultCsrfToken(headerName, ""); } - @Test(expected = IllegalArgumentException.class) - public void constructorNullTokenValue() { - new DefaultCsrfToken(headerName, parameterName, null); - } + @Test + public void testIsValid() { + DefaultCsrfToken token = new DefaultCsrfToken(headerName, parameterName); - @Test(expected = IllegalArgumentException.class) - public void constructorEmptyTokenValue() { - new DefaultCsrfToken(headerName, parameterName, ""); + String value1 = token.getToken(); + assertThat(value1).isNotEmpty(); + String value2 = token.getToken(); + assertThat(value2).isNotEmpty(); + + assertThat(value1).doesNotMatch(value2); + + assertThat(token.isValid(value1)).isTrue(); + assertThat(token.isValid(value2)).isTrue(); + assertThat(token.isValid(value2.replaceAll("^.{10}","INVALID000"))).isFalse(); } + } diff --git a/web/src/test/java/org/springframework/security/web/csrf/HttpSessionCsrfTokenRepositoryTests.java b/web/src/test/java/org/springframework/security/web/csrf/HttpSessionCsrfTokenRepositoryTests.java index 7164a2d20bc..64fd79ecf7e 100644 --- a/web/src/test/java/org/springframework/security/web/csrf/HttpSessionCsrfTokenRepositoryTests.java +++ b/web/src/test/java/org/springframework/security/web/csrf/HttpSessionCsrfTokenRepositoryTests.java @@ -89,7 +89,7 @@ public void loadTokenNullWhenSessionExists() { @Test public void saveToken() { - CsrfToken tokenToSave = new DefaultCsrfToken("123", "abc", "def"); + CsrfToken tokenToSave = new DefaultCsrfToken("123", "abc"); repo.saveToken(tokenToSave, request, response); String attrName = request.getSession().getAttributeNames().nextElement(); @@ -100,7 +100,7 @@ public void saveToken() { @Test public void saveTokenCustomSessionAttribute() { - CsrfToken tokenToSave = new DefaultCsrfToken("123", "abc", "def"); + CsrfToken tokenToSave = new DefaultCsrfToken("123", "abc"); String sessionAttributeName = "custom"; repo.setSessionAttributeName(sessionAttributeName); repo.saveToken(tokenToSave, request, response); diff --git a/web/src/test/java/org/springframework/security/web/csrf/LazyCsrfTokenRepositoryTests.java b/web/src/test/java/org/springframework/security/web/csrf/LazyCsrfTokenRepositoryTests.java index e0d50666cf5..6e66131428b 100644 --- a/web/src/test/java/org/springframework/security/web/csrf/LazyCsrfTokenRepositoryTests.java +++ b/web/src/test/java/org/springframework/security/web/csrf/LazyCsrfTokenRepositoryTests.java @@ -51,7 +51,7 @@ public class LazyCsrfTokenRepositoryTests { @Before public void setup() { - this.token = new DefaultCsrfToken("header", "param", "token"); + this.token = new DefaultCsrfToken("header", "param"); when(this.delegate.generateToken(this.request)).thenReturn(this.token); when(this.request.getAttribute(HttpServletResponse.class.getName())) .thenReturn(this.response); @@ -99,4 +99,4 @@ public void loadTokenDelegates() { verify(this.delegate).loadToken(this.request); } -} \ No newline at end of file +} diff --git a/web/src/test/java/org/springframework/security/web/method/annotation/CsrfTokenArgumentResolverTests.java b/web/src/test/java/org/springframework/security/web/method/annotation/CsrfTokenArgumentResolverTests.java index 9651bf7ee35..550fd01a65f 100644 --- a/web/src/test/java/org/springframework/security/web/method/annotation/CsrfTokenArgumentResolverTests.java +++ b/web/src/test/java/org/springframework/security/web/method/annotation/CsrfTokenArgumentResolverTests.java @@ -55,7 +55,7 @@ public class CsrfTokenArgumentResolverTests { @Before public void setup() { - token = new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "secret"); + token = new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf"); resolver = new CsrfTokenArgumentResolver(); request = new MockHttpServletRequest(); webRequest = new ServletWebRequest(request); diff --git a/web/src/test/java/org/springframework/security/web/servlet/support/csrf/CsrfRequestDataValueProcessorTests.java b/web/src/test/java/org/springframework/security/web/servlet/support/csrf/CsrfRequestDataValueProcessorTests.java index 62d02439cd6..8329e0aff54 100644 --- a/web/src/test/java/org/springframework/security/web/servlet/support/csrf/CsrfRequestDataValueProcessorTests.java +++ b/web/src/test/java/org/springframework/security/web/servlet/support/csrf/CsrfRequestDataValueProcessorTests.java @@ -46,7 +46,7 @@ public void setup() { request = new MockHttpServletRequest(); processor = new CsrfRequestDataValueProcessor(); - token = new DefaultCsrfToken("1", "a", "b"); + token = new DefaultCsrfToken("1", "a"); request.setAttribute(CsrfToken.class.getName(), token); expected.put(token.getParameterName(), token.getToken()); @@ -127,7 +127,7 @@ public void processUrl() { @Test public void createGetExtraHiddenFieldsHasCsrfToken() { - CsrfToken token = new DefaultCsrfToken("1", "a", "b"); + CsrfToken token = new DefaultCsrfToken("1", "a"); request.setAttribute(CsrfToken.class.getName(), token); Map expected = new HashMap(); expected.put(token.getParameterName(), token.getToken());