17
17
package org .springframework .web .socket .server .support ;
18
18
19
19
import java .io .IOException ;
20
+ import java .lang .reflect .Method ;
20
21
import java .util .Arrays ;
21
22
import java .util .Collections ;
22
23
import java .util .Map ;
26
27
import javax .servlet .http .HttpServletRequest ;
27
28
import javax .servlet .http .HttpServletResponse ;
28
29
import javax .websocket .Endpoint ;
30
+ import javax .websocket .server .ServerEndpointConfig ;
29
31
32
+ import org .apache .tomcat .websocket .server .WsHandshakeRequest ;
33
+ import org .apache .tomcat .websocket .server .WsHttpUpgradeHandler ;
30
34
import org .apache .tomcat .websocket .server .WsServerContainer ;
31
35
import org .springframework .http .server .ServerHttpRequest ;
32
36
import org .springframework .http .server .ServerHttpResponse ;
33
37
import org .springframework .http .server .ServletServerHttpRequest ;
34
38
import org .springframework .http .server .ServletServerHttpResponse ;
35
39
import org .springframework .util .Assert ;
40
+ import org .springframework .util .ReflectionUtils ;
36
41
import org .springframework .web .socket .server .HandshakeFailureException ;
37
42
import org .springframework .web .socket .server .endpoint .ServerEndpointRegistration ;
38
43
@@ -60,6 +65,18 @@ public void upgradeInternal(ServerHttpRequest request, ServerHttpResponse respon
60
65
Assert .isTrue (response instanceof ServletServerHttpResponse );
61
66
HttpServletResponse servletResponse = ((ServletServerHttpResponse ) response ).getServletResponse ();
62
67
68
+ if (hasDoUpgrade ) {
69
+ doUpgrade (servletRequest , servletResponse , acceptedProtocol , endpoint );
70
+ }
71
+ else {
72
+ upgradeTomcat80RC1 (servletRequest , acceptedProtocol , endpoint );
73
+ }
74
+ }
75
+
76
+ private void doUpgrade (HttpServletRequest servletRequest , HttpServletResponse servletResponse ,
77
+ String acceptedProtocol , Endpoint endpoint ) {
78
+
79
+ StringBuffer requestUrl = servletRequest .getRequestURL ();
63
80
String path = servletRequest .getRequestURI (); // shouldn't matter
64
81
Map <String , String > pathParams = Collections .<String , String > emptyMap ();
65
82
@@ -71,11 +88,11 @@ public void upgradeInternal(ServerHttpRequest request, ServerHttpResponse respon
71
88
}
72
89
catch (ServletException ex ) {
73
90
throw new HandshakeFailureException (
74
- "Servlet request failed to upgrade to WebSocket, uri=" + request . getURI () , ex );
91
+ "Servlet request failed to upgrade to WebSocket, uri=" + requestUrl , ex );
75
92
}
76
93
catch (IOException ex ) {
77
94
throw new HandshakeFailureException (
78
- "Response update failed during upgrade to WebSocket, uri=" + request . getURI () , ex );
95
+ "Response update failed during upgrade to WebSocket, uri=" + requestUrl , ex );
79
96
}
80
97
}
81
98
@@ -85,4 +102,36 @@ private WsServerContainer getContainer(HttpServletRequest servletRequest) {
85
102
return (WsServerContainer ) servletContext .getAttribute (attribute );
86
103
}
87
104
88
- }
105
+ // FIXME: Remove this after RC2 is out
106
+
107
+ private void upgradeTomcat80RC1 (HttpServletRequest request , String protocol , Endpoint endpoint ) {
108
+
109
+ WsHttpUpgradeHandler upgradeHandler ;
110
+ try {
111
+ upgradeHandler = request .upgrade (WsHttpUpgradeHandler .class );
112
+ }
113
+ catch (Exception e ) {
114
+ throw new HandshakeFailureException ("Unable to create UpgardeHandler" , e );
115
+ }
116
+
117
+ WsHandshakeRequest webSocketRequest = new WsHandshakeRequest (request );
118
+ try {
119
+ Method method = ReflectionUtils .findMethod (WsHandshakeRequest .class , "finished" );
120
+ ReflectionUtils .makeAccessible (method );
121
+ method .invoke (webSocketRequest );
122
+ }
123
+ catch (Exception ex ) {
124
+ throw new HandshakeFailureException ("Failed to upgrade HttpServletRequest" , ex );
125
+ }
126
+
127
+ ServerEndpointConfig endpointConfig = new ServerEndpointRegistration ("/shouldntmatter" , endpoint );
128
+
129
+ upgradeHandler .preInit (endpoint , endpointConfig , getContainer (request ), webSocketRequest ,
130
+ protocol , Collections .<String , String > emptyMap (), request .isSecure ());
131
+ }
132
+
133
+ private static boolean hasDoUpgrade = (ReflectionUtils .findMethod (WsServerContainer .class ,
134
+ "doUpgrade" , HttpServletRequest .class , HttpServletResponse .class ,
135
+ ServerEndpointConfig .class , Map .class ) != null );
136
+
137
+ }
0 commit comments