diff --git a/driver-core/src/main/com/mongodb/connection/netty/NettyStream.java b/driver-core/src/main/com/mongodb/connection/netty/NettyStream.java
index b8e5e8d0671..2eca8673730 100644
--- a/driver-core/src/main/com/mongodb/connection/netty/NettyStream.java
+++ b/driver-core/src/main/com/mongodb/connection/netty/NettyStream.java
@@ -24,10 +24,12 @@
import com.mongodb.MongoSocketOpenException;
import com.mongodb.MongoSocketReadTimeoutException;
import com.mongodb.ServerAddress;
+import com.mongodb.annotations.ThreadSafe;
import com.mongodb.connection.AsyncCompletionHandler;
import com.mongodb.connection.SocketSettings;
import com.mongodb.connection.SslSettings;
import com.mongodb.connection.Stream;
+import com.mongodb.lang.Nullable;
import io.netty.bootstrap.Bootstrap;
import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.CompositeByteBuf;
@@ -35,16 +37,16 @@
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
-import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
+import io.netty.channel.ChannelPipeline;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.channel.socket.SocketChannel;
import io.netty.handler.ssl.SslHandler;
import io.netty.handler.timeout.ReadTimeoutException;
-import io.netty.util.concurrent.EventExecutor;
import org.bson.ByteBuf;
import javax.net.ssl.SSLContext;
@@ -58,6 +60,8 @@
import java.util.List;
import java.util.Queue;
import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.Future;
+import java.util.concurrent.ScheduledFuture;
import static com.mongodb.internal.connection.SslHelper.enableHostNameVerification;
import static com.mongodb.internal.connection.SslHelper.enableSni;
@@ -65,9 +69,39 @@
/**
* A Stream implementation based on Netty 4.0.
+ * Just like it is for the {@link java.nio.channels.AsynchronousSocketChannel},
+ * concurrent pending1 readers
+ * (whether {@linkplain #read(int, int) synchronous} or {@linkplain #readAsync(int, AsyncCompletionHandler) asynchronous})
+ * are not supported by {@link NettyStream}.
+ * However, this class does not have a fail-fast mechanism checking for such situations.
+ *
+ * 1We cannot simply say that read methods are not allowed be run concurrently because strictly speaking they are allowed,
+ * as explained below.
+ * {@code
+ * NettyStream stream = ...;
+ * stream.readAsync(1, new AsyncCompletionHandler() {//inv1
+ * @Override
+ * public void completed(ByteBuf o) {
+ * stream.readAsync(//inv2
+ * 1, ...);//ret2
+ * }
+ *
+ * @Override
+ * public void failed(Throwable t) {
+ * }
+ * });//ret1
+ * }
+ * Arrows on the diagram below represent happens-before relations.
+ * {@code
+ * int1 -> inv2 -> ret2
+ * \--------> ret1
+ * }
+ * As shown on the diagram, the method {@link #readAsync(int, AsyncCompletionHandler)} runs concurrently with
+ * itself in the example above. However, there are no concurrent pending readers because the second operation
+ * is invoked after the first operation has completed reading despite the method has not returned yet.
*/
final class NettyStream implements Stream {
- private static final String READ_HANDLER_NAME = "ReadTimeoutHandler";
+ private static final byte NO_SCHEDULE_TIME = 0;
private final ServerAddress address;
private final SocketSettings settings;
private final SslSettings sslSettings;
@@ -79,8 +113,19 @@ final class NettyStream implements Stream {
private volatile Channel channel;
private final LinkedList pendingInboundBuffers = new LinkedList();
- private volatile PendingReader pendingReader;
- private volatile Throwable pendingException;
+ /* The fields pendingReader, pendingException are always written/read inside synchronized blocks
+ * that use the same NettyStream object, so they can be plain.*/
+ private PendingReader pendingReader;
+ private Throwable pendingException;
+ /* The fields readTimeoutTask, readTimeoutMillis are each written only in the ChannelInitializer.initChannel method
+ * (in addition to the write of the default value and the write by variable initializers),
+ * and read only when NettyStream users read data, or Netty event loop handles incoming data.
+ * Since actions done by the ChannelInitializer.initChannel method
+ * are ordered (in the happens-before order) before user read actions and before event loop actions that handle incoming data,
+ * these fields can be plain.*/
+ @Nullable
+ private ReadTimeoutTask readTimeoutTask;
+ private long readTimeoutMillis = NO_SCHEDULE_TIME;
NettyStream(final ServerAddress address, final SocketSettings settings, final SslSettings sslSettings, final EventLoopGroup workerGroup,
final Class extends SocketChannel> socketChannelClass, final ByteBufAllocator allocator) {
@@ -135,6 +180,7 @@ private void initializeChannel(final AsyncCompletionHandler handler, final
bootstrap.handler(new ChannelInitializer() {
@Override
public void initChannel(final SocketChannel ch) {
+ ChannelPipeline pipeline = ch.pipeline();
if (sslSettings.isEnabled()) {
SSLEngine engine = getSslContext().createSSLEngine(address.getHost(), address.getPort());
engine.setUseClientMode(true);
@@ -144,13 +190,20 @@ public void initChannel(final SocketChannel ch) {
enableHostNameVerification(sslParameters);
}
engine.setSSLParameters(sslParameters);
- ch.pipeline().addFirst("ssl", new SslHandler(engine, false));
+ pipeline.addFirst("ssl", new SslHandler(engine, false));
}
+
int readTimeout = settings.getReadTimeout(MILLISECONDS);
- if (readTimeout > 0) {
- ch.pipeline().addLast(READ_HANDLER_NAME, new ReadTimeoutHandler(readTimeout));
+ if (readTimeout > NO_SCHEDULE_TIME) {
+ readTimeoutMillis = readTimeout;
+ /* We need at least one handler before (in the inbound evaluation order) the InboundBufferHandler,
+ * so that we can fire exception events (they are inbound events) using its context and the InboundBufferHandler
+ * receives them. SslHandler is not always present, so adding a NOOP handler.*/
+ pipeline.addLast(new ChannelInboundHandlerAdapter());
+ readTimeoutTask = new ReadTimeoutTask(pipeline.lastContext());
}
- ch.pipeline().addLast(new InboundBufferHandler());
+
+ pipeline.addLast(new InboundBufferHandler());
}
});
final ChannelFuture channelFuture = bootstrap.connect(nextAddress);
@@ -193,14 +246,27 @@ public void operationComplete(final ChannelFuture future) throws Exception {
@Override
public void readAsync(final int numBytes, final AsyncCompletionHandler handler) {
- scheduleReadTimeout();
+ readAsync(numBytes, handler, readTimeoutMillis);
+ }
+
+ /**
+ * @param numBytes Must be equal to {@link #pendingReader}{@code .numBytes} when called by a Netty channel handler.
+ * @param handler Must be equal to {@link #pendingReader}{@code .handler} when called by a Netty channel handler.
+ * @param readTimeoutMillis Must be equal to {@link #NO_SCHEDULE_TIME} when called by a Netty channel handler.
+ * Timeouts may be scheduled only by the public read methods. Taking into account that concurrent pending
+ * readers are not allowed, there must not be a situation when threads attempt to schedule a timeout
+ * before the previous one is either cancelled or completed.
+ */
+ private void readAsync(final int numBytes, final AsyncCompletionHandler handler, final long readTimeoutMillis) {
ByteBuf buffer = null;
Throwable exceptionResult = null;
synchronized (this) {
exceptionResult = pendingException;
if (exceptionResult == null) {
if (!hasBytesAvailable(numBytes)) {
- pendingReader = new PendingReader(numBytes, handler);
+ if (pendingReader == null) {//called by a public read method
+ pendingReader = new PendingReader(numBytes, handler, scheduleReadTimeout(readTimeoutTask, readTimeoutMillis));
+ }
} else {
CompositeByteBuf composite = allocator.compositeBuffer(pendingInboundBuffers.size());
int bytesNeeded = numBytes;
@@ -223,13 +289,16 @@ public void readAsync(final int numBytes, final AsyncCompletionHandler
buffer = new NettyByteBuf(composite).flip();
}
}
+ if (!(exceptionResult == null && buffer == null)//the read operation has completed
+ && pendingReader != null) {//we need to clear the pending reader
+ cancel(pendingReader.timeout);
+ this.pendingReader = null;
+ }
}
if (exceptionResult != null) {
- disableReadTimeout();
handler.failed(exceptionResult);
}
if (buffer != null) {
- disableReadTimeout();
handler.completed(buffer);
}
}
@@ -253,14 +322,12 @@ private void handleReadResponse(final io.netty.buffer.ByteBuf buffer, final Thro
} else {
pendingException = t;
}
- if (pendingReader != null) {
- localPendingReader = pendingReader;
- pendingReader = null;
- }
+ localPendingReader = pendingReader;
}
if (localPendingReader != null) {
- readAsync(localPendingReader.numBytes, localPendingReader.handler);
+ //timeouts may be scheduled only by the public read methods
+ readAsync(localPendingReader.numBytes, localPendingReader.handler, NO_SCHEDULE_TIME);
}
}
@@ -336,10 +403,14 @@ public void exceptionCaught(final ChannelHandlerContext ctx, final Throwable t)
private static final class PendingReader {
private final int numBytes;
private final AsyncCompletionHandler handler;
+ @Nullable
+ private final ScheduledFuture> timeout;
- private PendingReader(final int numBytes, final AsyncCompletionHandler handler) {
+ private PendingReader(
+ final int numBytes, final AsyncCompletionHandler handler, @Nullable final ScheduledFuture> timeout) {
this.numBytes = numBytes;
this.handler = handler;
+ this.timeout = timeout;
}
}
@@ -423,44 +494,52 @@ public void operationComplete(final ChannelFuture future) {
}
}
- private void scheduleReadTimeout() {
- adjustTimeout(false);
+ private static void cancel(@Nullable final Future> f) {
+ if (f != null) {
+ f.cancel(false);
+ }
}
- private void disableReadTimeout() {
- adjustTimeout(true);
+ private static long combinedTimeout(final long timeout, final int additionalTimeout) {
+ if (timeout == NO_SCHEDULE_TIME) {
+ return NO_SCHEDULE_TIME;
+ } else {
+ return Math.addExact(timeout, additionalTimeout);
+ }
}
- private void adjustTimeout(final boolean disable) {
- ChannelHandler timeoutHandler = channel.pipeline().get(READ_HANDLER_NAME);
- if (timeoutHandler != null) {
- final ReadTimeoutHandler readTimeoutHandler = (ReadTimeoutHandler) timeoutHandler;
- final ChannelHandlerContext handlerContext = channel.pipeline().context(timeoutHandler);
- EventExecutor executor = handlerContext.executor();
+ private static ScheduledFuture> scheduleReadTimeout(@Nullable final ReadTimeoutTask readTimeoutTask, final long timeoutMillis) {
+ if (timeoutMillis == NO_SCHEDULE_TIME) {
+ return null;
+ } else {
+ //assert readTimeoutTask != null : "readTimeoutTask must be initialized if read timeouts are enabled";
+ return readTimeoutTask.schedule(timeoutMillis);
+ }
+ }
- if (disable) {
- if (executor.inEventLoop()) {
- readTimeoutHandler.removeTimeout(handlerContext);
- } else {
- executor.submit(new Runnable() {
- @Override
- public void run() {
- readTimeoutHandler.removeTimeout(handlerContext);
- }
- });
- }
- } else {
- if (executor.inEventLoop()) {
- readTimeoutHandler.scheduleTimeout(handlerContext);
- } else {
- executor.submit(new Runnable() {
- @Override
- public void run() {
- readTimeoutHandler.scheduleTimeout(handlerContext);
- }
- });
- }
+ @ThreadSafe
+ private static final class ReadTimeoutTask implements Runnable {
+ private final ChannelHandlerContext ctx;
+
+ private ReadTimeoutTask(final ChannelHandlerContext timeoutChannelHandlerContext) {
+ ctx = timeoutChannelHandlerContext;
+ }
+
+ @Override
+ public void run() {
+ try {
+ if (ctx.channel().isOpen()) {
+ ctx.fireExceptionCaught(ReadTimeoutException.INSTANCE);
+ ctx.close();
}
+ } catch (final Throwable t) {
+ ctx.fireExceptionCaught(t);
}
+ }
+
+ private ScheduledFuture> schedule(final long timeoutMillis) {
+ //assert timeoutMillis > 0 : timeoutMillis;
+ return ctx.executor().schedule(this, timeoutMillis, MILLISECONDS);
+ }
}
}
diff --git a/driver-core/src/main/com/mongodb/connection/netty/ReadTimeoutHandler.java b/driver-core/src/main/com/mongodb/connection/netty/ReadTimeoutHandler.java
deleted file mode 100644
index 4b4533f3cde..00000000000
--- a/driver-core/src/main/com/mongodb/connection/netty/ReadTimeoutHandler.java
+++ /dev/null
@@ -1,78 +0,0 @@
-/*
- * Copyright 2008-present MongoDB, Inc.
- * Copyright 2012 The Netty Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package com.mongodb.connection.netty;
-
-import io.netty.channel.ChannelHandlerContext;
-import io.netty.channel.ChannelInboundHandlerAdapter;
-import io.netty.handler.timeout.ReadTimeoutException;
-
-import java.util.concurrent.ScheduledFuture;
-import java.util.concurrent.TimeUnit;
-
-import static com.mongodb.assertions.Assertions.isTrue;
-import static com.mongodb.assertions.Assertions.isTrueArgument;
-
-/**
- * Passes a {@link ReadTimeoutException} if the time between a {@link #scheduleTimeout} and {@link #removeTimeout} is longer than the set
- * timeout.
- */
-final class ReadTimeoutHandler extends ChannelInboundHandlerAdapter {
- private final long readTimeout;
- private volatile ScheduledFuture> timeout;
-
- ReadTimeoutHandler(final long readTimeout) {
- isTrueArgument("readTimeout must be greater than zero.", readTimeout > 0);
- this.readTimeout = readTimeout;
- }
-
- void scheduleTimeout(final ChannelHandlerContext ctx) {
- isTrue("Handler called from the eventLoop", ctx.channel().eventLoop().inEventLoop());
- if (timeout == null) {
- timeout = ctx.executor().schedule(new ReadTimeoutTask(ctx), readTimeout, TimeUnit.MILLISECONDS);
- }
- }
-
- void removeTimeout(final ChannelHandlerContext ctx) {
- isTrue("Handler called from the eventLoop", ctx.channel().eventLoop().inEventLoop());
- if (timeout != null) {
- timeout.cancel(false);
- timeout = null;
- }
- }
-
- private static final class ReadTimeoutTask implements Runnable {
-
- private final ChannelHandlerContext ctx;
-
- ReadTimeoutTask(final ChannelHandlerContext ctx) {
- this.ctx = ctx;
- }
-
- @Override
- public void run() {
- if (ctx.channel().isOpen()) {
- try {
- ctx.fireExceptionCaught(ReadTimeoutException.INSTANCE);
- ctx.close();
- } catch (Throwable t) {
- ctx.fireExceptionCaught(t);
- }
- }
- }
- }
-}