Skip to content

Commit 6eb0e9e

Browse files
committed
Unwrap decorated request or response
Closes: gh-23598
1 parent 9db4118 commit 6eb0e9e

File tree

4 files changed

+87
-23
lines changed

4 files changed

+87
-23
lines changed

spring-webflux/src/main/java/org/springframework/web/reactive/socket/server/upgrade/JettyRequestUpgradeStrategy.java

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2018 the original author or authors.
2+
* Copyright 2002-2019 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -18,6 +18,7 @@
1818

1919
import java.io.IOException;
2020
import java.util.function.Supplier;
21+
2122
import javax.servlet.ServletContext;
2223
import javax.servlet.http.HttpServletRequest;
2324
import javax.servlet.http.HttpServletResponse;
@@ -32,7 +33,9 @@
3233
import org.springframework.http.server.reactive.AbstractServerHttpRequest;
3334
import org.springframework.http.server.reactive.AbstractServerHttpResponse;
3435
import org.springframework.http.server.reactive.ServerHttpRequest;
36+
import org.springframework.http.server.reactive.ServerHttpRequestDecorator;
3537
import org.springframework.http.server.reactive.ServerHttpResponse;
38+
import org.springframework.http.server.reactive.ServerHttpResponseDecorator;
3639
import org.springframework.lang.Nullable;
3740
import org.springframework.util.Assert;
3841
import org.springframework.web.reactive.socket.HandshakeInfo;
@@ -144,8 +147,8 @@ public Mono<Void> upgrade(ServerWebExchange exchange, WebSocketHandler handler,
144147
ServerHttpRequest request = exchange.getRequest();
145148
ServerHttpResponse response = exchange.getResponse();
146149

147-
HttpServletRequest servletRequest = getHttpServletRequest(request);
148-
HttpServletResponse servletResponse = getHttpServletResponse(response);
150+
HttpServletRequest servletRequest = getNativeRequest(request);
151+
HttpServletResponse servletResponse = getNativeResponse(response);
149152

150153
HandshakeInfo handshakeInfo = handshakeInfoFactory.get();
151154
DataBufferFactory factory = response.bufferFactory();
@@ -173,14 +176,30 @@ public Mono<Void> upgrade(ServerWebExchange exchange, WebSocketHandler handler,
173176
return Mono.empty();
174177
}
175178

176-
private HttpServletRequest getHttpServletRequest(ServerHttpRequest request) {
177-
Assert.isInstanceOf(AbstractServerHttpRequest.class, request);
178-
return ((AbstractServerHttpRequest) request).getNativeRequest();
179+
private static HttpServletRequest getNativeRequest(ServerHttpRequest request) {
180+
if (request instanceof AbstractServerHttpRequest) {
181+
return ((AbstractServerHttpRequest) request).getNativeRequest();
182+
}
183+
else if (request instanceof ServerHttpRequestDecorator) {
184+
return getNativeRequest(((ServerHttpRequestDecorator) request).getDelegate());
185+
}
186+
else {
187+
throw new IllegalArgumentException(
188+
"Couldn't find HttpServletRequest in " + request.getClass().getName());
189+
}
179190
}
180191

181-
private HttpServletResponse getHttpServletResponse(ServerHttpResponse response) {
182-
Assert.isInstanceOf(AbstractServerHttpResponse.class, response);
183-
return ((AbstractServerHttpResponse) response).getNativeResponse();
192+
private static HttpServletResponse getNativeResponse(ServerHttpResponse response) {
193+
if (response instanceof AbstractServerHttpResponse) {
194+
return ((AbstractServerHttpResponse) response).getNativeResponse();
195+
}
196+
else if (response instanceof ServerHttpResponseDecorator) {
197+
return getNativeResponse(((ServerHttpResponseDecorator) response).getDelegate());
198+
}
199+
else {
200+
throw new IllegalArgumentException(
201+
"Couldn't find HttpServletResponse in " + response.getClass().getName());
202+
}
184203
}
185204

186205
private void startLazily(HttpServletRequest request) {

spring-webflux/src/main/java/org/springframework/web/reactive/socket/server/upgrade/ReactorNettyRequestUpgradeStrategy.java

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2018 the original author or authors.
2+
* Copyright 2002-2019 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -24,6 +24,7 @@
2424
import org.springframework.core.io.buffer.NettyDataBufferFactory;
2525
import org.springframework.http.server.reactive.AbstractServerHttpResponse;
2626
import org.springframework.http.server.reactive.ServerHttpResponse;
27+
import org.springframework.http.server.reactive.ServerHttpResponseDecorator;
2728
import org.springframework.lang.Nullable;
2829
import org.springframework.web.reactive.socket.HandshakeInfo;
2930
import org.springframework.web.reactive.socket.WebSocketHandler;
@@ -72,7 +73,7 @@ public Mono<Void> upgrade(ServerWebExchange exchange, WebSocketHandler handler,
7273
@Nullable String subProtocol, Supplier<HandshakeInfo> handshakeInfoFactory) {
7374

7475
ServerHttpResponse response = exchange.getResponse();
75-
HttpServerResponse reactorResponse = ((AbstractServerHttpResponse) response).getNativeResponse();
76+
HttpServerResponse reactorResponse = getNativeResponse(response);
7677
HandshakeInfo handshakeInfo = handshakeInfoFactory.get();
7778
NettyDataBufferFactory bufferFactory = (NettyDataBufferFactory) response.bufferFactory();
7879

@@ -85,4 +86,17 @@ public Mono<Void> upgrade(ServerWebExchange exchange, WebSocketHandler handler,
8586
});
8687
}
8788

89+
private static HttpServerResponse getNativeResponse(ServerHttpResponse response) {
90+
if (response instanceof AbstractServerHttpResponse) {
91+
return ((AbstractServerHttpResponse) response).getNativeResponse();
92+
}
93+
else if (response instanceof ServerHttpResponseDecorator) {
94+
return getNativeResponse(((ServerHttpResponseDecorator) response).getDelegate());
95+
}
96+
else {
97+
throw new IllegalArgumentException(
98+
"Couldn't find native response in " + response.getClass().getName());
99+
}
100+
}
101+
88102
}

spring-webflux/src/main/java/org/springframework/web/reactive/socket/server/upgrade/TomcatRequestUpgradeStrategy.java

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2018 the original author or authors.
2+
* Copyright 2002-2019 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -19,6 +19,7 @@
1919
import java.io.IOException;
2020
import java.util.Collections;
2121
import java.util.function.Supplier;
22+
2223
import javax.servlet.ServletException;
2324
import javax.servlet.http.HttpServletRequest;
2425
import javax.servlet.http.HttpServletResponse;
@@ -32,7 +33,9 @@
3233
import org.springframework.http.server.reactive.AbstractServerHttpRequest;
3334
import org.springframework.http.server.reactive.AbstractServerHttpResponse;
3435
import org.springframework.http.server.reactive.ServerHttpRequest;
36+
import org.springframework.http.server.reactive.ServerHttpRequestDecorator;
3537
import org.springframework.http.server.reactive.ServerHttpResponse;
38+
import org.springframework.http.server.reactive.ServerHttpResponseDecorator;
3639
import org.springframework.lang.Nullable;
3740
import org.springframework.util.Assert;
3841
import org.springframework.web.reactive.socket.HandshakeInfo;
@@ -130,8 +133,8 @@ public Mono<Void> upgrade(ServerWebExchange exchange, WebSocketHandler handler,
130133
ServerHttpRequest request = exchange.getRequest();
131134
ServerHttpResponse response = exchange.getResponse();
132135

133-
HttpServletRequest servletRequest = getHttpServletRequest(request);
134-
HttpServletResponse servletResponse = getHttpServletResponse(response);
136+
HttpServletRequest servletRequest = getNativeRequest(request);
137+
HttpServletResponse servletResponse = getNativeResponse(response);
135138

136139
HandshakeInfo handshakeInfo = handshakeInfoFactory.get();
137140
DataBufferFactory bufferFactory = response.bufferFactory();
@@ -155,14 +158,30 @@ public Mono<Void> upgrade(ServerWebExchange exchange, WebSocketHandler handler,
155158
return Mono.empty();
156159
}
157160

158-
private HttpServletRequest getHttpServletRequest(ServerHttpRequest request) {
159-
Assert.isInstanceOf(AbstractServerHttpRequest.class, request, "ServletServerHttpRequest required");
160-
return ((AbstractServerHttpRequest) request).getNativeRequest();
161+
private static HttpServletRequest getNativeRequest(ServerHttpRequest request) {
162+
if (request instanceof AbstractServerHttpRequest) {
163+
return ((AbstractServerHttpRequest) request).getNativeRequest();
164+
}
165+
else if (request instanceof ServerHttpRequestDecorator) {
166+
return getNativeRequest(((ServerHttpRequestDecorator) request).getDelegate());
167+
}
168+
else {
169+
throw new IllegalArgumentException(
170+
"Couldn't find HttpServletRequest in " + request.getClass().getName());
171+
}
161172
}
162173

163-
private HttpServletResponse getHttpServletResponse(ServerHttpResponse response) {
164-
Assert.isInstanceOf(AbstractServerHttpResponse.class, response, "ServletServerHttpResponse required");
165-
return ((AbstractServerHttpResponse) response).getNativeResponse();
174+
private static HttpServletResponse getNativeResponse(ServerHttpResponse response) {
175+
if (response instanceof AbstractServerHttpResponse) {
176+
return ((AbstractServerHttpResponse) response).getNativeResponse();
177+
}
178+
else if (response instanceof ServerHttpResponseDecorator) {
179+
return getNativeResponse(((ServerHttpResponseDecorator) response).getDelegate());
180+
}
181+
else {
182+
throw new IllegalArgumentException(
183+
"Couldn't find HttpServletResponse in " + response.getClass().getName());
184+
}
166185
}
167186

168187
private WsServerContainer getContainer(HttpServletRequest request) {

spring-webflux/src/main/java/org/springframework/web/reactive/socket/server/upgrade/UndertowRequestUpgradeStrategy.java

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
import org.springframework.core.io.buffer.DataBufferFactory;
3434
import org.springframework.http.server.reactive.AbstractServerHttpRequest;
3535
import org.springframework.http.server.reactive.ServerHttpRequest;
36+
import org.springframework.http.server.reactive.ServerHttpRequestDecorator;
3637
import org.springframework.lang.Nullable;
3738
import org.springframework.util.Assert;
3839
import org.springframework.web.reactive.socket.HandshakeInfo;
@@ -55,9 +56,7 @@ public class UndertowRequestUpgradeStrategy implements RequestUpgradeStrategy {
5556
public Mono<Void> upgrade(ServerWebExchange exchange, WebSocketHandler handler,
5657
@Nullable String subProtocol, Supplier<HandshakeInfo> handshakeInfoFactory) {
5758

58-
ServerHttpRequest request = exchange.getRequest();
59-
Assert.isInstanceOf(AbstractServerHttpRequest.class, request);
60-
HttpServerExchange httpExchange = ((AbstractServerHttpRequest) request).getNativeRequest();
59+
HttpServerExchange httpExchange = getNativeRequest(exchange.getRequest());
6160

6261
Set<String> protocols = (subProtocol != null ? Collections.singleton(subProtocol) : Collections.emptySet());
6362
Hybi13Handshake handshake = new Hybi13Handshake(protocols, false);
@@ -77,6 +76,19 @@ public Mono<Void> upgrade(ServerWebExchange exchange, WebSocketHandler handler,
7776
return Mono.empty();
7877
}
7978

79+
private static HttpServerExchange getNativeRequest(ServerHttpRequest request) {
80+
if (request instanceof AbstractServerHttpRequest) {
81+
return ((AbstractServerHttpRequest) request).getNativeRequest();
82+
}
83+
else if (request instanceof ServerHttpRequestDecorator) {
84+
return getNativeRequest(((ServerHttpRequestDecorator) request).getDelegate());
85+
}
86+
else {
87+
throw new IllegalArgumentException(
88+
"Couldn't find HttpServerExchange in " + request.getClass().getName());
89+
}
90+
}
91+
8092

8193
private class DefaultCallback implements WebSocketConnectionCallback {
8294

0 commit comments

Comments
 (0)