Skip to content

Commit e1f51cb

Browse files
committed
Check both https and wss in forwarded header checks
Closes gh-27097
1 parent 6ec7cff commit e1f51cb

File tree

4 files changed

+18
-12
lines changed

4 files changed

+18
-12
lines changed

spring-web/src/main/java/org/springframework/web/filter/ForwardedHeaderFilter.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2020 the original author or authors.
2+
* Copyright 2002-2021 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -239,7 +239,7 @@ private static class ForwardedHeaderExtractingRequest extends ForwardedHeaderRem
239239
int port = uriComponents.getPort();
240240

241241
this.scheme = uriComponents.getScheme();
242-
this.secure = "https".equals(this.scheme);
242+
this.secure = "https".equals(this.scheme) || "wss".equals(this.scheme);
243243
this.host = uriComponents.getHost();
244244
this.port = (port == -1 ? (this.secure ? 443 : 80) : port);
245245

spring-web/src/main/java/org/springframework/web/util/UriComponentsBuilder.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -882,7 +882,7 @@ else if (isForwardedSslOn(headers)) {
882882
}
883883

884884
if (this.scheme != null && ((this.scheme.equals("http") && "80".equals(this.port)) ||
885-
(this.scheme.equals("https") && "443".equals(this.port)))) {
885+
((this.scheme.equals("https") || this.scheme.equals("wss")) && "443".equals(this.port)))) {
886886
port(null);
887887
}
888888

spring-web/src/test/java/org/springframework/web/filter/ForwardedHeaderFilterTests.java

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@
3030
import org.junit.jupiter.api.BeforeEach;
3131
import org.junit.jupiter.api.Nested;
3232
import org.junit.jupiter.api.Test;
33+
import org.junit.jupiter.params.ParameterizedTest;
34+
import org.junit.jupiter.params.provider.ValueSource;
3335

3436
import org.springframework.web.testfixture.servlet.MockFilterChain;
3537
import org.springframework.web.testfixture.servlet.MockHttpServletRequest;
@@ -102,10 +104,11 @@ public void shouldNotFilter() {
102104
assertThat(this.filter.shouldNotFilter(new MockHttpServletRequest())).isTrue();
103105
}
104106

105-
@Test
106-
public void forwardedRequest() throws Exception {
107+
@ParameterizedTest
108+
@ValueSource(strings = {"https", "wss"})
109+
public void forwardedRequest(String protocol) throws Exception {
107110
this.request.setRequestURI("/mvc-showcase");
108-
this.request.addHeader(X_FORWARDED_PROTO, "https");
111+
this.request.addHeader(X_FORWARDED_PROTO, protocol);
109112
this.request.addHeader(X_FORWARDED_HOST, "84.198.58.199");
110113
this.request.addHeader(X_FORWARDED_PORT, "443");
111114
this.request.addHeader("foo", "bar");
@@ -115,8 +118,8 @@ public void forwardedRequest() throws Exception {
115118
HttpServletRequest actual = (HttpServletRequest) this.filterChain.getRequest();
116119

117120
assertThat(actual).isNotNull();
118-
assertThat(actual.getRequestURL().toString()).isEqualTo("https://84.198.58.199/mvc-showcase");
119-
assertThat(actual.getScheme()).isEqualTo("https");
121+
assertThat(actual.getRequestURL().toString()).isEqualTo(protocol + "://84.198.58.199/mvc-showcase");
122+
assertThat(actual.getScheme()).isEqualTo(protocol);
120123
assertThat(actual.getServerName()).isEqualTo("84.198.58.199");
121124
assertThat(actual.getServerPort()).isEqualTo(443);
122125
assertThat(actual.isSecure()).isTrue();

spring-web/src/test/java/org/springframework/web/util/UriComponentsBuilderTests.java

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@
2828
import java.util.function.BiConsumer;
2929

3030
import org.junit.jupiter.api.Test;
31+
import org.junit.jupiter.params.ParameterizedTest;
32+
import org.junit.jupiter.params.provider.ValueSource;
3133

3234
import org.springframework.http.HttpHeaders;
3335
import org.springframework.http.HttpRequest;
@@ -374,10 +376,11 @@ void fromHttpRequest() {
374376
assertThat(result.getQuery()).isEqualTo("a=1");
375377
}
376378

377-
@Test // SPR-12771
378-
void fromHttpRequestResetsPortBeforeSettingIt() {
379+
@ParameterizedTest // gh-17368, gh-27097
380+
@ValueSource(strings = {"https", "wss"})
381+
void fromHttpRequestResetsPortBeforeSettingIt(String protocol) {
379382
MockHttpServletRequest request = new MockHttpServletRequest();
380-
request.addHeader("X-Forwarded-Proto", "https");
383+
request.addHeader("X-Forwarded-Proto", protocol);
381384
request.addHeader("X-Forwarded-Host", "84.198.58.199");
382385
request.addHeader("X-Forwarded-Port", 443);
383386
request.setScheme("http");
@@ -388,7 +391,7 @@ void fromHttpRequestResetsPortBeforeSettingIt() {
388391
HttpRequest httpRequest = new ServletServerHttpRequest(request);
389392
UriComponents result = UriComponentsBuilder.fromHttpRequest(httpRequest).build();
390393

391-
assertThat(result.getScheme()).isEqualTo("https");
394+
assertThat(result.getScheme()).isEqualTo(protocol);
392395
assertThat(result.getHost()).isEqualTo("84.198.58.199");
393396
assertThat(result.getPort()).isEqualTo(-1);
394397
assertThat(result.getPath()).isEqualTo("/rest/mobile/users/1");

0 commit comments

Comments
 (0)