Skip to content

Commit ebba374

Browse files
authored
GH-3993: Fix async race condition in TcpOutGateway (#3995)
* GH-3993: Fix async race condition in TcpOutGateway Fixes #3993 When `TcpOutboundGateway` is in an `async` mode and `CCF` is configured not for `singleUse` an `Semaphore` around an obtained `TcpConnection` is involved. If we fail on `TcpConnection.send()`, resources are not clean up, including the mentioned `Semaphore`: in async mode this happens only when we receive a reply. * Catch an exception on the `TcpConnection.send()` and perform `cleanUp()` in async mode. * Add `cleanUp()` into a scheduled task from the `TcpOutboundGateway.AsyncReply` when no reply arrives in time. * Optimize `TcpOutboundGateway.AsyncReply` behavior to cancel no-reply scheduled task when reply arrives into a `CompletableFuture` **Cherry-pick to `5.5.x`** * * Call `cleanUp()` from no response scheduled task only if `future.completeExceptionally()` is `true`
1 parent f8c0779 commit ebba374

File tree

2 files changed

+83
-16
lines changed

2 files changed

+83
-16
lines changed

spring-integration-ip/src/main/java/org/springframework/integration/ip/tcp/TcpOutboundGateway.java

Lines changed: 39 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2001-2022 the original author or authors.
2+
* Copyright 2001-2023 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -22,6 +22,7 @@
2222
import java.util.concurrent.CompletableFuture;
2323
import java.util.concurrent.ConcurrentHashMap;
2424
import java.util.concurrent.CountDownLatch;
25+
import java.util.concurrent.ScheduledFuture;
2526
import java.util.concurrent.Semaphore;
2627
import java.util.concurrent.TimeUnit;
2728

@@ -59,7 +60,6 @@
5960
* <p>
6061
* {@link org.springframework.context.Lifecycle} methods delegate to the underlying {@link AbstractConnectionFactory}.
6162
*
62-
*
6363
* @author Gary Russell
6464
* @author Artem Bilan
6565
*
@@ -223,7 +223,17 @@ protected Object handleRequestMessage(Message<?> requestMessage) {
223223
this.pendingReplies.put(connectionId, reply);
224224
String connectionIdToLog = connectionId;
225225
logger.debug(() -> "Added pending reply " + connectionIdToLog);
226-
connection.send(requestMessage);
226+
try {
227+
connection.send(requestMessage);
228+
}
229+
catch (Exception ex) {
230+
// If it cannot send, then no reply for this connection.
231+
// Therefor release resources for subsequent requests.
232+
if (async) {
233+
cleanUp(haveSemaphore, connection, connectionId);
234+
}
235+
throw ex;
236+
}
227237
if (this.closeStreamAfterSend) {
228238
connection.shutdownOutput();
229239
}
@@ -326,7 +336,7 @@ public boolean onMessage(Message<?> message) {
326336
if (reply == null) {
327337
if (message instanceof ErrorMessage) {
328338
/*
329-
* Socket errors are sent here so they can be conveyed to any waiting thread.
339+
* Socket errors are sent here, so they can be conveyed to any waiting thread.
330340
* If there's not one, simply ignore.
331341
*/
332342
return false;
@@ -427,7 +437,11 @@ private final class AsyncReply {
427437

428438
private final boolean haveSemaphore;
429439

430-
private final CompletableFuture<Message<?>> future = new CompletableFuture<>();
440+
private final ScheduledFuture<?> noResponseFuture;
441+
442+
private final CompletableFuture<Message<?>> future =
443+
new CompletableFuture<Message<?>>()
444+
.thenApply(this::cancelNoResponseFutureIfAny);
431445

432446
private volatile Message<?> reply;
433447

@@ -440,13 +454,27 @@ private final class AsyncReply {
440454
this.connection = connection;
441455
this.haveSemaphore = haveSemaphore;
442456
if (async && remoteTimeout > 0) {
443-
getTaskScheduler()
444-
.schedule(() -> {
445-
TcpOutboundGateway.this.pendingReplies.remove(connection.getConnectionId());
446-
this.future.completeExceptionally(
447-
new MessageTimeoutException(requestMessage, "Timed out waiting for response"));
448-
}, Instant.now().plusMillis(remoteTimeout));
457+
this.noResponseFuture =
458+
getTaskScheduler()
459+
.schedule(() -> {
460+
if (this.future.completeExceptionally(
461+
new MessageTimeoutException(requestMessage,
462+
"Timed out waiting for response"))) {
463+
464+
cleanUp(this.haveSemaphore, this.connection, this.connection.getConnectionId());
465+
}
466+
}, Instant.now().plusMillis(remoteTimeout));
467+
}
468+
else {
469+
this.noResponseFuture = null;
470+
}
471+
}
472+
473+
private Message<?> cancelNoResponseFutureIfAny(Message<?> message) {
474+
if (this.noResponseFuture != null) {
475+
this.noResponseFuture.cancel(true);
449476
}
477+
return message;
450478
}
451479

452480
TcpConnection getConnection() {

spring-integration-ip/src/test/java/org/springframework/integration/ip/tcp/TcpOutboundGatewayTests.java

Lines changed: 44 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
import org.springframework.integration.ip.tcp.connection.AbstractClientConnectionFactory;
6262
import org.springframework.integration.ip.tcp.connection.CachingClientConnectionFactory;
6363
import org.springframework.integration.ip.tcp.connection.FailoverClientConnectionFactory;
64+
import org.springframework.integration.ip.tcp.connection.TcpConnection;
6465
import org.springframework.integration.ip.tcp.connection.TcpConnectionSupport;
6566
import org.springframework.integration.ip.tcp.connection.TcpNetClientConnectionFactory;
6667
import org.springframework.integration.ip.tcp.connection.TcpNioClientConnectionFactory;
@@ -80,9 +81,14 @@
8081
import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler;
8182

8283
import static org.assertj.core.api.Assertions.assertThat;
84+
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
8385
import static org.assertj.core.api.Assertions.catchThrowable;
8486
import static org.assertj.core.api.Assertions.fail;
8587
import static org.awaitility.Awaitility.await;
88+
import static org.mockito.ArgumentMatchers.any;
89+
import static org.mockito.BDDMockito.given;
90+
import static org.mockito.BDDMockito.willReturn;
91+
import static org.mockito.BDDMockito.willThrow;
8692
import static org.mockito.Mockito.doThrow;
8793
import static org.mockito.Mockito.mock;
8894
import static org.mockito.Mockito.verify;
@@ -397,7 +403,7 @@ private void testGoodNetGWTimeoutGuts(AbstractClientConnectionFactory ccf,
397403

398404
Expression remoteTimeoutExpression = Mockito.mock(Expression.class);
399405

400-
when(remoteTimeoutExpression.getValue(Mockito.any(EvaluationContext.class), Mockito.any(Message.class),
406+
when(remoteTimeoutExpression.getValue(any(EvaluationContext.class), any(Message.class),
401407
Mockito.eq(Long.class))).thenReturn(50L, 60000L);
402408

403409
gateway.setRemoteTimeoutExpression(remoteTimeoutExpression);
@@ -488,7 +494,7 @@ void testCachingFailover() throws Exception {
488494
TcpConnectionSupport mockConn1 = makeMockConnection();
489495
when(factory1.getConnection()).thenReturn(mockConn1);
490496
doThrow(new UncheckedIOException(new IOException("fail")))
491-
.when(mockConn1).send(Mockito.any(Message.class));
497+
.when(mockConn1).send(any(Message.class));
492498

493499
AbstractClientConnectionFactory factory2 = new TcpNetClientConnectionFactory("localhost",
494500
serverSocket.get().getLocalPort());
@@ -521,7 +527,7 @@ void testCachingFailover() throws Exception {
521527
assertThat(reply.getPayload()).isEqualTo("bar");
522528
done.set(true);
523529
gateway.stop();
524-
verify(mockConn1).send(Mockito.any(Message.class));
530+
verify(mockConn1).send(any(Message.class));
525531
factory2.stop();
526532
serverSocket.get().close();
527533
}
@@ -571,7 +577,7 @@ void testFailoverCached() throws Exception {
571577
when(factory1.getConnection()).thenReturn(mockConn1);
572578
when(factory1.isSingleUse()).thenReturn(true);
573579
doThrow(new UncheckedIOException(new IOException("fail")))
574-
.when(mockConn1).send(Mockito.any(Message.class));
580+
.when(mockConn1).send(any(Message.class));
575581
CachingClientConnectionFactory cachingFactory1 = new CachingClientConnectionFactory(factory1, 1);
576582

577583
AbstractClientConnectionFactory factory2 = new TcpNetClientConnectionFactory("localhost",
@@ -606,7 +612,7 @@ void testFailoverCached() throws Exception {
606612
assertThat(reply.getPayload()).isEqualTo("bar");
607613
done.set(true);
608614
gateway.stop();
609-
verify(mockConn1).send(Mockito.any(Message.class));
615+
verify(mockConn1).send(any(Message.class));
610616
factory2.stop();
611617
serverSocket.get().close();
612618
}
@@ -1081,4 +1087,37 @@ void testAsyncTimeout() throws Exception {
10811087
}
10821088
}
10831089

1090+
@Test
1091+
void semaphoreIsReleasedOnAsyncSendFailure() throws InterruptedException {
1092+
AbstractClientConnectionFactory ccf = mock(AbstractClientConnectionFactory.class);
1093+
1094+
TcpConnection connection = mock(TcpConnectionSupport.class);
1095+
1096+
given(connection.getConnectionId()).willReturn("testId");
1097+
willThrow(new RuntimeException("intentional"))
1098+
.given(connection)
1099+
.send(any(Message.class));
1100+
1101+
willReturn(connection)
1102+
.given(ccf)
1103+
.getConnection();
1104+
1105+
TcpOutboundGateway gateway = new TcpOutboundGateway();
1106+
gateway.setConnectionFactory(ccf);
1107+
gateway.setAsync(true);
1108+
gateway.setBeanFactory(mock(BeanFactory.class));
1109+
gateway.setRemoteTimeout(-1);
1110+
gateway.afterPropertiesSet();
1111+
1112+
assertThatExceptionOfType(MessageHandlingException.class)
1113+
.isThrownBy(() -> gateway.handleMessage(new GenericMessage<>("Test1")))
1114+
.withCauseExactlyInstanceOf(RuntimeException.class)
1115+
.withStackTraceContaining("intentional");
1116+
1117+
assertThatExceptionOfType(MessageHandlingException.class)
1118+
.isThrownBy(() -> gateway.handleMessage(new GenericMessage<>("Test2")))
1119+
.withCauseExactlyInstanceOf(RuntimeException.class)
1120+
.withStackTraceContaining("intentional");
1121+
}
1122+
10841123
}

0 commit comments

Comments
 (0)