diff --git a/spring-integration-ip/src/main/java/org/springframework/integration/ip/tcp/TcpOutboundGateway.java b/spring-integration-ip/src/main/java/org/springframework/integration/ip/tcp/TcpOutboundGateway.java
index 13c14c14497..8f7c9f36dc1 100644
--- a/spring-integration-ip/src/main/java/org/springframework/integration/ip/tcp/TcpOutboundGateway.java
+++ b/spring-integration-ip/src/main/java/org/springframework/integration/ip/tcp/TcpOutboundGateway.java
@@ -1,5 +1,5 @@
/*
- * Copyright 2001-2022 the original author or authors.
+ * Copyright 2001-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -22,6 +22,7 @@
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;
@@ -59,7 +60,6 @@
*
* {@link org.springframework.context.Lifecycle} methods delegate to the underlying {@link AbstractConnectionFactory}.
*
- *
* @author Gary Russell
* @author Artem Bilan
*
@@ -223,7 +223,17 @@ protected Object handleRequestMessage(Message> requestMessage) {
this.pendingReplies.put(connectionId, reply);
String connectionIdToLog = connectionId;
logger.debug(() -> "Added pending reply " + connectionIdToLog);
- connection.send(requestMessage);
+ try {
+ connection.send(requestMessage);
+ }
+ catch (Exception ex) {
+ // If it cannot send, then no reply for this connection.
+ // Therefor release resources for subsequent requests.
+ if (async) {
+ cleanUp(haveSemaphore, connection, connectionId);
+ }
+ throw ex;
+ }
if (this.closeStreamAfterSend) {
connection.shutdownOutput();
}
@@ -326,7 +336,7 @@ public boolean onMessage(Message> message) {
if (reply == null) {
if (message instanceof ErrorMessage) {
/*
- * Socket errors are sent here so they can be conveyed to any waiting thread.
+ * Socket errors are sent here, so they can be conveyed to any waiting thread.
* If there's not one, simply ignore.
*/
return false;
@@ -427,7 +437,11 @@ private final class AsyncReply {
private final boolean haveSemaphore;
- private final CompletableFuture> future = new CompletableFuture<>();
+ private final ScheduledFuture> noResponseFuture;
+
+ private final CompletableFuture> future =
+ new CompletableFuture>()
+ .thenApply(this::cancelNoResponseFutureIfAny);
private volatile Message> reply;
@@ -440,13 +454,27 @@ private final class AsyncReply {
this.connection = connection;
this.haveSemaphore = haveSemaphore;
if (async && remoteTimeout > 0) {
- getTaskScheduler()
- .schedule(() -> {
- TcpOutboundGateway.this.pendingReplies.remove(connection.getConnectionId());
- this.future.completeExceptionally(
- new MessageTimeoutException(requestMessage, "Timed out waiting for response"));
- }, Instant.now().plusMillis(remoteTimeout));
+ this.noResponseFuture =
+ getTaskScheduler()
+ .schedule(() -> {
+ if (this.future.completeExceptionally(
+ new MessageTimeoutException(requestMessage,
+ "Timed out waiting for response"))) {
+
+ cleanUp(this.haveSemaphore, this.connection, this.connection.getConnectionId());
+ }
+ }, Instant.now().plusMillis(remoteTimeout));
+ }
+ else {
+ this.noResponseFuture = null;
+ }
+ }
+
+ private Message> cancelNoResponseFutureIfAny(Message> message) {
+ if (this.noResponseFuture != null) {
+ this.noResponseFuture.cancel(true);
}
+ return message;
}
TcpConnection getConnection() {
diff --git a/spring-integration-ip/src/test/java/org/springframework/integration/ip/tcp/TcpOutboundGatewayTests.java b/spring-integration-ip/src/test/java/org/springframework/integration/ip/tcp/TcpOutboundGatewayTests.java
index 6eacc9d0d00..2eb7531287f 100644
--- a/spring-integration-ip/src/test/java/org/springframework/integration/ip/tcp/TcpOutboundGatewayTests.java
+++ b/spring-integration-ip/src/test/java/org/springframework/integration/ip/tcp/TcpOutboundGatewayTests.java
@@ -61,6 +61,7 @@
import org.springframework.integration.ip.tcp.connection.AbstractClientConnectionFactory;
import org.springframework.integration.ip.tcp.connection.CachingClientConnectionFactory;
import org.springframework.integration.ip.tcp.connection.FailoverClientConnectionFactory;
+import org.springframework.integration.ip.tcp.connection.TcpConnection;
import org.springframework.integration.ip.tcp.connection.TcpConnectionSupport;
import org.springframework.integration.ip.tcp.connection.TcpNetClientConnectionFactory;
import org.springframework.integration.ip.tcp.connection.TcpNioClientConnectionFactory;
@@ -80,9 +81,14 @@
import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler;
import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.assertj.core.api.Assertions.catchThrowable;
import static org.assertj.core.api.Assertions.fail;
import static org.awaitility.Awaitility.await;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.BDDMockito.given;
+import static org.mockito.BDDMockito.willReturn;
+import static org.mockito.BDDMockito.willThrow;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
@@ -397,7 +403,7 @@ private void testGoodNetGWTimeoutGuts(AbstractClientConnectionFactory ccf,
Expression remoteTimeoutExpression = Mockito.mock(Expression.class);
- when(remoteTimeoutExpression.getValue(Mockito.any(EvaluationContext.class), Mockito.any(Message.class),
+ when(remoteTimeoutExpression.getValue(any(EvaluationContext.class), any(Message.class),
Mockito.eq(Long.class))).thenReturn(50L, 60000L);
gateway.setRemoteTimeoutExpression(remoteTimeoutExpression);
@@ -488,7 +494,7 @@ void testCachingFailover() throws Exception {
TcpConnectionSupport mockConn1 = makeMockConnection();
when(factory1.getConnection()).thenReturn(mockConn1);
doThrow(new UncheckedIOException(new IOException("fail")))
- .when(mockConn1).send(Mockito.any(Message.class));
+ .when(mockConn1).send(any(Message.class));
AbstractClientConnectionFactory factory2 = new TcpNetClientConnectionFactory("localhost",
serverSocket.get().getLocalPort());
@@ -521,7 +527,7 @@ void testCachingFailover() throws Exception {
assertThat(reply.getPayload()).isEqualTo("bar");
done.set(true);
gateway.stop();
- verify(mockConn1).send(Mockito.any(Message.class));
+ verify(mockConn1).send(any(Message.class));
factory2.stop();
serverSocket.get().close();
}
@@ -571,7 +577,7 @@ void testFailoverCached() throws Exception {
when(factory1.getConnection()).thenReturn(mockConn1);
when(factory1.isSingleUse()).thenReturn(true);
doThrow(new UncheckedIOException(new IOException("fail")))
- .when(mockConn1).send(Mockito.any(Message.class));
+ .when(mockConn1).send(any(Message.class));
CachingClientConnectionFactory cachingFactory1 = new CachingClientConnectionFactory(factory1, 1);
AbstractClientConnectionFactory factory2 = new TcpNetClientConnectionFactory("localhost",
@@ -606,7 +612,7 @@ void testFailoverCached() throws Exception {
assertThat(reply.getPayload()).isEqualTo("bar");
done.set(true);
gateway.stop();
- verify(mockConn1).send(Mockito.any(Message.class));
+ verify(mockConn1).send(any(Message.class));
factory2.stop();
serverSocket.get().close();
}
@@ -1081,4 +1087,37 @@ void testAsyncTimeout() throws Exception {
}
}
+ @Test
+ void semaphoreIsReleasedOnAsyncSendFailure() throws InterruptedException {
+ AbstractClientConnectionFactory ccf = mock(AbstractClientConnectionFactory.class);
+
+ TcpConnection connection = mock(TcpConnectionSupport.class);
+
+ given(connection.getConnectionId()).willReturn("testId");
+ willThrow(new RuntimeException("intentional"))
+ .given(connection)
+ .send(any(Message.class));
+
+ willReturn(connection)
+ .given(ccf)
+ .getConnection();
+
+ TcpOutboundGateway gateway = new TcpOutboundGateway();
+ gateway.setConnectionFactory(ccf);
+ gateway.setAsync(true);
+ gateway.setBeanFactory(mock(BeanFactory.class));
+ gateway.setRemoteTimeout(-1);
+ gateway.afterPropertiesSet();
+
+ assertThatExceptionOfType(MessageHandlingException.class)
+ .isThrownBy(() -> gateway.handleMessage(new GenericMessage<>("Test1")))
+ .withCauseExactlyInstanceOf(RuntimeException.class)
+ .withStackTraceContaining("intentional");
+
+ assertThatExceptionOfType(MessageHandlingException.class)
+ .isThrownBy(() -> gateway.handleMessage(new GenericMessage<>("Test2")))
+ .withCauseExactlyInstanceOf(RuntimeException.class)
+ .withStackTraceContaining("intentional");
+ }
+
}