From d835d20ed5c5d9cadf06b153b4b4223b50127ec0 Mon Sep 17 00:00:00 2001 From: Valentin Kovalenko Date: Wed, 20 Jan 2021 12:12:25 -0700 Subject: [PATCH] Backport a fix of the read timeout implementation in NettyStream The original PR https://github.com/mongodb/mongo-java-driver/pull/636 JAVA-3920 --- .../mongodb/connection/netty/NettyStream.java | 181 +++++++++++++----- .../connection/netty/ReadTimeoutHandler.java | 78 -------- 2 files changed, 130 insertions(+), 129 deletions(-) delete mode 100644 driver-core/src/main/com/mongodb/connection/netty/ReadTimeoutHandler.java diff --git a/driver-core/src/main/com/mongodb/connection/netty/NettyStream.java b/driver-core/src/main/com/mongodb/connection/netty/NettyStream.java index b8e5e8d0671..2eca8673730 100644 --- a/driver-core/src/main/com/mongodb/connection/netty/NettyStream.java +++ b/driver-core/src/main/com/mongodb/connection/netty/NettyStream.java @@ -24,10 +24,12 @@ import com.mongodb.MongoSocketOpenException; import com.mongodb.MongoSocketReadTimeoutException; import com.mongodb.ServerAddress; +import com.mongodb.annotations.ThreadSafe; import com.mongodb.connection.AsyncCompletionHandler; import com.mongodb.connection.SocketSettings; import com.mongodb.connection.SslSettings; import com.mongodb.connection.Stream; +import com.mongodb.lang.Nullable; import io.netty.bootstrap.Bootstrap; import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.CompositeByteBuf; @@ -35,16 +37,16 @@ import io.netty.channel.Channel; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFutureListener; -import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.channel.ChannelInitializer; import io.netty.channel.ChannelOption; +import io.netty.channel.ChannelPipeline; import io.netty.channel.EventLoopGroup; import io.netty.channel.SimpleChannelInboundHandler; import io.netty.channel.socket.SocketChannel; import io.netty.handler.ssl.SslHandler; import io.netty.handler.timeout.ReadTimeoutException; -import io.netty.util.concurrent.EventExecutor; import org.bson.ByteBuf; import javax.net.ssl.SSLContext; @@ -58,6 +60,8 @@ import java.util.List; import java.util.Queue; import java.util.concurrent.CountDownLatch; +import java.util.concurrent.Future; +import java.util.concurrent.ScheduledFuture; import static com.mongodb.internal.connection.SslHelper.enableHostNameVerification; import static com.mongodb.internal.connection.SslHelper.enableSni; @@ -65,9 +69,39 @@ /** * A Stream implementation based on Netty 4.0. + * Just like it is for the {@link java.nio.channels.AsynchronousSocketChannel}, + * concurrent pending1 readers + * (whether {@linkplain #read(int, int) synchronous} or {@linkplain #readAsync(int, AsyncCompletionHandler) asynchronous}) + * are not supported by {@link NettyStream}. + * However, this class does not have a fail-fast mechanism checking for such situations. + *
+ * 1We cannot simply say that read methods are not allowed be run concurrently because strictly speaking they are allowed, + * as explained below. + *
{@code
+ * NettyStream stream = ...;
+ * stream.readAsync(1, new AsyncCompletionHandler() {//inv1
+ *  @Override
+ *  public void completed(ByteBuf o) {
+ *      stream.readAsync(//inv2
+ *              1, ...);//ret2
+ *  }
+ *
+ *  @Override
+ *  public void failed(Throwable t) {
+ *  }
+ * });//ret1
+ * }
+ * Arrows on the diagram below represent happens-before relations. + *
{@code
+ * int1 -> inv2 -> ret2
+ *      \--------> ret1
+ * }
+ * As shown on the diagram, the method {@link #readAsync(int, AsyncCompletionHandler)} runs concurrently with + * itself in the example above. However, there are no concurrent pending readers because the second operation + * is invoked after the first operation has completed reading despite the method has not returned yet. */ final class NettyStream implements Stream { - private static final String READ_HANDLER_NAME = "ReadTimeoutHandler"; + private static final byte NO_SCHEDULE_TIME = 0; private final ServerAddress address; private final SocketSettings settings; private final SslSettings sslSettings; @@ -79,8 +113,19 @@ final class NettyStream implements Stream { private volatile Channel channel; private final LinkedList pendingInboundBuffers = new LinkedList(); - private volatile PendingReader pendingReader; - private volatile Throwable pendingException; + /* The fields pendingReader, pendingException are always written/read inside synchronized blocks + * that use the same NettyStream object, so they can be plain.*/ + private PendingReader pendingReader; + private Throwable pendingException; + /* The fields readTimeoutTask, readTimeoutMillis are each written only in the ChannelInitializer.initChannel method + * (in addition to the write of the default value and the write by variable initializers), + * and read only when NettyStream users read data, or Netty event loop handles incoming data. + * Since actions done by the ChannelInitializer.initChannel method + * are ordered (in the happens-before order) before user read actions and before event loop actions that handle incoming data, + * these fields can be plain.*/ + @Nullable + private ReadTimeoutTask readTimeoutTask; + private long readTimeoutMillis = NO_SCHEDULE_TIME; NettyStream(final ServerAddress address, final SocketSettings settings, final SslSettings sslSettings, final EventLoopGroup workerGroup, final Class socketChannelClass, final ByteBufAllocator allocator) { @@ -135,6 +180,7 @@ private void initializeChannel(final AsyncCompletionHandler handler, final bootstrap.handler(new ChannelInitializer() { @Override public void initChannel(final SocketChannel ch) { + ChannelPipeline pipeline = ch.pipeline(); if (sslSettings.isEnabled()) { SSLEngine engine = getSslContext().createSSLEngine(address.getHost(), address.getPort()); engine.setUseClientMode(true); @@ -144,13 +190,20 @@ public void initChannel(final SocketChannel ch) { enableHostNameVerification(sslParameters); } engine.setSSLParameters(sslParameters); - ch.pipeline().addFirst("ssl", new SslHandler(engine, false)); + pipeline.addFirst("ssl", new SslHandler(engine, false)); } + int readTimeout = settings.getReadTimeout(MILLISECONDS); - if (readTimeout > 0) { - ch.pipeline().addLast(READ_HANDLER_NAME, new ReadTimeoutHandler(readTimeout)); + if (readTimeout > NO_SCHEDULE_TIME) { + readTimeoutMillis = readTimeout; + /* We need at least one handler before (in the inbound evaluation order) the InboundBufferHandler, + * so that we can fire exception events (they are inbound events) using its context and the InboundBufferHandler + * receives them. SslHandler is not always present, so adding a NOOP handler.*/ + pipeline.addLast(new ChannelInboundHandlerAdapter()); + readTimeoutTask = new ReadTimeoutTask(pipeline.lastContext()); } - ch.pipeline().addLast(new InboundBufferHandler()); + + pipeline.addLast(new InboundBufferHandler()); } }); final ChannelFuture channelFuture = bootstrap.connect(nextAddress); @@ -193,14 +246,27 @@ public void operationComplete(final ChannelFuture future) throws Exception { @Override public void readAsync(final int numBytes, final AsyncCompletionHandler handler) { - scheduleReadTimeout(); + readAsync(numBytes, handler, readTimeoutMillis); + } + + /** + * @param numBytes Must be equal to {@link #pendingReader}{@code .numBytes} when called by a Netty channel handler. + * @param handler Must be equal to {@link #pendingReader}{@code .handler} when called by a Netty channel handler. + * @param readTimeoutMillis Must be equal to {@link #NO_SCHEDULE_TIME} when called by a Netty channel handler. + * Timeouts may be scheduled only by the public read methods. Taking into account that concurrent pending + * readers are not allowed, there must not be a situation when threads attempt to schedule a timeout + * before the previous one is either cancelled or completed. + */ + private void readAsync(final int numBytes, final AsyncCompletionHandler handler, final long readTimeoutMillis) { ByteBuf buffer = null; Throwable exceptionResult = null; synchronized (this) { exceptionResult = pendingException; if (exceptionResult == null) { if (!hasBytesAvailable(numBytes)) { - pendingReader = new PendingReader(numBytes, handler); + if (pendingReader == null) {//called by a public read method + pendingReader = new PendingReader(numBytes, handler, scheduleReadTimeout(readTimeoutTask, readTimeoutMillis)); + } } else { CompositeByteBuf composite = allocator.compositeBuffer(pendingInboundBuffers.size()); int bytesNeeded = numBytes; @@ -223,13 +289,16 @@ public void readAsync(final int numBytes, final AsyncCompletionHandler buffer = new NettyByteBuf(composite).flip(); } } + if (!(exceptionResult == null && buffer == null)//the read operation has completed + && pendingReader != null) {//we need to clear the pending reader + cancel(pendingReader.timeout); + this.pendingReader = null; + } } if (exceptionResult != null) { - disableReadTimeout(); handler.failed(exceptionResult); } if (buffer != null) { - disableReadTimeout(); handler.completed(buffer); } } @@ -253,14 +322,12 @@ private void handleReadResponse(final io.netty.buffer.ByteBuf buffer, final Thro } else { pendingException = t; } - if (pendingReader != null) { - localPendingReader = pendingReader; - pendingReader = null; - } + localPendingReader = pendingReader; } if (localPendingReader != null) { - readAsync(localPendingReader.numBytes, localPendingReader.handler); + //timeouts may be scheduled only by the public read methods + readAsync(localPendingReader.numBytes, localPendingReader.handler, NO_SCHEDULE_TIME); } } @@ -336,10 +403,14 @@ public void exceptionCaught(final ChannelHandlerContext ctx, final Throwable t) private static final class PendingReader { private final int numBytes; private final AsyncCompletionHandler handler; + @Nullable + private final ScheduledFuture timeout; - private PendingReader(final int numBytes, final AsyncCompletionHandler handler) { + private PendingReader( + final int numBytes, final AsyncCompletionHandler handler, @Nullable final ScheduledFuture timeout) { this.numBytes = numBytes; this.handler = handler; + this.timeout = timeout; } } @@ -423,44 +494,52 @@ public void operationComplete(final ChannelFuture future) { } } - private void scheduleReadTimeout() { - adjustTimeout(false); + private static void cancel(@Nullable final Future f) { + if (f != null) { + f.cancel(false); + } } - private void disableReadTimeout() { - adjustTimeout(true); + private static long combinedTimeout(final long timeout, final int additionalTimeout) { + if (timeout == NO_SCHEDULE_TIME) { + return NO_SCHEDULE_TIME; + } else { + return Math.addExact(timeout, additionalTimeout); + } } - private void adjustTimeout(final boolean disable) { - ChannelHandler timeoutHandler = channel.pipeline().get(READ_HANDLER_NAME); - if (timeoutHandler != null) { - final ReadTimeoutHandler readTimeoutHandler = (ReadTimeoutHandler) timeoutHandler; - final ChannelHandlerContext handlerContext = channel.pipeline().context(timeoutHandler); - EventExecutor executor = handlerContext.executor(); + private static ScheduledFuture scheduleReadTimeout(@Nullable final ReadTimeoutTask readTimeoutTask, final long timeoutMillis) { + if (timeoutMillis == NO_SCHEDULE_TIME) { + return null; + } else { + //assert readTimeoutTask != null : "readTimeoutTask must be initialized if read timeouts are enabled"; + return readTimeoutTask.schedule(timeoutMillis); + } + } - if (disable) { - if (executor.inEventLoop()) { - readTimeoutHandler.removeTimeout(handlerContext); - } else { - executor.submit(new Runnable() { - @Override - public void run() { - readTimeoutHandler.removeTimeout(handlerContext); - } - }); - } - } else { - if (executor.inEventLoop()) { - readTimeoutHandler.scheduleTimeout(handlerContext); - } else { - executor.submit(new Runnable() { - @Override - public void run() { - readTimeoutHandler.scheduleTimeout(handlerContext); - } - }); - } + @ThreadSafe + private static final class ReadTimeoutTask implements Runnable { + private final ChannelHandlerContext ctx; + + private ReadTimeoutTask(final ChannelHandlerContext timeoutChannelHandlerContext) { + ctx = timeoutChannelHandlerContext; + } + + @Override + public void run() { + try { + if (ctx.channel().isOpen()) { + ctx.fireExceptionCaught(ReadTimeoutException.INSTANCE); + ctx.close(); } + } catch (final Throwable t) { + ctx.fireExceptionCaught(t); } + } + + private ScheduledFuture schedule(final long timeoutMillis) { + //assert timeoutMillis > 0 : timeoutMillis; + return ctx.executor().schedule(this, timeoutMillis, MILLISECONDS); + } } } diff --git a/driver-core/src/main/com/mongodb/connection/netty/ReadTimeoutHandler.java b/driver-core/src/main/com/mongodb/connection/netty/ReadTimeoutHandler.java deleted file mode 100644 index 4b4533f3cde..00000000000 --- a/driver-core/src/main/com/mongodb/connection/netty/ReadTimeoutHandler.java +++ /dev/null @@ -1,78 +0,0 @@ -/* - * Copyright 2008-present MongoDB, Inc. - * Copyright 2012 The Netty Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.mongodb.connection.netty; - -import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.ChannelInboundHandlerAdapter; -import io.netty.handler.timeout.ReadTimeoutException; - -import java.util.concurrent.ScheduledFuture; -import java.util.concurrent.TimeUnit; - -import static com.mongodb.assertions.Assertions.isTrue; -import static com.mongodb.assertions.Assertions.isTrueArgument; - -/** - * Passes a {@link ReadTimeoutException} if the time between a {@link #scheduleTimeout} and {@link #removeTimeout} is longer than the set - * timeout. - */ -final class ReadTimeoutHandler extends ChannelInboundHandlerAdapter { - private final long readTimeout; - private volatile ScheduledFuture timeout; - - ReadTimeoutHandler(final long readTimeout) { - isTrueArgument("readTimeout must be greater than zero.", readTimeout > 0); - this.readTimeout = readTimeout; - } - - void scheduleTimeout(final ChannelHandlerContext ctx) { - isTrue("Handler called from the eventLoop", ctx.channel().eventLoop().inEventLoop()); - if (timeout == null) { - timeout = ctx.executor().schedule(new ReadTimeoutTask(ctx), readTimeout, TimeUnit.MILLISECONDS); - } - } - - void removeTimeout(final ChannelHandlerContext ctx) { - isTrue("Handler called from the eventLoop", ctx.channel().eventLoop().inEventLoop()); - if (timeout != null) { - timeout.cancel(false); - timeout = null; - } - } - - private static final class ReadTimeoutTask implements Runnable { - - private final ChannelHandlerContext ctx; - - ReadTimeoutTask(final ChannelHandlerContext ctx) { - this.ctx = ctx; - } - - @Override - public void run() { - if (ctx.channel().isOpen()) { - try { - ctx.fireExceptionCaught(ReadTimeoutException.INSTANCE); - ctx.close(); - } catch (Throwable t) { - ctx.fireExceptionCaught(t); - } - } - } - } -}