16
16
17
17
package org .springframework .messaging .simp .broker ;
18
18
19
+ import java .security .Principal ;
19
20
import java .util .Collection ;
21
+ import java .util .Map ;
22
+ import java .util .concurrent .ConcurrentHashMap ;
23
+ import java .util .concurrent .ScheduledFuture ;
20
24
21
25
import org .springframework .messaging .Message ;
22
26
import org .springframework .messaging .MessageChannel ;
27
31
import org .springframework .messaging .support .MessageBuilder ;
28
32
import org .springframework .messaging .support .MessageHeaderAccessor ;
29
33
import org .springframework .messaging .support .MessageHeaderInitializer ;
34
+ import org .springframework .scheduling .TaskScheduler ;
30
35
import org .springframework .util .Assert ;
31
36
import org .springframework .util .MultiValueMap ;
32
37
import org .springframework .util .PathMatcher ;
@@ -43,10 +48,18 @@ public class SimpleBrokerMessageHandler extends AbstractBrokerMessageHandler {
43
48
44
49
private static final byte [] EMPTY_PAYLOAD = new byte [0 ];
45
50
51
+ private final Map <String , SessionInfo > sessions = new ConcurrentHashMap <String , SessionInfo >();
52
+
46
53
private SubscriptionRegistry subscriptionRegistry ;
47
54
48
55
private PathMatcher pathMatcher ;
49
56
57
+ private TaskScheduler taskScheduler ;
58
+
59
+ private long [] heartbeatValue ;
60
+
61
+ private ScheduledFuture <?> heartbeatFuture ;
62
+
50
63
private MessageHeaderInitializer headerInitializer ;
51
64
52
65
@@ -100,6 +113,49 @@ public void setPathMatcher(PathMatcher pathMatcher) {
100
113
initPathMatcherToUse ();
101
114
}
102
115
116
+ /**
117
+ * Configure the {@link org.springframework.scheduling.TaskScheduler} to
118
+ * use for providing heartbeat support. Setting this property also sets the
119
+ * {@link #setHeartbeatValue heartbeatValue} to "10000, 10000".
120
+ * <p>By default this is not set.
121
+ * @since 4.2
122
+ */
123
+ public void setTaskScheduler (TaskScheduler taskScheduler ) {
124
+ Assert .notNull (taskScheduler );
125
+ this .taskScheduler = taskScheduler ;
126
+ if (this .heartbeatValue == null ) {
127
+ this .heartbeatValue = new long [] {10000 , 10000 };
128
+ }
129
+ }
130
+
131
+ /**
132
+ * Return the configured TaskScheduler.
133
+ */
134
+ public TaskScheduler getTaskScheduler () {
135
+ return this .taskScheduler ;
136
+ }
137
+
138
+ /**
139
+ * Configure the value for the heart-beat settings. The first number
140
+ * represents how often the server will write or send a heartbeat.
141
+ * The second is how often the client should write. 0 means no heartbeats.
142
+ * <p>By default this is set to "0, 0" unless the {@link #setTaskScheduler
143
+ * taskScheduler} in which case the default becomes "10000,10000"
144
+ * (in milliseconds).
145
+ * @since 4.2
146
+ */
147
+ public void setHeartbeatValue (long [] heartbeat ) {
148
+ Assert .notNull (heartbeat );
149
+ this .heartbeatValue = heartbeat ;
150
+ }
151
+
152
+ /**
153
+ * The configured value for the heart-beat settings.
154
+ */
155
+ public long [] getHeartbeatValue () {
156
+ return this .heartbeatValue ;
157
+ }
158
+
103
159
/**
104
160
* Configure a {@link MessageHeaderInitializer} to apply to the headers
105
161
* of all messages sent to the client outbound channel.
@@ -120,11 +176,37 @@ public MessageHeaderInitializer getHeaderInitializer() {
120
176
@ Override
121
177
public void startInternal () {
122
178
publishBrokerAvailableEvent ();
179
+ if (getTaskScheduler () != null ) {
180
+ long interval = initHeartbeatTaskDelay ();
181
+ if (interval > 0 ) {
182
+ this .heartbeatFuture = this .taskScheduler .scheduleWithFixedDelay (new HeartbeatTask (), interval );
183
+ }
184
+ }
185
+ else {
186
+ Assert .isTrue (getHeartbeatValue () == null ||
187
+ (getHeartbeatValue ()[0 ] == 0 && getHeartbeatValue ()[1 ] == 0 ),
188
+ "Heartbeat values configured but no TaskScheduler is provided." );
189
+ }
190
+ }
191
+
192
+ private long initHeartbeatTaskDelay () {
193
+ if (getHeartbeatValue () == null ) {
194
+ return 0 ;
195
+ }
196
+ else if (getHeartbeatValue ()[0 ] > 0 && getHeartbeatValue ()[1 ] > 0 ) {
197
+ return Math .min (getHeartbeatValue ()[0 ], getHeartbeatValue ()[1 ]);
198
+ }
199
+ else {
200
+ return (getHeartbeatValue ()[0 ] > 0 ? getHeartbeatValue ()[0 ] : getHeartbeatValue ()[1 ]);
201
+ }
123
202
}
124
203
125
204
@ Override
126
205
public void stopInternal () {
127
206
publishBrokerUnavailableEvent ();
207
+ if (this .heartbeatFuture != null ) {
208
+ this .heartbeatFuture .cancel (true );
209
+ }
128
210
}
129
211
130
212
@ Override
@@ -133,6 +215,9 @@ protected void handleMessageInternal(Message<?> message) {
133
215
SimpMessageType messageType = SimpMessageHeaderAccessor .getMessageType (headers );
134
216
String destination = SimpMessageHeaderAccessor .getDestination (headers );
135
217
String sessionId = SimpMessageHeaderAccessor .getSessionId (headers );
218
+ Principal user = SimpMessageHeaderAccessor .getUser (headers );
219
+
220
+ updateSessionReadTime (sessionId );
136
221
137
222
if (!checkDestinationPrefix (destination )) {
138
223
return ;
@@ -150,23 +235,21 @@ protected void handleMessageInternal(Message<?> message) {
150
235
}
151
236
else if (SimpMessageType .CONNECT .equals (messageType )) {
152
237
logMessage (message );
238
+ long [] clientHeartbeat = SimpMessageHeaderAccessor .getHeartbeat (headers );
239
+ long [] serverHeartbeat = getHeartbeatValue ();
240
+ this .sessions .put (sessionId , new SessionInfo (sessionId , user , clientHeartbeat , serverHeartbeat ));
153
241
SimpMessageHeaderAccessor connectAck = SimpMessageHeaderAccessor .create (SimpMessageType .CONNECT_ACK );
154
242
initHeaders (connectAck );
155
243
connectAck .setSessionId (sessionId );
156
244
connectAck .setUser (SimpMessageHeaderAccessor .getUser (headers ));
157
245
connectAck .setHeader (SimpMessageHeaderAccessor .CONNECT_MESSAGE_HEADER , message );
246
+ connectAck .setHeader (SimpMessageHeaderAccessor .HEART_BEAT_HEADER , serverHeartbeat );
158
247
Message <byte []> messageOut = MessageBuilder .createMessage (EMPTY_PAYLOAD , connectAck .getMessageHeaders ());
159
248
getClientOutboundChannel ().send (messageOut );
160
249
}
161
250
else if (SimpMessageType .DISCONNECT .equals (messageType )) {
162
251
logMessage (message );
163
- this .subscriptionRegistry .unregisterAllSubscriptions (sessionId );
164
- SimpMessageHeaderAccessor disconnectAck = SimpMessageHeaderAccessor .create (SimpMessageType .DISCONNECT_ACK );
165
- initHeaders (disconnectAck );
166
- disconnectAck .setSessionId (sessionId );
167
- disconnectAck .setUser (SimpMessageHeaderAccessor .getUser (headers ));
168
- Message <byte []> messageOut = MessageBuilder .createMessage (EMPTY_PAYLOAD , disconnectAck .getMessageHeaders ());
169
- getClientOutboundChannel ().send (messageOut );
252
+ handleDisconnect (sessionId , user );
170
253
}
171
254
else if (SimpMessageType .SUBSCRIBE .equals (messageType )) {
172
255
logMessage (message );
@@ -178,6 +261,15 @@ else if (SimpMessageType.UNSUBSCRIBE.equals(messageType)) {
178
261
}
179
262
}
180
263
264
+ private void updateSessionReadTime (String sessionId ) {
265
+ if (sessionId != null ) {
266
+ SessionInfo info = this .sessions .get (sessionId );
267
+ if (info != null ) {
268
+ info .setLastReadTime (System .currentTimeMillis ());
269
+ }
270
+ }
271
+ }
272
+
181
273
private void logMessage (Message <?> message ) {
182
274
if (logger .isDebugEnabled ()) {
183
275
SimpMessageHeaderAccessor accessor = MessageHeaderAccessor .getAccessor (message , SimpMessageHeaderAccessor .class );
@@ -192,11 +284,23 @@ private void initHeaders(SimpMessageHeaderAccessor accessor) {
192
284
}
193
285
}
194
286
287
+ private void handleDisconnect (String sessionId , Principal user ) {
288
+ this .sessions .remove (sessionId );
289
+ this .subscriptionRegistry .unregisterAllSubscriptions (sessionId );
290
+ SimpMessageHeaderAccessor accessor = SimpMessageHeaderAccessor .create (SimpMessageType .DISCONNECT_ACK );
291
+ accessor .setSessionId (sessionId );
292
+ accessor .setUser (user );
293
+ initHeaders (accessor );
294
+ Message <byte []> message = MessageBuilder .createMessage (EMPTY_PAYLOAD , accessor .getMessageHeaders ());
295
+ getClientOutboundChannel ().send (message );
296
+ }
297
+
195
298
protected void sendMessageToSubscribers (String destination , Message <?> message ) {
196
299
MultiValueMap <String ,String > subscriptions = this .subscriptionRegistry .findSubscriptions (message );
197
300
if (!subscriptions .isEmpty () && logger .isDebugEnabled ()) {
198
301
logger .debug ("Broadcasting to " + subscriptions .size () + " sessions." );
199
302
}
303
+ long now = System .currentTimeMillis ();
200
304
for (String sessionId : subscriptions .keySet ()) {
201
305
for (String subscriptionId : subscriptions .get (sessionId )) {
202
306
SimpMessageHeaderAccessor headerAccessor = SimpMessageHeaderAccessor .create (SimpMessageType .MESSAGE );
@@ -212,6 +316,12 @@ protected void sendMessageToSubscribers(String destination, Message<?> message)
212
316
catch (Throwable ex ) {
213
317
logger .error ("Failed to send " + message , ex );
214
318
}
319
+ finally {
320
+ SessionInfo info = this .sessions .get (sessionId );
321
+ if (info != null ) {
322
+ info .setLastWriteTime (now );
323
+ }
324
+ }
215
325
}
216
326
}
217
327
}
@@ -221,4 +331,93 @@ public String toString() {
221
331
return "SimpleBroker[" + this .subscriptionRegistry + "]" ;
222
332
}
223
333
334
+
335
+ private static class SessionInfo {
336
+
337
+ /* STOMP spec: receiver SHOULD take into account an error margin */
338
+ private static final long HEARTBEAT_MULTIPLIER = 3 ;
339
+
340
+
341
+ private final String sessiondId ;
342
+
343
+ private final Principal user ;
344
+
345
+ private final long readInterval ;
346
+
347
+ private final long writeInterval ;
348
+
349
+ private volatile long lastReadTime ;
350
+
351
+ private volatile long lastWriteTime ;
352
+
353
+
354
+ public SessionInfo (String sessiondId , Principal user , long [] clientHeartbeat , long [] serverHeartbeat ) {
355
+ this .sessiondId = sessiondId ;
356
+ this .user = user ;
357
+ if (clientHeartbeat != null && serverHeartbeat != null ) {
358
+ this .readInterval = (clientHeartbeat [0 ] > 0 && serverHeartbeat [1 ] > 0 ?
359
+ Math .max (clientHeartbeat [0 ], serverHeartbeat [1 ]) * HEARTBEAT_MULTIPLIER : 0 );
360
+ this .writeInterval = (clientHeartbeat [1 ] > 0 && serverHeartbeat [0 ] > 0 ?
361
+ Math .max (clientHeartbeat [1 ], serverHeartbeat [0 ]) : 0 );
362
+ }
363
+ else {
364
+ this .readInterval = 0 ;
365
+ this .writeInterval = 0 ;
366
+ }
367
+ this .lastReadTime = this .lastWriteTime = System .currentTimeMillis ();
368
+ }
369
+
370
+ public String getSessiondId () {
371
+ return this .sessiondId ;
372
+ }
373
+
374
+ public Principal getUser () {
375
+ return this .user ;
376
+ }
377
+
378
+ public long getReadInterval () {
379
+ return this .readInterval ;
380
+ }
381
+
382
+ public long getWriteInterval () {
383
+ return this .writeInterval ;
384
+ }
385
+
386
+ public long getLastReadTime () {
387
+ return this .lastReadTime ;
388
+ }
389
+
390
+ public void setLastReadTime (long lastReadTime ) {
391
+ this .lastReadTime = lastReadTime ;
392
+ }
393
+
394
+ public long getLastWriteTime () {
395
+ return this .lastWriteTime ;
396
+ }
397
+
398
+ public void setLastWriteTime (long lastWriteTime ) {
399
+ this .lastWriteTime = lastWriteTime ;
400
+ }
401
+ }
402
+
403
+ private class HeartbeatTask implements Runnable {
404
+
405
+ @ Override
406
+ public void run () {
407
+ long now = System .currentTimeMillis ();
408
+ for (SessionInfo info : sessions .values ()) {
409
+ if (info .getReadInterval () > 0 && (now - info .getLastReadTime ()) > info .getReadInterval ()) {
410
+ handleDisconnect (info .getSessiondId (), info .getUser ());
411
+ }
412
+ if (info .getWriteInterval () > 0 && (now - info .getLastWriteTime ()) > info .getWriteInterval ()) {
413
+ SimpMessageHeaderAccessor accessor = SimpMessageHeaderAccessor .create (SimpMessageType .HEARTBEAT );
414
+ accessor .setSessionId (info .getSessiondId ());
415
+ accessor .setUser (info .getUser ());
416
+ initHeaders (accessor );
417
+ MessageHeaders headers = accessor .getMessageHeaders ();
418
+ getClientOutboundChannel ().send (MessageBuilder .createMessage (EMPTY_PAYLOAD , headers ));
419
+ }
420
+ }
421
+ }
422
+ }
224
423
}
0 commit comments