From ed99a9bf734544074f20ffb536bbabc90ee2c206 Mon Sep 17 00:00:00 2001 From: nicktorwald Date: Tue, 10 Sep 2019 17:51:13 +0700 Subject: [PATCH] Propagate a client thumbstone when a client creation fails Closes: #30 --- .../RoundRobinSocketProviderImpl.java | 8 +- .../org/tarantool/TarantoolClientImpl.java | 10 +- .../protocol/ReadableViaSelectorChannel.java | 15 ++- .../tarantool/ClientReconnectClusterIT.java | 70 +++++++++++++ .../java/org/tarantool/ClientReconnectIT.java | 99 ++++++++++++++++++- src/test/java/org/tarantool/TestUtils.java | 65 ++++++++++++ 6 files changed, 251 insertions(+), 16 deletions(-) diff --git a/src/main/java/org/tarantool/RoundRobinSocketProviderImpl.java b/src/main/java/org/tarantool/RoundRobinSocketProviderImpl.java index 1107799c..804c869b 100644 --- a/src/main/java/org/tarantool/RoundRobinSocketProviderImpl.java +++ b/src/main/java/org/tarantool/RoundRobinSocketProviderImpl.java @@ -140,7 +140,7 @@ protected InetSocketAddress getLastObtainedAddress() { @Override protected SocketChannel makeAttempt(int retryNumber, Throwable lastError) throws IOException { if (retryNumber > getAddressCount()) { - throwFatalError("No more connection addresses are left."); + throwFatalError("No more connection addresses are left.", lastError); } int retriesLimit = getRetriesLimit(); @@ -165,7 +165,7 @@ protected SocketChannel makeAttempt(int retryNumber, Throwable lastError) throws @Override public void setRetriesLimit(int retriesLimit) { if (retriesLimit == 0) { - throwFatalError("Retries count should be at least 1 or more"); + throwFatalError("Retries count should be at least 1 or more", null); } super.setRetriesLimit(retriesLimit); } @@ -212,8 +212,8 @@ public void refreshAddresses(Collection addresses) { updateAddressList(addresses); } - private void throwFatalError(String message) { - throw new CommunicationException(message); + private void throwFatalError(String message, Throwable lastError) { + throw new CommunicationException(message, lastError); } } diff --git a/src/main/java/org/tarantool/TarantoolClientImpl.java b/src/main/java/org/tarantool/TarantoolClientImpl.java index dcb7a8ba..bad3158c 100644 --- a/src/main/java/org/tarantool/TarantoolClientImpl.java +++ b/src/main/java/org/tarantool/TarantoolClientImpl.java @@ -28,9 +28,6 @@ public class TarantoolClientImpl extends TarantoolBase> implements TarantoolClient { - public static final CommunicationException NOT_INIT_EXCEPTION - = new CommunicationException("Not connected, initializing connection"); - protected TarantoolClientConfig config; protected long operationTimeout; @@ -101,7 +98,6 @@ public TarantoolClientImpl(SocketChannelProvider socketProvider, TarantoolClient } private void initClient(SocketChannelProvider socketProvider, TarantoolClientConfig config) { - this.thumbstone = NOT_INIT_EXCEPTION; this.config = config; this.initialRequestSize = config.defaultRequestSize; this.operationTimeout = config.operationExpiryTimeMillis; @@ -130,8 +126,8 @@ private void startConnector(long initTimeoutMillis) { CommunicationException e = new CommunicationException( initTimeoutMillis + "ms is exceeded when waiting for client initialization. " + - "You could configure init timeout in TarantoolConfig" - ); + "You could configure init timeout in TarantoolConfig", + thumbstone); close(e); throw e; @@ -147,7 +143,7 @@ protected void reconnect(Throwable lastError) { int retryNumber = 0; while (!Thread.currentThread().isInterrupted()) { try { - channel = socketProvider.get(retryNumber++, lastError == NOT_INIT_EXCEPTION ? null : lastError); + channel = socketProvider.get(retryNumber++, lastError); } catch (Exception e) { closeChannel(channel); lastError = e; diff --git a/src/main/java/org/tarantool/protocol/ReadableViaSelectorChannel.java b/src/main/java/org/tarantool/protocol/ReadableViaSelectorChannel.java index f6565ba5..3f797ca6 100644 --- a/src/main/java/org/tarantool/protocol/ReadableViaSelectorChannel.java +++ b/src/main/java/org/tarantool/protocol/ReadableViaSelectorChannel.java @@ -37,14 +37,14 @@ public int read(ByteBuffer buffer) throws IOException { count = n = channel.read(buffer); if (n < 0) { - throw new CommunicationException("Channel read failed " + n); + throw new CommunicationException("Channel read failed " + formatReadBytes(n)); } while (buffer.remaining() > 0) { selector.select(); n = channel.read(buffer); if (n < 0) { - throw new CommunicationException("Channel read failed: " + n); + throw new CommunicationException("Channel read failed: " + formatReadBytes(n)); } count += n; } @@ -61,4 +61,15 @@ public void close() throws IOException { selector.close(); channel.close(); } + + /** + * Formats the bytes count to a human readable message. + * + * @param bytes number of bytes + * + * @return formatted message + */ + private String formatReadBytes(int bytes) { + return bytes < 0 ? "EOF" : bytes + " bytes"; + } } diff --git a/src/test/java/org/tarantool/ClientReconnectClusterIT.java b/src/test/java/org/tarantool/ClientReconnectClusterIT.java index c0228874..f46031fa 100644 --- a/src/test/java/org/tarantool/ClientReconnectClusterIT.java +++ b/src/test/java/org/tarantool/ClientReconnectClusterIT.java @@ -2,6 +2,8 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.tarantool.TestUtils.findCause; import static org.tarantool.TestUtils.makeDefaultClusterClientConfig; import static org.tarantool.TestUtils.makeDiscoveryFunction; @@ -13,6 +15,7 @@ import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.Test; +import java.net.ConnectException; import java.time.Duration; import java.util.Arrays; import java.util.Collections; @@ -393,6 +396,73 @@ void testDelayFunctionResultFetch() { expectDisconnected(client, spaceId, pkId); } + @Test + void testRoundRobinSocketProviderRefusedByFakeReason() { + stopInstancesAndAwait(SRV1); + stopInstancesAndAwait(SRV2); + stopInstancesAndAwait(SRV3); + + RuntimeException error = new RuntimeException("Fake error"); + TarantoolClusterClientConfig config = makeDefaultClusterClientConfig(); + config.initTimeoutMillis = 1000; + Throwable exception = assertThrows( + CommunicationException.class, + () -> { + new TarantoolClusterClient( + config, + TestUtils.wrapByErroredProvider(new RoundRobinSocketProviderImpl( + "localhost:" + PORTS[0], + "localhost:" + PORTS[1], + "localhost:" + PORTS[2] + ), error) + ); + } + ); + assertTrue(findCause(exception, error)); + } + + @Test + void testRoundRobinSocketProviderRefused() { + stopInstancesAndAwait(SRV1); + stopInstancesAndAwait(SRV2); + stopInstancesAndAwait(SRV3); + + TarantoolClusterClientConfig config = makeDefaultClusterClientConfig(); + config.initTimeoutMillis = 1000; + Throwable exception = assertThrows( + CommunicationException.class, + () -> { + new TarantoolClusterClient( + config, + new RoundRobinSocketProviderImpl("localhost:" + PORTS[0]) + ); + } + ); + assertTrue(findCause(exception, ConnectException.class)); + } + + @Test + void testRoundRobinSocketProviderRefusedAfterConnect() { + final TarantoolClientImpl client = makeClusterClient( + "localhost:" + PORTS[0], + "localhost:" + PORTS[1], + "localhost:" + PORTS[2] + ); + + client.ping(); + stopInstancesAndAwait(SRV1); + + client.ping(); + stopInstancesAndAwait(SRV2); + + client.ping(); + stopInstancesAndAwait(SRV3); + + CommunicationException exception = assertThrows(CommunicationException.class, client::ping); + Throwable origin = exception.getCause(); + assertEquals(origin, client.getThumbstone()); + } + private void tryAwait(CyclicBarrier barrier) { try { barrier.await(6000, TimeUnit.MILLISECONDS); diff --git a/src/test/java/org/tarantool/ClientReconnectIT.java b/src/test/java/org/tarantool/ClientReconnectIT.java index f644fc8f..af2afc86 100644 --- a/src/test/java/org/tarantool/ClientReconnectIT.java +++ b/src/test/java/org/tarantool/ClientReconnectIT.java @@ -7,6 +7,7 @@ import static org.junit.jupiter.api.Assertions.assertTimeoutPreemptively; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.fail; +import static org.tarantool.TestUtils.findCause; import static org.tarantool.TestUtils.makeDefaultClientConfig; import static org.tarantool.TestUtils.makeTestClient; @@ -16,6 +17,7 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.function.Executable; +import java.net.ConnectException; import java.nio.channels.SocketChannel; import java.time.Duration; import java.util.Collections; @@ -27,6 +29,7 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.atomic.AtomicReferenceArray; import java.util.concurrent.locks.LockSupport; @@ -377,7 +380,7 @@ public void run() { /** * Verify that we don't exceed a file descriptor limit (and so likely don't * leak file descriptors) when trying to connect to an existing node with - * wrong authentification credentials. + * wrong authentication credentials. *

