diff --git a/src/main/java/org/tarantool/TarantoolClientImpl.java b/src/main/java/org/tarantool/TarantoolClientImpl.java index 5256021e..c1d40b79 100644 --- a/src/main/java/org/tarantool/TarantoolClientImpl.java +++ b/src/main/java/org/tarantool/TarantoolClientImpl.java @@ -22,7 +22,6 @@ import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.locks.Condition; -import java.util.concurrent.locks.LockSupport; import java.util.concurrent.locks.ReentrantLock; public class TarantoolClientImpl extends TarantoolBase> implements TarantoolClient { @@ -72,10 +71,12 @@ public class TarantoolClientImpl extends TarantoolBase> implements Tar @Override public void run() { while (!Thread.currentThread().isInterrupted()) { - if (state.compareAndSet(StateHelper.RECONNECT, 0)) { - reconnect(0, thumbstone); + reconnect(0, thumbstone); + try { + state.awaitReconnection(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); } - LockSupport.park(state); } } }); @@ -146,14 +147,10 @@ protected void connect(final SocketChannel channel) throws Exception { TarantoolGreeting greeting = ProtoUtils.connect(channel, config.username, config.password); this.serverVersion = greeting.getServerVersion(); } catch (IOException e) { - try { - channel.close(); - } catch (IOException ignored) { - // No-op - } - + closeChannel(channel); throw new CommunicationException("Couldn't connect to tarantool", e); } + channel.configureBlocking(false); this.channel = channel; this.readChannel = new ReadableViaSelectorChannel(channel); @@ -169,44 +166,42 @@ protected void connect(final SocketChannel channel) throws Exception { } protected void startThreads(String threadName) throws InterruptedException { - final CountDownLatch init = new CountDownLatch(2); - reader = new Thread(new Runnable() { - @Override - public void run() { - init.countDown(); - if (state.acquire(StateHelper.READING)) { - try { - readThread(); - } finally { - state.release(StateHelper.READING); - if (state.compareAndSet(0, StateHelper.RECONNECT)) { - LockSupport.unpark(connector); - } + final CountDownLatch ioThreadStarted = new CountDownLatch(2); + final AtomicInteger leftIoThreads = new AtomicInteger(2); + reader = new Thread(() -> { + ioThreadStarted.countDown(); + if (state.acquire(StateHelper.READING)) { + try { + readThread(); + } finally { + state.release(StateHelper.READING); + // only last of two IO-threads can signal for reconnection + if (leftIoThreads.decrementAndGet() == 0) { + state.trySignalForReconnection(); } } } }); - writer = new Thread(new Runnable() { - @Override - public void run() { - init.countDown(); - if (state.acquire(StateHelper.WRITING)) { - try { - writeThread(); - } finally { - state.release(StateHelper.WRITING); - if (state.compareAndSet(0, StateHelper.RECONNECT)) { - LockSupport.unpark(connector); - } + writer = new Thread(() -> { + ioThreadStarted.countDown(); + if (state.acquire(StateHelper.WRITING)) { + try { + writeThread(); + } finally { + state.release(StateHelper.WRITING); + // only last of two IO-threads can signal for reconnection + if (leftIoThreads.decrementAndGet() == 0) { + state.trySignalForReconnection(); } } } }); + state.release(StateHelper.RECONNECT); configureThreads(threadName); reader.start(); writer.start(); - init.await(); + ioThreadStarted.await(); } protected void configureThreads(String threadName) { @@ -356,25 +351,21 @@ private boolean directWrite(ByteBuffer buffer) throws InterruptedException, IOEx } protected void readThread() { - try { - while (!Thread.currentThread().isInterrupted()) { - try { - TarantoolPacket packet = ProtoUtils.readPacket(readChannel); + while (!Thread.currentThread().isInterrupted()) { + try { + TarantoolPacket packet = ProtoUtils.readPacket(readChannel); - Map headers = packet.getHeaders(); + Map headers = packet.getHeaders(); - Long syncId = (Long) headers.get(Key.SYNC.getId()); - TarantoolOp future = futures.remove(syncId); - stats.received++; - wait.decrementAndGet(); - complete(packet, future); - } catch (Exception e) { - die("Cant read answer", e); - return; - } + Long syncId = (Long) headers.get(Key.SYNC.getId()); + TarantoolOp future = futures.remove(syncId); + stats.received++; + wait.decrementAndGet(); + complete(packet, future); + } catch (Exception e) { + die("Cant read answer", e); + return; } - } catch (Exception e) { - die("Cant init thread", e); } } @@ -489,7 +480,7 @@ protected void stopIO() { try { readChannel.close(); // also closes this.channel } catch (IOException ignored) { - // No-op + // no-op } } closeChannel(channel); @@ -639,6 +630,7 @@ public TarantoolClientStats getStats() { * Manages state changes. */ protected final class StateHelper { + static final int UNINITIALIZED = 0; static final int READING = 1; static final int WRITING = 2; static final int ALIVE = READING | WRITING; @@ -648,10 +640,22 @@ protected final class StateHelper { private final AtomicInteger state; private final AtomicReference nextAliveLatch = - new AtomicReference(new CountDownLatch(1)); + new AtomicReference<>(new CountDownLatch(1)); private final CountDownLatch closedLatch = new CountDownLatch(1); + /** + * The condition variable to signal a reconnection is needed from reader / + * writer threads and waiting for that signal from the reconnection thread. + * + * The lock variable to access this condition. + * + * @see #awaitReconnection() + * @see #trySignalForReconnection() + */ + protected final ReentrantLock connectorLock = new ReentrantLock(); + protected final Condition reconnectRequired = connectorLock.newCondition(); + protected StateHelper(int state) { this.state = new AtomicInteger(state); } @@ -660,30 +664,51 @@ protected int getState() { return state.get(); } + /** + * Set CLOSED state, drop RECONNECT state. + */ protected boolean close() { for (; ; ) { - int st = getState(); - if ((st & CLOSED) == CLOSED) { + int currentState = getState(); + + /* CLOSED is the terminal state. */ + if ((currentState & CLOSED) == CLOSED) { return false; } - if (compareAndSet(st, (st & ~RECONNECT) | CLOSED)) { + + /* Drop RECONNECT, set CLOSED. */ + if (compareAndSet(currentState, (currentState & ~RECONNECT) | CLOSED)) { return true; } } } + /** + * Move from a current state to a give one. + * + * Some moves are forbidden. + */ protected boolean acquire(int mask) { for (; ; ) { - int st = getState(); - if ((st & CLOSED) == CLOSED) { + int currentState = getState(); + + /* CLOSED is the terminal state. */ + if ((currentState & CLOSED) == CLOSED) { + return false; + } + + /* Don't move to READING, WRITING or ALIVE from RECONNECT. */ + if ((currentState & RECONNECT) > mask) { return false; } - if ((st & mask) != 0) { + /* Cannot move from a state to the same state. */ + if ((currentState & mask) != 0) { throw new IllegalStateException("State is already " + mask); } - if (compareAndSet(st, st | mask)) { + /* Set acquired state. */ + if (compareAndSet(currentState, currentState | mask)) { return true; } } @@ -691,8 +716,8 @@ protected boolean acquire(int mask) { protected void release(int mask) { for (; ; ) { - int st = getState(); - if (compareAndSet(st, st & ~mask)) { + int currentState = getState(); + if (compareAndSet(currentState, currentState & ~mask)) { return; } } @@ -713,10 +738,18 @@ protected boolean compareAndSet(int expect, int update) { return true; } + /** + * Reconnection uses another way to await state via receiving a signal + * instead of latches. + */ protected void awaitState(int state) throws InterruptedException { - CountDownLatch latch = getStateLatch(state); - if (latch != null) { - latch.await(); + if (state == RECONNECT) { + awaitReconnection(); + } else { + CountDownLatch latch = getStateLatch(state); + if (latch != null) { + latch.await(); + } } } @@ -740,6 +773,38 @@ private CountDownLatch getStateLatch(int state) { } return null; } + + /** + * Blocks until a reconnection signal will be received. + * + * @see #trySignalForReconnection() + */ + private void awaitReconnection() throws InterruptedException { + connectorLock.lock(); + try { + while (getState() != StateHelper.RECONNECT) { + reconnectRequired.await(); + } + } finally { + connectorLock.unlock(); + } + } + + /** + * Signals to the connector that reconnection process can be performed. + * + * @see #awaitReconnection() + */ + private void trySignalForReconnection() { + if (compareAndSet(StateHelper.UNINITIALIZED, StateHelper.RECONNECT)) { + connectorLock.lock(); + try { + reconnectRequired.signal(); + } finally { + connectorLock.unlock(); + } + } + } } protected static class TarantoolOp extends CompletableFuture { diff --git a/src/test/java/org/tarantool/ClientReconnectIT.java b/src/test/java/org/tarantool/ClientReconnectIT.java index f4300c1d..a38875bc 100644 --- a/src/test/java/org/tarantool/ClientReconnectIT.java +++ b/src/test/java/org/tarantool/ClientReconnectIT.java @@ -231,7 +231,8 @@ public void testLongParallelCloseReconnects() { SocketChannelProvider provider = new TestSocketChannelProvider(host, port, RESTART_TIMEOUT).setSoLinger(0); - final AtomicReferenceArray clients = new AtomicReferenceArray(numClients); + final AtomicReferenceArray clients = + new AtomicReferenceArray<>(numClients); for (int idx = 0; idx < clients.length(); idx++) { clients.set(idx, makeClient(provider)); @@ -249,9 +250,7 @@ public void testLongParallelCloseReconnects() { @Override public void run() { while (!Thread.currentThread().isInterrupted() && deadline > System.currentTimeMillis()) { - int idx = rnd.nextInt(clients.length()); - try { TarantoolClient cli = clients.get(idx); @@ -300,7 +299,7 @@ public void run() { // Wait for all threads to finish. try { - assertTrue(latch.await(RESTART_TIMEOUT, TimeUnit.MILLISECONDS)); + assertTrue(latch.await(RESTART_TIMEOUT * 2, TimeUnit.MILLISECONDS)); } catch (InterruptedException e) { fail(e); }