Skip to content

Commit 750cb73

Browse files
committed
Introduce single-value request predicates
This commit introduces new HTTP method, Content-Type, and Accept header request predicates that handle single values. Previously, these predicates were always dealt with as single-value collections, which introduced computational overhead. Closes gh-32244
1 parent 5851cdc commit 750cb73

File tree

4 files changed

+372
-145
lines changed

4 files changed

+372
-145
lines changed

spring-webflux/src/main/java/org/springframework/web/reactive/function/server/RequestPredicates.java

Lines changed: 144 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@
4747
import org.springframework.http.server.PathContainer;
4848
import org.springframework.http.server.RequestPath;
4949
import org.springframework.http.server.reactive.ServerHttpRequest;
50-
import org.springframework.lang.NonNull;
5150
import org.springframework.lang.Nullable;
5251
import org.springframework.util.Assert;
5352
import org.springframework.util.CollectionUtils;
@@ -90,7 +89,8 @@ public static RequestPredicate all() {
9089
* @return a predicate that tests against the given HTTP method
9190
*/
9291
public static RequestPredicate method(HttpMethod httpMethod) {
93-
return new HttpMethodPredicate(httpMethod);
92+
Assert.notNull(httpMethod, "HttpMethod must not be null");
93+
return new SingleHttpMethodPredicate(httpMethod);
9494
}
9595

9696
/**
@@ -101,7 +101,13 @@ public static RequestPredicate method(HttpMethod httpMethod) {
101101
* @since 5.1
102102
*/
103103
public static RequestPredicate methods(HttpMethod... httpMethods) {
104-
return new HttpMethodPredicate(httpMethods);
104+
Assert.notEmpty(httpMethods, "HttpMethods must not be empty");
105+
if (httpMethods.length == 1) {
106+
return new SingleHttpMethodPredicate(httpMethods[0]);
107+
}
108+
else {
109+
return new MultipleHttpMethodsPredicate(httpMethods);
110+
}
105111
}
106112

107113
/**
@@ -151,7 +157,12 @@ public static RequestPredicate headers(Predicate<ServerRequest.Headers> headersP
151157
*/
152158
public static RequestPredicate contentType(MediaType... mediaTypes) {
153159
Assert.notEmpty(mediaTypes, "'mediaTypes' must not be empty");
154-
return new ContentTypePredicate(mediaTypes);
160+
if (mediaTypes.length == 1) {
161+
return new SingleContentTypePredicate(mediaTypes[0]);
162+
}
163+
else {
164+
return new MultipleContentTypesPredicate(mediaTypes);
165+
}
155166
}
156167

157168
/**
@@ -163,7 +174,12 @@ public static RequestPredicate contentType(MediaType... mediaTypes) {
163174
*/
164175
public static RequestPredicate accept(MediaType... mediaTypes) {
165176
Assert.notEmpty(mediaTypes, "'mediaTypes' must not be empty");
166-
return new AcceptPredicate(mediaTypes);
177+
if (mediaTypes.length == 1) {
178+
return new SingleAcceptPredicate(mediaTypes[0]);
179+
}
180+
else {
181+
return new MultipleAcceptsPredicate(mediaTypes);
182+
}
167183
}
168184

169185
/**
@@ -529,29 +545,23 @@ public boolean modifiesAttributes() {
529545
}
530546

531547

532-
private static class HttpMethodPredicate implements RequestPredicate {
533-
534-
private final Set<HttpMethod> httpMethods;
548+
private static class SingleHttpMethodPredicate implements RequestPredicate {
535549

536-
public HttpMethodPredicate(HttpMethod httpMethod) {
537-
Assert.notNull(httpMethod, "HttpMethod must not be null");
538-
this.httpMethods = Set.of(httpMethod);
539-
}
550+
private final HttpMethod httpMethod;
540551

541-
public HttpMethodPredicate(HttpMethod... httpMethods) {
542-
Assert.notEmpty(httpMethods, "HttpMethods must not be empty");
543-
this.httpMethods = new LinkedHashSet<>(Arrays.asList(httpMethods));
552+
public SingleHttpMethodPredicate(HttpMethod httpMethod) {
553+
this.httpMethod = httpMethod;
544554
}
545555

546556
@Override
547557
public boolean test(ServerRequest request) {
548558
HttpMethod method = method(request);
549-
boolean match = this.httpMethods.contains(method);
550-
traceMatch("Method", this.httpMethods, method, match);
559+
boolean match = this.httpMethod.equals(method);
560+
traceMatch("Method", this.httpMethod, method, match);
551561
return match;
552562
}
553563

554-
private static HttpMethod method(ServerRequest request) {
564+
static HttpMethod method(ServerRequest request) {
555565
if (CorsUtils.isPreFlightRequest(request.exchange().getRequest())) {
556566
String accessControlRequestMethod =
557567
request.headers().firstHeader(HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD);
@@ -562,19 +572,42 @@ private static HttpMethod method(ServerRequest request) {
562572
return request.method();
563573
}
564574

575+
@Override
576+
public void accept(Visitor visitor) {
577+
visitor.method(Set.of(this.httpMethod));
578+
}
579+
580+
@Override
581+
public String toString() {
582+
return this.httpMethod.toString();
583+
}
584+
}
585+
586+
587+
private static class MultipleHttpMethodsPredicate implements RequestPredicate {
588+
589+
private final Set<HttpMethod> httpMethods;
590+
591+
public MultipleHttpMethodsPredicate(HttpMethod[] httpMethods) {
592+
this.httpMethods = new LinkedHashSet<>(Arrays.asList(httpMethods));
593+
}
594+
595+
@Override
596+
public boolean test(ServerRequest request) {
597+
HttpMethod method = SingleHttpMethodPredicate.method(request);
598+
boolean match = this.httpMethods.contains(method);
599+
traceMatch("Method", this.httpMethods, method, match);
600+
return match;
601+
}
602+
565603
@Override
566604
public void accept(Visitor visitor) {
567605
visitor.method(Collections.unmodifiableSet(this.httpMethods));
568606
}
569607

570608
@Override
571609
public String toString() {
572-
if (this.httpMethods.size() == 1) {
573-
return this.httpMethods.iterator().next().toString();
574-
}
575-
else {
576-
return this.httpMethods.toString();
577-
}
610+
return this.httpMethods.toString();
578611
}
579612
}
580613

@@ -669,20 +702,46 @@ public String toString() {
669702
}
670703

671704

672-
private static class ContentTypePredicate extends HeadersPredicate {
705+
private static class SingleContentTypePredicate extends HeadersPredicate {
673706

674-
private final Set<MediaType> mediaTypes;
707+
private final MediaType mediaType;
675708

676-
public ContentTypePredicate(MediaType... mediaTypes) {
677-
this(Set.of(mediaTypes));
709+
public SingleContentTypePredicate(MediaType mediaType) {
710+
super(headers -> {
711+
MediaType contentType = headers.contentType().orElse(MediaType.APPLICATION_OCTET_STREAM);
712+
boolean match = mediaType.includes(contentType);
713+
traceMatch("Content-Type", mediaType, contentType, match);
714+
return match;
715+
});
716+
this.mediaType = mediaType;
717+
}
718+
719+
@Override
720+
public void accept(Visitor visitor) {
721+
visitor.header(HttpHeaders.CONTENT_TYPE, this.mediaType.toString());
678722
}
679723

680-
private ContentTypePredicate(Set<MediaType> mediaTypes) {
724+
@Override
725+
public String toString() {
726+
return "Content-Type: " + this.mediaType;
727+
}
728+
}
729+
730+
731+
private static class MultipleContentTypesPredicate extends HeadersPredicate {
732+
733+
private final MediaType[] mediaTypes;
734+
735+
public MultipleContentTypesPredicate(MediaType[] mediaTypes) {
681736
super(headers -> {
682-
MediaType contentType =
683-
headers.contentType().orElse(MediaType.APPLICATION_OCTET_STREAM);
684-
boolean match = mediaTypes.stream()
685-
.anyMatch(mediaType -> mediaType.includes(contentType));
737+
MediaType contentType = headers.contentType().orElse(MediaType.APPLICATION_OCTET_STREAM);
738+
boolean match = false;
739+
for (MediaType mediaType : mediaTypes) {
740+
if (mediaType.includes(contentType)) {
741+
match = true;
742+
break;
743+
}
744+
}
686745
traceMatch("Content-Type", mediaTypes, contentType, match);
687746
return match;
688747
});
@@ -691,44 +750,37 @@ private ContentTypePredicate(Set<MediaType> mediaTypes) {
691750

692751
@Override
693752
public void accept(Visitor visitor) {
694-
visitor.header(HttpHeaders.CONTENT_TYPE,
695-
(this.mediaTypes.size() == 1) ?
696-
this.mediaTypes.iterator().next().toString() :
697-
this.mediaTypes.toString());
753+
visitor.header(HttpHeaders.CONTENT_TYPE, Arrays.toString(this.mediaTypes));
698754
}
699755

700756
@Override
701757
public String toString() {
702-
return String.format("Content-Type: %s",
703-
(this.mediaTypes.size() == 1) ?
704-
this.mediaTypes.iterator().next().toString() :
705-
this.mediaTypes.toString());
758+
return "Content-Type: " + Arrays.toString(this.mediaTypes);
706759
}
707760
}
708761

709762

710-
private static class AcceptPredicate extends HeadersPredicate {
711-
712-
private final Set<MediaType> mediaTypes;
763+
private static class SingleAcceptPredicate extends HeadersPredicate {
713764

714-
public AcceptPredicate(MediaType... mediaTypes) {
715-
this(Set.of(mediaTypes));
716-
}
765+
private final MediaType mediaType;
717766

718-
private AcceptPredicate(Set<MediaType> mediaTypes) {
767+
public SingleAcceptPredicate(MediaType mediaType) {
719768
super(headers -> {
720769
List<MediaType> acceptedMediaTypes = acceptedMediaTypes(headers);
721-
boolean match = acceptedMediaTypes.stream()
722-
.anyMatch(acceptedMediaType -> mediaTypes.stream()
723-
.anyMatch(acceptedMediaType::isCompatibleWith));
724-
traceMatch("Accept", mediaTypes, acceptedMediaTypes, match);
770+
boolean match = false;
771+
for (MediaType acceptedMediaType : acceptedMediaTypes) {
772+
if (acceptedMediaType.isCompatibleWith(mediaType)) {
773+
match = true;
774+
break;
775+
}
776+
}
777+
traceMatch("Accept", mediaType, acceptedMediaTypes, match);
725778
return match;
726779
});
727-
this.mediaTypes = mediaTypes;
780+
this.mediaType = mediaType;
728781
}
729782

730-
@NonNull
731-
private static List<MediaType> acceptedMediaTypes(ServerRequest.Headers headers) {
783+
static List<MediaType> acceptedMediaTypes(ServerRequest.Headers headers) {
732784
List<MediaType> acceptedMediaTypes = headers.accept();
733785
if (acceptedMediaTypes.isEmpty()) {
734786
acceptedMediaTypes = Collections.singletonList(MediaType.ALL);
@@ -741,18 +793,47 @@ private static List<MediaType> acceptedMediaTypes(ServerRequest.Headers headers)
741793

742794
@Override
743795
public void accept(Visitor visitor) {
744-
visitor.header(HttpHeaders.ACCEPT,
745-
(this.mediaTypes.size() == 1) ?
746-
this.mediaTypes.iterator().next().toString() :
747-
this.mediaTypes.toString());
796+
visitor.header(HttpHeaders.ACCEPT, this.mediaType.toString());
797+
}
798+
799+
@Override
800+
public String toString() {
801+
return "Accept: " + this.mediaType;
802+
}
803+
}
804+
805+
806+
private static class MultipleAcceptsPredicate extends HeadersPredicate {
807+
808+
private final MediaType[] mediaTypes;
809+
810+
public MultipleAcceptsPredicate(MediaType[] mediaTypes) {
811+
super(headers -> {
812+
List<MediaType> acceptedMediaTypes = SingleAcceptPredicate.acceptedMediaTypes(headers);
813+
boolean match = false;
814+
outer:
815+
for (MediaType acceptedMediaType : acceptedMediaTypes) {
816+
for (MediaType mediaType : mediaTypes) {
817+
if (acceptedMediaType.isCompatibleWith(mediaType)) {
818+
match = true;
819+
break outer;
820+
}
821+
}
822+
}
823+
traceMatch("Accept", mediaTypes, acceptedMediaTypes, match);
824+
return match;
825+
});
826+
this.mediaTypes = mediaTypes;
827+
}
828+
829+
@Override
830+
public void accept(Visitor visitor) {
831+
visitor.header(HttpHeaders.ACCEPT, Arrays.toString(this.mediaTypes));
748832
}
749833

750834
@Override
751835
public String toString() {
752-
return String.format("Accept: %s",
753-
(this.mediaTypes.size() == 1) ?
754-
this.mediaTypes.iterator().next().toString() :
755-
this.mediaTypes.toString());
836+
return "Accept: " + Arrays.toString(this.mediaTypes);
756837
}
757838
}
758839

spring-webflux/src/test/java/org/springframework/web/reactive/function/server/RequestPredicatesTests.java

Lines changed: 50 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -219,11 +219,10 @@ void headersCors() {
219219

220220

221221
@Test
222-
void contentType() {
223-
MediaType json = MediaType.APPLICATION_JSON;
224-
RequestPredicate predicate = RequestPredicates.contentType(json);
222+
void singleContentType() {
223+
RequestPredicate predicate = RequestPredicates.contentType(MediaType.APPLICATION_JSON);
225224
MockServerHttpRequest mockRequest = MockServerHttpRequest.get("https://example.com")
226-
.header(HttpHeaders.CONTENT_TYPE, json.toString())
225+
.header(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE)
227226
.build();
228227
ServerRequest request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList());
229228
assertThat(predicate.test(request)).isTrue();
@@ -236,15 +235,58 @@ void contentType() {
236235
}
237236

238237
@Test
239-
void accept() {
240-
MediaType json = MediaType.APPLICATION_JSON;
241-
RequestPredicate predicate = RequestPredicates.accept(json);
238+
void multipleContentTypes() {
239+
RequestPredicate predicate = RequestPredicates.contentType(MediaType.APPLICATION_JSON, MediaType.TEXT_PLAIN);
242240
MockServerHttpRequest mockRequest = MockServerHttpRequest.get("https://example.com")
243-
.header(HttpHeaders.ACCEPT, json.toString())
241+
.header(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE)
244242
.build();
245243
ServerRequest request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList());
246244
assertThat(predicate.test(request)).isTrue();
247245

246+
mockRequest = MockServerHttpRequest.get("https://example.com")
247+
.header(HttpHeaders.CONTENT_TYPE, MediaType.TEXT_PLAIN_VALUE)
248+
.build();
249+
request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList());
250+
assertThat(predicate.test(request)).isTrue();
251+
252+
mockRequest = MockServerHttpRequest.get("https://example.com")
253+
.header(HttpHeaders.CONTENT_TYPE, "foo/bar")
254+
.build();
255+
request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList());
256+
assertThat(predicate.test(request)).isFalse();
257+
}
258+
259+
@Test
260+
void singleAccept() {
261+
RequestPredicate predicate = RequestPredicates.accept(MediaType.APPLICATION_JSON, MediaType.TEXT_PLAIN);
262+
MockServerHttpRequest mockRequest = MockServerHttpRequest.get("https://example.com")
263+
.header(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE)
264+
.build();
265+
ServerRequest request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList());
266+
assertThat(predicate.test(request)).isTrue();
267+
268+
mockRequest = MockServerHttpRequest.get("https://example.com")
269+
.header(HttpHeaders.ACCEPT, "foo/bar")
270+
.build();
271+
request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList());
272+
assertThat(predicate.test(request)).isFalse();
273+
}
274+
275+
@Test
276+
void multipleAccepts() {
277+
RequestPredicate predicate = RequestPredicates.accept(MediaType.APPLICATION_JSON, MediaType.TEXT_PLAIN);
278+
MockServerHttpRequest mockRequest = MockServerHttpRequest.get("https://example.com")
279+
.header(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE)
280+
.build();
281+
ServerRequest request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList());
282+
assertThat(predicate.test(request)).isTrue();
283+
284+
mockRequest = MockServerHttpRequest.get("https://example.com")
285+
.header(HttpHeaders.ACCEPT, MediaType.TEXT_PLAIN_VALUE)
286+
.build();
287+
request = new DefaultServerRequest(MockServerWebExchange.from(mockRequest), Collections.emptyList());
288+
assertThat(predicate.test(request)).isTrue();
289+
248290
mockRequest = MockServerHttpRequest.get("https://example.com")
249291
.header(HttpHeaders.ACCEPT, "foo/bar")
250292
.build();

0 commit comments

Comments
 (0)