* The test sets SO_LINGER to 0 for outgoing connections to avoid producing * many TIME_WAIT sockets, because an available port range can be @@ -412,13 +415,103 @@ public void testReconnectWrongAuth() throws Exception { client.close(); } - private TestSocketChannelProvider makeZeroLingerProvider() { + @Test + void testFirstConnectionRefused() { + RuntimeException error = new RuntimeException("Fake error"); + TarantoolClientConfig config = makeDefaultClientConfig(); + config.initTimeoutMillis = 100; + Throwable exception = assertThrows( + CommunicationException.class, + () -> new TarantoolClientImpl(makeErroredProvider(error), config) + ); + assertTrue(findCause(exception, error)); + } + + @Test + void testConnectionRefusedAfterConnect() { + TarantoolClientImpl client = new TarantoolClientImpl(makeErroredProvider(null), makeDefaultClientConfig()); + client.ping(); + + testHelper.stopInstance(); + CommunicationException exception = assertThrows(CommunicationException.class, client::ping); + + Throwable origin = exception.getCause(); + assertEquals(origin, client.getThumbstone()); + + testHelper.startInstance(); + } + + @Test + void testSocketProviderRefusedByFakeReason() { + TarantoolClientConfig config = makeDefaultClientConfig(); + RuntimeException error = new RuntimeException("Fake error"); + config.initTimeoutMillis = 1000; + + SingleSocketChannelProviderImpl socketProvider = new SingleSocketChannelProviderImpl("localhost:3301"); + + testHelper.stopInstance(); + Throwable exception = assertThrows( + CommunicationException.class, + () -> new TarantoolClientImpl(TestUtils.wrapByErroredProvider(socketProvider, error), config) + ); + testHelper.startInstance(); + assertTrue(findCause(exception, error)); + } + + @Test + void testSingleSocketProviderRefused() { + testHelper.stopInstance(); + + TarantoolClientConfig config = makeDefaultClientConfig(); + config.initTimeoutMillis = 1000; + + SingleSocketChannelProviderImpl socketProvider = new SingleSocketChannelProviderImpl("localhost:3301"); + + Throwable exception = assertThrows( + CommunicationException.class, + () -> new TarantoolClientImpl(socketProvider, config) + ); + testHelper.startInstance(); + assertTrue(findCause(exception, ConnectException.class)); + } + + @Test + void testSingleSocketProviderRefusedAfterConnect() { + TarantoolClientImpl client = new TarantoolClientImpl(socketChannelProvider, makeDefaultClientConfig()); + + client.ping(); + testHelper.stopInstance(); + + CommunicationException exception = assertThrows(CommunicationException.class, client::ping); + Throwable origin = exception.getCause(); + assertEquals(origin, client.getThumbstone()); + + testHelper.startInstance(); + } + + private SocketChannelProvider makeZeroLingerProvider() { return new TestSocketChannelProvider( TarantoolTestHelper.HOST, TarantoolTestHelper.PORT, RESTART_TIMEOUT ).setSoLinger(0); } - TarantoolClient makeClient(SocketChannelProvider provider) { + private SocketChannelProvider makeErroredProvider(RuntimeException error) { + return new SocketChannelProvider() { + private final SocketChannelProvider delegate = makeZeroLingerProvider(); + private AtomicReference errorReference = new AtomicReference<>(error); + + @Override + public SocketChannel get(int retryNumber, Throwable lastError) { + RuntimeException rawError = errorReference.get(); + if (rawError != null) { + throw rawError; + } + return delegate.get(retryNumber, lastError); + } + }; + } + + private TarantoolClient makeClient(SocketChannelProvider provider) { return new TarantoolClientImpl(provider, makeDefaultClientConfig()); } diff --git a/src/test/java/org/tarantool/TestUtils.java b/src/test/java/org/tarantool/TestUtils.java index b65a5de8..cb2b73e6 100644 --- a/src/test/java/org/tarantool/TestUtils.java +++ b/src/test/java/org/tarantool/TestUtils.java @@ -5,6 +5,7 @@ import java.net.InetSocketAddress; import java.net.Socket; import java.net.SocketAddress; +import java.nio.channels.SocketChannel; import java.util.Collection; import java.util.HashMap; import java.util.List; @@ -290,4 +291,68 @@ public static TarantoolClientConfig makeDefaultClientConfig() { return config; } + /** + * Wraps a socket channel provider + * {@link SocketChannelProvider#get(int, Throwable)} method. + * When an error is raised the wrapper substitutes + * this error by the predefined one. The original value is + * still accessible as a cause of the injected error. + * + * @param provider provider to be wrapped + * @param error error to be thrown instead of original + * + * @return wrapped provider + */ + public static SocketChannelProvider wrapByErroredProvider(SocketChannelProvider provider, RuntimeException error) { + return new SocketChannelProvider() { + private final SocketChannelProvider delegate = provider; + + @Override + public SocketChannel get(int retryNumber, Throwable lastError) { + try { + return delegate.get(retryNumber, lastError); + } catch (Exception e) { + error.initCause(e); + throw error; + } + } + }; + } + + /** + * Searches recursively the given cause for a root error. + * + * @param error root error + * @param cause cause to be found + * + * @return {@literal true} if cause is found within a cause chain + */ + public static boolean findCause(Throwable error, Throwable cause) { + while (error.getCause() != null) { + error = error.getCause(); + if (cause.equals(error)) { + return true; + } + } + return false; + } + + /** + * Searches recursively the first cause being the given class type. + * + * @param error root error + * @param type cause class to be found + * + * @return {@literal true} if cause is found within a cause chain + */ + public static boolean findCause(Throwable error, Class type) { + while (error.getCause() != null) { + error = error.getCause(); + if (type == error.getClass()) { + return true; + } + } + return false; + } + }