40
40
import java .net .StandardSocketOptions ;
41
41
import java .nio .ByteBuffer ;
42
42
import java .nio .channels .CompletionHandler ;
43
+ import java .nio .channels .InterruptedByTimeoutException ;
43
44
import java .nio .channels .SelectionKey ;
44
45
import java .nio .channels .Selector ;
45
46
import java .nio .channels .SocketChannel ;
49
50
import java .util .concurrent .ExecutorService ;
50
51
import java .util .concurrent .Future ;
51
52
import java .util .concurrent .TimeUnit ;
53
+ import java .util .concurrent .atomic .AtomicReference ;
52
54
55
+ import static com .mongodb .assertions .Assertions .assertFalse ;
53
56
import static com .mongodb .assertions .Assertions .assertTrue ;
54
57
import static com .mongodb .assertions .Assertions .isTrue ;
55
58
import static com .mongodb .internal .connection .ServerAddressHelper .getSocketAddresses ;
@@ -99,19 +102,39 @@ public void close() {
99
102
100
103
private static class SelectorMonitor implements Closeable {
101
104
102
- private static final class Pair {
105
+ static final class SocketRegistration {
103
106
private final SocketChannel socketChannel ;
104
107
private final Runnable attachment ;
108
+ private final AtomicReference <ConnectionRegistrationState > connectionRegistrationState ;
105
109
106
- private Pair (final SocketChannel socketChannel , final Runnable attachment ) {
110
+ enum ConnectionRegistrationState {
111
+ CONNECTING ,
112
+ CONNECTED ,
113
+ TIMEOUT_OUT
114
+ }
115
+
116
+ private SocketRegistration (final SocketChannel socketChannel , final Runnable attachment ) {
107
117
this .socketChannel = socketChannel ;
108
118
this .attachment = attachment ;
119
+ this .connectionRegistrationState = new AtomicReference <>(ConnectionRegistrationState .CONNECTING );
120
+ }
121
+
122
+ public boolean markConnectionEstablishmentTimedOut () {
123
+ return connectionRegistrationState .compareAndSet (
124
+ ConnectionRegistrationState .CONNECTING ,
125
+ ConnectionRegistrationState .TIMEOUT_OUT );
126
+ }
127
+
128
+ public boolean markConnectionEstablishmentCompleted () {
129
+ return connectionRegistrationState .compareAndSet (
130
+ ConnectionRegistrationState .CONNECTING ,
131
+ ConnectionRegistrationState .CONNECTED );
109
132
}
110
133
}
111
134
112
135
private final Selector selector ;
113
136
private volatile boolean isClosed ;
114
- private final ConcurrentLinkedDeque <Pair > pendingRegistrations = new ConcurrentLinkedDeque <>();
137
+ private final ConcurrentLinkedDeque <SocketRegistration > pendingRegistrations = new ConcurrentLinkedDeque <>();
115
138
116
139
SelectorMonitor () {
117
140
try {
@@ -121,23 +144,29 @@ private Pair(final SocketChannel socketChannel, final Runnable attachment) {
121
144
}
122
145
}
123
146
147
+ // Monitors OP_CONNECT events.
124
148
void start () {
125
149
Thread selectorThread = new Thread (() -> {
126
150
try {
127
151
while (!isClosed ) {
128
152
try {
129
153
selector .select ();
130
-
131
154
for (SelectionKey selectionKey : selector .selectedKeys ()) {
132
155
selectionKey .cancel ();
133
- Runnable runnable = (Runnable ) selectionKey .attachment ();
134
- runnable .run ();
156
+ SocketRegistration socketRegistration = (SocketRegistration ) selectionKey .attachment ();
157
+
158
+ boolean markedCompleted = socketRegistration .markConnectionEstablishmentCompleted ();
159
+ if (markedCompleted ) {
160
+ Runnable runnable = socketRegistration .attachment ;
161
+ runnable .run ();
162
+ } else {
163
+ assertFalse (socketRegistration .socketChannel .isOpen ());
164
+ }
135
165
}
136
166
137
- for (Iterator <Pair > iter = pendingRegistrations .iterator (); iter .hasNext ();) {
138
- Pair pendingRegistration = iter .next ();
139
- pendingRegistration .socketChannel .register (selector , SelectionKey .OP_CONNECT ,
140
- pendingRegistration .attachment );
167
+ for (Iterator <SocketRegistration > iter = pendingRegistrations .iterator (); iter .hasNext ();) {
168
+ SocketRegistration pendingRegistration = iter .next ();
169
+ pendingRegistration .socketChannel .register (selector , SelectionKey .OP_CONNECT , pendingRegistration );
141
170
iter .remove ();
142
171
}
143
172
} catch (Exception e ) {
@@ -156,8 +185,9 @@ void start() {
156
185
selectorThread .start ();
157
186
}
158
187
159
- void register (final SocketChannel channel , final Runnable attachment ) {
160
- pendingRegistrations .add (new Pair (channel , attachment ));
188
+
189
+ void register (final SocketRegistration registration ) {
190
+ pendingRegistrations .add (registration );
161
191
selector .wakeup ();
162
192
}
163
193
@@ -203,41 +233,65 @@ public void openAsync(final OperationContext operationContext, final AsyncComple
203
233
204
234
socketChannel .connect (getSocketAddresses (getServerAddress (), inetAddressResolver ).get (0 ));
205
235
206
- selectorMonitor .register (socketChannel , () -> {
207
- try {
208
- if (!socketChannel .finishConnect ()) {
209
- throw new MongoSocketOpenException ("Failed to finish connect" , getServerAddress ());
210
- }
236
+ SelectorMonitor .SocketRegistration socketRegistration = new SelectorMonitor .SocketRegistration (
237
+ socketChannel , () -> initializeTslChannel (handler , socketChannel ));
211
238
212
- SSLEngine sslEngine = getSslContext ().createSSLEngine (getServerAddress ().getHost (),
213
- getServerAddress ().getPort ());
214
- sslEngine .setUseClientMode (true );
239
+ int connectTimeoutMs = getSettings ().getConnectTimeout (TimeUnit .MILLISECONDS );
215
240
216
- SSLParameters sslParameters = sslEngine .getSSLParameters ();
217
- enableSni (getServerAddress ().getHost (), sslParameters );
241
+ group .getTimeoutExecutor ().schedule (() -> {
242
+ boolean markedTimedOut = socketRegistration .markConnectionEstablishmentTimedOut ();
243
+ if (markedTimedOut ) {
244
+ closeAndTimeout (handler , socketChannel );
245
+ }
246
+ }, connectTimeoutMs , TimeUnit .MILLISECONDS );
218
247
219
- if (!sslSettings .isInvalidHostNameAllowed ()) {
220
- enableHostNameVerification (sslParameters );
221
- }
222
- sslEngine .setSSLParameters (sslParameters );
248
+ selectorMonitor .register (socketRegistration );
249
+ } catch (IOException e ) {
250
+ handler .failed (new MongoSocketOpenException ("Exception opening socket" , getServerAddress (), e ));
251
+ } catch (Throwable t ) {
252
+ handler .failed (t );
253
+ }
254
+ }
223
255
224
- BufferAllocator bufferAllocator = new BufferProviderAllocator ();
256
+ private void closeAndTimeout (final AsyncCompletionHandler <Void > handler , final SocketChannel socketChannel ) {
257
+ InterruptedByTimeoutException interruptedByTimeoutException = new InterruptedByTimeoutException ();
258
+ try {
259
+ socketChannel .close ();
260
+ } catch (Exception e ) {
261
+ interruptedByTimeoutException .addSuppressed (e );
262
+ }
263
+ handler .failed (new MongoSocketOpenException ("Exception opening socket" , getAddress (), new InterruptedByTimeoutException ()));
264
+ }
225
265
226
- TlsChannel tlsChannel = ClientTlsChannel .newBuilder (socketChannel , sslEngine )
227
- .withEncryptedBufferAllocator (bufferAllocator )
228
- .withPlainBufferAllocator (bufferAllocator )
229
- .build ();
266
+ private void initializeTslChannel (final AsyncCompletionHandler <Void > handler , final SocketChannel socketChannel ) {
267
+ try {
268
+ if (!socketChannel .finishConnect ()) {
269
+ throw new MongoSocketOpenException ("Failed to finish connect" , getServerAddress ());
270
+ }
230
271
231
- // build asynchronous channel, based in the TLS channel and associated with the global group.
232
- setChannel (new AsynchronousTlsChannelAdapter (new AsynchronousTlsChannel (group , tlsChannel , socketChannel )));
272
+ SSLEngine sslEngine = getSslContext ().createSSLEngine (getServerAddress ().getHost (),
273
+ getServerAddress ().getPort ());
274
+ sslEngine .setUseClientMode (true );
233
275
234
- handler .completed (null );
235
- } catch (IOException e ) {
236
- handler .failed (new MongoSocketOpenException ("Exception opening socket" , getServerAddress (), e ));
237
- } catch (Throwable t ) {
238
- handler .failed (t );
239
- }
240
- });
276
+ SSLParameters sslParameters = sslEngine .getSSLParameters ();
277
+ enableSni (getServerAddress ().getHost (), sslParameters );
278
+
279
+ if (!sslSettings .isInvalidHostNameAllowed ()) {
280
+ enableHostNameVerification (sslParameters );
281
+ }
282
+ sslEngine .setSSLParameters (sslParameters );
283
+
284
+ BufferAllocator bufferAllocator = new BufferProviderAllocator ();
285
+
286
+ TlsChannel tlsChannel = ClientTlsChannel .newBuilder (socketChannel , sslEngine )
287
+ .withEncryptedBufferAllocator (bufferAllocator )
288
+ .withPlainBufferAllocator (bufferAllocator )
289
+ .build ();
290
+
291
+ // build asynchronous channel, based in the TLS channel and associated with the global group.
292
+ setChannel (new AsynchronousTlsChannelAdapter (new AsynchronousTlsChannel (group , tlsChannel , socketChannel )));
293
+
294
+ handler .completed (null );
241
295
} catch (IOException e ) {
242
296
handler .failed (new MongoSocketOpenException ("Exception opening socket" , getServerAddress (), e ));
243
297
} catch (Throwable t ) {
0 commit comments