Skip to content

Race condition in TarantoolClientImpl #145

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
197 changes: 131 additions & 66 deletions src/main/java/org/tarantool/TarantoolClientImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.LockSupport;
import java.util.concurrent.locks.ReentrantLock;

public class TarantoolClientImpl extends TarantoolBase<Future<?>> implements TarantoolClient {
Expand Down Expand Up @@ -72,10 +71,12 @@ public class TarantoolClientImpl extends TarantoolBase<Future<?>> implements Tar
@Override
public void run() {
while (!Thread.currentThread().isInterrupted()) {
if (state.compareAndSet(StateHelper.RECONNECT, 0)) {
reconnect(0, thumbstone);
reconnect(0, thumbstone);
try {
state.awaitReconnection();
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
}
LockSupport.park(state);
}
}
});
Expand Down Expand Up @@ -146,14 +147,10 @@ protected void connect(final SocketChannel channel) throws Exception {
TarantoolGreeting greeting = ProtoUtils.connect(channel, config.username, config.password);
this.serverVersion = greeting.getServerVersion();
} catch (IOException e) {
try {
channel.close();
} catch (IOException ignored) {
// No-op
}

closeChannel(channel);
throw new CommunicationException("Couldn't connect to tarantool", e);
}

channel.configureBlocking(false);
this.channel = channel;
this.readChannel = new ReadableViaSelectorChannel(channel);
Expand All @@ -169,44 +166,42 @@ protected void connect(final SocketChannel channel) throws Exception {
}

