From cc36724aa7709ad38e67f9256e0307420f41777d Mon Sep 17 00:00:00 2001 From: nicktorwald Date: Fri, 29 Mar 2019 19:09:30 +0700 Subject: [PATCH] Race condition in TarantoolClientImpl A state of a client is a set of the following flags: {READING, WRITING, RECONNECT, CLOSED}. Let's name a state when no flags are set UNINITIALIZED. A reader thread sets READING, performs reading until an error or an interruption, drops READING and tries to trigger reconnection (if a state allows, see below). A writer do quite same things, but with the WRITING flag. The key point here is that a reconnection is triggered from a reader/writer thread and only when certain conditions are met. The prerequisite to set RECONNECT and signal (unpark) a connector thread is that a client has UNINITIALIZED state. There are several problems here: - Say, a reader stalls a bit after dropping READING, then a writer drops WRITING and trigger reconnection. Then reader wokes up and set RECONNECT again. - Calling unpark() N times for a connector thread when it is alive doesn't lead to skipping next N park() calls, so the problem above is not just about extra reconnection, but lead the connector thread to be stuck. - Say, a reader stalls just before setting READING. A writer is hit by an IO error and triggers reconnection (set RECONNECT, unpark connector). Then the reader wakes up and set READING+RECONNECT state that disallows a connector thread to proceed further (it expects pure RECONNECT). Even when the reader drops READING it will not wake up (unpark) the connector thread, because RECONNECT was already set (state is not UNINITIALIZED). This commit introduces several changes that eliminate the problems above: - Use ReentrantLock + Condition instead of park() / unpark() to never miss signals to reconnect, does not matter whether a connector is parked. - Ensure a reader and a writer threads from one generation (that are created on the same reconnection iteration) triggers reconnection once. - Hold RECONNECT state most of time a connector works (while acquiring a new socket, connecting and reading Tarantool greeting) and prevent to set READING/WRITING while RECONNECT is set. - Ensure a new reconnection iteration will start only after old reader and old writer threads exit (because we cannot receive a reconnection signal until we send it). Fixes: #142 Affects: #34, #136 --- .../org/tarantool/TarantoolClientImpl.java | 197 ++++++++++++------ .../java/org/tarantool/ClientReconnectIT.java | 7 +- 2 files changed, 134 insertions(+), 70 deletions(-) diff --git a/src/main/java/org/tarantool/TarantoolClientImpl.java b/src/main/java/org/tarantool/TarantoolClientImpl.java index 5256021e..c1d40b79 100644 --- a/src/main/java/org/tarantool/TarantoolClientImpl.java +++ b/src/main/java/org/tarantool/TarantoolClientImpl.java @@ -22,7 +22,6 @@ import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.locks.Condition; -import java.util.concurrent.locks.LockSupport; import java.util.concurrent.locks.ReentrantLock; public class TarantoolClientImpl extends TarantoolBase> implements TarantoolClient { @@ -72,10 +71,12 @@ public class TarantoolClientImpl extends TarantoolBase> implements Tar @Override public void run() { while (!Thread.currentThread().isInterrupted()) { - if (state.compareAndSet(StateHelper.RECONNECT, 0)) { - reconnect(0, thumbstone); + reconnect(0, thumbstone); + try { + state.awaitReconnection(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); } - LockSupport.park(state); } } }); @@ -146,14 +147,10 @@ protected void connect(final SocketChannel channel) throws Exception { TarantoolGreeting greeting = ProtoUtils.connect(channel, config.username, config.password); this.serverVersion = greeting.getServerVersion(); } catch (IOException e) { - try { - channel.close(); - } catch (IOException ignored) { - // No-op - } - + closeChannel(channel); throw new CommunicationException("Couldn't connect to tarantool", e); } + channel.configureBlocking(false); this.channel = channel; this.readChannel = new ReadableViaSelectorChannel(channel); @@ -169,44 +166,42 @@ protected void connect(final SocketChannel channel) throws Exception { } protected void startThreads(String threadName) throws InterruptedException { - final CountDownLatch init = new CountDownLatch(2); - reader = new Thread(new Runnable() { - @Override - public void run() { - init.countDown(); - if (state.acquire(StateHelper.READING)) { - try { - readThread(); - } finally { - state.release(StateHelper.READING); - if (state.compareAndSet(0, StateHelper.RECONNECT)) { - LockSupport.unpark(connector); - } + final CountDownLatch ioThreadStarted = new CountDownLatch(2); + final AtomicInteger leftIoThreads = new AtomicInteger(2); + reader = new Thread(() -> { + ioThreadStarted.countDown(); + if (state.acquire(StateHelper.READING)) { + try { + readThread(); + } finally { + state.release(StateHelper.READING); + // only last of two IO-threads can signal for reconnection + if (leftIoThreads.decrementAndGet() == 0) { + state.trySignalForReconnection(); } } } }); - writer = new Thread(new Runnable() { - @Override - public void run() { - init.countDown(); - if (state.acquire(StateHelper.WRITING)) { - try { - writeThread(); - } finally { - state.release(StateHelper.WRITING); - if (state.compareAndSet(0, StateHelper.RECONNECT)) { - LockSupport.unpark(connector); - } + writer = new Thread(() -> { + ioThreadStarted.countDown(); + if (state.acquire(StateHelper.WRITING)) { + try { + writeThread(); + } finally { + state.release(StateHelper.WRITING); + // only last of two IO-threads can signal for reconnection + if (leftIoThreads.decrementAndGet() == 0) { + state.trySignalForReconnection(); } } } }); + state.release(StateHelper.RECONNECT); configureThreads(threadName); reader.start(); writer.start(); - init.await(); + ioThreadStarted.await(); } protected void configureThreads(String threadName) { @@ -356,25 +351,21 @@ private boolean directWrite(ByteBuffer buffer) throws InterruptedException, IOEx } protected void readThread() { - try { - while (!Thread.currentThread().isInterrupted()) { - try { - TarantoolPacket packet = ProtoUtils.readPacket(readChannel); + while (!Thread.currentThread().isInterrupted()) { + try { + TarantoolPacket packet = ProtoUtils.readPacket(readChannel); - Map headers = packet.getHeaders(); + Map headers = packet.getHeaders(); - Long syncId = (Long) headers.get(Key.SYNC.getId()); - TarantoolOp future = futures.remove(syncId); - stats.received++; - wait.decrementAndGet(); - complete(packet, future); - } catch (Exception e) { - die("Cant read answer", e); - return; - } + Long syncId = (Long) headers.get(Key.SYNC.getId()); + TarantoolOp future = futures.remove(syncId); + stats.received++; + wait.decrementAndGet(); + complete(packet, future); + } catch (Exception e) { + die("Cant read answer", e); + return; } - } catch (Exception e) { - die("Cant init thread", e); } } @@ -489,7 +480,7 @@ protected void stopIO() { try { readChannel.close(); // also closes this.channel } catch (IOException ignored) { - // No-op + // no-op } } closeChannel(channel); @@ -639,6 +630,7 @@ public TarantoolClientStats getStats() { * Manages state changes. */ protected final class StateHelper { + static final int UNINITIALIZED = 0; static final int READING = 1; static final int WRITING = 2; static final int ALIVE = READING | WRITING; @@ -648,10 +640,22 @@ protected final class StateHelper { private final AtomicInteger state; private final AtomicReference nextAliveLatch = - new AtomicReference(new CountDownLatch(1)); + new AtomicReference<>(new CountDownLatch(1)); private final CountDownLatch closedLatch = new CountDownLatch(1); + /** + * The condition variable to signal a reconnection is needed from reader / + * writer threads and waiting for that signal from the reconnection thread. + * + * The lock variable to access this condition. + * + * @see #awaitReconnection() + * @see #trySignalForReconnection() + */ + protected final ReentrantLock connectorLock = new ReentrantLock(); + protected final Condition reconnectRequired = connectorLock.newCondition(); + protected StateHelper(int state) { this.state = new AtomicInteger(state); } @@ -660,30 +664,51 @@ protected int getState() { return state.get(); } + /** + * Set CLOSED state, drop RECONNECT state. + */ protected boolean close() { for (; ; ) { - int st = getState(); - if ((st & CLOSED) == CLOSED) { + int currentState = getState(); + + /* CLOSED is the terminal state. */ + if ((currentState & CLOSED) == CLOSED) { return false; } - if (compareAndSet(st, (st & ~RECONNECT) | CLOSED)) { + + /* Drop RECONNECT, set CLOSED. */ + if (compareAndSet(currentState, (currentState & ~RECONNECT) | CLOSED)) { return true; } } } + /** + * Move from a current state to a give one. + * + * Some moves are forbidden. + */ protected boolean acquire(int mask) { for (; ; ) { - int st = getState(); - if ((st & CLOSED) == CLOSED) { + int currentState = getState(); + + /* CLOSED is the terminal state. */ + if ((currentState & CLOSED) == CLOSED) { + return false; + } + + /* Don't move to READING, WRITING or ALIVE from RECONNECT. */ + if ((currentState & RECONNECT) > mask) { return false; } - if ((st & mask) != 0) { + /* Cannot move from a state to the same state. */ + if ((currentState & mask) != 0) { throw new IllegalStateException("State is already " + mask); } - if (compareAndSet(st, st | mask)) { + /* Set acquired state. */ + if (compareAndSet(currentState, currentState | mask)) { return true; } } @@ -691,8 +716,8 @@ protected boolean acquire(int mask) { protected void release(int mask) { for (; ; ) { - int st = getState(); - if (compareAndSet(st, st & ~mask)) { + int currentState = getState(); + if (compareAndSet(currentState, currentState & ~mask)) { return; } } @@ -713,10 +738,18 @@ protected boolean compareAndSet(int expect, int update) { return true; } + /** + * Reconnection uses another way to await state via receiving a signal + * instead of latches. + */ protected void awaitState(int state) throws InterruptedException { - CountDownLatch latch = getStateLatch(state); - if (latch != null) { - latch.await(); + if (state == RECONNECT) { + awaitReconnection(); + } else { + CountDownLatch latch = getStateLatch(state); + if (latch != null) { + latch.await(); + } } } @@ -740,6 +773,38 @@ private CountDownLatch getStateLatch(int state) { } return null; } + + /** + * Blocks until a reconnection signal will be received. + * + * @see #trySignalForReconnection() + */ + private void awaitReconnection() throws InterruptedException { + connectorLock.lock(); + try { + while (getState() != StateHelper.RECONNECT) { + reconnectRequired.await(); + } + } finally { + connectorLock.unlock(); + } + } + + /** + * Signals to the connector that reconnection process can be performed. + * + * @see #awaitReconnection() + */ + private void trySignalForReconnection() { + if (compareAndSet(StateHelper.UNINITIALIZED, StateHelper.RECONNECT)) { + connectorLock.lock(); + try { + reconnectRequired.signal(); + } finally { + connectorLock.unlock(); + } + } + } } protected static class TarantoolOp extends CompletableFuture { diff --git a/src/test/java/org/tarantool/ClientReconnectIT.java b/src/test/java/org/tarantool/ClientReconnectIT.java index f4300c1d..a38875bc 100644 --- a/src/test/java/org/tarantool/ClientReconnectIT.java +++ b/src/test/java/org/tarantool/ClientReconnectIT.java @@ -231,7 +231,8 @@ public void testLongParallelCloseReconnects() { SocketChannelProvider provider = new TestSocketChannelProvider(host, port, RESTART_TIMEOUT).setSoLinger(0); - final AtomicReferenceArray clients = new AtomicReferenceArray(numClients); + final AtomicReferenceArray clients = + new AtomicReferenceArray<>(numClients); for (int idx = 0; idx < clients.length(); idx++) { clients.set(idx, makeClient(provider)); @@ -249,9 +250,7 @@ public void testLongParallelCloseReconnects() { @Override public void run() { while (!Thread.currentThread().isInterrupted() && deadline > System.currentTimeMillis()) { - int idx = rnd.nextInt(clients.length()); - try { TarantoolClient cli = clients.get(idx); @@ -300,7 +299,7 @@ public void run() { // Wait for all threads to finish. try { - assertTrue(latch.await(RESTART_TIMEOUT, TimeUnit.MILLISECONDS)); + assertTrue(latch.await(RESTART_TIMEOUT * 2, TimeUnit.MILLISECONDS)); } catch (InterruptedException e) { fail(e); }