Skip to content

Commit 568e05f

Browse files
fix(NODE-5127): implement reject kmsRequest on server close (#3964)
1 parent 4e56482 commit 568e05f

File tree

2 files changed

+167
-62
lines changed

2 files changed

+167
-62
lines changed

src/client-side-encryption/state_machine.ts

Lines changed: 77 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ import {
1313
import { type ProxyOptions } from '../cmap/connection';
1414
import { getSocks, type SocksLib } from '../deps';
1515
import { type MongoClient, type MongoClientOptions } from '../mongo_client';
16-
import { BufferPool, MongoDBCollectionNamespace } from '../utils';
16+
import { BufferPool, MongoDBCollectionNamespace, promiseWithResolvers } from '../utils';
1717
import { type DataKey } from './client_encryption';
1818
import { MongoCryptError } from './errors';
1919
import { type MongocryptdManager } from './mongocryptd_manager';
@@ -282,7 +282,7 @@ export class StateMachine {
282282
* @param kmsContext - A C++ KMS context returned from the bindings
283283
* @returns A promise that resolves when the KMS reply has be fully parsed
284284
*/
285-
kmsRequest(request: MongoCryptKMSRequest): Promise<void> {
285+
async kmsRequest(request: MongoCryptKMSRequest): Promise<void> {
286286
const parsedUrl = request.endpoint.split(':');
287287
const port = parsedUrl[1] != null ? Number.parseInt(parsedUrl[1], 10) : HTTPS_PORT;
288288
const options: tls.ConnectionOptions & { host: string; port: number } = {
@@ -291,52 +291,73 @@ export class StateMachine {
291291
port
292292
};
293293
const message = request.message;
294+
const buffer = new BufferPool();
294295

295-
// TODO(NODE-3959): We can adopt `for-await on(socket, 'data')` with logic to control abort
296-
// eslint-disable-next-line @typescript-eslint/no-misused-promises, no-async-promise-executor
297-
return new Promise(async (resolve, reject) => {
298-
const buffer = new BufferPool();
296+
const netSocket: net.Socket = new net.Socket();
297+
let socket: tls.TLSSocket;
299298

300-
// eslint-disable-next-line prefer-const
301-
let socket: net.Socket;
302-
let rawSocket: net.Socket;
303-
304-
function destroySockets() {
305-
for (const sock of [socket, rawSocket]) {
306-
if (sock) {
307-
sock.removeAllListeners();
308-
sock.destroy();
309-
}
299+
function destroySockets() {
300+
for (const sock of [socket, netSocket]) {
301+
if (sock) {
302+
sock.removeAllListeners();
303+
sock.destroy();
310304
}
311305
}
306+
}
312307

313-
function ontimeout() {
314-
destroySockets();
315-
reject(new MongoCryptError('KMS request timed out'));
316-
}
308+
function ontimeout() {
309+
return new MongoCryptError('KMS request timed out');
310+
}
311+
312+
function onerror(cause: Error) {
313+
return new MongoCryptError('KMS request failed', { cause });
314+
}
317315

318-
function onerror(err: Error) {
319-
destroySockets();
320-
const mcError = new MongoCryptError('KMS request failed', { cause: err });
321-
reject(mcError);
316+
function onclose() {
317+
return new MongoCryptError('KMS request closed');
318+
}
319+
320+
const tlsOptions = this.options.tlsOptions;
321+
if (tlsOptions) {
322+
const kmsProvider = request.kmsProvider as ClientEncryptionDataKeyProvider;
323+
const providerTlsOptions = tlsOptions[kmsProvider];
324+
if (providerTlsOptions) {
325+
const error = this.validateTlsOptions(kmsProvider, providerTlsOptions);
326+
if (error) {
327+
throw error;
328+
}
329+
try {
330+
await this.setTlsOptions(providerTlsOptions, options);
331+
} catch (err) {
332+
throw onerror(err);
333+
}
322334
}
335+
}
323336

337+
const {
338+
promise: willConnect,
339+
reject: rejectOnNetSocketError,
340+
resolve: resolveOnNetSocketConnect
341+
} = promiseWithResolvers<void>();
342+
netSocket
343+
.once('timeout', () => rejectOnNetSocketError(ontimeout()))
344+
.once('error', err => rejectOnNetSocketError(onerror(err)))
345+
.once('close', () => rejectOnNetSocketError(onclose()))
346+
.once('connect', () => resolveOnNetSocketConnect());
347+
348+
try {
324349
if (this.options.proxyOptions && this.options.proxyOptions.proxyHost) {
325-
rawSocket = net.connect({
350+
netSocket.connect({
326351
host: this.options.proxyOptions.proxyHost,
327352
port: this.options.proxyOptions.proxyPort || 1080
328353
});
354+
await willConnect;
329355

330-
rawSocket.on('timeout', ontimeout);
331-
rawSocket.on('error', onerror);
332356
try {
333-
// eslint-disable-next-line @typescript-eslint/no-var-requires
334-
const events = require('events') as typeof import('events');
335-
await events.once(rawSocket, 'connect');
336357
socks ??= loadSocks();
337358
options.socket = (
338359
await socks.SocksClient.createConnection({
339-
existing_socket: rawSocket,
360+
existing_socket: netSocket,
340361
command: 'connect',
341362
destination: { host: options.host, port: options.port },
342363
proxy: {
@@ -350,45 +371,39 @@ export class StateMachine {
350371
})
351372
).socket;
352373
} catch (err) {
353-
return onerror(err);
374+
throw onerror(err);
354375
}
355376
}
356377

357-
const tlsOptions = this.options.tlsOptions;
358-
if (tlsOptions) {
359-
const kmsProvider = request.kmsProvider as ClientEncryptionDataKeyProvider;
360-
const providerTlsOptions = tlsOptions[kmsProvider];
361-
if (providerTlsOptions) {
362-
const error = this.validateTlsOptions(kmsProvider, providerTlsOptions);
363-
if (error) reject(error);
364-
try {
365-
await this.setTlsOptions(providerTlsOptions, options);
366-
} catch (error) {
367-
return onerror(error);
368-
}
369-
}
370-
}
371378
socket = tls.connect(options, () => {
372379
socket.write(message);
373380
});
374381

375-
socket.once('timeout', ontimeout);
376-
socket.once('error', onerror);
377-
378-
socket.on('data', data => {
379-
buffer.append(data);
380-
while (request.bytesNeeded > 0 && buffer.length) {
381-
const bytesNeeded = Math.min(request.bytesNeeded, buffer.length);
382-
request.addResponse(buffer.read(bytesNeeded));
383-
}
382+
const {
383+
promise: willResolveKmsRequest,
384+
reject: rejectOnTlsSocketError,
385+
resolve
386+
} = promiseWithResolvers<void>();
387+
socket
388+
.once('timeout', () => rejectOnTlsSocketError(ontimeout()))
389+
.once('error', err => rejectOnTlsSocketError(onerror(err)))
390+
.once('close', () => rejectOnTlsSocketError(onclose()))
391+
.on('data', data => {
392+
buffer.append(data);
393+
while (request.bytesNeeded > 0 && buffer.length) {
394+
const bytesNeeded = Math.min(request.bytesNeeded, buffer.length);
395+
request.addResponse(buffer.read(bytesNeeded));
396+
}
384397

385-
if (request.bytesNeeded <= 0) {
386-
// There's no need for any more activity on this socket at this point.
387-
destroySockets();
388-
resolve();
389-
}
390-
});
391-
});
398+
if (request.bytesNeeded <= 0) {
399+
resolve();
400+
}
401+
});
402+
await willResolveKmsRequest;
403+
} finally {
404+
// There's no need for any more activity on this socket at this point.
405+
destroySockets();
406+
}
392407
}
393408

394409
*requests(context: MongoCryptContext) {

test/unit/client-side-encryption/state_machine.test.ts

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,96 @@ describe('StateMachine', function () {
251251
});
252252
});
253253

254+
context('when server closed the socket', function () {
255+
context('Socks5', function () {
256+
let server;
257+
258+
beforeEach(async function () {
259+
server = net.createServer(async socket => {
260+
socket.end();
261+
});
262+
server.listen(0);
263+
await once(server, 'listening');
264+
});
265+
266+
afterEach(function () {
267+
server.close();
268+
});
269+
270+
it('throws a MongoCryptError with SocksClientError cause', async function () {
271+
const stateMachine = new StateMachine({
272+
proxyOptions: {
273+
proxyHost: 'localhost',
274+
proxyPort: server.address().port
275+
}
276+
} as any);
277+
const request = new MockRequest(Buffer.from('foobar'), 500);
278+
279+
try {
280+
await stateMachine.kmsRequest(request);
281+
} catch (err) {
282+
expect(err.name).to.equal('MongoCryptError');
283+
expect(err.message).to.equal('KMS request failed');
284+
expect(err.cause.constructor.name).to.equal('SocksClientError');
285+
return;
286+
}
287+
expect.fail('missed exception');
288+
});
289+
});
290+
291+
context('endpoint with host and port', function () {
292+
let server;
293+
let serverSocket;
294+
295+
beforeEach(async function () {
296+
server = net.createServer(async socket => {
297+
serverSocket = socket;
298+
});
299+
server.listen(0);
300+
await once(server, 'listening');
301+
});
302+
303+
afterEach(function () {
304+
server.close();
305+
});
306+
307+
beforeEach(async function () {
308+
const netSocket = net.connect({
309+
port: server.address().port
310+
});
311+
await once(netSocket, 'connect');
312+
this.sinon.stub(tls, 'connect').returns(netSocket);
313+
});
314+
315+
afterEach(function () {
316+
server.close();
317+
this.sinon.restore();
318+
});
319+
320+
it('throws a MongoCryptError error', async function () {
321+
const stateMachine = new StateMachine({
322+
host: 'localhost',
323+
port: server.address().port
324+
} as any);
325+
const request = new MockRequest(Buffer.from('foobar'), 500);
326+
327+
try {
328+
const kmsRequestPromise = stateMachine.kmsRequest(request);
329+
330+
await promisify(setTimeout)(0);
331+
serverSocket.end();
332+
333+
await kmsRequestPromise;
334+
} catch (err) {
335+
expect(err.name).to.equal('MongoCryptError');
336+
expect(err.message).to.equal('KMS request closed');
337+
return;
338+
}
339+
expect.fail('missed exception');
340+
});
341+
});
342+
});
343+
254344
afterEach(function () {
255345
this.sinon.restore();
256346
});

0 commit comments

Comments
 (0)