diff --git a/driver-core/src/main/com/mongodb/internal/TimeoutSettings.java b/driver-core/src/main/com/mongodb/internal/TimeoutSettings.java index 486a893d74c..e1f0bc0b795 100644 --- a/driver-core/src/main/com/mongodb/internal/TimeoutSettings.java +++ b/driver-core/src/main/com/mongodb/internal/TimeoutSettings.java @@ -165,6 +165,11 @@ public TimeoutSettings withReadTimeoutMS(final long readTimeoutMS) { maxTimeMS, maxCommitTimeMS, wTimeoutMS, maxWaitTimeMS); } + public TimeoutSettings withConnectTimeoutMS(final long connectTimeoutMS) { + return new TimeoutSettings(generationId, timeoutMS, serverSelectionTimeoutMS, connectTimeoutMS, readTimeoutMS, maxAwaitTimeMS, + maxTimeMS, maxCommitTimeMS, wTimeoutMS, maxWaitTimeMS); + } + public TimeoutSettings withServerSelectionTimeoutMS(final long serverSelectionTimeoutMS) { return new TimeoutSettings(timeoutMS, serverSelectionTimeoutMS, connectTimeoutMS, readTimeoutMS, maxAwaitTimeMS, maxTimeMS, maxCommitTimeMS, wTimeoutMS, maxWaitTimeMS); diff --git a/driver-core/src/main/com/mongodb/internal/connection/TlsChannelStreamFactoryFactory.java b/driver-core/src/main/com/mongodb/internal/connection/TlsChannelStreamFactoryFactory.java index daf0d8cecdd..df8b3c2fe42 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/TlsChannelStreamFactoryFactory.java +++ b/driver-core/src/main/com/mongodb/internal/connection/TlsChannelStreamFactoryFactory.java @@ -40,6 +40,7 @@ import java.net.StandardSocketOptions; import java.nio.ByteBuffer; import java.nio.channels.CompletionHandler; +import java.nio.channels.InterruptedByTimeoutException; import java.nio.channels.SelectionKey; import java.nio.channels.Selector; import java.nio.channels.SocketChannel; @@ -49,6 +50,7 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; import static com.mongodb.assertions.Assertions.assertTrue; import static com.mongodb.assertions.Assertions.isTrue; @@ -97,21 +99,40 @@ public void close() { group.shutdown(); } + /** + * Monitors `OP_CONNECT` events for socket connections. + */ private static class SelectorMonitor implements Closeable { - private static final class Pair { + static final class SocketRegistration { private final SocketChannel socketChannel; - private final Runnable attachment; + private final AtomicReference afterConnectAction; - private Pair(final SocketChannel socketChannel, final Runnable attachment) { + SocketRegistration(final SocketChannel socketChannel, final Runnable afterConnectAction) { this.socketChannel = socketChannel; - this.attachment = attachment; + this.afterConnectAction = new AtomicReference<>(afterConnectAction); + } + + boolean tryCancelPendingConnection() { + return tryTakeAction() != null; + } + + void runAfterConnectActionIfNotCanceled() { + Runnable afterConnectActionToExecute = tryTakeAction(); + if (afterConnectActionToExecute != null) { + afterConnectActionToExecute.run(); + } + } + + @Nullable + private Runnable tryTakeAction() { + return afterConnectAction.getAndSet(null); } } private final Selector selector; private volatile boolean isClosed; - private final ConcurrentLinkedDeque pendingRegistrations = new ConcurrentLinkedDeque<>(); + private final ConcurrentLinkedDeque pendingRegistrations = new ConcurrentLinkedDeque<>(); SelectorMonitor() { try { @@ -127,17 +148,14 @@ void start() { while (!isClosed) { try { selector.select(); - for (SelectionKey selectionKey : selector.selectedKeys()) { selectionKey.cancel(); - Runnable runnable = (Runnable) selectionKey.attachment(); - runnable.run(); + ((SocketRegistration) selectionKey.attachment()).runAfterConnectActionIfNotCanceled(); } - for (Iterator iter = pendingRegistrations.iterator(); iter.hasNext();) { - Pair pendingRegistration = iter.next(); - pendingRegistration.socketChannel.register(selector, SelectionKey.OP_CONNECT, - pendingRegistration.attachment); + for (Iterator iter = pendingRegistrations.iterator(); iter.hasNext();) { + SocketRegistration pendingRegistration = iter.next(); + pendingRegistration.socketChannel.register(selector, SelectionKey.OP_CONNECT, pendingRegistration); iter.remove(); } } catch (Exception e) { @@ -156,8 +174,8 @@ void start() { selectorThread.start(); } - void register(final SocketChannel channel, final Runnable attachment) { - pendingRegistrations.add(new Pair(channel, attachment)); + void register(final SocketRegistration registration) { + pendingRegistrations.add(registration); selector.wakeup(); } @@ -200,44 +218,79 @@ public void openAsync(final OperationContext operationContext, final AsyncComple if (getSettings().getSendBufferSize() > 0) { socketChannel.setOption(StandardSocketOptions.SO_SNDBUF, getSettings().getSendBufferSize()); } - + //getConnectTimeoutMs MUST be called before connection attempt, as it might throw MongoOperationTimeout exception. + int connectTimeoutMs = operationContext.getTimeoutContext().getConnectTimeoutMs(); socketChannel.connect(getSocketAddresses(getServerAddress(), inetAddressResolver).get(0)); + SelectorMonitor.SocketRegistration socketRegistration = new SelectorMonitor.SocketRegistration( + socketChannel, () -> initializeTslChannel(handler, socketChannel)); - selectorMonitor.register(socketChannel, () -> { - try { - if (!socketChannel.finishConnect()) { - throw new MongoSocketOpenException("Failed to finish connect", getServerAddress()); - } + if (connectTimeoutMs > 0) { + scheduleTimeoutInterruption(handler, socketRegistration, connectTimeoutMs); + } + selectorMonitor.register(socketRegistration); + } catch (IOException e) { + handler.failed(new MongoSocketOpenException("Exception opening socket", getServerAddress(), e)); + } catch (Throwable t) { + handler.failed(t); + } + } - SSLEngine sslEngine = getSslContext().createSSLEngine(getServerAddress().getHost(), - getServerAddress().getPort()); - sslEngine.setUseClientMode(true); + private void scheduleTimeoutInterruption(final AsyncCompletionHandler handler, + final SelectorMonitor.SocketRegistration socketRegistration, + final int connectTimeoutMs) { + group.getTimeoutExecutor().schedule(() -> { + if (socketRegistration.tryCancelPendingConnection()) { + closeAndTimeout(handler, socketRegistration.socketChannel); + } + }, connectTimeoutMs, TimeUnit.MILLISECONDS); + } - SSLParameters sslParameters = sslEngine.getSSLParameters(); - enableSni(getServerAddress().getHost(), sslParameters); + private void closeAndTimeout(final AsyncCompletionHandler handler, final SocketChannel socketChannel) { + // We check if this stream was closed before timeout exception. + boolean streamClosed = isClosed(); + InterruptedByTimeoutException timeoutException = new InterruptedByTimeoutException(); + try { + socketChannel.close(); + } catch (Exception e) { + timeoutException.addSuppressed(e); + } - if (!sslSettings.isInvalidHostNameAllowed()) { - enableHostNameVerification(sslParameters); - } - sslEngine.setSSLParameters(sslParameters); + if (streamClosed) { + handler.completed(null); + } else { + handler.failed(new MongoSocketOpenException("Exception opening socket", getAddress(), timeoutException)); + } + } - BufferAllocator bufferAllocator = new BufferProviderAllocator(); + private void initializeTslChannel(final AsyncCompletionHandler handler, final SocketChannel socketChannel) { + try { + if (!socketChannel.finishConnect()) { + throw new MongoSocketOpenException("Failed to finish connect", getServerAddress()); + } - TlsChannel tlsChannel = ClientTlsChannel.newBuilder(socketChannel, sslEngine) - .withEncryptedBufferAllocator(bufferAllocator) - .withPlainBufferAllocator(bufferAllocator) - .build(); + SSLEngine sslEngine = getSslContext().createSSLEngine(getServerAddress().getHost(), + getServerAddress().getPort()); + sslEngine.setUseClientMode(true); - // build asynchronous channel, based in the TLS channel and associated with the global group. - setChannel(new AsynchronousTlsChannelAdapter(new AsynchronousTlsChannel(group, tlsChannel, socketChannel))); + SSLParameters sslParameters = sslEngine.getSSLParameters(); + enableSni(getServerAddress().getHost(), sslParameters); - handler.completed(null); - } catch (IOException e) { - handler.failed(new MongoSocketOpenException("Exception opening socket", getServerAddress(), e)); - } catch (Throwable t) { - handler.failed(t); - } - }); + if (!sslSettings.isInvalidHostNameAllowed()) { + enableHostNameVerification(sslParameters); + } + sslEngine.setSSLParameters(sslParameters); + + BufferAllocator bufferAllocator = new BufferProviderAllocator(); + + TlsChannel tlsChannel = ClientTlsChannel.newBuilder(socketChannel, sslEngine) + .withEncryptedBufferAllocator(bufferAllocator) + .withPlainBufferAllocator(bufferAllocator) + .build(); + + // build asynchronous channel, based in the TLS channel and associated with the global group. + setChannel(new AsynchronousTlsChannelAdapter(new AsynchronousTlsChannel(group, tlsChannel, socketChannel))); + + handler.completed(null); } catch (IOException e) { handler.failed(new MongoSocketOpenException("Exception opening socket", getServerAddress(), e)); } catch (Throwable t) { diff --git a/driver-core/src/main/com/mongodb/internal/connection/tlschannel/async/AsynchronousTlsChannelGroup.java b/driver-core/src/main/com/mongodb/internal/connection/tlschannel/async/AsynchronousTlsChannelGroup.java index 57db0df66e8..d9b1420a6e3 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/tlschannel/async/AsynchronousTlsChannelGroup.java +++ b/driver-core/src/main/com/mongodb/internal/connection/tlschannel/async/AsynchronousTlsChannelGroup.java @@ -823,4 +823,13 @@ public long getCurrentWriteCount() { public long getCurrentRegistrationCount() { return registrations.mappingCount(); } + + /** + * Returns the timeout executor used by this channel group. + * + * @return the timeout executor + */ + public ScheduledThreadPoolExecutor getTimeoutExecutor() { + return timeoutExecutor; + } } diff --git a/driver-core/src/test/functional/com/mongodb/internal/connection/TlsChannelStreamFunctionalTest.java b/driver-core/src/test/functional/com/mongodb/internal/connection/TlsChannelStreamFunctionalTest.java new file mode 100644 index 00000000000..3f80fcddfa3 --- /dev/null +++ b/driver-core/src/test/functional/com/mongodb/internal/connection/TlsChannelStreamFunctionalTest.java @@ -0,0 +1,150 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.internal.connection; + +import com.mongodb.MongoSocketOpenException; +import com.mongodb.ServerAddress; +import com.mongodb.connection.SocketSettings; +import com.mongodb.connection.SslSettings; +import com.mongodb.internal.TimeoutContext; +import com.mongodb.internal.TimeoutSettings; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.mockito.MockedStatic; +import org.mockito.Mockito; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; + +import java.io.IOException; +import java.net.ServerSocket; +import java.nio.channels.InterruptedByTimeoutException; +import java.nio.channels.SocketChannel; +import java.util.concurrent.TimeUnit; + +import static com.mongodb.internal.connection.OperationContext.simpleOperationContext; +import static java.lang.String.format; +import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; +import static org.mockito.Mockito.atLeast; +import static org.mockito.Mockito.verify; + +class TlsChannelStreamFunctionalTest { + private static final SslSettings SSL_SETTINGS = SslSettings.builder().enabled(true).build(); + private static final String UNREACHABLE_PRIVATE_IP_ADDRESS = "10.255.255.1"; + private static final int UNREACHABLE_PORT = 65333; + + @ParameterizedTest + @ValueSource(ints = {500, 1000, 2000}) + void shouldInterruptConnectionEstablishmentWhenConnectionTimeoutExpires(final int connectTimeoutMs) throws IOException { + //given + try (StreamFactoryFactory streamFactoryFactory = new TlsChannelStreamFactoryFactory(new DefaultInetAddressResolver()); + MockedStatic socketChannelMockedStatic = Mockito.mockStatic(SocketChannel.class)) { + SingleResultSpyCaptor singleResultSpyCaptor = new SingleResultSpyCaptor<>(); + socketChannelMockedStatic.when(SocketChannel::open).thenAnswer(singleResultSpyCaptor); + + StreamFactory streamFactory = streamFactoryFactory.create(SocketSettings.builder() + .connectTimeout(connectTimeoutMs, TimeUnit.MILLISECONDS) + .build(), SSL_SETTINGS); + + Stream stream = streamFactory.create(new ServerAddress(UNREACHABLE_PRIVATE_IP_ADDRESS, UNREACHABLE_PORT)); + long connectOpenStart = System.nanoTime(); + + //when + OperationContext operationContext = createOperationContext(connectTimeoutMs); + MongoSocketOpenException mongoSocketOpenException = assertThrows(MongoSocketOpenException.class, () -> + stream.open(operationContext)); + + //then + long elapsedMs = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - connectOpenStart); + // Allow for some timing imprecision due to test overhead. + int maximumAcceptableTimeoutOvershoot = 300; + + assertInstanceOf(InterruptedByTimeoutException.class, mongoSocketOpenException.getCause()); + assertFalse(connectTimeoutMs > elapsedMs, + format("Connection timed-out sooner than expected. ConnectTimeoutMS: %d, elapsedMs: %d", connectTimeoutMs, elapsedMs)); + assertTrue(elapsedMs - connectTimeoutMs <= maximumAcceptableTimeoutOvershoot, + format("Connection timeout overshoot time %d ms should be within %d ms", elapsedMs - connectTimeoutMs, + maximumAcceptableTimeoutOvershoot)); + + SocketChannel actualSpySocketChannel = singleResultSpyCaptor.getResult(); + assertNotNull(actualSpySocketChannel, "SocketChannel was not opened"); + verify(actualSpySocketChannel, atLeast(1)).close(); + } + } + + @ParameterizedTest + @ValueSource(ints = {0, 500, 1000, 2000}) + void shouldEstablishConnection(final int connectTimeoutMs) throws IOException, InterruptedException { + //given + try (StreamFactoryFactory streamFactoryFactory = new TlsChannelStreamFactoryFactory(new DefaultInetAddressResolver()); + MockedStatic socketChannelMockedStatic = Mockito.mockStatic(SocketChannel.class); + ServerSocket serverSocket = new ServerSocket(0, 1)) { + SingleResultSpyCaptor singleResultSpyCaptor = new SingleResultSpyCaptor<>(); + socketChannelMockedStatic.when(SocketChannel::open).thenAnswer(singleResultSpyCaptor); + + StreamFactory streamFactory = streamFactoryFactory.create(SocketSettings.builder() + .connectTimeout(connectTimeoutMs, TimeUnit.MILLISECONDS) + .build(), SSL_SETTINGS); + + Stream stream = streamFactory.create(new ServerAddress(serverSocket.getInetAddress(), serverSocket.getLocalPort())); + try { + //when + stream.open(createOperationContext(connectTimeoutMs)); + + //then + SocketChannel actualSpySocketChannel = singleResultSpyCaptor.getResult(); + assertNotNull(actualSpySocketChannel, "SocketChannel was not opened"); + assertTrue(actualSpySocketChannel.isConnected()); + + // Wait to verify that socket was not closed by timeout. + MILLISECONDS.sleep(connectTimeoutMs * 2L); + assertTrue(actualSpySocketChannel.isConnected()); + assertFalse(stream.isClosed()); + } finally { + stream.close(); + } + } + } + + private static final class SingleResultSpyCaptor implements Answer { + private volatile T result = null; + + public T getResult() { + return result; + } + + @Override + public T answer(final InvocationOnMock invocationOnMock) throws Throwable { + if (result != null) { + fail(invocationOnMock.getMethod().getName() + " was called more then once"); + } + @SuppressWarnings("unchecked") + T returnedValue = (T) invocationOnMock.callRealMethod(); + result = Mockito.spy(returnedValue); + return result; + } + } + + private static OperationContext createOperationContext(final int connectTimeoutMs) { + return simpleOperationContext(new TimeoutContext(TimeoutSettings.DEFAULT.withConnectTimeoutMS(connectTimeoutMs))); + } +} diff --git a/driver-core/src/test/unit/com/mongodb/internal/TimeoutSettingsTest.java b/driver-core/src/test/unit/com/mongodb/internal/TimeoutSettingsTest.java index 71f63d32e6d..9bffd08542b 100644 --- a/driver-core/src/test/unit/com/mongodb/internal/TimeoutSettingsTest.java +++ b/driver-core/src/test/unit/com/mongodb/internal/TimeoutSettingsTest.java @@ -53,10 +53,11 @@ Collection timeoutSettingsTest() { .withMaxAwaitTimeMS(11) .withMaxCommitMS(999L) .withReadTimeoutMS(11_000) + .withConnectTimeoutMS(500) .withWTimeoutMS(222L); assertAll( () -> assertEquals(30_000, timeoutSettings.getServerSelectionTimeoutMS()), - () -> assertEquals(10_000, timeoutSettings.getConnectTimeoutMS()), + () -> assertEquals(500, timeoutSettings.getConnectTimeoutMS()), () -> assertEquals(11_000, timeoutSettings.getReadTimeoutMS()), () -> assertEquals(100, timeoutSettings.getTimeoutMS()), () -> assertEquals(111, timeoutSettings.getMaxTimeMS()),