protected void startThreads(String threadName) throws InterruptedException {
final CountDownLatch init = new CountDownLatch(2);
reader = new Thread(new Runnable() {
@Override
public void run() {
init.countDown();
if (state.acquire(StateHelper.READING)) {
try {
readThread();
} finally {
state.release(StateHelper.READING);
if (state.compareAndSet(0, StateHelper.RECONNECT)) {
LockSupport.unpark(connector);
}
final CountDownLatch ioThreadStarted = new CountDownLatch(2);
final AtomicInteger leftIoThreads = new AtomicInteger(2);
reader = new Thread(() -> {
ioThreadStarted.countDown();
if (state.acquire(StateHelper.READING)) {
try {
readThread();
} finally {
state.release(StateHelper.READING);
// only last of two IO-threads can signal for reconnection
if (leftIoThreads.decrementAndGet() == 0) {
state.trySignalForReconnection();
}
}
}
});
writer = new Thread(new Runnable() {
@Override
public void run() {
init.countDown();
if (state.acquire(StateHelper.WRITING)) {
try {
writeThread();
} finally {
state.release(StateHelper.WRITING);
if (state.compareAndSet(0, StateHelper.RECONNECT)) {
LockSupport.unpark(connector);
}
writer = new Thread(() -> {
ioThreadStarted.countDown();
if (state.acquire(StateHelper.WRITING)) {
try {
writeThread();
} finally {
state.release(StateHelper.WRITING);
// only last of two IO-threads can signal for reconnection
if (leftIoThreads.decrementAndGet() == 0) {
state.trySignalForReconnection();
}
}
}
});
state.release(StateHelper.RECONNECT);

configureThreads(threadName);
reader.start();
writer.start();
init.await();
ioThreadStarted.await();
}

protected void configureThreads(String threadName) {
Expand Down Expand Up @@ -356,25 +351,21 @@ private boolean directWrite(ByteBuffer buffer) throws InterruptedException, IOEx
}

protected void readThread() {
try {
while (!Thread.currentThread().isInterrupted()) {
try {
TarantoolPacket packet = ProtoUtils.readPacket(readChannel);
while (!Thread.currentThread().isInterrupted()) {
try {
TarantoolPacket packet = ProtoUtils.readPacket(readChannel);

Map<Integer, Object> headers = packet.getHeaders();
Map<Integer, Object> headers = packet.getHeaders();

Long syncId = (Long) headers.get(Key.SYNC.getId());
TarantoolOp<?> future = futures.remove(syncId);
stats.received++;
wait.decrementAndGet();
complete(packet, future);
} catch (Exception e) {
die("Cant read answer", e);
return;
}
Long syncId = (Long) headers.get(Key.SYNC.getId());
TarantoolOp<?> future = futures.remove(syncId);
stats.received++;
wait.decrementAndGet();
complete(packet, future);
} catch (Exception e) {
die("Cant read answer", e);
return;
}
} catch (Exception e) {
die("Cant init thread", e);
}
}

Expand Down Expand Up @@ -489,7 +480,7 @@ protected void stopIO() {
try {
readChannel.close(); // also closes this.channel
} catch (IOException ignored) {
// No-op
// no-op
}
}
closeChannel(channel);
Expand Down Expand Up @@ -639,6 +630,7 @@ public TarantoolClientStats getStats() {
* Manages state changes.
*/
protected final class StateHelper {
static final int UNINITIALIZED = 0;
static final int READING = 1;
static final int WRITING = 2;
static final int ALIVE = READING | WRITING;
Expand All @@ -648,10 +640,22 @@ protected final class StateHelper {
private final AtomicInteger state;

private final AtomicReference<CountDownLatch> nextAliveLatch =
new AtomicReference<CountDownLatch>(new CountDownLatch(1));
new AtomicReference<>(new CountDownLatch(1));

private final CountDownLatch closedLatch = new CountDownLatch(1);

/**
* The condition variable to signal a reconnection is needed from reader /
* writer threads and waiting for that signal from the reconnection thread.
*
* The lock variable to access this condition.
*
* @see #awaitReconnection()
* @see #trySignalForReconnection()
*/
protected final ReentrantLock connectorLock = new ReentrantLock();
protected final Condition reconnectRequired = connectorLock.newCondition();

protected StateHelper(int state) {
this.state = new AtomicInteger(state);
}
Expand All @@ -660,39 +664,60 @@ protected int getState() {
return state.get();
}

/**
* Set CLOSED state, drop RECONNECT state.
*/
protected boolean close() {
for (; ; ) {
int st = getState();
if ((st & CLOSED) == CLOSED) {
int currentState = getState();

/* CLOSED is the terminal state. */
if ((currentState & CLOSED) == CLOSED) {
return false;
}
if (compareAndSet(st, (st & ~RECONNECT) | CLOSED)) {

/* Drop RECONNECT, set CLOSED. */
if (compareAndSet(currentState, (currentState & ~RECONNECT) | CLOSED)) {
return true;
}
}
}

/**
* Move from a current state to a give one.
*
* Some moves are forbidden.
*/
protected boolean acquire(int mask) {
for (; ; ) {
int st = getState();
if ((st & CLOSED) == CLOSED) {
int currentState = getState();

/* CLOSED is the terminal state. */
if ((currentState & CLOSED) == CLOSED) {
return false;
}

/* Don't move to READING, WRITING or ALIVE from RECONNECT. */
if ((currentState & RECONNECT) > mask) {
return false;
}

if ((st & mask) != 0) {
/* Cannot move from a state to the same state. */
if ((currentState & mask) != 0) {
throw new IllegalStateException("State is already " + mask);
}

if (compareAndSet(st, st | mask)) {
/* Set acquired state. */
if (compareAndSet(currentState, currentState | mask)) {
return true;
}
}
}

protected void release(int mask) {
for (; ; ) {
int st = getState();
if (compareAndSet(st, st & ~mask)) {
int currentState = getState();
if (compareAndSet(currentState, currentState & ~mask)) {
return;
}
}
Expand All @@ -713,10 +738,18 @@ protected boolean compareAndSet(int expect, int update) {
return true;
}

/**
* Reconnection uses another way to await state via receiving a signal
* instead of latches.
*/
protected void awaitState(int state) throws InterruptedException {
CountDownLatch latch = getStateLatch(state);
if (latch != null) {
latch.await();
if (state == RECONNECT) {
awaitReconnection();
} else {
CountDownLatch latch = getStateLatch(state);
if (latch != null) {
latch.await();
}
}
}

Expand All @@ -740,6 +773,38 @@ private CountDownLatch getStateLatch(int state) {
}
return null;
}

/**
* Blocks until a reconnection signal will be received.
*
* @see #trySignalForReconnection()
*/
private void awaitReconnection() throws InterruptedException {
connectorLock.lock();
try {
while (getState() != StateHelper.RECONNECT) {
reconnectRequired.await();
}
} finally {
connectorLock.unlock();
}
}

/**
* Signals to the connector that reconnection process can be performed.
*
* @see #awaitReconnection()
*/
private void trySignalForReconnection() {
if (compareAndSet(StateHelper.UNINITIALIZED, StateHelper.RECONNECT)) {
connectorLock.lock();
try {
reconnectRequired.signal();
} finally {
connectorLock.unlock();
}
}
}
}

protected static class TarantoolOp<V> extends CompletableFuture<V> {
Expand Down
7 changes: 3 additions & 4 deletions src/test/java/org/tarantool/ClientReconnectIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,8 @@ public void testLongParallelCloseReconnects() {
SocketChannelProvider provider = new TestSocketChannelProvider(host,
port, RESTART_TIMEOUT).setSoLinger(0);

final AtomicReferenceArray<TarantoolClient> clients = new AtomicReferenceArray<TarantoolClient>(numClients);
final AtomicReferenceArray<TarantoolClient> clients =
new AtomicReferenceArray<>(numClients);

for (int idx = 0; idx < clients.length(); idx++) {
clients.set(idx, makeClient(provider));
Expand All @@ -249,9 +250,7 @@ public void testLongParallelCloseReconnects() {
@Override
public void run() {
while (!Thread.currentThread().isInterrupted() && deadline > System.currentTimeMillis()) {

int idx = rnd.nextInt(clients.length());

try {
TarantoolClient cli = clients.get(idx);

Expand Down Expand Up @@ -300,7 +299,7 @@ public void run() {

// Wait for all threads to finish.
try {
assertTrue(latch.await(RESTART_TIMEOUT, TimeUnit.MILLISECONDS));
assertTrue(latch.await(RESTART_TIMEOUT * 2, TimeUnit.MILLISECONDS));
} catch (InterruptedException e) {
fail(e);
}
Expand Down