Skip to content

Commit 1472e97

Browse files
committed
Update SpringConfiguration to support beans by type
Issue: SPR-10605
1 parent f0dda0e commit 1472e97

File tree

3 files changed

+59
-27
lines changed

3 files changed

+59
-27
lines changed

spring-messaging/src/test/java/org/springframework/messaging/simp/JettyTestServer.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
import org.springframework.util.SocketUtils;
2323
import org.springframework.web.context.WebApplicationContext;
2424
import org.springframework.web.servlet.DispatcherServlet;
25-
import org.springframework.web.socket.TestServer;
2625

2726

2827
/**

spring-websocket/src/main/java/org/springframework/web/socket/server/endpoint/SpringConfigurator.java

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@
1616

1717
package org.springframework.web.socket.server.endpoint;
1818

19+
import java.util.Map;
20+
import java.util.concurrent.ConcurrentHashMap;
21+
1922
import javax.websocket.server.ServerEndpoint;
2023
import javax.websocket.server.ServerEndpointConfig.Configurator;
2124

@@ -24,6 +27,7 @@
2427
import org.springframework.core.annotation.AnnotationUtils;
2528
import org.springframework.stereotype.Component;
2629
import org.springframework.util.ClassUtils;
30+
import org.springframework.util.ObjectUtils;
2731
import org.springframework.web.context.ContextLoader;
2832
import org.springframework.web.context.WebApplicationContext;
2933

@@ -48,7 +52,13 @@ public class SpringConfigurator extends Configurator {
4852

4953
private static Log logger = LogFactory.getLog(SpringConfigurator.class);
5054

55+
private static final Map<String, Map<Class<?>, String>> cache =
56+
new ConcurrentHashMap<String, Map<Class<?>, String>>();
57+
58+
private static final String NO_VALUE = ObjectUtils.identityToString(new Object());
5159

60+
61+
@SuppressWarnings("unchecked")
5262
@Override
5363
public <T> T getEndpointInstance(Class<T> endpointClass) throws InstantiationException {
5464

@@ -77,10 +87,44 @@ public <T> T getEndpointInstance(Class<T> endpointClass) throws InstantiationExc
7787
return endpoint;
7888
}
7989

90+
beanName = getBeanNameByType(wac, endpointClass);
91+
if (beanName != null) {
92+
return (T) wac.getBean(beanName);
93+
}
94+
8095
if (logger.isTraceEnabled()) {
8196
logger.trace("Creating new @ServerEndpoint instance of type " + endpointClass);
8297
}
8398
return wac.getAutowireCapableBeanFactory().createBean(endpointClass);
8499
}
85100

101+
private String getBeanNameByType(WebApplicationContext wac, Class<?> endpointClass) {
102+
103+
String wacId = wac.getId();
104+
105+
Map<Class<?>, String> beanNamesByType = cache.get(wacId);
106+
if (beanNamesByType == null) {
107+
beanNamesByType = new ConcurrentHashMap<Class<?>, String>();
108+
cache.put(wacId, beanNamesByType);
109+
}
110+
111+
if (!beanNamesByType.containsKey(endpointClass)) {
112+
String[] names = wac.getBeanNamesForType(endpointClass);
113+
if (names.length == 1) {
114+
beanNamesByType.put(endpointClass, names[0]);
115+
}
116+
else {
117+
beanNamesByType.put(endpointClass, NO_VALUE);
118+
if (names.length > 1) {
119+
String message = "Found multiple @ServerEndpoint's of type " + endpointClass + ", names=" + names;
120+
logger.error(message);
121+
throw new IllegalStateException(message);
122+
}
123+
}
124+
}
125+
126+
String beanName = beanNamesByType.get(endpointClass);
127+
return NO_VALUE.equals(beanName) ? null : beanName;
128+
}
129+
86130
}

spring-websocket/src/test/java/org/springframework/web/socket/server/endpoint/SpringConfiguratorTests.java

Lines changed: 15 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,7 @@
1616

1717
package org.springframework.web.socket.server.endpoint;
1818

19-
import javax.websocket.Endpoint;
20-
import javax.websocket.EndpointConfig;
21-
import javax.websocket.Session;
19+
import javax.websocket.server.ServerEndpoint;
2220

2321
import org.junit.After;
2422
import org.junit.Before;
@@ -65,22 +63,22 @@ public void destroy() {
6563

6664

6765
@Test
68-
public void getEndpointInstancePerConnection() throws Exception {
66+
public void getEndpointPerConnection() throws Exception {
6967
PerConnectionEchoEndpoint endpoint = this.configurator.getEndpointInstance(PerConnectionEchoEndpoint.class);
7068
assertNotNull(endpoint);
7169
}
7270

7371
@Test
74-
public void getEndpointInstanceSingletonByType() throws Exception {
72+
public void getEndpointSingletonByType() throws Exception {
7573
EchoEndpoint expected = this.webAppContext.getBean(EchoEndpoint.class);
7674
EchoEndpoint actual = this.configurator.getEndpointInstance(EchoEndpoint.class);
7775
assertSame(expected, actual);
7876
}
7977

8078
@Test
81-
public void getEndpointInstanceSingletonByComponentName() throws Exception {
82-
AnotherEchoEndpoint expected = this.webAppContext.getBean(AnotherEchoEndpoint.class);
83-
AnotherEchoEndpoint actual = this.configurator.getEndpointInstance(AnotherEchoEndpoint.class);
79+
public void getEndpointSingletonByComponentName() throws Exception {
80+
ComponentEchoEndpoint expected = this.webAppContext.getBean(ComponentEchoEndpoint.class);
81+
ComponentEchoEndpoint actual = this.configurator.getEndpointInstance(ComponentEchoEndpoint.class);
8482
assertSame(expected, actual);
8583
}
8684

@@ -90,7 +88,7 @@ public void getEndpointInstanceSingletonByComponentName() throws Exception {
9088
static class Config {
9189

9290
@Bean
93-
public EchoEndpoint echoEndpoint() {
91+
public EchoEndpoint javaConfigEndpoint() {
9492
return new EchoEndpoint(echoService());
9593
}
9694

@@ -100,7 +98,8 @@ public EchoService echoService() {
10098
}
10199
}
102100

103-
private static class EchoEndpoint extends Endpoint {
101+
@ServerEndpoint("/echo")
102+
private static class EchoEndpoint {
104103

105104
@SuppressWarnings("unused")
106105
private final EchoService service;
@@ -109,29 +108,23 @@ private static class EchoEndpoint extends Endpoint {
109108
public EchoEndpoint(EchoService service) {
110109
this.service = service;
111110
}
112-
113-
@Override
114-
public void onOpen(Session session, EndpointConfig config) {
115-
}
116111
}
117112

118-
@Component("myEchoEndpoint")
119-
private static class AnotherEchoEndpoint extends Endpoint {
113+
@Component("myComponentEchoEndpoint")
114+
@ServerEndpoint("/echo")
115+
private static class ComponentEchoEndpoint {
120116

121117
@SuppressWarnings("unused")
122118
private final EchoService service;
123119

124120
@Autowired
125-
public AnotherEchoEndpoint(EchoService service) {
121+
public ComponentEchoEndpoint(EchoService service) {
126122
this.service = service;
127123
}
128-
129-
@Override
130-
public void onOpen(Session session, EndpointConfig config) {
131-
}
132124
}
133125

134-
private static class PerConnectionEchoEndpoint extends Endpoint {
126+
@ServerEndpoint("/echo")
127+
private static class PerConnectionEchoEndpoint {
135128

136129
@SuppressWarnings("unused")
137130
private final EchoService service;
@@ -140,10 +133,6 @@ private static class PerConnectionEchoEndpoint extends Endpoint {
140133
public PerConnectionEchoEndpoint(EchoService service) {
141134
this.service = service;
142135
}
143-
144-
@Override
145-
public void onOpen(Session session, EndpointConfig config) {
146-
}
147136
}
148137

149138
private static class EchoService { }

0 commit comments

Comments
 (0)