Skip to content

GH-3993: Fix async race condition in TcpOutGateway #3995

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2001-2022 the original author or authors.
* Copyright 2001-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -22,6 +22,7 @@
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;

Expand Down Expand Up @@ -59,7 +60,6 @@
* <p>
* {@link org.springframework.context.Lifecycle} methods delegate to the underlying {@link AbstractConnectionFactory}.
*
*
* @author Gary Russell
* @author Artem Bilan
*
Expand Down Expand Up @@ -223,7 +223,17 @@ protected Object handleRequestMessage(Message<?> requestMessage) {
this.pendingReplies.put(connectionId, reply);
String connectionIdToLog = connectionId;
logger.debug(() -> "Added pending reply " + connectionIdToLog);
connection.send(requestMessage);
try {
connection.send(requestMessage);
}
catch (Exception ex) {
// If it cannot send, then no reply for this connection.
// Therefor release resources for subsequent requests.
if (async) {
cleanUp(haveSemaphore, connection, connectionId);
}
throw ex;
}
if (this.closeStreamAfterSend) {
connection.shutdownOutput();
}
Expand Down Expand Up @@ -326,7 +336,7 @@ public boolean onMessage(Message<?> message) {
if (reply == null) {
if (message instanceof ErrorMessage) {
/*
* Socket errors are sent here so they can be conveyed to any waiting thread.
* Socket errors are sent here, so they can be conveyed to any waiting thread.
* If there's not one, simply ignore.
*/
return false;
Expand Down Expand Up @@ -427,7 +437,11 @@ private final class AsyncReply {

private final boolean haveSemaphore;

private final CompletableFuture<Message<?>> future = new CompletableFuture<>();
private final ScheduledFuture<?> noResponseFuture;

private final CompletableFuture<Message<?>> future =
new CompletableFuture<Message<?>>()
.thenApply(this::cancelNoResponseFutureIfAny);

private volatile Message<?> reply;

Expand All @@ -440,13 +454,27 @@ private final class AsyncReply {
this.connection = connection;
this.haveSemaphore = haveSemaphore;
if (async && remoteTimeout > 0) {
getTaskScheduler()
.schedule(() -> {
TcpOutboundGateway.this.pendingReplies.remove(connection.getConnectionId());
this.future.completeExceptionally(
new MessageTimeoutException(requestMessage, "Timed out waiting for response"));
}, Instant.now().plusMillis(remoteTimeout));
this.noResponseFuture =
getTaskScheduler()
.schedule(() -> {
if (this.future.completeExceptionally(
new MessageTimeoutException(requestMessage,
"Timed out waiting for response"))) {

cleanUp(this.haveSemaphore, this.connection, this.connection.getConnectionId());
}
}, Instant.now().plusMillis(remoteTimeout));
}
else {
this.noResponseFuture = null;
}
}

private Message<?> cancelNoResponseFutureIfAny(Message<?> message) {
if (this.noResponseFuture != null) {
this.noResponseFuture.cancel(true);
}
return message;
}

TcpConnection getConnection() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
import org.springframework.integration.ip.tcp.connection.AbstractClientConnectionFactory;
import org.springframework.integration.ip.tcp.connection.CachingClientConnectionFactory;
import org.springframework.integration.ip.tcp.connection.FailoverClientConnectionFactory;
import org.springframework.integration.ip.tcp.connection.TcpConnection;
import org.springframework.integration.ip.tcp.connection.TcpConnectionSupport;
import org.springframework.integration.ip.tcp.connection.TcpNetClientConnectionFactory;
import org.springframework.integration.ip.tcp.connection.TcpNioClientConnectionFactory;
Expand All @@ -80,9 +81,14 @@
import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.assertj.core.api.Assertions.catchThrowable;
import static org.assertj.core.api.Assertions.fail;
import static org.awaitility.Awaitility.await;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.BDDMockito.given;
import static org.mockito.BDDMockito.willReturn;
import static org.mockito.BDDMockito.willThrow;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
Expand Down Expand Up @@ -397,7 +403,7 @@ private void testGoodNetGWTimeoutGuts(AbstractClientConnectionFactory ccf,

Expression remoteTimeoutExpression = Mockito.mock(Expression.class);

when(remoteTimeoutExpression.getValue(Mockito.any(EvaluationContext.class), Mockito.any(Message.class),
when(remoteTimeoutExpression.getValue(any(EvaluationContext.class), any(Message.class),
Mockito.eq(Long.class))).thenReturn(50L, 60000L);

gateway.setRemoteTimeoutExpression(remoteTimeoutExpression);
Expand Down Expand Up @@ -488,7 +494,7 @@ void testCachingFailover() throws Exception {
TcpConnectionSupport mockConn1 = makeMockConnection();
when(factory1.getConnection()).thenReturn(mockConn1);
doThrow(new UncheckedIOException(new IOException("fail")))
.when(mockConn1).send(Mockito.any(Message.class));
.when(mockConn1).send(any(Message.class));

AbstractClientConnectionFactory factory2 = new TcpNetClientConnectionFactory("localhost",
serverSocket.get().getLocalPort());
Expand Down Expand Up @@ -521,7 +527,7 @@ void testCachingFailover() throws Exception {
assertThat(reply.getPayload()).isEqualTo("bar");
done.set(true);
gateway.stop();
verify(mockConn1).send(Mockito.any(Message.class));
verify(mockConn1).send(any(Message.class));
factory2.stop();
serverSocket.get().close();
}
Expand Down Expand Up @@ -571,7 +577,7 @@ void testFailoverCached() throws Exception {
when(factory1.getConnection()).thenReturn(mockConn1);
when(factory1.isSingleUse()).thenReturn(true);
doThrow(new UncheckedIOException(new IOException("fail")))
.when(mockConn1).send(Mockito.any(Message.class));
.when(mockConn1).send(any(Message.class));
CachingClientConnectionFactory cachingFactory1 = new CachingClientConnectionFactory(factory1, 1);

AbstractClientConnectionFactory factory2 = new TcpNetClientConnectionFactory("localhost",
Expand Down Expand Up @@ -606,7 +612,7 @@ void testFailoverCached() throws Exception {
assertThat(reply.getPayload()).isEqualTo("bar");
done.set(true);
gateway.stop();
verify(mockConn1).send(Mockito.any(Message.class));
verify(mockConn1).send(any(Message.class));
factory2.stop();
serverSocket.get().close();
}
Expand Down Expand Up @@ -1081,4 +1087,37 @@ void testAsyncTimeout() throws Exception {
}
}

@Test
void semaphoreIsReleasedOnAsyncSendFailure() throws InterruptedException {
AbstractClientConnectionFactory ccf = mock(AbstractClientConnectionFactory.class);

TcpConnection connection = mock(TcpConnectionSupport.class);

given(connection.getConnectionId()).willReturn("testId");
willThrow(new RuntimeException("intentional"))
.given(connection)
.send(any(Message.class));

willReturn(connection)
.given(ccf)
.getConnection();

TcpOutboundGateway gateway = new TcpOutboundGateway();
gateway.setConnectionFactory(ccf);
gateway.setAsync(true);
gateway.setBeanFactory(mock(BeanFactory.class));
gateway.setRemoteTimeout(-1);
gateway.afterPropertiesSet();

assertThatExceptionOfType(MessageHandlingException.class)
.isThrownBy(() -> gateway.handleMessage(new GenericMessage<>("Test1")))
.withCauseExactlyInstanceOf(RuntimeException.class)
.withStackTraceContaining("intentional");

assertThatExceptionOfType(MessageHandlingException.class)
.isThrownBy(() -> gateway.handleMessage(new GenericMessage<>("Test2")))
.withCauseExactlyInstanceOf(RuntimeException.class)
.withStackTraceContaining("intentional");
}

}