Skip to content

Commit 38294a7

Browse files
committed
Add connection timeout handler to TlsChannelImpl.
JAVA-5856
1 parent 144e287 commit 38294a7

File tree

2 files changed

+103
-40
lines changed

2 files changed

+103
-40
lines changed

driver-core/src/main/com/mongodb/internal/connection/TlsChannelStreamFactoryFactory.java

Lines changed: 94 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
import java.net.StandardSocketOptions;
4141
import java.nio.ByteBuffer;
4242
import java.nio.channels.CompletionHandler;
43+
import java.nio.channels.InterruptedByTimeoutException;
4344
import java.nio.channels.SelectionKey;
4445
import java.nio.channels.Selector;
4546
import java.nio.channels.SocketChannel;
@@ -49,7 +50,9 @@
4950
import java.util.concurrent.ExecutorService;
5051
import java.util.concurrent.Future;
5152
import java.util.concurrent.TimeUnit;
53+
import java.util.concurrent.atomic.AtomicReference;
5254

55+
import static com.mongodb.assertions.Assertions.assertFalse;
5356
import static com.mongodb.assertions.Assertions.assertTrue;
5457
import static com.mongodb.assertions.Assertions.isTrue;
5558
import static com.mongodb.internal.connection.ServerAddressHelper.getSocketAddresses;
@@ -99,19 +102,39 @@ public void close() {
99102

100103
private static class SelectorMonitor implements Closeable {
101104

102-
private static final class Pair {
105+
static final class SocketRegistration {
103106
private final SocketChannel socketChannel;
104107
private final Runnable attachment;
108+
private final AtomicReference<ConnectionRegistrationState> connectionRegistrationState;
105109

106-
private Pair(final SocketChannel socketChannel, final Runnable attachment) {
110+
enum ConnectionRegistrationState {
111+
CONNECTING,
112+
CONNECTED,
113+
TIMEOUT_OUT
114+
}
115+
116+
private SocketRegistration(final SocketChannel socketChannel, final Runnable attachment) {
107117
this.socketChannel = socketChannel;
108118
this.attachment = attachment;
119+
this.connectionRegistrationState = new AtomicReference<>(ConnectionRegistrationState.CONNECTING);
120+
}
121+
122+
public boolean markConnectionEstablishmentTimedOut() {
123+
return connectionRegistrationState.compareAndSet(
124+
ConnectionRegistrationState.CONNECTING,
125+
ConnectionRegistrationState.TIMEOUT_OUT);
126+
}
127+
128+
public boolean markConnectionEstablishmentCompleted() {
129+
return connectionRegistrationState.compareAndSet(
130+
ConnectionRegistrationState.CONNECTING,
131+
ConnectionRegistrationState.CONNECTED);
109132
}
110133
}
111134

112135
private final Selector selector;
113136
private volatile boolean isClosed;
114-
private final ConcurrentLinkedDeque<Pair> pendingRegistrations = new ConcurrentLinkedDeque<>();
137+
private final ConcurrentLinkedDeque<SocketRegistration> pendingRegistrations = new ConcurrentLinkedDeque<>();
115138

116139
SelectorMonitor() {
117140
try {
@@ -121,23 +144,29 @@ private Pair(final SocketChannel socketChannel, final Runnable attachment) {
121144
}
122145
}
123146

147+
// Monitors OP_CONNECT events.
124148
void start() {
125149
Thread selectorThread = new Thread(() -> {
126150
try {
127151
while (!isClosed) {
128152
try {
129153
selector.select();
130-
131154
for (SelectionKey selectionKey : selector.selectedKeys()) {
132155
selectionKey.cancel();
133-
Runnable runnable = (Runnable) selectionKey.attachment();
134-
runnable.run();
156+
SocketRegistration socketRegistration = (SocketRegistration) selectionKey.attachment();
157+
158+
boolean markedCompleted = socketRegistration.markConnectionEstablishmentCompleted();
159+
if (markedCompleted) {
160+
Runnable runnable = socketRegistration.attachment;
161+
runnable.run();
162+
} else {
163+
assertFalse(socketRegistration.socketChannel.isOpen());
164+
}
135165
}
136166

137-
for (Iterator<Pair> iter = pendingRegistrations.iterator(); iter.hasNext();) {
138-
Pair pendingRegistration = iter.next();
139-
pendingRegistration.socketChannel.register(selector, SelectionKey.OP_CONNECT,
140-
pendingRegistration.attachment);
167+
for (Iterator<SocketRegistration> iter = pendingRegistrations.iterator(); iter.hasNext();) {
168+
SocketRegistration pendingRegistration = iter.next();
169+
pendingRegistration.socketChannel.register(selector, SelectionKey.OP_CONNECT, pendingRegistration);
141170
iter.remove();
142171
}
143172
} catch (Exception e) {
@@ -156,8 +185,9 @@ void start() {
156185
selectorThread.start();
157186
}
158187

159-
void register(final SocketChannel channel, final Runnable attachment) {
160-
pendingRegistrations.add(new Pair(channel, attachment));
188+
189+
void register(final SocketRegistration registration) {
190+
pendingRegistrations.add(registration);
161191
selector.wakeup();
162192
}
163193

@@ -203,41 +233,65 @@ public void openAsync(final OperationContext operationContext, final AsyncComple
203233

204234
socketChannel.connect(getSocketAddresses(getServerAddress(), inetAddressResolver).get(0));
205235

206-
selectorMonitor.register(socketChannel, () -> {
207-
try {
208-
if (!socketChannel.finishConnect()) {
209-
throw new MongoSocketOpenException("Failed to finish connect", getServerAddress());
210-
}
236+
SelectorMonitor.SocketRegistration socketRegistration = new SelectorMonitor.SocketRegistration(
237+
socketChannel, () -> initializeTslChannel(handler, socketChannel));
211238

212-
SSLEngine sslEngine = getSslContext().createSSLEngine(getServerAddress().getHost(),
213-
getServerAddress().getPort());
214-
sslEngine.setUseClientMode(true);
239+
int connectTimeoutMs = getSettings().getConnectTimeout(TimeUnit.MILLISECONDS);
215240

216-
SSLParameters sslParameters = sslEngine.getSSLParameters();
217-
enableSni(getServerAddress().getHost(), sslParameters);
241+
group.getTimeoutExecutor().schedule(() -> {
242+
boolean markedTimedOut = socketRegistration.markConnectionEstablishmentTimedOut();
243+
if (markedTimedOut) {
244+
closeAndTimeout(handler, socketChannel);
245+
}
246+
}, connectTimeoutMs, TimeUnit.MILLISECONDS);
218247

219-
if (!sslSettings.isInvalidHostNameAllowed()) {
220-
enableHostNameVerification(sslParameters);
221-
}
222-
sslEngine.setSSLParameters(sslParameters);
248+
selectorMonitor.register(socketRegistration);
249+
} catch (IOException e) {
250+
handler.failed(new MongoSocketOpenException("Exception opening socket", getServerAddress(), e));
251+
} catch (Throwable t) {
252+
handler.failed(t);
253+
}
254+
}
223255

224-
BufferAllocator bufferAllocator = new BufferProviderAllocator();
256+
private void closeAndTimeout(final AsyncCompletionHandler<Void> handler, final SocketChannel socketChannel) {
257+
InterruptedByTimeoutException interruptedByTimeoutException = new InterruptedByTimeoutException();
258+
try {
259+
socketChannel.close();
260+
} catch (Exception e) {
261+
interruptedByTimeoutException.addSuppressed(e);
262+
}
263+
handler.failed(new MongoSocketOpenException("Exception opening socket", getAddress(), new InterruptedByTimeoutException()));
264+
}
225265

226-
TlsChannel tlsChannel = ClientTlsChannel.newBuilder(socketChannel, sslEngine)
227-
.withEncryptedBufferAllocator(bufferAllocator)
228-
.withPlainBufferAllocator(bufferAllocator)
229-
.build();
266+
private void initializeTslChannel(final AsyncCompletionHandler<Void> handler, final SocketChannel socketChannel) {
267+
try {
268+
if (!socketChannel.finishConnect()) {
269+
throw new MongoSocketOpenException("Failed to finish connect", getServerAddress());
270+
}
230271

231-
// build asynchronous channel, based in the TLS channel and associated with the global group.
232-
setChannel(new AsynchronousTlsChannelAdapter(new AsynchronousTlsChannel(group, tlsChannel, socketChannel)));
272+
SSLEngine sslEngine = getSslContext().createSSLEngine(getServerAddress().getHost(),
273+
getServerAddress().getPort());
274+
sslEngine.setUseClientMode(true);
233275

234-
handler.completed(null);
235-
} catch (IOException e) {
236-
handler.failed(new MongoSocketOpenException("Exception opening socket", getServerAddress(), e));
237-
} catch (Throwable t) {
238-
handler.failed(t);
239-
}
240-
});
276+
SSLParameters sslParameters = sslEngine.getSSLParameters();
277+
enableSni(getServerAddress().getHost(), sslParameters);
278+
279+
if (!sslSettings.isInvalidHostNameAllowed()) {
280+
enableHostNameVerification(sslParameters);
281+
}
282+
sslEngine.setSSLParameters(sslParameters);
283+
284+
BufferAllocator bufferAllocator = new BufferProviderAllocator();
285+
286+
TlsChannel tlsChannel = ClientTlsChannel.newBuilder(socketChannel, sslEngine)
287+
.withEncryptedBufferAllocator(bufferAllocator)
288+
.withPlainBufferAllocator(bufferAllocator)
289+
.build();
290+
291+
// build asynchronous channel, based in the TLS channel and associated with the global group.
292+
setChannel(new AsynchronousTlsChannelAdapter(new AsynchronousTlsChannel(group, tlsChannel, socketChannel)));
293+
294+
handler.completed(null);
241295
} catch (IOException e) {
242296
handler.failed(new MongoSocketOpenException("Exception opening socket", getServerAddress(), e));
243297
} catch (Throwable t) {

driver-core/src/main/com/mongodb/internal/connection/tlschannel/async/AsynchronousTlsChannelGroup.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -823,4 +823,13 @@ public long getCurrentWriteCount() {
823823
public long getCurrentRegistrationCount() {
824824
return registrations.mappingCount();
825825
}
826+
827+
/**
828+
* Returns the timeout executor used by this channel group.
829+
*
830+
* @return the timeout executor
831+
*/
832+
public ScheduledThreadPoolExecutor getTimeoutExecutor() {
833+
return timeoutExecutor;
834+
}
826835
}

0 commit comments

Comments
 (0)