diff --git a/spring-integration-ip/src/main/java/org/springframework/integration/ip/tcp/TcpOutboundGateway.java b/spring-integration-ip/src/main/java/org/springframework/integration/ip/tcp/TcpOutboundGateway.java index 13c14c14497..8f7c9f36dc1 100644 --- a/spring-integration-ip/src/main/java/org/springframework/integration/ip/tcp/TcpOutboundGateway.java +++ b/spring-integration-ip/src/main/java/org/springframework/integration/ip/tcp/TcpOutboundGateway.java @@ -1,5 +1,5 @@ /* - * Copyright 2001-2022 the original author or authors. + * Copyright 2001-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,6 +22,7 @@ import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ScheduledFuture; import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; @@ -59,7 +60,6 @@ *

* {@link org.springframework.context.Lifecycle} methods delegate to the underlying {@link AbstractConnectionFactory}. * - * * @author Gary Russell * @author Artem Bilan * @@ -223,7 +223,17 @@ protected Object handleRequestMessage(Message requestMessage) { this.pendingReplies.put(connectionId, reply); String connectionIdToLog = connectionId; logger.debug(() -> "Added pending reply " + connectionIdToLog); - connection.send(requestMessage); + try { + connection.send(requestMessage); + } + catch (Exception ex) { + // If it cannot send, then no reply for this connection. + // Therefor release resources for subsequent requests. + if (async) { + cleanUp(haveSemaphore, connection, connectionId); + } + throw ex; + } if (this.closeStreamAfterSend) { connection.shutdownOutput(); } @@ -326,7 +336,7 @@ public boolean onMessage(Message message) { if (reply == null) { if (message instanceof ErrorMessage) { /* - * Socket errors are sent here so they can be conveyed to any waiting thread. + * Socket errors are sent here, so they can be conveyed to any waiting thread. * If there's not one, simply ignore. */ return false; @@ -427,7 +437,11 @@ private final class AsyncReply { private final boolean haveSemaphore; - private final CompletableFuture> future = new CompletableFuture<>(); + private final ScheduledFuture noResponseFuture; + + private final CompletableFuture> future = + new CompletableFuture>() + .thenApply(this::cancelNoResponseFutureIfAny); private volatile Message reply; @@ -440,13 +454,27 @@ private final class AsyncReply { this.connection = connection; this.haveSemaphore = haveSemaphore; if (async && remoteTimeout > 0) { - getTaskScheduler() - .schedule(() -> { - TcpOutboundGateway.this.pendingReplies.remove(connection.getConnectionId()); - this.future.completeExceptionally( - new MessageTimeoutException(requestMessage, "Timed out waiting for response")); - }, Instant.now().plusMillis(remoteTimeout)); + this.noResponseFuture = + getTaskScheduler() + .schedule(() -> { + if (this.future.completeExceptionally( + new MessageTimeoutException(requestMessage, + "Timed out waiting for response"))) { + + cleanUp(this.haveSemaphore, this.connection, this.connection.getConnectionId()); + } + }, Instant.now().plusMillis(remoteTimeout)); + } + else { + this.noResponseFuture = null; + } + } + + private Message cancelNoResponseFutureIfAny(Message message) { + if (this.noResponseFuture != null) { + this.noResponseFuture.cancel(true); } + return message; } TcpConnection getConnection() { diff --git a/spring-integration-ip/src/test/java/org/springframework/integration/ip/tcp/TcpOutboundGatewayTests.java b/spring-integration-ip/src/test/java/org/springframework/integration/ip/tcp/TcpOutboundGatewayTests.java index 6eacc9d0d00..2eb7531287f 100644 --- a/spring-integration-ip/src/test/java/org/springframework/integration/ip/tcp/TcpOutboundGatewayTests.java +++ b/spring-integration-ip/src/test/java/org/springframework/integration/ip/tcp/TcpOutboundGatewayTests.java @@ -61,6 +61,7 @@ import org.springframework.integration.ip.tcp.connection.AbstractClientConnectionFactory; import org.springframework.integration.ip.tcp.connection.CachingClientConnectionFactory; import org.springframework.integration.ip.tcp.connection.FailoverClientConnectionFactory; +import org.springframework.integration.ip.tcp.connection.TcpConnection; import org.springframework.integration.ip.tcp.connection.TcpConnectionSupport; import org.springframework.integration.ip.tcp.connection.TcpNetClientConnectionFactory; import org.springframework.integration.ip.tcp.connection.TcpNioClientConnectionFactory; @@ -80,9 +81,14 @@ import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.assertj.core.api.Assertions.catchThrowable; import static org.assertj.core.api.Assertions.fail; import static org.awaitility.Awaitility.await; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; +import static org.mockito.BDDMockito.willReturn; +import static org.mockito.BDDMockito.willThrow; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; @@ -397,7 +403,7 @@ private void testGoodNetGWTimeoutGuts(AbstractClientConnectionFactory ccf, Expression remoteTimeoutExpression = Mockito.mock(Expression.class); - when(remoteTimeoutExpression.getValue(Mockito.any(EvaluationContext.class), Mockito.any(Message.class), + when(remoteTimeoutExpression.getValue(any(EvaluationContext.class), any(Message.class), Mockito.eq(Long.class))).thenReturn(50L, 60000L); gateway.setRemoteTimeoutExpression(remoteTimeoutExpression); @@ -488,7 +494,7 @@ void testCachingFailover() throws Exception { TcpConnectionSupport mockConn1 = makeMockConnection(); when(factory1.getConnection()).thenReturn(mockConn1); doThrow(new UncheckedIOException(new IOException("fail"))) - .when(mockConn1).send(Mockito.any(Message.class)); + .when(mockConn1).send(any(Message.class)); AbstractClientConnectionFactory factory2 = new TcpNetClientConnectionFactory("localhost", serverSocket.get().getLocalPort()); @@ -521,7 +527,7 @@ void testCachingFailover() throws Exception { assertThat(reply.getPayload()).isEqualTo("bar"); done.set(true); gateway.stop(); - verify(mockConn1).send(Mockito.any(Message.class)); + verify(mockConn1).send(any(Message.class)); factory2.stop(); serverSocket.get().close(); } @@ -571,7 +577,7 @@ void testFailoverCached() throws Exception { when(factory1.getConnection()).thenReturn(mockConn1); when(factory1.isSingleUse()).thenReturn(true); doThrow(new UncheckedIOException(new IOException("fail"))) - .when(mockConn1).send(Mockito.any(Message.class)); + .when(mockConn1).send(any(Message.class)); CachingClientConnectionFactory cachingFactory1 = new CachingClientConnectionFactory(factory1, 1); AbstractClientConnectionFactory factory2 = new TcpNetClientConnectionFactory("localhost", @@ -606,7 +612,7 @@ void testFailoverCached() throws Exception { assertThat(reply.getPayload()).isEqualTo("bar"); done.set(true); gateway.stop(); - verify(mockConn1).send(Mockito.any(Message.class)); + verify(mockConn1).send(any(Message.class)); factory2.stop(); serverSocket.get().close(); } @@ -1081,4 +1087,37 @@ void testAsyncTimeout() throws Exception { } } + @Test + void semaphoreIsReleasedOnAsyncSendFailure() throws InterruptedException { + AbstractClientConnectionFactory ccf = mock(AbstractClientConnectionFactory.class); + + TcpConnection connection = mock(TcpConnectionSupport.class); + + given(connection.getConnectionId()).willReturn("testId"); + willThrow(new RuntimeException("intentional")) + .given(connection) + .send(any(Message.class)); + + willReturn(connection) + .given(ccf) + .getConnection(); + + TcpOutboundGateway gateway = new TcpOutboundGateway(); + gateway.setConnectionFactory(ccf); + gateway.setAsync(true); + gateway.setBeanFactory(mock(BeanFactory.class)); + gateway.setRemoteTimeout(-1); + gateway.afterPropertiesSet(); + + assertThatExceptionOfType(MessageHandlingException.class) + .isThrownBy(() -> gateway.handleMessage(new GenericMessage<>("Test1"))) + .withCauseExactlyInstanceOf(RuntimeException.class) + .withStackTraceContaining("intentional"); + + assertThatExceptionOfType(MessageHandlingException.class) + .isThrownBy(() -> gateway.handleMessage(new GenericMessage<>("Test2"))) + .withCauseExactlyInstanceOf(RuntimeException.class) + .withStackTraceContaining("intentional"); + } + }