47
47
import org .springframework .http .server .PathContainer ;
48
48
import org .springframework .http .server .RequestPath ;
49
49
import org .springframework .http .server .reactive .ServerHttpRequest ;
50
- import org .springframework .lang .NonNull ;
51
50
import org .springframework .lang .Nullable ;
52
51
import org .springframework .util .Assert ;
53
52
import org .springframework .util .CollectionUtils ;
@@ -90,7 +89,8 @@ public static RequestPredicate all() {
90
89
* @return a predicate that tests against the given HTTP method
91
90
*/
92
91
public static RequestPredicate method (HttpMethod httpMethod ) {
93
- return new HttpMethodPredicate (httpMethod );
92
+ Assert .notNull (httpMethod , "HttpMethod must not be null" );
93
+ return new SingleHttpMethodPredicate (httpMethod );
94
94
}
95
95
96
96
/**
@@ -101,7 +101,13 @@ public static RequestPredicate method(HttpMethod httpMethod) {
101
101
* @since 5.1
102
102
*/
103
103
public static RequestPredicate methods (HttpMethod ... httpMethods ) {
104
- return new HttpMethodPredicate (httpMethods );
104
+ Assert .notEmpty (httpMethods , "HttpMethods must not be empty" );
105
+ if (httpMethods .length == 1 ) {
106
+ return new SingleHttpMethodPredicate (httpMethods [0 ]);
107
+ }
108
+ else {
109
+ return new MultipleHttpMethodsPredicate (httpMethods );
110
+ }
105
111
}
106
112
107
113
/**
@@ -151,7 +157,12 @@ public static RequestPredicate headers(Predicate<ServerRequest.Headers> headersP
151
157
*/
152
158
public static RequestPredicate contentType (MediaType ... mediaTypes ) {
153
159
Assert .notEmpty (mediaTypes , "'mediaTypes' must not be empty" );
154
- return new ContentTypePredicate (mediaTypes );
160
+ if (mediaTypes .length == 1 ) {
161
+ return new SingleContentTypePredicate (mediaTypes [0 ]);
162
+ }
163
+ else {
164
+ return new MultipleContentTypesPredicate (mediaTypes );
165
+ }
155
166
}
156
167
157
168
/**
@@ -163,7 +174,12 @@ public static RequestPredicate contentType(MediaType... mediaTypes) {
163
174
*/
164
175
public static RequestPredicate accept (MediaType ... mediaTypes ) {
165
176
Assert .notEmpty (mediaTypes , "'mediaTypes' must not be empty" );
166
- return new AcceptPredicate (mediaTypes );
177
+ if (mediaTypes .length == 1 ) {
178
+ return new SingleAcceptPredicate (mediaTypes [0 ]);
179
+ }
180
+ else {
181
+ return new MultipleAcceptsPredicate (mediaTypes );
182
+ }
167
183
}
168
184
169
185
/**
@@ -529,29 +545,23 @@ public boolean modifiesAttributes() {
529
545
}
530
546
531
547
532
- private static class HttpMethodPredicate implements RequestPredicate {
533
-
534
- private final Set <HttpMethod > httpMethods ;
548
+ private static class SingleHttpMethodPredicate implements RequestPredicate {
535
549
536
- public HttpMethodPredicate (HttpMethod httpMethod ) {
537
- Assert .notNull (httpMethod , "HttpMethod must not be null" );
538
- this .httpMethods = Set .of (httpMethod );
539
- }
550
+ private final HttpMethod httpMethod ;
540
551
541
- public HttpMethodPredicate (HttpMethod ... httpMethods ) {
542
- Assert .notEmpty (httpMethods , "HttpMethods must not be empty" );
543
- this .httpMethods = new LinkedHashSet <>(Arrays .asList (httpMethods ));
552
+ public SingleHttpMethodPredicate (HttpMethod httpMethod ) {
553
+ this .httpMethod = httpMethod ;
544
554
}
545
555
546
556
@ Override
547
557
public boolean test (ServerRequest request ) {
548
558
HttpMethod method = method (request );
549
- boolean match = this .httpMethods . contains (method );
550
- traceMatch ("Method" , this .httpMethods , method , match );
559
+ boolean match = this .httpMethod . equals (method );
560
+ traceMatch ("Method" , this .httpMethod , method , match );
551
561
return match ;
552
562
}
553
563
554
- private static HttpMethod method (ServerRequest request ) {
564
+ static HttpMethod method (ServerRequest request ) {
555
565
if (CorsUtils .isPreFlightRequest (request .exchange ().getRequest ())) {
556
566
String accessControlRequestMethod =
557
567
request .headers ().firstHeader (HttpHeaders .ACCESS_CONTROL_REQUEST_METHOD );
@@ -562,19 +572,42 @@ private static HttpMethod method(ServerRequest request) {
562
572
return request .method ();
563
573
}
564
574
575
+ @ Override
576
+ public void accept (Visitor visitor ) {
577
+ visitor .method (Set .of (this .httpMethod ));
578
+ }
579
+
580
+ @ Override
581
+ public String toString () {
582
+ return this .httpMethod .toString ();
583
+ }
584
+ }
585
+
586
+
587
+ private static class MultipleHttpMethodsPredicate implements RequestPredicate {
588
+
589
+ private final Set <HttpMethod > httpMethods ;
590
+
591
+ public MultipleHttpMethodsPredicate (HttpMethod [] httpMethods ) {
592
+ this .httpMethods = new LinkedHashSet <>(Arrays .asList (httpMethods ));
593
+ }
594
+
595
+ @ Override
596
+ public boolean test (ServerRequest request ) {
597
+ HttpMethod method = SingleHttpMethodPredicate .method (request );
598
+ boolean match = this .httpMethods .contains (method );
599
+ traceMatch ("Method" , this .httpMethods , method , match );
600
+ return match ;
601
+ }
602
+
565
603
@ Override
566
604
public void accept (Visitor visitor ) {
567
605
visitor .method (Collections .unmodifiableSet (this .httpMethods ));
568
606
}
569
607
570
608
@ Override
571
609
public String toString () {
572
- if (this .httpMethods .size () == 1 ) {
573
- return this .httpMethods .iterator ().next ().toString ();
574
- }
575
- else {
576
- return this .httpMethods .toString ();
577
- }
610
+ return this .httpMethods .toString ();
578
611
}
579
612
}
580
613
@@ -669,20 +702,46 @@ public String toString() {
669
702
}
670
703
671
704
672
- private static class ContentTypePredicate extends HeadersPredicate {
705
+ private static class SingleContentTypePredicate extends HeadersPredicate {
673
706
674
- private final Set < MediaType > mediaTypes ;
707
+ private final MediaType mediaType ;
675
708
676
- public ContentTypePredicate (MediaType ... mediaTypes ) {
677
- this (Set .of (mediaTypes ));
709
+ public SingleContentTypePredicate (MediaType mediaType ) {
710
+ super (headers -> {
711
+ MediaType contentType = headers .contentType ().orElse (MediaType .APPLICATION_OCTET_STREAM );
712
+ boolean match = mediaType .includes (contentType );
713
+ traceMatch ("Content-Type" , mediaType , contentType , match );
714
+ return match ;
715
+ });
716
+ this .mediaType = mediaType ;
717
+ }
718
+
719
+ @ Override
720
+ public void accept (Visitor visitor ) {
721
+ visitor .header (HttpHeaders .CONTENT_TYPE , this .mediaType .toString ());
678
722
}
679
723
680
- private ContentTypePredicate (Set <MediaType > mediaTypes ) {
724
+ @ Override
725
+ public String toString () {
726
+ return "Content-Type: " + this .mediaType ;
727
+ }
728
+ }
729
+
730
+
731
+ private static class MultipleContentTypesPredicate extends HeadersPredicate {
732
+
733
+ private final MediaType [] mediaTypes ;
734
+
735
+ public MultipleContentTypesPredicate (MediaType [] mediaTypes ) {
681
736
super (headers -> {
682
- MediaType contentType =
683
- headers .contentType ().orElse (MediaType .APPLICATION_OCTET_STREAM );
684
- boolean match = mediaTypes .stream ()
685
- .anyMatch (mediaType -> mediaType .includes (contentType ));
737
+ MediaType contentType = headers .contentType ().orElse (MediaType .APPLICATION_OCTET_STREAM );
738
+ boolean match = false ;
739
+ for (MediaType mediaType : mediaTypes ) {
740
+ if (mediaType .includes (contentType )) {
741
+ match = true ;
742
+ break ;
743
+ }
744
+ }
686
745
traceMatch ("Content-Type" , mediaTypes , contentType , match );
687
746
return match ;
688
747
});
@@ -691,44 +750,37 @@ private ContentTypePredicate(Set<MediaType> mediaTypes) {
691
750
692
751
@ Override
693
752
public void accept (Visitor visitor ) {
694
- visitor .header (HttpHeaders .CONTENT_TYPE ,
695
- (this .mediaTypes .size () == 1 ) ?
696
- this .mediaTypes .iterator ().next ().toString () :
697
- this .mediaTypes .toString ());
753
+ visitor .header (HttpHeaders .CONTENT_TYPE , Arrays .toString (this .mediaTypes ));
698
754
}
699
755
700
756
@ Override
701
757
public String toString () {
702
- return String .format ("Content-Type: %s" ,
703
- (this .mediaTypes .size () == 1 ) ?
704
- this .mediaTypes .iterator ().next ().toString () :
705
- this .mediaTypes .toString ());
758
+ return "Content-Type: " + Arrays .toString (this .mediaTypes );
706
759
}
707
760
}
708
761
709
762
710
- private static class AcceptPredicate extends HeadersPredicate {
711
-
712
- private final Set <MediaType > mediaTypes ;
763
+ private static class SingleAcceptPredicate extends HeadersPredicate {
713
764
714
- public AcceptPredicate (MediaType ... mediaTypes ) {
715
- this (Set .of (mediaTypes ));
716
- }
765
+ private final MediaType mediaType ;
717
766
718
- private AcceptPredicate ( Set < MediaType > mediaTypes ) {
767
+ public SingleAcceptPredicate ( MediaType mediaType ) {
719
768
super (headers -> {
720
769
List <MediaType > acceptedMediaTypes = acceptedMediaTypes (headers );
721
- boolean match = acceptedMediaTypes .stream ()
722
- .anyMatch (acceptedMediaType -> mediaTypes .stream ()
723
- .anyMatch (acceptedMediaType ::isCompatibleWith ));
724
- traceMatch ("Accept" , mediaTypes , acceptedMediaTypes , match );
770
+ boolean match = false ;
771
+ for (MediaType acceptedMediaType : acceptedMediaTypes ) {
772
+ if (acceptedMediaType .isCompatibleWith (mediaType )) {
773
+ match = true ;
774
+ break ;
775
+ }
776
+ }
777
+ traceMatch ("Accept" , mediaType , acceptedMediaTypes , match );
725
778
return match ;
726
779
});
727
- this .mediaTypes = mediaTypes ;
780
+ this .mediaType = mediaType ;
728
781
}
729
782
730
- @ NonNull
731
- private static List <MediaType > acceptedMediaTypes (ServerRequest .Headers headers ) {
783
+ static List <MediaType > acceptedMediaTypes (ServerRequest .Headers headers ) {
732
784
List <MediaType > acceptedMediaTypes = headers .accept ();
733
785
if (acceptedMediaTypes .isEmpty ()) {
734
786
acceptedMediaTypes = Collections .singletonList (MediaType .ALL );
@@ -741,18 +793,47 @@ private static List<MediaType> acceptedMediaTypes(ServerRequest.Headers headers)
741
793
742
794
@ Override
743
795
public void accept (Visitor visitor ) {
744
- visitor .header (HttpHeaders .ACCEPT ,
745
- (this .mediaTypes .size () == 1 ) ?
746
- this .mediaTypes .iterator ().next ().toString () :
747
- this .mediaTypes .toString ());
796
+ visitor .header (HttpHeaders .ACCEPT , this .mediaType .toString ());
797
+ }
798
+
799
+ @ Override
800
+ public String toString () {
801
+ return "Accept: " + this .mediaType ;
802
+ }
803
+ }
804
+
805
+
806
+ private static class MultipleAcceptsPredicate extends HeadersPredicate {
807
+
808
+ private final MediaType [] mediaTypes ;
809
+
810
+ public MultipleAcceptsPredicate (MediaType [] mediaTypes ) {
811
+ super (headers -> {
812
+ List <MediaType > acceptedMediaTypes = SingleAcceptPredicate .acceptedMediaTypes (headers );
813
+ boolean match = false ;
814
+ outer :
815
+ for (MediaType acceptedMediaType : acceptedMediaTypes ) {
816
+ for (MediaType mediaType : mediaTypes ) {
817
+ if (acceptedMediaType .isCompatibleWith (mediaType )) {
818
+ match = true ;
819
+ break outer ;
820
+ }
821
+ }
822
+ }
823
+ traceMatch ("Accept" , mediaTypes , acceptedMediaTypes , match );
824
+ return match ;
825
+ });
826
+ this .mediaTypes = mediaTypes ;
827
+ }
828
+
829
+ @ Override
830
+ public void accept (Visitor visitor ) {
831
+ visitor .header (HttpHeaders .ACCEPT , Arrays .toString (this .mediaTypes ));
748
832
}
749
833
750
834
@ Override
751
835
public String toString () {
752
- return String .format ("Accept: %s" ,
753
- (this .mediaTypes .size () == 1 ) ?
754
- this .mediaTypes .iterator ().next ().toString () :
755
- this .mediaTypes .toString ());
836
+ return "Accept: " + Arrays .toString (this .mediaTypes );
756
837
}
757
838
}
758
839
0 commit comments