@@ -13,7 +13,7 @@ import {
13
13
import { type ProxyOptions } from '../cmap/connection' ;
14
14
import { getSocks , type SocksLib } from '../deps' ;
15
15
import { type MongoClient , type MongoClientOptions } from '../mongo_client' ;
16
- import { BufferPool , MongoDBCollectionNamespace } from '../utils' ;
16
+ import { BufferPool , MongoDBCollectionNamespace , promiseWithResolvers } from '../utils' ;
17
17
import { type DataKey } from './client_encryption' ;
18
18
import { MongoCryptError } from './errors' ;
19
19
import { type MongocryptdManager } from './mongocryptd_manager' ;
@@ -282,7 +282,7 @@ export class StateMachine {
282
282
* @param kmsContext - A C++ KMS context returned from the bindings
283
283
* @returns A promise that resolves when the KMS reply has be fully parsed
284
284
*/
285
- kmsRequest ( request : MongoCryptKMSRequest ) : Promise < void > {
285
+ async kmsRequest ( request : MongoCryptKMSRequest ) : Promise < void > {
286
286
const parsedUrl = request . endpoint . split ( ':' ) ;
287
287
const port = parsedUrl [ 1 ] != null ? Number . parseInt ( parsedUrl [ 1 ] , 10 ) : HTTPS_PORT ;
288
288
const options : tls . ConnectionOptions & { host : string ; port : number } = {
@@ -291,52 +291,73 @@ export class StateMachine {
291
291
port
292
292
} ;
293
293
const message = request . message ;
294
+ const buffer = new BufferPool ( ) ;
294
295
295
- // TODO(NODE-3959): We can adopt `for-await on(socket, 'data')` with logic to control abort
296
- // eslint-disable-next-line @typescript-eslint/no-misused-promises, no-async-promise-executor
297
- return new Promise ( async ( resolve , reject ) => {
298
- const buffer = new BufferPool ( ) ;
296
+ const netSocket : net . Socket = new net . Socket ( ) ;
297
+ let socket : tls . TLSSocket ;
299
298
300
- // eslint-disable-next-line prefer-const
301
- let socket : net . Socket ;
302
- let rawSocket : net . Socket ;
303
-
304
- function destroySockets ( ) {
305
- for ( const sock of [ socket , rawSocket ] ) {
306
- if ( sock ) {
307
- sock . removeAllListeners ( ) ;
308
- sock . destroy ( ) ;
309
- }
299
+ function destroySockets ( ) {
300
+ for ( const sock of [ socket , netSocket ] ) {
301
+ if ( sock ) {
302
+ sock . removeAllListeners ( ) ;
303
+ sock . destroy ( ) ;
310
304
}
311
305
}
306
+ }
312
307
313
- function ontimeout ( ) {
314
- destroySockets ( ) ;
315
- reject ( new MongoCryptError ( 'KMS request timed out' ) ) ;
316
- }
308
+ function ontimeout ( ) {
309
+ return new MongoCryptError ( 'KMS request timed out' ) ;
310
+ }
311
+
312
+ function onerror ( cause : Error ) {
313
+ return new MongoCryptError ( 'KMS request failed' , { cause } ) ;
314
+ }
317
315
318
- function onerror ( err : Error ) {
319
- destroySockets ( ) ;
320
- const mcError = new MongoCryptError ( 'KMS request failed' , { cause : err } ) ;
321
- reject ( mcError ) ;
316
+ function onclose ( ) {
317
+ return new MongoCryptError ( 'KMS request closed' ) ;
318
+ }
319
+
320
+ const tlsOptions = this . options . tlsOptions ;
321
+ if ( tlsOptions ) {
322
+ const kmsProvider = request . kmsProvider as ClientEncryptionDataKeyProvider ;
323
+ const providerTlsOptions = tlsOptions [ kmsProvider ] ;
324
+ if ( providerTlsOptions ) {
325
+ const error = this . validateTlsOptions ( kmsProvider , providerTlsOptions ) ;
326
+ if ( error ) {
327
+ throw error ;
328
+ }
329
+ try {
330
+ await this . setTlsOptions ( providerTlsOptions , options ) ;
331
+ } catch ( err ) {
332
+ throw onerror ( err ) ;
333
+ }
322
334
}
335
+ }
323
336
337
+ const {
338
+ promise : willConnect ,
339
+ reject : rejectOnNetSocketError ,
340
+ resolve : resolveOnNetSocketConnect
341
+ } = promiseWithResolvers < void > ( ) ;
342
+ netSocket
343
+ . once ( 'timeout' , ( ) => rejectOnNetSocketError ( ontimeout ( ) ) )
344
+ . once ( 'error' , err => rejectOnNetSocketError ( onerror ( err ) ) )
345
+ . once ( 'close' , ( ) => rejectOnNetSocketError ( onclose ( ) ) )
346
+ . once ( 'connect' , ( ) => resolveOnNetSocketConnect ( ) ) ;
347
+
348
+ try {
324
349
if ( this . options . proxyOptions && this . options . proxyOptions . proxyHost ) {
325
- rawSocket = net . connect ( {
350
+ netSocket . connect ( {
326
351
host : this . options . proxyOptions . proxyHost ,
327
352
port : this . options . proxyOptions . proxyPort || 1080
328
353
} ) ;
354
+ await willConnect ;
329
355
330
- rawSocket . on ( 'timeout' , ontimeout ) ;
331
- rawSocket . on ( 'error' , onerror ) ;
332
356
try {
333
- // eslint-disable-next-line @typescript-eslint/no-var-requires
334
- const events = require ( 'events' ) as typeof import ( 'events' ) ;
335
- await events . once ( rawSocket , 'connect' ) ;
336
357
socks ??= loadSocks ( ) ;
337
358
options . socket = (
338
359
await socks . SocksClient . createConnection ( {
339
- existing_socket : rawSocket ,
360
+ existing_socket : netSocket ,
340
361
command : 'connect' ,
341
362
destination : { host : options . host , port : options . port } ,
342
363
proxy : {
@@ -350,45 +371,39 @@ export class StateMachine {
350
371
} )
351
372
) . socket ;
352
373
} catch ( err ) {
353
- return onerror ( err ) ;
374
+ throw onerror ( err ) ;
354
375
}
355
376
}
356
377
357
- const tlsOptions = this . options . tlsOptions ;
358
- if ( tlsOptions ) {
359
- const kmsProvider = request . kmsProvider as ClientEncryptionDataKeyProvider ;
360
- const providerTlsOptions = tlsOptions [ kmsProvider ] ;
361
- if ( providerTlsOptions ) {
362
- const error = this . validateTlsOptions ( kmsProvider , providerTlsOptions ) ;
363
- if ( error ) reject ( error ) ;
364
- try {
365
- await this . setTlsOptions ( providerTlsOptions , options ) ;
366
- } catch ( error ) {
367
- return onerror ( error ) ;
368
- }
369
- }
370
- }
371
378
socket = tls . connect ( options , ( ) => {
372
379
socket . write ( message ) ;
373
380
} ) ;
374
381
375
- socket . once ( 'timeout' , ontimeout ) ;
376
- socket . once ( 'error' , onerror ) ;
377
-
378
- socket . on ( 'data' , data => {
379
- buffer . append ( data ) ;
380
- while ( request . bytesNeeded > 0 && buffer . length ) {
381
- const bytesNeeded = Math . min ( request . bytesNeeded , buffer . length ) ;
382
- request . addResponse ( buffer . read ( bytesNeeded ) ) ;
383
- }
382
+ const {
383
+ promise : willResolveKmsRequest ,
384
+ reject : rejectOnTlsSocketError ,
385
+ resolve
386
+ } = promiseWithResolvers < void > ( ) ;
387
+ socket
388
+ . once ( 'timeout' , ( ) => rejectOnTlsSocketError ( ontimeout ( ) ) )
389
+ . once ( 'error' , err => rejectOnTlsSocketError ( onerror ( err ) ) )
390
+ . once ( 'close' , ( ) => rejectOnTlsSocketError ( onclose ( ) ) )
391
+ . on ( 'data' , data => {
392
+ buffer . append ( data ) ;
393
+ while ( request . bytesNeeded > 0 && buffer . length ) {
394
+ const bytesNeeded = Math . min ( request . bytesNeeded , buffer . length ) ;
395
+ request . addResponse ( buffer . read ( bytesNeeded ) ) ;
396
+ }
384
397
385
- if ( request . bytesNeeded <= 0 ) {
386
- // There's no need for any more activity on this socket at this point.
387
- destroySockets ( ) ;
388
- resolve ( ) ;
389
- }
390
- } ) ;
391
- } ) ;
398
+ if ( request . bytesNeeded <= 0 ) {
399
+ resolve ( ) ;
400
+ }
401
+ } ) ;
402
+ await willResolveKmsRequest ;
403
+ } finally {
404
+ // There's no need for any more activity on this socket at this point.
405
+ destroySockets ( ) ;
406
+ }
392
407
}
393
408
394
409
* requests ( context : MongoCryptContext ) {
0 commit comments