diff --git a/driver-core/src/main/com/mongodb/connection/netty/NettyStream.java b/driver-core/src/main/com/mongodb/connection/netty/NettyStream.java index 948da9162ec..a141cd7860b 100644 --- a/driver-core/src/main/com/mongodb/connection/netty/NettyStream.java +++ b/driver-core/src/main/com/mongodb/connection/netty/NettyStream.java @@ -35,7 +35,6 @@ import io.netty.channel.Channel; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFutureListener; -import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInitializer; import io.netty.channel.ChannelOption; @@ -44,7 +43,6 @@ import io.netty.channel.socket.SocketChannel; import io.netty.handler.ssl.SslHandler; import io.netty.handler.timeout.ReadTimeoutException; -import io.netty.util.concurrent.EventExecutor; import org.bson.ByteBuf; import javax.net.ssl.SSLContext; @@ -58,6 +56,8 @@ import java.util.List; import java.util.Queue; import java.util.concurrent.CountDownLatch; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; import static com.mongodb.internal.connection.SslHelper.enableHostNameVerification; import static com.mongodb.internal.connection.SslHelper.enableSni; @@ -67,7 +67,7 @@ * A Stream implementation based on Netty 4.0. */ final class NettyStream implements Stream { - private static final String READ_HANDLER_NAME = "ReadTimeoutHandler"; + private static final String INBOUND_BUFFER_HANDLER_NAME = "InboundBufferHandler"; private final ServerAddress address; private final SocketSettings settings; private final SslSettings sslSettings; @@ -81,6 +81,7 @@ final class NettyStream implements Stream { private final LinkedList pendingInboundBuffers = new LinkedList(); private volatile PendingReader pendingReader; private volatile Throwable pendingException; + private final int readTimeoutMs; NettyStream(final ServerAddress address, final SocketSettings settings, final SslSettings sslSettings, final EventLoopGroup workerGroup, final Class socketChannelClass, final ByteBufAllocator allocator) { @@ -90,6 +91,7 @@ final class NettyStream implements Stream { this.workerGroup = workerGroup; this.socketChannelClass = socketChannelClass; this.allocator = allocator; + this.readTimeoutMs = settings.getReadTimeout(MILLISECONDS); } @Override @@ -154,11 +156,7 @@ public void initChannel(final SocketChannel ch) { engine.setSSLParameters(sslParameters); ch.pipeline().addFirst("ssl", new SslHandler(engine, false)); } - int readTimeout = settings.getReadTimeout(MILLISECONDS); - if (readTimeout > 0) { - ch.pipeline().addLast(READ_HANDLER_NAME, new ReadTimeoutHandler(readTimeout)); - } - ch.pipeline().addLast(new InboundBufferHandler()); + ch.pipeline().addLast(INBOUND_BUFFER_HANDLER_NAME, new InboundBufferHandler()); } }); final ChannelFuture channelFuture = bootstrap.connect(nextAddress); @@ -186,7 +184,7 @@ public boolean supportsAdditionalTimeout() { @Override public ByteBuf read(final int numBytes, final int additionalTimeout) throws IOException { FutureAsyncCompletionHandler future = new FutureAsyncCompletionHandler(); - readAsync(numBytes, future, additionalTimeout); + readAsync(numBytes, future, additionalTimeout, null); return future.get(); } @@ -211,18 +209,19 @@ public void operationComplete(final ChannelFuture future) throws Exception { @Override public void readAsync(final int numBytes, final AsyncCompletionHandler handler) { - readAsync(numBytes, handler, 0); + readAsync(numBytes, handler, 0, null); } - private void readAsync(final int numBytes, final AsyncCompletionHandler handler, final int additionalTimeout) { - scheduleReadTimeout(additionalTimeout); + private void readAsync(final int numBytes, final AsyncCompletionHandler handler, final int additionalTimeout, + final PendingReader pendingReaderToReuse) { ByteBuf buffer = null; - Throwable exceptionResult = null; + Throwable exceptionResult; synchronized (this) { exceptionResult = pendingException; if (exceptionResult == null) { if (!hasBytesAvailable(numBytes)) { - pendingReader = new PendingReader(numBytes, handler); + pendingReader = (pendingReaderToReuse != null) ? pendingReaderToReuse + : new PendingReader(numBytes, handler, scheduleReadTimeout(additionalTimeout)); } else { CompositeByteBuf composite = allocator.compositeBuffer(pendingInboundBuffers.size()); int bytesNeeded = numBytes; @@ -246,12 +245,17 @@ private void readAsync(final int numBytes, final AsyncCompletionHandler } } } + + if ((exceptionResult != null || buffer != null) + && (pendingReaderToReuse != null) + && (pendingReaderToReuse.timeoutFuture != null)) { + pendingReaderToReuse.timeoutFuture.cancel(false); + } + if (exceptionResult != null) { - disableReadTimeout(); handler.failed(exceptionResult); } if (buffer != null) { - disableReadTimeout(); handler.completed(buffer); } } @@ -282,7 +286,7 @@ private void handleReadResponse(final io.netty.buffer.ByteBuf buffer, final Thro } if (localPendingReader != null) { - readAsync(localPendingReader.numBytes, localPendingReader.handler); + readAsync(localPendingReader.numBytes, localPendingReader.handler, 0, localPendingReader); } } @@ -358,10 +362,12 @@ public void exceptionCaught(final ChannelHandlerContext ctx, final Throwable t) private static final class PendingReader { private final int numBytes; private final AsyncCompletionHandler handler; + private final Future timeoutFuture; - private PendingReader(final int numBytes, final AsyncCompletionHandler handler) { + private PendingReader(final int numBytes, final AsyncCompletionHandler handler, final Future timeoutFuture) { this.numBytes = numBytes; this.handler = handler; + this.timeoutFuture = timeoutFuture; } } @@ -445,47 +451,38 @@ public void operationComplete(final ChannelFuture future) { } } - private void scheduleReadTimeout(final int additionalTimeout) { - adjustTimeout(false, additionalTimeout); - } + private Future scheduleReadTimeout(final int additionalTimeout) { + if (isClosed || readTimeoutMs <= 0) { + return null; + } - private void disableReadTimeout() { - adjustTimeout(true, 0); + final ChannelHandlerContext ctx = channel.pipeline().context(INBOUND_BUFFER_HANDLER_NAME); + if (ctx != null) { + final ReadTimeoutTask task = new ReadTimeoutTask(ctx); + return ctx.executor().schedule(task, readTimeoutMs + additionalTimeout, TimeUnit.MILLISECONDS); + } else { + return null; + } } - private void adjustTimeout(final boolean disable, final int additionalTimeout) { - if (isClosed) { - return; - } - ChannelHandler timeoutHandler = channel.pipeline().get(READ_HANDLER_NAME); - if (timeoutHandler != null) { - final ReadTimeoutHandler readTimeoutHandler = (ReadTimeoutHandler) timeoutHandler; - final ChannelHandlerContext handlerContext = channel.pipeline().context(timeoutHandler); - EventExecutor executor = handlerContext.executor(); - - if (disable) { - if (executor.inEventLoop()) { - readTimeoutHandler.removeTimeout(handlerContext); - } else { - executor.submit(new Runnable() { - @Override - public void run() { - readTimeoutHandler.removeTimeout(handlerContext); - } - }); - } - } else { - if (executor.inEventLoop()) { - readTimeoutHandler.scheduleTimeout(handlerContext, additionalTimeout); - } else { - executor.submit(new Runnable() { - @Override - public void run() { - readTimeoutHandler.scheduleTimeout(handlerContext, additionalTimeout); - } - }); - } + private static final class ReadTimeoutTask implements Runnable { + + private final ChannelHandlerContext ctx; + + ReadTimeoutTask(final ChannelHandlerContext ctx) { + this.ctx = ctx; + } + + @Override + public void run() { + if (ctx.channel().isOpen()) { + try { + ctx.fireExceptionCaught(ReadTimeoutException.INSTANCE); + ctx.close(); + } catch (Throwable t) { + ctx.fireExceptionCaught(t); } } + } } } diff --git a/driver-core/src/main/com/mongodb/connection/netty/ReadTimeoutHandler.java b/driver-core/src/main/com/mongodb/connection/netty/ReadTimeoutHandler.java deleted file mode 100644 index 824c8f7d6a3..00000000000 --- a/driver-core/src/main/com/mongodb/connection/netty/ReadTimeoutHandler.java +++ /dev/null @@ -1,78 +0,0 @@ -/* - * Copyright 2008-present MongoDB, Inc. - * Copyright 2012 The Netty Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.mongodb.connection.netty; - -import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.ChannelInboundHandlerAdapter; -import io.netty.handler.timeout.ReadTimeoutException; - -import java.util.concurrent.ScheduledFuture; -import java.util.concurrent.TimeUnit; - -import static com.mongodb.assertions.Assertions.isTrue; -import static com.mongodb.assertions.Assertions.isTrueArgument; - -/** - * Passes a {@link ReadTimeoutException} if the time between a {@link #scheduleTimeout} and {@link #removeTimeout} is longer than the set - * timeout. - */ -final class ReadTimeoutHandler extends ChannelInboundHandlerAdapter { - private final long readTimeout; - private volatile ScheduledFuture timeout; - - ReadTimeoutHandler(final long readTimeout) { - isTrueArgument("readTimeout must be greater than zero.", readTimeout > 0); - this.readTimeout = readTimeout; - } - - void scheduleTimeout(final ChannelHandlerContext ctx, final int additionalTimeout) { - isTrue("Handler called from the eventLoop", ctx.channel().eventLoop().inEventLoop()); - if (timeout == null) { - timeout = ctx.executor().schedule(new ReadTimeoutTask(ctx), readTimeout + additionalTimeout, TimeUnit.MILLISECONDS); - } - } - - void removeTimeout(final ChannelHandlerContext ctx) { - isTrue("Handler called from the eventLoop", ctx.channel().eventLoop().inEventLoop()); - if (timeout != null) { - timeout.cancel(false); - timeout = null; - } - } - - private static final class ReadTimeoutTask implements Runnable { - - private final ChannelHandlerContext ctx; - - ReadTimeoutTask(final ChannelHandlerContext ctx) { - this.ctx = ctx; - } - - @Override - public void run() { - if (ctx.channel().isOpen()) { - try { - ctx.fireExceptionCaught(ReadTimeoutException.INSTANCE); - ctx.close(); - } catch (Throwable t) { - ctx.fireExceptionCaught(t); - } - } - } - } -}