Skip to content

Fix the read timeout implementation in NettyStream #635

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions driver-core/src/main/com/mongodb/annotations/ThreadSafe.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@


/**
* The class to which this annotation is applied is thread-safe. This means that no sequences of accesses (reads and writes to public
* fields, calls to public methods) may put the object into an invalid state, regardless of the interleaving of those actions by the
* The class or method to which this annotation is applied is thread-safe. This means that no sequences of accesses (reads and writes to
* public fields, calls to public methods) may put the object into an invalid state, regardless of the interleaving of those actions by the
* runtime, and without requiring any additional synchronization or coordination on the part of the caller.
*/
@Documented
@Target(ElementType.TYPE)
@Target({ElementType.TYPE, ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
public @interface ThreadSafe {
}
75 changes: 48 additions & 27 deletions driver-core/src/main/com/mongodb/connection/netty/NettyStream.java
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@
import io.netty.channel.socket.SocketChannel;
import io.netty.handler.ssl.SslHandler;
import io.netty.handler.timeout.ReadTimeoutException;
import io.netty.util.concurrent.EventExecutor;
import org.bson.ByteBuf;

import javax.net.ssl.SSLContext;
Expand All @@ -59,15 +58,47 @@
import java.util.Queue;
import java.util.concurrent.CountDownLatch;

import static com.mongodb.assertions.Assertions.isTrueArgument;
import static com.mongodb.internal.connection.SslHelper.enableHostNameVerification;
import static com.mongodb.internal.connection.SslHelper.enableSni;
import static java.util.concurrent.TimeUnit.MILLISECONDS;

/**
* A Stream implementation based on Netty 4.0.
* Just like it is for the {@link java.nio.channels.AsynchronousSocketChannel},
* concurrent pending<sup>1</sup> readers
* (whether {@linkplain #read(int, int) synchronous} or {@linkplain #readAsync(int, AsyncCompletionHandler, int) asynchronous})
* are not supported by {@link NettyStream}.
* However, this class does not have a fail-fast mechanism checking for such situations.
* <hr>
* <sup>1</sup>We cannot simply say that read methods are not allowed be run concurrently because strictly speaking they are allowed,
* as explained below.
* <pre>{@code
* NettyStream stream = ...;
* stream.readAsync(1, new AsyncCompletionHandler<ByteBuf>() {//inv1
* @Override
* public void completed(ByteBuf o) {
* stream.readAsync(//inv2
* 1, ...);//ret2
* }
*
* @Override
* public void failed(Throwable t) {
* }
* });//ret1
* }</pre>
* Arrows on the diagram below represent happens-before relations.
* <pre>{@code
* int1 -> inv2 -> ret2
* \--------> ret1
* }</pre>
* As shown on the diagram, the method {@link #readAsync(int, AsyncCompletionHandler)} runs concurrently with
* itself in the example above. However, there are no concurrent pending readers because the second operation
* is invoked after the first operation has completed reading despite the method has not returned yet.
*/
final class NettyStream implements Stream {
private static final String READ_HANDLER_NAME = "ReadTimeoutHandler";
private static final int NO_SCHEDULE_TIMEOUT = -1;
private final ServerAddress address;
private final SocketSettings settings;
private final SslSettings sslSettings;
Expand All @@ -79,8 +110,8 @@ final class NettyStream implements Stream {
private volatile Channel channel;

private final LinkedList<io.netty.buffer.ByteBuf> pendingInboundBuffers = new LinkedList<io.netty.buffer.ByteBuf>();
private volatile PendingReader pendingReader;
private volatile Throwable pendingException;
private PendingReader pendingReader;
private Throwable pendingException;

NettyStream(final ServerAddress address, final SocketSettings settings, final SslSettings sslSettings, final EventLoopGroup workerGroup,
final Class<? extends SocketChannel> socketChannelClass, final ByteBufAllocator allocator) {
Expand Down Expand Up @@ -185,6 +216,7 @@ public boolean supportsAdditionalTimeout() {

@Override
public ByteBuf read(final int numBytes, final int additionalTimeout) throws IOException {
isTrueArgument("additionalTimeout must not be negative", additionalTimeout >= 0);
FutureAsyncCompletionHandler<ByteBuf> future = new FutureAsyncCompletionHandler<ByteBuf>();
readAsync(numBytes, future, additionalTimeout);
return future.get();
Expand Down Expand Up @@ -214,6 +246,12 @@ public void readAsync(final int numBytes, final AsyncCompletionHandler<ByteBuf>
readAsync(numBytes, handler, 0);
}

/**
* @param additionalTimeout Must be equal to {@link #NO_SCHEDULE_TIMEOUT} when the method is called by a Netty channel handler.
* A timeout is scheduled only by the public read methods. Taking into account that concurrent pending readers
* are not allowed, there must not be a situation when threads attempt to schedule a timeout
* before the previous one is removed or completed.
*/
private void readAsync(final int numBytes, final AsyncCompletionHandler<ByteBuf> handler, final int additionalTimeout) {
scheduleReadTimeout(additionalTimeout);
ByteBuf buffer = null;
Expand Down Expand Up @@ -282,7 +320,8 @@ private void handleReadResponse(final io.netty.buffer.ByteBuf buffer, final Thro
}

if (localPendingReader != null) {
readAsync(localPendingReader.numBytes, localPendingReader.handler);
//if there is a pending reader, then the reader has scheduled a timeout and we should not attempt to schedule another one
readAsync(localPendingReader.numBytes, localPendingReader.handler, NO_SCHEDULE_TIMEOUT);
}
}

Expand Down Expand Up @@ -446,6 +485,9 @@ public void operationComplete(final ChannelFuture future) {
}

private void scheduleReadTimeout(final int additionalTimeout) {
if (additionalTimeout == NO_SCHEDULE_TIMEOUT) {
return;
}
adjustTimeout(false, additionalTimeout);
}

Expand All @@ -460,31 +502,10 @@ private void adjustTimeout(final boolean disable, final int additionalTimeout) {
ChannelHandler timeoutHandler = channel.pipeline().get(READ_HANDLER_NAME);
if (timeoutHandler != null) {
final ReadTimeoutHandler readTimeoutHandler = (ReadTimeoutHandler) timeoutHandler;
final ChannelHandlerContext handlerContext = channel.pipeline().context(timeoutHandler);
EventExecutor executor = handlerContext.executor();

if (disable) {
if (executor.inEventLoop()) {
readTimeoutHandler.removeTimeout(handlerContext);
} else {
executor.submit(new Runnable() {
@Override
public void run() {
readTimeoutHandler.removeTimeout(handlerContext);
}
});
}
readTimeoutHandler.removeTimeout();
} else {
if (executor.inEventLoop()) {
readTimeoutHandler.scheduleTimeout(handlerContext, additionalTimeout);
} else {
executor.submit(new Runnable() {
@Override
public void run() {
readTimeoutHandler.scheduleTimeout(handlerContext, additionalTimeout);
}
});
}
readTimeoutHandler.scheduleTimeout(additionalTimeout);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,62 +17,180 @@

package com.mongodb.connection.netty;

import com.mongodb.annotations.NotThreadSafe;
import com.mongodb.annotations.ThreadSafe;
import com.mongodb.lang.Nullable;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandler;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.handler.timeout.IdleStateHandler;
import io.netty.handler.timeout.ReadTimeoutException;

import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.StampedLock;

import static com.mongodb.assertions.Assertions.isTrue;
import static com.mongodb.assertions.Assertions.isTrueArgument;

/**
* Passes a {@link ReadTimeoutException} if the time between a {@link #scheduleTimeout} and {@link #removeTimeout} is longer than the set
* timeout.
* This {@link ChannelInboundHandler} allows {@linkplain #scheduleTimeout(int) scheduling} and {@linkplain #removeTimeout() removing}
* timeouts. A timeout is a delayed task that {@linkplain ChannelHandlerContext#fireExceptionCaught(Throwable) fires}
* {@link ReadTimeoutException#INSTANCE} and {@linkplain ChannelHandlerContext#close() closes} the channel;
* this can be prevented by removing the timeout.
* <p>
* This class guarantees that there are no concurrent timeouts scheduled for a channel.
* Note that despite instances of this class are not thread-safe
* (only {@linkplain io.netty.channel.ChannelHandler.Sharable sharable} handlers must be thread-safe),
* methods {@link #scheduleTimeout(int)} and {@link #removeTimeout()} are linearizable.
* <p>
* The Netty-related lifecycle management in this class is inspired by the {@link IdleStateHandler}.
* See the <a href="https://netty.io/wiki/new-and-noteworthy-in-4.0.html#simplified-channel-state-model">channel state model</a>
* for additional details.
*/
@NotThreadSafe
final class ReadTimeoutHandler extends ChannelInboundHandlerAdapter {
private final long readTimeout;
private volatile ScheduledFuture<?> timeout;
private final long readTimeoutMillis;
private final Lock nonreentrantLock;
@Nullable
private ChannelHandlerContext ctx;
@Nullable
private ScheduledFuture<?> timeout;

ReadTimeoutHandler(final long readTimeout) {
isTrueArgument("readTimeout must be greater than zero.", readTimeout > 0);
this.readTimeout = readTimeout;
ReadTimeoutHandler(final long readTimeoutMillis) {
isTrueArgument("readTimeoutMillis must be positive", readTimeoutMillis > 0);
this.readTimeoutMillis = readTimeoutMillis;
nonreentrantLock = new StampedLock().asWriteLock();
}

void scheduleTimeout(final ChannelHandlerContext ctx, final int additionalTimeout) {
isTrue("Handler called from the eventLoop", ctx.channel().eventLoop().inEventLoop());
if (timeout == null) {
timeout = ctx.executor().schedule(new ReadTimeoutTask(ctx), readTimeout + additionalTimeout, TimeUnit.MILLISECONDS);
private void register(final ChannelHandlerContext context) {
nonreentrantLock.lock();
try {
final ChannelHandlerContext ctx = this.ctx;
if (ctx == context) {
return;
}
assert ctx == null : "Attempted to register a context before a previous one is deregistered";
this.ctx = context;
} finally {
nonreentrantLock.unlock();
}
}

private void deregister() {
nonreentrantLock.lock();
try {
unsynchronizedCancel();
ctx = null;
} finally {
nonreentrantLock.unlock();
}
}

void removeTimeout(final ChannelHandlerContext ctx) {
isTrue("Handler called from the eventLoop", ctx.channel().eventLoop().inEventLoop());
private void unsynchronizedCancel() {
final ScheduledFuture<?> timeout = this.timeout;
if (timeout != null) {
timeout.cancel(false);
timeout = null;
this.timeout = null;
}
}

private static final class ReadTimeoutTask implements Runnable {
@Override
public void channelActive(final ChannelHandlerContext ctx) throws Exception {
/* This method is invoked only if the handler is added to a channel pipeline before the channelActive event is fired.
* Because of this fact we also need to monitor the handlerAdded event.*/
register(ctx);
super.channelActive(ctx);
}

private final ChannelHandlerContext ctx;
@Override
public void handlerAdded(final ChannelHandlerContext ctx) throws Exception {
final Channel channel = ctx.channel();
if (channel.isActive()//the channelActive event has already been fired and our channelActive method will not be called
/* Check that the channel is registered with an event loop.
* If it is not the case, then our channelRegistered method calls the register method.*/
&& channel.isRegistered()) {
register(ctx);
} else {
/* The channelActive event has not been fired. When it is fired, our channelActive method will be called
* and we will call the register method there.*/
}
super.handlerAdded(ctx);
}

ReadTimeoutTask(final ChannelHandlerContext ctx) {
this.ctx = ctx;
@Override
public void channelRegistered(final ChannelHandlerContext ctx) throws Exception {
if (ctx.channel().isActive()) {//the channelActive event has already been fired and our channelActive method will not be called
register(ctx);
}
super.channelRegistered(ctx);
}

@Override
public void channelInactive(final ChannelHandlerContext ctx) throws Exception {
deregister();
super.channelInactive(ctx);
}

@Override
public void handlerRemoved(final ChannelHandlerContext ctx) throws Exception {
deregister();
super.handlerRemoved(ctx);
}

@Override
public void channelUnregistered(final ChannelHandlerContext ctx) throws Exception {
deregister();
super.channelUnregistered(ctx);
}

@Override
public void run() {
if (ctx.channel().isOpen()) {
/**
* Schedules a new timeout.
* A timeout must be {@linkplain #removeTimeout() removed} before another one is allowed to be scheduled.
*/
@ThreadSafe
void scheduleTimeout(final int additionalTimeoutMillis) {
isTrueArgument("additionalTimeoutMillis must not be negative", additionalTimeoutMillis >= 0);
nonreentrantLock.lock();
try {
final ChannelHandlerContext ctx = this.ctx;
if (ctx == null) {//no context is registered
return;
}
final ScheduledFuture<?> timeout = this.timeout;
assert timeout == null || timeout.isDone() : "Attempted to schedule a timeout before the previous one is removed or completed";
this.timeout = ctx.executor().schedule(() -> {
try {
ctx.fireExceptionCaught(ReadTimeoutException.INSTANCE);
ctx.close();
} catch (Throwable t) {
fireTimeoutException(ctx);
} catch (final Throwable t) {
ctx.fireExceptionCaught(t);
}
}
}, readTimeoutMillis + additionalTimeoutMillis, TimeUnit.MILLISECONDS);
} finally {
nonreentrantLock.unlock();
}
}

/**
* Either removes the previously {@linkplain #scheduleTimeout(int) scheduled} timeout, or does nothing.
* After removing a timeout, another one may be scheduled.
*/
@ThreadSafe
void removeTimeout() {
nonreentrantLock.lock();
try {
unsynchronizedCancel();
} finally {
nonreentrantLock.unlock();
}
}

private static void fireTimeoutException(final ChannelHandlerContext ctx) {
if (!ctx.channel().isOpen()) {
return;
}
ctx.fireExceptionCaught(ReadTimeoutException.INSTANCE);
ctx.close();
}
}