diff --git a/driver-core/src/main/com/mongodb/annotations/ThreadSafe.java b/driver-core/src/main/com/mongodb/annotations/ThreadSafe.java index e07f8dc8c1e..209456cbd32 100644 --- a/driver-core/src/main/com/mongodb/annotations/ThreadSafe.java +++ b/driver-core/src/main/com/mongodb/annotations/ThreadSafe.java @@ -18,12 +18,12 @@ /** - * The class to which this annotation is applied is thread-safe. This means that no sequences of accesses (reads and writes to public - * fields, calls to public methods) may put the object into an invalid state, regardless of the interleaving of those actions by the + * The class or method to which this annotation is applied is thread-safe. This means that no sequences of accesses (reads and writes to + * public fields, calls to public methods) may put the object into an invalid state, regardless of the interleaving of those actions by the * runtime, and without requiring any additional synchronization or coordination on the part of the caller. */ @Documented -@Target(ElementType.TYPE) +@Target({ElementType.TYPE, ElementType.METHOD}) @Retention(RetentionPolicy.RUNTIME) public @interface ThreadSafe { } diff --git a/driver-core/src/main/com/mongodb/connection/netty/NettyStream.java b/driver-core/src/main/com/mongodb/connection/netty/NettyStream.java index 948da9162ec..4549249002d 100644 --- a/driver-core/src/main/com/mongodb/connection/netty/NettyStream.java +++ b/driver-core/src/main/com/mongodb/connection/netty/NettyStream.java @@ -44,7 +44,6 @@ import io.netty.channel.socket.SocketChannel; import io.netty.handler.ssl.SslHandler; import io.netty.handler.timeout.ReadTimeoutException; -import io.netty.util.concurrent.EventExecutor; import org.bson.ByteBuf; import javax.net.ssl.SSLContext; @@ -59,15 +58,47 @@ import java.util.Queue; import java.util.concurrent.CountDownLatch; +import static com.mongodb.assertions.Assertions.isTrueArgument; import static com.mongodb.internal.connection.SslHelper.enableHostNameVerification; import static com.mongodb.internal.connection.SslHelper.enableSni; import static java.util.concurrent.TimeUnit.MILLISECONDS; /** * A Stream implementation based on Netty 4.0. + * Just like it is for the {@link java.nio.channels.AsynchronousSocketChannel}, + * concurrent pending1 readers + * (whether {@linkplain #read(int, int) synchronous} or {@linkplain #readAsync(int, AsyncCompletionHandler, int) asynchronous}) + * are not supported by {@link NettyStream}. + * However, this class does not have a fail-fast mechanism checking for such situations. + *
+ * 1We cannot simply say that read methods are not allowed be run concurrently because strictly speaking they are allowed, + * as explained below. + *
{@code
+ * NettyStream stream = ...;
+ * stream.readAsync(1, new AsyncCompletionHandler() {//inv1
+ *  @Override
+ *  public void completed(ByteBuf o) {
+ *      stream.readAsync(//inv2
+ *              1, ...);//ret2
+ *  }
+ *
+ *  @Override
+ *  public void failed(Throwable t) {
+ *  }
+ * });//ret1
+ * }
+ * Arrows on the diagram below represent happens-before relations. + *
{@code
+ * int1 -> inv2 -> ret2
+ *      \--------> ret1
+ * }
+ * As shown on the diagram, the method {@link #readAsync(int, AsyncCompletionHandler)} runs concurrently with + * itself in the example above. However, there are no concurrent pending readers because the second operation + * is invoked after the first operation has completed reading despite the method has not returned yet. */ final class NettyStream implements Stream { private static final String READ_HANDLER_NAME = "ReadTimeoutHandler"; + private static final int NO_SCHEDULE_TIMEOUT = -1; private final ServerAddress address; private final SocketSettings settings; private final SslSettings sslSettings; @@ -79,8 +110,8 @@ final class NettyStream implements Stream { private volatile Channel channel; private final LinkedList pendingInboundBuffers = new LinkedList(); - private volatile PendingReader pendingReader; - private volatile Throwable pendingException; + private PendingReader pendingReader; + private Throwable pendingException; NettyStream(final ServerAddress address, final SocketSettings settings, final SslSettings sslSettings, final EventLoopGroup workerGroup, final Class socketChannelClass, final ByteBufAllocator allocator) { @@ -185,6 +216,7 @@ public boolean supportsAdditionalTimeout() { @Override public ByteBuf read(final int numBytes, final int additionalTimeout) throws IOException { + isTrueArgument("additionalTimeout must not be negative", additionalTimeout >= 0); FutureAsyncCompletionHandler future = new FutureAsyncCompletionHandler(); readAsync(numBytes, future, additionalTimeout); return future.get(); @@ -214,6 +246,12 @@ public void readAsync(final int numBytes, final AsyncCompletionHandler readAsync(numBytes, handler, 0); } + /** + * @param additionalTimeout Must be equal to {@link #NO_SCHEDULE_TIMEOUT} when the method is called by a Netty channel handler. + * A timeout is scheduled only by the public read methods. Taking into account that concurrent pending readers + * are not allowed, there must not be a situation when threads attempt to schedule a timeout + * before the previous one is removed or completed. + */ private void readAsync(final int numBytes, final AsyncCompletionHandler handler, final int additionalTimeout) { scheduleReadTimeout(additionalTimeout); ByteBuf buffer = null; @@ -282,7 +320,8 @@ private void handleReadResponse(final io.netty.buffer.ByteBuf buffer, final Thro } if (localPendingReader != null) { - readAsync(localPendingReader.numBytes, localPendingReader.handler); + //if there is a pending reader, then the reader has scheduled a timeout and we should not attempt to schedule another one + readAsync(localPendingReader.numBytes, localPendingReader.handler, NO_SCHEDULE_TIMEOUT); } } @@ -446,6 +485,9 @@ public void operationComplete(final ChannelFuture future) { } private void scheduleReadTimeout(final int additionalTimeout) { + if (additionalTimeout == NO_SCHEDULE_TIMEOUT) { + return; + } adjustTimeout(false, additionalTimeout); } @@ -460,31 +502,10 @@ private void adjustTimeout(final boolean disable, final int additionalTimeout) { ChannelHandler timeoutHandler = channel.pipeline().get(READ_HANDLER_NAME); if (timeoutHandler != null) { final ReadTimeoutHandler readTimeoutHandler = (ReadTimeoutHandler) timeoutHandler; - final ChannelHandlerContext handlerContext = channel.pipeline().context(timeoutHandler); - EventExecutor executor = handlerContext.executor(); - if (disable) { - if (executor.inEventLoop()) { - readTimeoutHandler.removeTimeout(handlerContext); - } else { - executor.submit(new Runnable() { - @Override - public void run() { - readTimeoutHandler.removeTimeout(handlerContext); - } - }); - } + readTimeoutHandler.removeTimeout(); } else { - if (executor.inEventLoop()) { - readTimeoutHandler.scheduleTimeout(handlerContext, additionalTimeout); - } else { - executor.submit(new Runnable() { - @Override - public void run() { - readTimeoutHandler.scheduleTimeout(handlerContext, additionalTimeout); - } - }); - } + readTimeoutHandler.scheduleTimeout(additionalTimeout); } } } diff --git a/driver-core/src/main/com/mongodb/connection/netty/ReadTimeoutHandler.java b/driver-core/src/main/com/mongodb/connection/netty/ReadTimeoutHandler.java index 824c8f7d6a3..436e395ed6e 100644 --- a/driver-core/src/main/com/mongodb/connection/netty/ReadTimeoutHandler.java +++ b/driver-core/src/main/com/mongodb/connection/netty/ReadTimeoutHandler.java @@ -17,62 +17,180 @@ package com.mongodb.connection.netty; +import com.mongodb.annotations.NotThreadSafe; +import com.mongodb.annotations.ThreadSafe; +import com.mongodb.lang.Nullable; +import io.netty.channel.Channel; import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandler; import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.handler.timeout.IdleStateHandler; import io.netty.handler.timeout.ReadTimeoutException; import java.util.concurrent.ScheduledFuture; import java.util.concurrent.TimeUnit; +import java.util.concurrent.locks.Lock; +import java.util.concurrent.locks.StampedLock; -import static com.mongodb.assertions.Assertions.isTrue; import static com.mongodb.assertions.Assertions.isTrueArgument; /** - * Passes a {@link ReadTimeoutException} if the time between a {@link #scheduleTimeout} and {@link #removeTimeout} is longer than the set - * timeout. + * This {@link ChannelInboundHandler} allows {@linkplain #scheduleTimeout(int) scheduling} and {@linkplain #removeTimeout() removing} + * timeouts. A timeout is a delayed task that {@linkplain ChannelHandlerContext#fireExceptionCaught(Throwable) fires} + * {@link ReadTimeoutException#INSTANCE} and {@linkplain ChannelHandlerContext#close() closes} the channel; + * this can be prevented by removing the timeout. + *

+ * This class guarantees that there are no concurrent timeouts scheduled for a channel. + * Note that despite instances of this class are not thread-safe + * (only {@linkplain io.netty.channel.ChannelHandler.Sharable sharable} handlers must be thread-safe), + * methods {@link #scheduleTimeout(int)} and {@link #removeTimeout()} are linearizable. + *

+ * The Netty-related lifecycle management in this class is inspired by the {@link IdleStateHandler}. + * See the channel state model + * for additional details. */ +@NotThreadSafe final class ReadTimeoutHandler extends ChannelInboundHandlerAdapter { - private final long readTimeout; - private volatile ScheduledFuture timeout; + private final long readTimeoutMillis; + private final Lock nonreentrantLock; + @Nullable + private ChannelHandlerContext ctx; + @Nullable + private ScheduledFuture timeout; - ReadTimeoutHandler(final long readTimeout) { - isTrueArgument("readTimeout must be greater than zero.", readTimeout > 0); - this.readTimeout = readTimeout; + ReadTimeoutHandler(final long readTimeoutMillis) { + isTrueArgument("readTimeoutMillis must be positive", readTimeoutMillis > 0); + this.readTimeoutMillis = readTimeoutMillis; + nonreentrantLock = new StampedLock().asWriteLock(); } - void scheduleTimeout(final ChannelHandlerContext ctx, final int additionalTimeout) { - isTrue("Handler called from the eventLoop", ctx.channel().eventLoop().inEventLoop()); - if (timeout == null) { - timeout = ctx.executor().schedule(new ReadTimeoutTask(ctx), readTimeout + additionalTimeout, TimeUnit.MILLISECONDS); + private void register(final ChannelHandlerContext context) { + nonreentrantLock.lock(); + try { + final ChannelHandlerContext ctx = this.ctx; + if (ctx == context) { + return; + } + assert ctx == null : "Attempted to register a context before a previous one is deregistered"; + this.ctx = context; + } finally { + nonreentrantLock.unlock(); + } + } + + private void deregister() { + nonreentrantLock.lock(); + try { + unsynchronizedCancel(); + ctx = null; + } finally { + nonreentrantLock.unlock(); } } - void removeTimeout(final ChannelHandlerContext ctx) { - isTrue("Handler called from the eventLoop", ctx.channel().eventLoop().inEventLoop()); + private void unsynchronizedCancel() { + final ScheduledFuture timeout = this.timeout; if (timeout != null) { timeout.cancel(false); - timeout = null; + this.timeout = null; } } - private static final class ReadTimeoutTask implements Runnable { + @Override + public void channelActive(final ChannelHandlerContext ctx) throws Exception { + /* This method is invoked only if the handler is added to a channel pipeline before the channelActive event is fired. + * Because of this fact we also need to monitor the handlerAdded event.*/ + register(ctx); + super.channelActive(ctx); + } - private final ChannelHandlerContext ctx; + @Override + public void handlerAdded(final ChannelHandlerContext ctx) throws Exception { + final Channel channel = ctx.channel(); + if (channel.isActive()//the channelActive event has already been fired and our channelActive method will not be called + /* Check that the channel is registered with an event loop. + * If it is not the case, then our channelRegistered method calls the register method.*/ + && channel.isRegistered()) { + register(ctx); + } else { + /* The channelActive event has not been fired. When it is fired, our channelActive method will be called + * and we will call the register method there.*/ + } + super.handlerAdded(ctx); + } - ReadTimeoutTask(final ChannelHandlerContext ctx) { - this.ctx = ctx; + @Override + public void channelRegistered(final ChannelHandlerContext ctx) throws Exception { + if (ctx.channel().isActive()) {//the channelActive event has already been fired and our channelActive method will not be called + register(ctx); } + super.channelRegistered(ctx); + } + + @Override + public void channelInactive(final ChannelHandlerContext ctx) throws Exception { + deregister(); + super.channelInactive(ctx); + } + + @Override + public void handlerRemoved(final ChannelHandlerContext ctx) throws Exception { + deregister(); + super.handlerRemoved(ctx); + } + + @Override + public void channelUnregistered(final ChannelHandlerContext ctx) throws Exception { + deregister(); + super.channelUnregistered(ctx); + } - @Override - public void run() { - if (ctx.channel().isOpen()) { + /** + * Schedules a new timeout. + * A timeout must be {@linkplain #removeTimeout() removed} before another one is allowed to be scheduled. + */ + @ThreadSafe + void scheduleTimeout(final int additionalTimeoutMillis) { + isTrueArgument("additionalTimeoutMillis must not be negative", additionalTimeoutMillis >= 0); + nonreentrantLock.lock(); + try { + final ChannelHandlerContext ctx = this.ctx; + if (ctx == null) {//no context is registered + return; + } + final ScheduledFuture timeout = this.timeout; + assert timeout == null || timeout.isDone() : "Attempted to schedule a timeout before the previous one is removed or completed"; + this.timeout = ctx.executor().schedule(() -> { try { - ctx.fireExceptionCaught(ReadTimeoutException.INSTANCE); - ctx.close(); - } catch (Throwable t) { + fireTimeoutException(ctx); + } catch (final Throwable t) { ctx.fireExceptionCaught(t); } - } + }, readTimeoutMillis + additionalTimeoutMillis, TimeUnit.MILLISECONDS); + } finally { + nonreentrantLock.unlock(); + } + } + + /** + * Either removes the previously {@linkplain #scheduleTimeout(int) scheduled} timeout, or does nothing. + * After removing a timeout, another one may be scheduled. + */ + @ThreadSafe + void removeTimeout() { + nonreentrantLock.lock(); + try { + unsynchronizedCancel(); + } finally { + nonreentrantLock.unlock(); + } + } + + private static void fireTimeoutException(final ChannelHandlerContext ctx) { + if (!ctx.channel().isOpen()) { + return; } + ctx.fireExceptionCaught(ReadTimeoutException.INSTANCE); + ctx.close(); } }