Skip to content

Commit 5a6a1bf

Browse files
committed
Refine requestMatcher Validation Rules
Closes gh-13850
1 parent 914ebd6 commit 5a6a1bf

File tree

6 files changed

+232
-22
lines changed

6 files changed

+232
-22
lines changed

config/src/main/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistry.java

Lines changed: 130 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2022 the original author or authors.
2+
* Copyright 2002-2023 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -26,6 +26,7 @@
2626
import jakarta.servlet.DispatcherType;
2727
import jakarta.servlet.ServletContext;
2828
import jakarta.servlet.ServletRegistration;
29+
import jakarta.servlet.http.HttpServletRequest;
2930

3031
import org.springframework.beans.factory.NoSuchBeanDefinitionException;
3132
import org.springframework.context.ApplicationContext;
@@ -203,11 +204,30 @@ public C requestMatchers(HttpMethod method, String... patterns) {
203204
if (!hasDispatcherServlet(registrations)) {
204205
return requestMatchers(RequestMatchers.antMatchersAsArray(method, patterns));
205206
}
206-
if (registrations.size() > 1) {
207-
String errorMessage = computeErrorMessage(registrations.values());
208-
throw new IllegalArgumentException(errorMessage);
207+
ServletRegistration dispatcherServlet = requireOneRootDispatcherServlet(registrations);
208+
if (dispatcherServlet != null) {
209+
if (registrations.size() == 1) {
210+
return requestMatchers(createMvcMatchers(method, patterns).toArray(RequestMatcher[]::new));
211+
}
212+
List<RequestMatcher> matchers = new ArrayList<>();
213+
for (String pattern : patterns) {
214+
AntPathRequestMatcher ant = new AntPathRequestMatcher(pattern, (method != null) ? method.name() : null);
215+
MvcRequestMatcher mvc = createMvcMatchers(method, pattern).get(0);
216+
matchers.add(new DispatcherServletDelegatingRequestMatcher(ant, mvc, servletContext));
217+
}
218+
return requestMatchers(matchers.toArray(new RequestMatcher[0]));
209219
}
210-
return requestMatchers(createMvcMatchers(method, patterns).toArray(new RequestMatcher[0]));
220+
dispatcherServlet = requireOnlyPathMappedDispatcherServlet(registrations);
221+
if (dispatcherServlet != null) {
222+
String mapping = dispatcherServlet.getMappings().iterator().next();
223+
List<MvcRequestMatcher> matchers = createMvcMatchers(method, patterns);
224+
for (MvcRequestMatcher matcher : matchers) {
225+
matcher.setServletPath(mapping.substring(0, mapping.length() - 2));
226+
}
227+
return requestMatchers(matchers.toArray(new RequestMatcher[0]));
228+
}
229+
String errorMessage = computeErrorMessage(registrations.values());
230+
throw new IllegalArgumentException(errorMessage);
211231
}
212232

213233
private Map<String, ? extends ServletRegistration> mappableServletRegistrations(ServletContext servletContext) {
@@ -225,22 +245,66 @@ private boolean hasDispatcherServlet(Map<String, ? extends ServletRegistration>
225245
if (registrations == null) {
226246
return false;
227247
}
228-
Class<?> dispatcherServlet = ClassUtils.resolveClassName("org.springframework.web.servlet.DispatcherServlet",
229-
null);
230248
for (ServletRegistration registration : registrations.values()) {
231-
try {
232-
Class<?> clazz = Class.forName(registration.getClassName());
233-
if (dispatcherServlet.isAssignableFrom(clazz)) {
234-
return true;
235-
}
236-
}
237-
catch (ClassNotFoundException ex) {
238-
return false;
249+
if (isDispatcherServlet(registration)) {
250+
return true;
239251
}
240252
}
241253
return false;
242254
}
243255

256+
private ServletRegistration requireOneRootDispatcherServlet(
257+
Map<String, ? extends ServletRegistration> registrations) {
258+
ServletRegistration rootDispatcherServlet = null;
259+
for (ServletRegistration registration : registrations.values()) {
260+
if (!isDispatcherServlet(registration)) {
261+
continue;
262+
}
263+
if (registration.getMappings().size() > 1) {
264+
return null;
265+
}
266+
if (!"/".equals(registration.getMappings().iterator().next())) {
267+
return null;
268+
}
269+
rootDispatcherServlet = registration;
270+
}
271+
return rootDispatcherServlet;
272+
}
273+
274+
private ServletRegistration requireOnlyPathMappedDispatcherServlet(
275+
Map<String, ? extends ServletRegistration> registrations) {
276+
ServletRegistration pathDispatcherServlet = null;
277+
for (ServletRegistration registration : registrations.values()) {
278+
if (!isDispatcherServlet(registration)) {
279+
return null;
280+
}
281+
if (registration.getMappings().size() > 1) {
282+
return null;
283+
}
284+
String mapping = registration.getMappings().iterator().next();
285+
if (!mapping.startsWith("/") || !mapping.endsWith("/*")) {
286+
return null;
287+
}
288+
if (pathDispatcherServlet != null) {
289+
return null;
290+
}
291+
pathDispatcherServlet = registration;
292+
}
293+
return pathDispatcherServlet;
294+
}
295+
296+
private boolean isDispatcherServlet(ServletRegistration registration) {
297+
Class<?> dispatcherServlet = ClassUtils.resolveClassName("org.springframework.web.servlet.DispatcherServlet",
298+
null);
299+
try {
300+
Class<?> clazz = Class.forName(registration.getClassName());
301+
return dispatcherServlet.isAssignableFrom(clazz);
302+
}
303+
catch (ClassNotFoundException ex) {
304+
return false;
305+
}
306+
}
307+
244308
private String computeErrorMessage(Collection<? extends ServletRegistration> registrations) {
245309
String template = "This method cannot decide whether these patterns are Spring MVC patterns or not. "
246310
+ "If this endpoint is a Spring MVC endpoint, please use requestMatchers(MvcRequestMatcher); "
@@ -380,4 +444,55 @@ static List<RequestMatcher> regexMatchers(String... regexPatterns) {
380444

381445
}
382446

447+
static class DispatcherServletDelegatingRequestMatcher implements RequestMatcher {
448+
449+
private final AntPathRequestMatcher ant;
450+
451+
private final MvcRequestMatcher mvc;
452+
453+
private final ServletContext servletContext;
454+
455+
DispatcherServletDelegatingRequestMatcher(AntPathRequestMatcher ant, MvcRequestMatcher mvc,
456+
ServletContext servletContext) {
457+
this.ant = ant;
458+
this.mvc = mvc;
459+
this.servletContext = servletContext;
460+
}
461+
462+
@Override
463+
public boolean matches(HttpServletRequest request) {
464+
String name = request.getHttpServletMapping().getServletName();
465+
ServletRegistration registration = this.servletContext.getServletRegistration(name);
466+
Assert.notNull(registration, "Failed to find servlet [" + name + "] in the servlet context");
467+
if (isDispatcherServlet(registration)) {
468+
return this.mvc.matches(request);
469+
}
470+
return this.ant.matches(request);
471+
}
472+
473+
@Override
474+
public MatchResult matcher(HttpServletRequest request) {
475+
String name = request.getHttpServletMapping().getServletName();
476+
ServletRegistration registration = this.servletContext.getServletRegistration(name);
477+
Assert.notNull(registration, "Failed to find servlet [" + name + "] in the servlet context");
478+
if (isDispatcherServlet(registration)) {
479+
return this.mvc.matcher(request);
480+
}
481+
return this.ant.matcher(request);
482+
}
483+
484+
private boolean isDispatcherServlet(ServletRegistration registration) {
485+
Class<?> dispatcherServlet = ClassUtils
486+
.resolveClassName("org.springframework.web.servlet.DispatcherServlet", null);
487+
try {
488+
Class<?> clazz = Class.forName(registration.getClassName());
489+
return dispatcherServlet.isAssignableFrom(clazz);
490+
}
491+
catch (ClassNotFoundException ex) {
492+
return false;
493+
}
494+
}
495+
496+
}
497+
383498
}

config/src/test/java/org/springframework/security/config/MockServletContext.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,11 @@ public ServletRegistration.Dynamic addServlet(@NonNull String servletName, Class
5555
return this.registrations;
5656
}
5757

58+
@Override
59+
public ServletRegistration getServletRegistration(String servletName) {
60+
return this.registrations.get(servletName);
61+
}
62+
5863
private static class MockServletRegistration implements ServletRegistration.Dynamic {
5964

6065
private final String name;
Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,32 +14,32 @@
1414
* limitations under the License.
1515
*/
1616

17-
package org.springframework.security.config.annotation.web.configurers;
17+
package org.springframework.security.config;
1818

1919
import jakarta.servlet.http.HttpServletRequest;
2020
import jakarta.servlet.http.MappingMatch;
2121

2222
import org.springframework.mock.web.MockHttpServletMapping;
2323

24-
final class TestMockHttpServletMappings {
24+
public final class TestMockHttpServletMappings {
2525

2626
private TestMockHttpServletMappings() {
2727

2828
}
2929

30-
static MockHttpServletMapping extension(HttpServletRequest request, String extension) {
30+
public static MockHttpServletMapping extension(HttpServletRequest request, String extension) {
3131
String uri = request.getRequestURI();
3232
String matchValue = uri.substring(0, uri.lastIndexOf(extension));
3333
return new MockHttpServletMapping(matchValue, "*" + extension, "extension", MappingMatch.EXTENSION);
3434
}
3535

36-
static MockHttpServletMapping path(HttpServletRequest request, String path) {
36+
public static MockHttpServletMapping path(HttpServletRequest request, String path) {
3737
String uri = request.getRequestURI();
3838
String matchValue = uri.substring(path.length());
3939
return new MockHttpServletMapping(matchValue, path + "/*", "path", MappingMatch.PATH);
4040
}
4141

42-
static MockHttpServletMapping defaultMapping() {
42+
public static MockHttpServletMapping defaultMapping() {
4343
return new MockHttpServletMapping("", "/", "default", MappingMatch.DEFAULT);
4444
}
4545

config/src/test/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistryTests.java

Lines changed: 90 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2022 the original author or authors.
2+
* Copyright 2002-2023 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -26,8 +26,11 @@
2626
import org.springframework.beans.factory.NoSuchBeanDefinitionException;
2727
import org.springframework.context.ApplicationContext;
2828
import org.springframework.http.HttpMethod;
29+
import org.springframework.mock.web.MockHttpServletRequest;
2930
import org.springframework.security.config.MockServletContext;
31+
import org.springframework.security.config.TestMockHttpServletMappings;
3032
import org.springframework.security.config.annotation.ObjectPostProcessor;
33+
import org.springframework.security.config.annotation.web.AbstractRequestMatcherRegistry.DispatcherServletDelegatingRequestMatcher;
3134
import org.springframework.security.web.servlet.util.matcher.MvcRequestMatcher;
3235
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
3336
import org.springframework.security.web.util.matcher.DispatcherTypeRequestMatcher;
@@ -40,6 +43,9 @@
4043
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
4144
import static org.mockito.BDDMockito.given;
4245
import static org.mockito.Mockito.mock;
46+
import static org.mockito.Mockito.verify;
47+
import static org.mockito.Mockito.verifyNoInteractions;
48+
import static org.mockito.Mockito.verifyNoMoreInteractions;
4349

4450
/**
4551
* Tests for {@link AbstractRequestMatcherRegistry}.
@@ -159,6 +165,8 @@ public void requestMatchersWhenMvcPresentInClassPathAndMvcIntrospectorBeanNotAva
159165
public void requestMatchersWhenNoDispatcherServletThenAntPathRequestMatcherType() {
160166
MockServletContext servletContext = new MockServletContext();
161167
given(this.context.getServletContext()).willReturn(servletContext);
168+
servletContext.addServlet("servletOne", Servlet.class).addMapping("/one");
169+
servletContext.addServlet("servletTwo", Servlet.class).addMapping("/two");
162170
List<RequestMatcher> requestMatchers = this.matcherRegistry.requestMatchers("/**");
163171
assertThat(requestMatchers).isNotEmpty();
164172
assertThat(requestMatchers).hasSize(1);
@@ -170,7 +178,26 @@ public void requestMatchersWhenAmbiguousServletsThenException() {
170178
MockServletContext servletContext = new MockServletContext();
171179
given(this.context.getServletContext()).willReturn(servletContext);
172180
servletContext.addServlet("dispatcherServlet", DispatcherServlet.class).addMapping("/");
173-
servletContext.addServlet("servletTwo", Servlet.class).addMapping("/servlet/**");
181+
servletContext.addServlet("servletTwo", DispatcherServlet.class).addMapping("/servlet/*");
182+
assertThatExceptionOfType(IllegalArgumentException.class)
183+
.isThrownBy(() -> this.matcherRegistry.requestMatchers("/**"));
184+
}
185+
186+
@Test
187+
public void requestMatchersWhenMultipleDispatcherServletMappingsThenException() {
188+
MockServletContext servletContext = new MockServletContext();
189+
given(this.context.getServletContext()).willReturn(servletContext);
190+
servletContext.addServlet("dispatcherServlet", DispatcherServlet.class).addMapping("/", "/mvc/*");
191+
assertThatExceptionOfType(IllegalArgumentException.class)
192+
.isThrownBy(() -> this.matcherRegistry.requestMatchers("/**"));
193+
}
194+
195+
@Test
196+
public void requestMatchersWhenPathDispatcherServletAndOtherServletsThenException() {
197+
MockServletContext servletContext = new MockServletContext();
198+
given(this.context.getServletContext()).willReturn(servletContext);
199+
servletContext.addServlet("dispatcherServlet", DispatcherServlet.class).addMapping("/mvc/*");
200+
servletContext.addServlet("default", Servlet.class).addMapping("/");
174201
assertThatExceptionOfType(IllegalArgumentException.class)
175202
.isThrownBy(() -> this.matcherRegistry.requestMatchers("/**"));
176203
}
@@ -187,6 +214,67 @@ public void requestMatchersWhenUnmappableServletsThenSkips() {
187214
assertThat(requestMatchers.get(0)).isInstanceOf(MvcRequestMatcher.class);
188215
}
189216

217+
@Test
218+
public void requestMatchersWhenOnlyDispatcherServletThenAllows() {
219+
MockServletContext servletContext = new MockServletContext();
220+
given(this.context.getServletContext()).willReturn(servletContext);
221+
servletContext.addServlet("dispatcherServlet", DispatcherServlet.class).addMapping("/mvc/*");
222+
List<RequestMatcher> requestMatchers = this.matcherRegistry.requestMatchers("/**");
223+
assertThat(requestMatchers).hasSize(1);
224+
assertThat(requestMatchers.get(0)).isInstanceOf(MvcRequestMatcher.class);
225+
}
226+
227+
@Test
228+
public void requestMatchersWhenImplicitServletsThenAllows() {
229+
mockMvcIntrospector(true);
230+
MockServletContext servletContext = new MockServletContext();
231+
given(this.context.getServletContext()).willReturn(servletContext);
232+
servletContext.addServlet("defaultServlet", Servlet.class);
233+
servletContext.addServlet("jspServlet", Servlet.class).addMapping("*.jsp", "*.jspx");
234+
servletContext.addServlet("dispatcherServlet", DispatcherServlet.class).addMapping("/");
235+
List<RequestMatcher> requestMatchers = this.matcherRegistry.requestMatchers("/**");
236+
assertThat(requestMatchers).hasSize(1);
237+
assertThat(requestMatchers.get(0)).isInstanceOf(DispatcherServletDelegatingRequestMatcher.class);
238+
}
239+
240+
@Test
241+
public void requestMatchersWhenPathBasedNonDispatcherServletThenAllows() {
242+
MockServletContext servletContext = new MockServletContext();
243+
given(this.context.getServletContext()).willReturn(servletContext);
244+
servletContext.addServlet("path", Servlet.class).addMapping("/services/*");
245+
servletContext.addServlet("default", DispatcherServlet.class).addMapping("/");
246+
List<RequestMatcher> requestMatchers = this.matcherRegistry.requestMatchers("/services/*");
247+
assertThat(requestMatchers).hasSize(1);
248+
assertThat(requestMatchers.get(0)).isInstanceOf(DispatcherServletDelegatingRequestMatcher.class);
249+
MockHttpServletRequest request = new MockHttpServletRequest("GET", "/services/endpoint");
250+
request.setHttpServletMapping(TestMockHttpServletMappings.defaultMapping());
251+
assertThat(requestMatchers.get(0).matcher(request).isMatch()).isTrue();
252+
request.setHttpServletMapping(TestMockHttpServletMappings.path(request, "/services"));
253+
request.setServletPath("/services");
254+
request.setPathInfo("/endpoint");
255+
assertThat(requestMatchers.get(0).matcher(request).isMatch()).isTrue();
256+
}
257+
258+
@Test
259+
public void matchesWhenDispatcherServletThenMvc() {
260+
MockServletContext servletContext = new MockServletContext();
261+
servletContext.addServlet("default", DispatcherServlet.class).addMapping("/");
262+
servletContext.addServlet("path", Servlet.class).addMapping("/services/*");
263+
MvcRequestMatcher mvc = mock(MvcRequestMatcher.class);
264+
AntPathRequestMatcher ant = mock(AntPathRequestMatcher.class);
265+
DispatcherServletDelegatingRequestMatcher requestMatcher = new DispatcherServletDelegatingRequestMatcher(ant,
266+
mvc, servletContext);
267+
MockHttpServletRequest request = new MockHttpServletRequest("GET", "/services/endpoint");
268+
request.setHttpServletMapping(TestMockHttpServletMappings.defaultMapping());
269+
assertThat(requestMatcher.matches(request)).isFalse();
270+
verify(mvc).matches(request);
271+
verifyNoInteractions(ant);
272+
request.setHttpServletMapping(TestMockHttpServletMappings.path(request, "/services"));
273+
assertThat(requestMatcher.matches(request)).isFalse();
274+
verify(ant).matches(request);
275+
verifyNoMoreInteractions(mvc);
276+
}
277+
190278
private void mockMvcIntrospector(boolean isPresent) {
191279
ApplicationContext context = this.matcherRegistry.getApplicationContext();
192280
given(context.containsBean("mvcHandlerMappingIntrospector")).willReturn(isPresent);

config/src/test/java/org/springframework/security/config/annotation/web/configurers/AuthorizeHttpRequestsConfigurerTests.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
import org.springframework.security.authorization.AuthorizationEventPublisher;
3737
import org.springframework.security.authorization.AuthorizationManager;
3838
import org.springframework.security.config.MockServletContext;
39+
import org.springframework.security.config.TestMockHttpServletMappings;
3940
import org.springframework.security.config.annotation.ObjectPostProcessor;
4041
import org.springframework.security.config.annotation.web.AbstractRequestMatcherRegistry;
4142
import org.springframework.security.config.annotation.web.builders.HttpSecurity;

config/src/test/java/org/springframework/security/config/annotation/web/configurers/ServletPatternRequestMatcherTests.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import org.junit.jupiter.api.Test;
2020

2121
import org.springframework.mock.web.MockHttpServletRequest;
22+
import org.springframework.security.config.TestMockHttpServletMappings;
2223

2324
import static org.assertj.core.api.Assertions.assertThat;
2425

0 commit comments

Comments
 (0)