Skip to content

Commit 5a0e42b

Browse files
committed
Add MultiProtocolWebSocketHandler
It makes it possible to deploy multiple WebSocketHandler's to a URL, each supporting a different sub-protocol. Issue: SPR-10786
1 parent 7bb9c63 commit 5a0e42b

21 files changed

+463
-27
lines changed

spring-websocket/src/main/java/org/springframework/web/socket/WebSocketSession.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,11 @@ public interface WebSocketSession {
6161
*/
6262
String getRemoteAddress();
6363

64+
/**
65+
* Return the negotiated sub-protocol or {@code null} if none was specified.
66+
*/
67+
String getAcceptedProtocol();
68+
6469
/**
6570
* Return whether the connection is still open.
6671
*/

spring-websocket/src/main/java/org/springframework/web/socket/adapter/ConfigurableWebSocketSession.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import java.security.Principal;
2121

2222
import org.springframework.web.socket.WebSocketSession;
23+
import org.springframework.web.socket.server.DefaultHandshakeHandler;
2324

2425
/**
2526
* A WebSocketSession with configurable properties.
@@ -37,4 +38,12 @@ public interface ConfigurableWebSocketSession extends WebSocketSession {
3738

3839
void setPrincipal(Principal principal);
3940

41+
/**
42+
* Set the protocol accepted as part of the WebSocket handshake. This property can be
43+
* used when the WebSocket handshake is performed through
44+
* {@link DefaultHandshakeHandler} rather than the underlying WebSocket runtime, or
45+
* when there is no WebSocket handshake (e.g. SockJS HTTP fallback options)
46+
*/
47+
void setAcceptedProtocol(String protocol);
48+
4049
}

spring-websocket/src/main/java/org/springframework/web/socket/adapter/JettyWebSocketSessionAdapter.java

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import java.security.Principal;
2323

2424
import org.eclipse.jetty.websocket.api.Session;
25+
import org.eclipse.jetty.websocket.api.UpgradeResponse;
2526
import org.springframework.util.Assert;
2627
import org.springframework.util.ObjectUtils;
2728
import org.springframework.web.socket.BinaryMessage;
@@ -44,11 +45,20 @@ public class JettyWebSocketSessionAdapter
4445

4546
private Principal principal;
4647

48+
private String protocol;
49+
4750

4851
@Override
4952
public void initSession(Session session) {
5053
Assert.notNull(session, "session must not be null");
5154
this.session = session;
55+
56+
if (this.protocol == null) {
57+
UpgradeResponse response = session.getUpgradeResponse();
58+
if ((response != null) && response.getAcceptedSubProtocol() != null) {
59+
this.protocol = response.getAcceptedSubProtocol();
60+
}
61+
}
5262
}
5363

5464
@Override
@@ -101,6 +111,16 @@ public void setRemoteAddress(String address) {
101111
// ignore
102112
}
103113

