Skip to content

Commit 1aef041

Browse files
artembilangaryrussell
authored andcommitted
GH-3993: Fix async race condition in TcpOutGateway (#3995)
* GH-3993: Fix async race condition in TcpOutGateway Fixes #3993 When `TcpOutboundGateway` is in an `async` mode and `CCF` is configured not for `singleUse` an `Semaphore` around an obtained `TcpConnection` is involved. If we fail on `TcpConnection.send()`, resources are not clean up, including the mentioned `Semaphore`: in async mode this happens only when we receive a reply. * Catch an exception on the `TcpConnection.send()` and perform `cleanUp()` in async mode. * Add `cleanUp()` into a scheduled task from the `TcpOutboundGateway.AsyncReply` when no reply arrives in time. * Optimize `TcpOutboundGateway.AsyncReply` behavior to cancel no-reply scheduled task when reply arrives into a `CompletableFuture` **Cherry-pick to `5.5.x`** * * Call `cleanUp()` from no response scheduled task only if `future.completeExceptionally()` is `true`
1 parent 0cce72a commit 1aef041

File tree

2 files changed

+85
-16
lines changed

2 files changed

+85
-16
lines changed

spring-integration-ip/src/main/java/org/springframework/integration/ip/tcp/TcpOutboundGateway.java

Lines changed: 41 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2001-2021 the original author or authors.
2+
* Copyright 2001-2023 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -17,10 +17,12 @@
1717
package org.springframework.integration.ip.tcp;
1818

1919
import java.io.IOException;
20-
import java.util.Date;
20+
import java.time.Instant;
2121
import java.util.Map;
22+
import java.util.concurrent.CompletableFuture;
2223
import java.util.concurrent.ConcurrentHashMap;
2324
import java.util.concurrent.CountDownLatch;
25+
import java.util.concurrent.ScheduledFuture;
2426
import java.util.concurrent.Semaphore;
2527
import java.util.concurrent.TimeUnit;
2628

@@ -59,7 +61,6 @@
5961
* <p>
6062
* {@link org.springframework.context.Lifecycle} methods delegate to the underlying {@link AbstractConnectionFactory}.
6163
*
62-
*
6364
* @author Gary Russell
6465
* @author Artem Bilan
6566
*
@@ -223,7 +224,17 @@ protected Object handleRequestMessage(Message<?> requestMessage) {
223224
this.pendingReplies.put(connectionId, reply);
224225
String connectionIdToLog = connectionId;
225226
logger.debug(() -> "Added pending reply " + connectionIdToLog);
226-
connection.send(requestMessage);
227+
try {
228+
connection.send(requestMessage);
229+
}
230+
catch (Exception ex) {
231+
// If it cannot send, then no reply for this connection.
232+
// Therefor release resources for subsequent requests.
233+
if (async) {
234+
cleanUp(haveSemaphore, connection, connectionId);
235+
}
236+
throw ex;
237+
}
227238
if (this.closeStreamAfterSend) {
228239
connection.shutdownOutput();
229240
}
@@ -326,7 +337,7 @@ public boolean onMessage(Message<?> message) {
326337
if (reply == null) {
327338
if (message instanceof ErrorMessage) {
328339
/*
329-
* Socket errors are sent here so they can be conveyed to any waiting thread.
340+
* Socket errors are sent here, so they can be conveyed to any waiting thread.
330341
* If there's not one, simply ignore.
331342
*/
332343
return false;
@@ -427,7 +438,11 @@ private final class AsyncReply {
427438

428439
private final boolean haveSemaphore;
429440

430-
private final SettableListenableFuture<Message<?>> future = new SettableListenableFuture<>();
441+
private final ScheduledFuture<?> noResponseFuture;
442+
443+
private final CompletableFuture<Message<?>> future =
444+
new CompletableFuture<Message<?>>()
445+
.thenApply(this::cancelNoResponseFutureIfAny);
431446

432447
private volatile Message<?> reply;
433448

@@ -440,12 +455,27 @@ private final class AsyncReply {
440455
this.connection = connection;
441456
this.haveSemaphore = haveSemaphore;
442457
if (async && remoteTimeout > 0) {
443-
getTaskScheduler().schedule(() -> {
444-
TcpOutboundGateway.this.pendingReplies.remove(connection.getConnectionId());
445-
this.future.setException(
446-
new MessageTimeoutException(requestMessage, "Timed out waiting for response"));
447-
}, new Date(System.currentTimeMillis() + remoteTimeout));
458+
this.noResponseFuture =
459+
getTaskScheduler()
460+
.schedule(() -> {
461+
if (this.future.completeExceptionally(
462+
new MessageTimeoutException(requestMessage,
463+
"Timed out waiting for response"))) {
464+
465+
cleanUp(this.haveSemaphore, this.connection, this.connection.getConnectionId());
466+
}
467+
}, Instant.now().plusMillis(remoteTimeout));
468+
}
469+
else {
470+
this.noResponseFuture = null;
471+
}
472+
}
473+
474+
private Message<?> cancelNoResponseFutureIfAny(Message<?> message) {
475+
if (this.noResponseFuture != null) {
476+
this.noResponseFuture.cancel(true);
448477
}
478+
return message;
449479
}
450480

451481
TcpConnection getConnection() {

spring-integration-ip/src/test/java/org/springframework/integration/ip/tcp/TcpOutboundGatewayTests.java

Lines changed: 44 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
import org.springframework.integration.ip.tcp.connection.AbstractClientConnectionFactory;
6262
import org.springframework.integration.ip.tcp.connection.CachingClientConnectionFactory;
6363
import org.springframework.integration.ip.tcp.connection.FailoverClientConnectionFactory;
64+
import org.springframework.integration.ip.tcp.connection.TcpConnection;
6465
import org.springframework.integration.ip.tcp.connection.TcpConnectionSupport;
6566
import org.springframework.integration.ip.tcp.connection.TcpNetClientConnectionFactory;
6667
import org.springframework.integration.ip.tcp.connection.TcpNioClientConnectionFactory;
@@ -80,9 +81,14 @@
8081
import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler;
8182

8283
import static org.assertj.core.api.Assertions.assertThat;
84+
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
8385
import static org.assertj.core.api.Assertions.catchThrowable;
8486
import static org.assertj.core.api.Assertions.fail;
8587
import static org.awaitility.Awaitility.await;
88+
import static org.mockito.ArgumentMatchers.any;
89+
import static org.mockito.BDDMockito.given;
90+
import static org.mockito.BDDMockito.willReturn;
91+
import static org.mockito.BDDMockito.willThrow;
8692
import static org.mockito.Mockito.doThrow;
8793
import static org.mockito.Mockito.mock;
8894
import static org.mockito.Mockito.verify;
@@ -397,7 +403,7 @@ private void testGoodNetGWTimeoutGuts(AbstractClientConnectionFactory ccf,
397403

398404
Expression remoteTimeoutExpression = Mockito.mock(Expression.class);
399405

400-
when(remoteTimeoutExpression.getValue(Mockito.any(EvaluationContext.class), Mockito.any(Message.class),
406+
when(remoteTimeoutExpression.getValue(any(EvaluationContext.class), any(Message.class),
401407
Mockito.eq(Long.class))).thenReturn(50L, 60000L);
402408

403409
gateway.setRemoteTimeoutExpression(remoteTimeoutExpression);
@@ -489,7 +495,7 @@ void testCachingFailover() throws Exception {
489495
TcpConnectionSupport mockConn1 = makeMockConnection();
490496
when(factory1.getConnection()).thenReturn(mockConn1);
491497
doThrow(new UncheckedIOException(new IOException("fail")))
492-
.when(mockConn1).send(Mockito.any(Message.class));
498+
.when(mockConn1).send(any(Message.class));
493499

494500
AbstractClientConnectionFactory factory2 = new TcpNetClientConnectionFactory("localhost",
495501
serverSocket.get().getLocalPort());
@@ -522,7 +528,7 @@ void testCachingFailover() throws Exception {
522528
assertThat(reply.getPayload()).isEqualTo("bar");
523529
done.set(true);
524530
gateway.stop();
525-
verify(mockConn1).send(Mockito.any(Message.class));
531+
verify(mockConn1).send(any(Message.class));
526532
factory2.stop();
527533
serverSocket.get().close();
528534
}
@@ -572,7 +578,7 @@ void testFailoverCached() throws Exception {
572578
when(factory1.getConnection()).thenReturn(mockConn1);
573579
when(factory1.isSingleUse()).thenReturn(true);
574580
doThrow(new UncheckedIOException(new IOException("fail")))
575-
.when(mockConn1).send(Mockito.any(Message.class));
581+
.when(mockConn1).send(any(Message.class));
576582
CachingClientConnectionFactory cachingFactory1 = new CachingClientConnectionFactory(factory1, 1);
577583

578584
AbstractClientConnectionFactory factory2 = new TcpNetClientConnectionFactory("localhost",
@@ -607,7 +613,7 @@ void testFailoverCached() throws Exception {
607613
assertThat(reply.getPayload()).isEqualTo("bar");
608614
done.set(true);
609615
gateway.stop();
610-
verify(mockConn1).send(Mockito.any(Message.class));
616+
verify(mockConn1).send(any(Message.class));
611617
factory2.stop();
612618
serverSocket.get().close();
613619
}
@@ -1080,4 +1086,37 @@ void testAsyncTimeout() throws Exception {
10801086
}
10811087
}
10821088

1089+
@Test
1090+
void semaphoreIsReleasedOnAsyncSendFailure() throws InterruptedException {
1091+
AbstractClientConnectionFactory ccf = mock(AbstractClientConnectionFactory.class);
1092+
1093+
TcpConnection connection = mock(TcpConnectionSupport.class);
1094+
1095+
given(connection.getConnectionId()).willReturn("testId");
1096+
willThrow(new RuntimeException("intentional"))
1097+
.given(connection)
1098+
.send(any(Message.class));
1099+
1100+
willReturn(connection)
1101+
.given(ccf)
1102+
.getConnection();
1103+
1104+
TcpOutboundGateway gateway = new TcpOutboundGateway();
1105+
gateway.setConnectionFactory(ccf);
1106+
gateway.setAsync(true);
1107+
gateway.setBeanFactory(mock(BeanFactory.class));
1108+
gateway.setRemoteTimeout(-1);
1109+
gateway.afterPropertiesSet();
1110+
1111+
assertThatExceptionOfType(MessageHandlingException.class)
1112+
.isThrownBy(() -> gateway.handleMessage(new GenericMessage<>("Test1")))
1113+
.withCauseExactlyInstanceOf(RuntimeException.class)
1114+
.withStackTraceContaining("intentional");
1115+
1116+
assertThatExceptionOfType(MessageHandlingException.class)
1117+
.isThrownBy(() -> gateway.handleMessage(new GenericMessage<>("Test2")))
1118+
.withCauseExactlyInstanceOf(RuntimeException.class)
1119+
.withStackTraceContaining("intentional");
1120+
}
1121+
10831122
}

0 commit comments

Comments
 (0)