114+
@Override
115+
public String getAcceptedProtocol() {
116+
return this.protocol;
117+
}
118+
119+
@Override
120+
public void setAcceptedProtocol(String protocol) {
121+
this.protocol = protocol;
122+
}
123+
104124
@Override
105125
public boolean isOpen() {
106126
return this.session.isOpen();

spring-websocket/src/main/java/org/springframework/web/socket/adapter/StandardWebSocketSessionAdapter.java

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import javax.websocket.CloseReason.CloseCodes;
2525

2626
import org.springframework.util.Assert;
27+
import org.springframework.util.StringUtils;
2728
import org.springframework.web.socket.BinaryMessage;
2829
import org.springframework.web.socket.CloseStatus;
2930
import org.springframework.web.socket.TextMessage;
@@ -45,11 +46,19 @@ public class StandardWebSocketSessionAdapter extends AbstractWebSocketSesssionAd
4546

4647
private String remoteAddress;
4748

49+
private String protocol;
50+
4851

4952
@Override
5053
public void initSession(javax.websocket.Session session) {
5154
Assert.notNull(session, "session must not be null");
5255
this.session = session;
56+
57+
if (this.protocol == null) {
58+
if (StringUtils.hasText(session.getNegotiatedSubprotocol())) {
59+
this.protocol = session.getNegotiatedSubprotocol();
60+
}
61+
}
5362
}
5463

5564
@Override
@@ -103,6 +112,16 @@ public void setRemoteAddress(String address) {
103112
this.remoteAddress = address;
104113
}
105114

115+
@Override
116+
public String getAcceptedProtocol() {
117+
return this.protocol;
118+
}
119+
120+
@Override
121+
public void setAcceptedProtocol(String protocol) {
122+
this.protocol = protocol;
123+
}
124+
106125
@Override
107126
public boolean isOpen() {
108127
return this.session.isOpen();

spring-websocket/src/main/java/org/springframework/web/socket/client/WebSocketConnectionManager.java

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ public class WebSocketConnectionManager extends ConnectionManagerSupport {
4343

4444
private WebSocketSession webSocketSession;
4545

46-
private final List<String> subProtocols = new ArrayList<String>();
46+
private final List<String> protocols = new ArrayList<String>();
4747

4848
private final boolean syncClientLifecycle;
4949

@@ -67,15 +67,15 @@ protected WebSocketHandler decorateWebSocketHandler(WebSocketHandler handler) {
6767
return new LoggingWebSocketHandlerDecorator(handler);
6868
}
6969

70-
public void setSubProtocols(List<String> subProtocols) {
71-
this.subProtocols.clear();
72-
if (!CollectionUtils.isEmpty(subProtocols)) {
73-
this.subProtocols.addAll(subProtocols);
70+
public void setSupportedProtocols(List<String> protocols) {
71+
this.protocols.clear();
72+
if (!CollectionUtils.isEmpty(protocols)) {
73+
this.protocols.addAll(protocols);
7474
}
7575
}
7676

77-
public List<String> getSubProtocols() {
78-
return this.subProtocols;
77+
public List<String> getSupportedProtocols() {
78+
return this.protocols;
7979
}
8080

8181
@Override
@@ -97,7 +97,7 @@ public void stopInternal() throws Exception {
9797
@Override
9898
protected void openConnection() throws Exception {
9999
HttpHeaders headers = new HttpHeaders();
100-
headers.setSecWebSocketProtocol(this.subProtocols);
100+
headers.setSecWebSocketProtocol(this.protocols);
101101
this.webSocketSession = this.client.doHandshake(this.webSocketHandler, headers, getUri());
102102
}
103103

spring-websocket/src/main/java/org/springframework/web/socket/client/endpoint/EndpointConnectionManager.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,8 @@ public EndpointConnectionManager(Class<? extends Endpoint> endpointClass, String
7474
}
7575

7676

77-
public void setSubProtocols(String... subprotocols) {
78-
this.configBuilder.preferredSubprotocols(Arrays.asList(subprotocols));
77+
public void setSupportedProtocols(String... protocols) {
78+
this.configBuilder.preferredSubprotocols(Arrays.asList(protocols));
7979
}
8080

8181
public void setExtensions(Extension... extensions) {

spring-websocket/src/main/java/org/springframework/web/socket/server/DefaultHandshakeHandler.java

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
import java.security.MessageDigest;
2222
import java.security.NoSuchAlgorithmException;
2323
import java.util.ArrayList;
24-
import java.util.Arrays;
2524
import java.util.Collections;
2625
import java.util.List;
2726

@@ -35,7 +34,6 @@
3534
import org.springframework.http.server.ServerHttpRequest;
3635
import org.springframework.http.server.ServerHttpResponse;
3736
import org.springframework.util.ClassUtils;
38-
import org.springframework.util.CollectionUtils;
3937
import org.springframework.util.StringUtils;
4038
import org.springframework.web.socket.WebSocketHandler;
4139

@@ -55,7 +53,7 @@ public class DefaultHandshakeHandler implements HandshakeHandler {
5553

5654
protected Log logger = LogFactory.getLog(getClass());
5755

58-
private List<String> supportedProtocols = new ArrayList<String>();
56+
private final List<String> supportedProtocols = new ArrayList<String>();
5957

6058
private final RequestUpgradeStrategy requestUpgradeStrategy;
6159

@@ -78,11 +76,22 @@ public DefaultHandshakeHandler(RequestUpgradeStrategy upgradeStrategy) {
7876
this.requestUpgradeStrategy = upgradeStrategy;
7977
}
8078

81-
79+
/**
80+
* Use this property to configure a list of sub-protocols that are supported.
81+
* The first protocol that matches what the client requested is selected.
82+
* If no protocol matches or this property is not configured, then the
83+
* response will not contain a Sec-WebSocket-Protocol header.
84+
*/
8285
public void setSupportedProtocols(String... protocols) {
83-
this.supportedProtocols = Arrays.asList(protocols);
86+
this.supportedProtocols.clear();
87+
for (String protocol : protocols) {
88+
this.supportedProtocols.add(protocol.toLowerCase());
89+
}
8490
}
8591

92+
/**
93+
* Return the list of supported sub-protocols.
94+
*/
8695
public String[] getSupportedProtocols() {
8796
return this.supportedProtocols.toArray(new String[this.supportedProtocols.size()]);
8897
}
@@ -191,9 +200,12 @@ protected boolean isValidOrigin(ServerHttpRequest request) {
191200
}
192201

193202
protected String selectProtocol(List<String> requestedProtocols) {
194-
if (CollectionUtils.isEmpty(requestedProtocols)) {
203+
if (requestedProtocols != null) {
195204
for (String protocol : requestedProtocols) {
196-
if (this.supportedProtocols.contains(protocol)) {
205+
if (this.supportedProtocols.contains(protocol.toLowerCase())) {
206+
if (logger.isDebugEnabled()) {
207+
logger.debug("Selected sub-protocol '" + protocol + "'");
208+
}
197209
return protocol;
198210
}
199211
}

spring-websocket/src/main/java/org/springframework/web/socket/server/endpoint/ServerEndpointRegistration.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ public class ServerEndpointRegistration implements ServerEndpointConfig, BeanFac
6060

6161
private List<Class<? extends Decoder>> decoders = new ArrayList<Class<? extends Decoder>>();
6262

63-
private List<String> subprotocols = new ArrayList<String>();
63+
private List<String> protocols = new ArrayList<String>();
6464

6565
private List<Extension> extensions = new ArrayList<Extension>();
6666

@@ -113,13 +113,13 @@ public Endpoint getEndpoint() {
113113
return (this.endpoint != null) ? this.endpoint : this.endpointProvider.getHandler();
114114
}
115115

116-
public void setSubprotocols(List<String> subprotocols) {
117-
this.subprotocols = subprotocols;
116+
public void setSubprotocols(List<String> protocols) {
117+
this.protocols = protocols;
118118
}
119119

120120
@Override
121121
public List<String> getSubprotocols() {
122-
return this.subprotocols;
122+
return this.protocols;
123123
}
124124

125125
public void setExtensions(List<Extension> extensions) {

spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractStandardUpgradeStrategy.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ public void upgrade(ServerHttpRequest request, ServerHttpResponse response,
4848
String protocol, WebSocketHandler handler) throws IOException, HandshakeFailureException {
4949

5050
StandardWebSocketSessionAdapter session = new StandardWebSocketSessionAdapter();
51-
this.wsSessionInitializer.initialize(request, response, session);
51+
this.wsSessionInitializer.initialize(request, response, protocol, session);
5252
StandardEndpointAdapter endpoint = new StandardEndpointAdapter(handler, session);
5353
upgradeInternal(request, response, protocol, endpoint);
5454
}

spring-websocket/src/main/java/org/springframework/web/socket/server/support/JettyRequestUpgradeStrategy.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ public String[] getSupportedVersions() {
8787

8888
@Override
8989
public void upgrade(ServerHttpRequest request, ServerHttpResponse response,
90-
String selectedProtocol, WebSocketHandler webSocketHandler) throws IOException {
90+
String protocol, WebSocketHandler webSocketHandler) throws IOException {
9191

9292
Assert.isInstanceOf(ServletServerHttpRequest.class, request);
9393
HttpServletRequest servletRequest = ((ServletServerHttpRequest) request).getServletRequest();
@@ -101,7 +101,7 @@ public void upgrade(ServerHttpRequest request, ServerHttpResponse response,
101101
}
102102

103103
JettyWebSocketSessionAdapter session = new JettyWebSocketSessionAdapter();
104-
this.wsSessionInitializer.initialize(request, response, session);
104+
this.wsSessionInitializer.initialize(request, response, protocol, session);
105105
JettyWebSocketListenerAdapter listener = new JettyWebSocketListenerAdapter(webSocketHandler, session);
106106

107107
servletRequest.setAttribute(WEBSOCKET_LISTENER_ATTR_NAME, listener);

spring-websocket/src/main/java/org/springframework/web/socket/server/support/ServerWebSocketSessionInitializer.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,14 @@
3030
*/
3131
public class ServerWebSocketSessionInitializer {
3232

33-
public void initialize(ServerHttpRequest request, ServerHttpResponse response, ConfigurableWebSocketSession session) {
33+
public void initialize(ServerHttpRequest request, ServerHttpResponse response,
34+
String protocol, ConfigurableWebSocketSession session) {
35+
3436
session.setUri(request.getURI());
3537
session.setRemoteHostName(request.getRemoteHostName());
3638
session.setRemoteAddress(request.getRemoteAddress());
3739
session.setPrincipal(request.getPrincipal());
40+
session.setAcceptedProtocol(protocol);
3841
}
3942

4043
}

spring-websocket/src/main/java/org/springframework/web/socket/sockjs/support/DefaultSockJsService.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,8 @@ protected AbstractSockJsSession getSockJsSession(String sessionId, WebSocketHand
245245
}
246246
logger.debug("Creating new session with session id \"" + sessionId + "\"");
247247
session = sessionFactory.createSession(sessionId, handler);
248-
this.sessionInitializer.initialize(request, response, session);
248+
String protocol = null; // TODO: https://github.com/sockjs/sockjs-client/issues/130
249+
this.sessionInitializer.initialize(request, response, protocol, session);
249250
this.sessions.put(sessionId, session);
250251
return session;
251252
}

spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/AbstractHttpReceivingTransportHandler.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,9 @@ public final void handleRequest(ServerHttpRequest request, ServerHttpResponse re
6363
return;
6464
}
6565

66+
// TODO: check "Sec-WebSocket-Protocol" header
67+
// https://github.com/sockjs/sockjs-client/issues/130
68+
6669
handleRequestInternal(request, response, session);
6770
}
6871

spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/AbstractHttpSockJsSession.java

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,12 +49,32 @@ public abstract class AbstractHttpSockJsSession extends AbstractSockJsSession {
4949

5050
private ServerHttpResponse response;
5151

52+
private String protocol;
53+
5254

5355
public AbstractHttpSockJsSession(String sessionId, SockJsConfiguration config, WebSocketHandler handler) {
5456
super(sessionId, config, handler);
5557
}
5658

5759

60+
/**
61+
* Unlike WebSocket where sub-protocol negotiation is part of the
62+
* initial handshake, in HTTP transports the same negotiation must
63+
* be emulated and the selected protocol set through this setter.
64+
*
65+
* @param protocol the sub-protocol to set
66+
*/
67+
public void setAcceptedProtocol(String protocol) {
68+
this.protocol = protocol;
69+
}
70+
71+
/**
72+
* Return the selected sub-protocol to use.
73+
*/
74+
public String getAcceptedProtocol() {
75+
return this.protocol;
76+
}
77+
5878
public synchronized void setInitialRequest(ServerHttpRequest request, ServerHttpResponse response,
5979
FrameFormat frameFormat) throws TransportErrorException {
6080

spring-websocket/src/main/java/org/springframework/web/socket/sockjs/transport/WebSocketServerSockJsSession.java

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,20 @@ public WebSocketServerSockJsSession(String sessionId, SockJsConfiguration config
4848
}
4949

5050

51+
@Override
52+
public String getAcceptedProtocol() {
53+
if (this.webSocketSession == null) {
54+
logger.warn("getAcceptedProtocol() invoked before WebSocketSession has been initialized.");
55+
return null;
56+
}
57+
return this.webSocketSession.getAcceptedProtocol();
58+
}
59+
60+
@Override
61+
public void setAcceptedProtocol(String protocol) {
62+
// ignore, webSocketSession should have it
63+
}
64+
5165
public void initWebSocketSession(WebSocketSession session) throws Exception {
5266
this.webSocketSession = session;
5367
try {

0 commit comments

Comments
 (0)