From fcc9e0908e424fb9b808a6d5d0d0bb83f9b3e358 Mon Sep 17 00:00:00 2001 From: Bailey Pearson Date: Wed, 22 Mar 2023 09:34:30 -0400 Subject: [PATCH 01/11] refactor: make getWorkflow sync & throw --- .eslintrc.json | 2 ++ src/cmap/auth/mongodb_oidc.ts | 27 +++++++++------------------ 2 files changed, 11 insertions(+), 18 deletions(-) diff --git a/.eslintrc.json b/.eslintrc.json index 9b3becfbdf9..153b36bbd7f 100644 --- a/.eslintrc.json +++ b/.eslintrc.json @@ -87,6 +87,7 @@ "global" ], "@typescript-eslint/no-explicit-any": "off", + "@typescript-eslint/require-await": "off", "no-restricted-imports": [ "error", { @@ -228,6 +229,7 @@ "@typescript-eslint/no-unsafe-call": "off", "@typescript-eslint/restrict-plus-operands": "off", "@typescript-eslint/restrict-template-expressions": "off", + "@typescript-eslint/require-await": "off", "no-return-await": "off", "@typescript-eslint/return-await": [ "error", diff --git a/src/cmap/auth/mongodb_oidc.ts b/src/cmap/auth/mongodb_oidc.ts index d5983ce6c6f..bc89ce3ad2b 100644 --- a/src/cmap/auth/mongodb_oidc.ts +++ b/src/cmap/auth/mongodb_oidc.ts @@ -88,17 +88,8 @@ export class MongoDBOIDC extends AuthProvider { return callback(new MongoMissingCredentialsError('AuthContext must provide credentials.')); } - getWorkflow(credentials, (error, workflow) => { - if (error) { - return callback(error); - } - if (!workflow) { - return callback( - new MongoRuntimeError( - `Could not load workflow for device ${credentials.mechanismProperties.PROVIDER_NAME}` - ) - ); - } + try { + const workflow = getWorkflow(credentials); workflow.execute(connection, credentials, reauthenticating).then( result => { return callback(undefined, result); @@ -107,7 +98,9 @@ export class MongoDBOIDC extends AuthProvider { callback(error); } ); - }); + } catch (error) { + callback(error); + } } /** @@ -150,15 +143,13 @@ export class MongoDBOIDC extends AuthProvider { /** * Gets either a device workflow or callback workflow. */ -function getWorkflow(credentials: MongoCredentials, callback: Callback): void { +function getWorkflow(credentials: MongoCredentials): Workflow { const providerName = credentials.mechanismProperties.PROVIDER_NAME; const workflow = OIDC_WORKFLOWS.get(providerName || 'callback'); if (!workflow) { - return callback( - new MongoInvalidArgumentError( - `Could not load workflow for provider ${credentials.mechanismProperties.PROVIDER_NAME}` - ) + throw new MongoInvalidArgumentError( + `Could not load workflow for provider ${credentials.mechanismProperties.PROVIDER_NAME}` ); } - callback(undefined, workflow); + return workflow; } From 5967761f7abdc6141766c21f4ca196524c60c753 Mon Sep 17 00:00:00 2001 From: Bailey Pearson Date: Wed, 22 Mar 2023 09:47:02 -0400 Subject: [PATCH 02/11] chore: make `prepare()` async --- src/cmap/auth/auth_provider.ts | 10 +- src/cmap/auth/mongodb_oidc.ts | 35 +++---- src/cmap/auth/scram.ts | 33 +++---- src/cmap/auth/x509.ts | 11 +-- src/cmap/connect.ts | 170 ++++++++++++++++----------------- test/tools/uri_spec_runner.ts | 1 - test/unit/cmap/connect.test.ts | 4 +- 7 files changed, 121 insertions(+), 143 deletions(-) diff --git a/src/cmap/auth/auth_provider.ts b/src/cmap/auth/auth_provider.ts index 82942c0a9a4..225bd991001 100644 --- a/src/cmap/auth/auth_provider.ts +++ b/src/cmap/auth/auth_provider.ts @@ -45,12 +45,12 @@ export class AuthProvider { * @param handshakeDoc - The document used for the initial handshake on a connection * @param authContext - Context for authentication flow */ - prepare( + async prepare( handshakeDoc: HandshakeDocument, - authContext: AuthContext, - callback: Callback - ): void { - callback(undefined, handshakeDoc); + /* eslint @typescript-eslint/no-unused-vars : 0 */ + authContext: AuthContext + ): Promise { + return handshakeDoc; } /** diff --git a/src/cmap/auth/mongodb_oidc.ts b/src/cmap/auth/mongodb_oidc.ts index bc89ce3ad2b..64ca94eed9d 100644 --- a/src/cmap/auth/mongodb_oidc.ts +++ b/src/cmap/auth/mongodb_oidc.ts @@ -106,37 +106,24 @@ export class MongoDBOIDC extends AuthProvider { /** * Add the speculative auth for the initial handshake. */ - override prepare( + override async prepare( handshakeDoc: HandshakeDocument, - authContext: AuthContext, - callback: Callback - ): void { + authContext: AuthContext + ): Promise { const { credentials } = authContext; if (!credentials) { - return callback(new MongoMissingCredentialsError('AuthContext must provide credentials.')); + throw new MongoMissingCredentialsError('AuthContext must provide credentials.'); } - getWorkflow(credentials, (error, workflow) => { - if (error) { - return callback(error); - } - if (!workflow) { - return callback( - new MongoRuntimeError( - `Could not load workflow for provider ${credentials.mechanismProperties.PROVIDER_NAME}` - ) - ); - } - workflow.speculativeAuth().then( - result => { - return callback(undefined, { ...handshakeDoc, ...result }); - }, - error => { - callback(error); - } + const workflow = getWorkflow(credentials); + if (!workflow) { + throw new MongoRuntimeError( + `Could not load workflow for provider ${credentials.mechanismProperties.PROVIDER_NAME}` ); - }); + } + const result = await workflow.speculativeAuth(); + return { ...handshakeDoc, ...result }; } } diff --git a/src/cmap/auth/scram.ts b/src/cmap/auth/scram.ts index dbe4ce25a39..896a9aa7d16 100644 --- a/src/cmap/auth/scram.ts +++ b/src/cmap/auth/scram.ts @@ -1,4 +1,5 @@ import * as crypto from 'crypto'; +import { promisify } from 'util'; import { Binary, Document } from '../../bson'; import { saslprep } from '../../deps'; @@ -19,37 +20,37 @@ type CryptoMethod = 'sha1' | 'sha256'; class ScramSHA extends AuthProvider { cryptoMethod: CryptoMethod; + randomBytesAsync: (size: number) => Promise; constructor(cryptoMethod: CryptoMethod) { super(); this.cryptoMethod = cryptoMethod || 'sha1'; + this.randomBytesAsync = promisify(crypto.randomBytes); } - override prepare(handshakeDoc: HandshakeDocument, authContext: AuthContext, callback: Callback) { + override async prepare( + handshakeDoc: HandshakeDocument, + authContext: AuthContext + ): Promise { const cryptoMethod = this.cryptoMethod; const credentials = authContext.credentials; if (!credentials) { - return callback(new MongoMissingCredentialsError('AuthContext must provide credentials.')); + throw new MongoMissingCredentialsError('AuthContext must provide credentials.'); } if (cryptoMethod === 'sha256' && saslprep == null) { emitWarning('Warning: no saslprep library specified. Passwords will not be sanitized'); } - crypto.randomBytes(24, (err, nonce) => { - if (err) { - return callback(err); - } + const nonce = await this.randomBytesAsync(24); + // store the nonce for later use + Object.assign(authContext, { nonce }); - // store the nonce for later use - Object.assign(authContext, { nonce }); - - const request = Object.assign({}, handshakeDoc, { - speculativeAuthenticate: Object.assign(makeFirstMessage(cryptoMethod, credentials, nonce), { - db: credentials.source - }) - }); - - callback(undefined, request); + const request = Object.assign({}, handshakeDoc, { + speculativeAuthenticate: Object.assign(makeFirstMessage(cryptoMethod, credentials, nonce), { + db: credentials.source + }) }); + + return request; } override auth(authContext: AuthContext, callback: Callback) { diff --git a/src/cmap/auth/x509.ts b/src/cmap/auth/x509.ts index a12e6f9d8a8..2f55e27ad90 100644 --- a/src/cmap/auth/x509.ts +++ b/src/cmap/auth/x509.ts @@ -6,20 +6,19 @@ import { AuthContext, AuthProvider } from './auth_provider'; import type { MongoCredentials } from './mongo_credentials'; export class X509 extends AuthProvider { - override prepare( + override async prepare( handshakeDoc: HandshakeDocument, - authContext: AuthContext, - callback: Callback - ): void { + authContext: AuthContext + ): Promise { const { credentials } = authContext; if (!credentials) { - return callback(new MongoMissingCredentialsError('AuthContext must provide credentials.')); + throw new MongoMissingCredentialsError('AuthContext must provide credentials.'); } Object.assign(handshakeDoc, { speculativeAuthenticate: x509AuthenticateCommand(credentials) }); - callback(undefined, handshakeDoc); + return handshakeDoc; } override auth(authContext: AuthContext, callback: Callback): void { diff --git a/src/cmap/connect.ts b/src/cmap/connect.ts index e7e7a87c896..c56447ffb14 100644 --- a/src/cmap/connect.ts +++ b/src/cmap/connect.ts @@ -119,93 +119,92 @@ function performInitialHandshake( const authContext = new AuthContext(conn, credentials, options); conn.authContext = authContext; - prepareHandshakeDocument(authContext, (err, handshakeDoc) => { - if (err || !handshakeDoc) { - return callback(err); - } - - const handshakeOptions: Document = Object.assign({}, options); - if (typeof options.connectTimeoutMS === 'number') { - // The handshake technically is a monitoring check, so its socket timeout should be connectTimeoutMS - handshakeOptions.socketTimeoutMS = options.connectTimeoutMS; - } - - const start = new Date().getTime(); - conn.command(ns('admin.$cmd'), handshakeDoc, handshakeOptions, (err, response) => { - if (err) { - callback(err); - return; + prepareHandshakeDocument(authContext).then( + handshakeDoc => { + const handshakeOptions: Document = Object.assign({}, options); + if (typeof options.connectTimeoutMS === 'number') { + // The handshake technically is a monitoring check, so its socket timeout should be connectTimeoutMS + handshakeOptions.socketTimeoutMS = options.connectTimeoutMS; } - if (response?.ok === 0) { - callback(new MongoServerError(response)); - return; - } + const start = new Date().getTime(); + conn.command(ns('admin.$cmd'), handshakeDoc, handshakeOptions, (err, response) => { + if (err) { + callback(err); + return; + } - if (!('isWritablePrimary' in response)) { - // Provide hello-style response document. - response.isWritablePrimary = response[LEGACY_HELLO_COMMAND]; - } + if (response?.ok === 0) { + callback(new MongoServerError(response)); + return; + } - if (response.helloOk) { - conn.helloOk = true; - } + if (!('isWritablePrimary' in response)) { + // Provide hello-style response document. + response.isWritablePrimary = response[LEGACY_HELLO_COMMAND]; + } - const supportedServerErr = checkSupportedServer(response, options); - if (supportedServerErr) { - callback(supportedServerErr); - return; - } + if (response.helloOk) { + conn.helloOk = true; + } - if (options.loadBalanced) { - if (!response.serviceId) { - return callback( - new MongoCompatibilityError( - 'Driver attempted to initialize in load balancing mode, ' + - 'but the server does not support this mode.' - ) - ); + const supportedServerErr = checkSupportedServer(response, options); + if (supportedServerErr) { + callback(supportedServerErr); + return; } - } - // NOTE: This is metadata attached to the connection while porting away from - // handshake being done in the `Server` class. Likely, it should be - // relocated, or at very least restructured. - conn.hello = response; - conn.lastHelloMS = new Date().getTime() - start; - - if (!response.arbiterOnly && credentials) { - // store the response on auth context - authContext.response = response; - - const resolvedCredentials = credentials.resolveAuthMechanism(response); - const provider = AUTH_PROVIDERS.get(resolvedCredentials.mechanism); - if (!provider) { - return callback( - new MongoInvalidArgumentError( - `No AuthProvider for ${resolvedCredentials.mechanism} defined.` - ) - ); + if (options.loadBalanced) { + if (!response.serviceId) { + return callback( + new MongoCompatibilityError( + 'Driver attempted to initialize in load balancing mode, ' + + 'but the server does not support this mode.' + ) + ); + } } - provider.auth(authContext, err => { - if (err) { - if (err instanceof MongoError) { - err.addErrorLabel(MongoErrorLabel.HandshakeError); - if (needsRetryableWriteLabel(err, response.maxWireVersion)) { - err.addErrorLabel(MongoErrorLabel.RetryableWriteError); + + // NOTE: This is metadata attached to the connection while porting away from + // handshake being done in the `Server` class. Likely, it should be + // relocated, or at very least restructured. + conn.hello = response; + conn.lastHelloMS = new Date().getTime() - start; + + if (!response.arbiterOnly && credentials) { + // store the response on auth context + authContext.response = response; + + const resolvedCredentials = credentials.resolveAuthMechanism(response); + const provider = AUTH_PROVIDERS.get(resolvedCredentials.mechanism); + if (!provider) { + return callback( + new MongoInvalidArgumentError( + `No AuthProvider for ${resolvedCredentials.mechanism} defined.` + ) + ); + } + provider.auth(authContext, err => { + if (err) { + if (err instanceof MongoError) { + err.addErrorLabel(MongoErrorLabel.HandshakeError); + if (needsRetryableWriteLabel(err, response.maxWireVersion)) { + err.addErrorLabel(MongoErrorLabel.RetryableWriteError); + } } + return callback(err); } - return callback(err); - } - callback(undefined, conn); - }); + callback(undefined, conn); + }); - return; - } + return; + } - callback(undefined, conn); - }); - }); + callback(undefined, conn); + }); + }, + error => callback(error) + ); } export interface HandshakeDocument extends Document { @@ -226,10 +225,9 @@ export interface HandshakeDocument extends Document { * * This function is only exposed for testing purposes. */ -export function prepareHandshakeDocument( - authContext: AuthContext, - callback: Callback -) { +export async function prepareHandshakeDocument( + authContext: AuthContext +): Promise { const options = authContext.options; const compressors = options.compressors ? options.compressors : []; const { serverApi } = authContext.connection; @@ -253,23 +251,19 @@ export function prepareHandshakeDocument( const provider = AUTH_PROVIDERS.get(AuthMechanism.MONGODB_SCRAM_SHA256); if (!provider) { // This auth mechanism is always present. - return callback( - new MongoInvalidArgumentError( - `No AuthProvider for ${AuthMechanism.MONGODB_SCRAM_SHA256} defined.` - ) + throw new MongoInvalidArgumentError( + `No AuthProvider for ${AuthMechanism.MONGODB_SCRAM_SHA256} defined.` ); } - return provider.prepare(handshakeDoc, authContext, callback); + return provider.prepare(handshakeDoc, authContext); } const provider = AUTH_PROVIDERS.get(credentials.mechanism); if (!provider) { - return callback( - new MongoInvalidArgumentError(`No AuthProvider for ${credentials.mechanism} defined.`) - ); + throw new MongoInvalidArgumentError(`No AuthProvider for ${credentials.mechanism} defined.`); } - return provider.prepare(handshakeDoc, authContext, callback); + return provider.prepare(handshakeDoc, authContext); } - callback(undefined, handshakeDoc); + return handshakeDoc; } /** @public */ diff --git a/test/tools/uri_spec_runner.ts b/test/tools/uri_spec_runner.ts index 492043aa11b..a31a25fa2dc 100644 --- a/test/tools/uri_spec_runner.ts +++ b/test/tools/uri_spec_runner.ts @@ -24,7 +24,6 @@ interface UriTest extends UriTestBase { }; options: Record; } - interface AuthTest extends UriTestBase { credential: { username: string; diff --git a/test/unit/cmap/connect.test.ts b/test/unit/cmap/connect.test.ts index 9a038951a3a..e361138d059 100644 --- a/test/unit/cmap/connect.test.ts +++ b/test/unit/cmap/connect.test.ts @@ -12,7 +12,7 @@ import { LEGACY_HELLO_COMMAND, MongoCredentials, MongoNetworkError, - prepareHandshakeDocument as prepareHandshakeDocumentCb + prepareHandshakeDocument } from '../../mongodb'; import { genClusterTime } from '../../tools/common'; import * as mock from '../../tools/mongodb-mock/index'; @@ -206,8 +206,6 @@ describe('Connect Tests', function () { }); context('prepareHandshakeDocument', () => { - const prepareHandshakeDocument = promisify(prepareHandshakeDocumentCb); - context('when serverApi.version is present', () => { const options = {}; const authContext = { From 913809414e5a0415f9c9309466ff964ccab4f6cf Mon Sep 17 00:00:00 2001 From: Bailey Pearson Date: Wed, 22 Mar 2023 13:07:41 -0400 Subject: [PATCH 03/11] prevent double callback call --- src/cmap/auth/mongodb_oidc.ts | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/src/cmap/auth/mongodb_oidc.ts b/src/cmap/auth/mongodb_oidc.ts index 64ca94eed9d..4911e05e7ac 100644 --- a/src/cmap/auth/mongodb_oidc.ts +++ b/src/cmap/auth/mongodb_oidc.ts @@ -88,19 +88,23 @@ export class MongoDBOIDC extends AuthProvider { return callback(new MongoMissingCredentialsError('AuthContext must provide credentials.')); } + let workflow; + try { - const workflow = getWorkflow(credentials); - workflow.execute(connection, credentials, reauthenticating).then( - result => { - return callback(undefined, result); - }, - error => { - callback(error); - } - ); + workflow = getWorkflow(credentials); } catch (error) { callback(error); + return; } + + workflow.execute(connection, credentials, reauthenticating).then( + result => { + return callback(undefined, result); + }, + error => { + callback(error); + } + ); } /** From f9eaab02e1b18085460eda12a9a55eae5077535c Mon Sep 17 00:00:00 2001 From: Neal Beeken Date: Wed, 22 Mar 2023 15:19:42 -0400 Subject: [PATCH 04/11] chore: small cleanups --- .eslintrc.json | 6 ++++++ src/cmap/auth/auth_provider.ts | 5 ++--- src/cmap/auth/scram.ts | 12 +++++++----- src/cmap/connect.ts | 11 +++-------- 4 files changed, 18 insertions(+), 16 deletions(-) diff --git a/.eslintrc.json b/.eslintrc.json index 153b36bbd7f..09be58ec580 100644 --- a/.eslintrc.json +++ b/.eslintrc.json @@ -121,6 +121,12 @@ "selector": "BinaryExpression[operator=/[=!]==?/] Literal[value='undefined']", "message": "Do not strictly check typeof undefined (NOTE: currently this rule only detects the usage of 'undefined' string literal so this could be a misfire)" } + ], + "@typescript-eslint/no-unused-vars": [ + "error", + { + "argsIgnorePattern": "^_" + } ] }, "overrides": [ diff --git a/src/cmap/auth/auth_provider.ts b/src/cmap/auth/auth_provider.ts index 225bd991001..4fa7fb46b0f 100644 --- a/src/cmap/auth/auth_provider.ts +++ b/src/cmap/auth/auth_provider.ts @@ -47,8 +47,7 @@ export class AuthProvider { */ async prepare( handshakeDoc: HandshakeDocument, - /* eslint @typescript-eslint/no-unused-vars : 0 */ - authContext: AuthContext + _authContext: AuthContext ): Promise { return handshakeDoc; } @@ -59,7 +58,7 @@ export class AuthProvider { * @param context - A shared context for authentication flow * @param callback - The callback to return the result from the authentication */ - auth(context: AuthContext, callback: Callback): void { + auth(_context: AuthContext, callback: Callback): void { // TODO(NODE-3483): Replace this with MongoMethodOverrideError callback(new MongoRuntimeError('`auth` method must be overridden by subclass')); } diff --git a/src/cmap/auth/scram.ts b/src/cmap/auth/scram.ts index 896a9aa7d16..3ce2c7ab91a 100644 --- a/src/cmap/auth/scram.ts +++ b/src/cmap/auth/scram.ts @@ -42,13 +42,15 @@ class ScramSHA extends AuthProvider { const nonce = await this.randomBytesAsync(24); // store the nonce for later use - Object.assign(authContext, { nonce }); + authContext.nonce = nonce; - const request = Object.assign({}, handshakeDoc, { - speculativeAuthenticate: Object.assign(makeFirstMessage(cryptoMethod, credentials, nonce), { + const request = { + ...handshakeDoc, + speculativeAuthenticate: { + ...makeFirstMessage(cryptoMethod, credentials, nonce), db: credentials.source - }) - }); + } + }; return request; } diff --git a/src/cmap/connect.ts b/src/cmap/connect.ts index c56447ffb14..887c9f29c54 100644 --- a/src/cmap/connect.ts +++ b/src/cmap/connect.ts @@ -15,7 +15,6 @@ import { MongoNetworkError, MongoNetworkTimeoutError, MongoRuntimeError, - MongoServerError, needsRetryableWriteLabel } from '../error'; import { Callback, ClientMetadata, HostAddress, makeClientMetadata, ns } from '../utils'; @@ -28,7 +27,7 @@ import { Plain } from './auth/plain'; import { AuthMechanism } from './auth/providers'; import { ScramSHA1, ScramSHA256 } from './auth/scram'; import { X509 } from './auth/x509'; -import { Connection, ConnectionOptions, CryptoConnection } from './connection'; +import { CommandOptions, Connection, ConnectionOptions, CryptoConnection } from './connection'; import { MAX_SUPPORTED_SERVER_VERSION, MAX_SUPPORTED_WIRE_VERSION, @@ -121,7 +120,8 @@ function performInitialHandshake( conn.authContext = authContext; prepareHandshakeDocument(authContext).then( handshakeDoc => { - const handshakeOptions: Document = Object.assign({}, options); + // @ts-expect-error: TODO(NODE-XXXX): The options need to be filtered properly, Connection options differ from Command options + const handshakeOptions: CommandOptions = { ...options }; if (typeof options.connectTimeoutMS === 'number') { // The handshake technically is a monitoring check, so its socket timeout should be connectTimeoutMS handshakeOptions.socketTimeoutMS = options.connectTimeoutMS; @@ -134,11 +134,6 @@ function performInitialHandshake( return; } - if (response?.ok === 0) { - callback(new MongoServerError(response)); - return; - } - if (!('isWritablePrimary' in response)) { // Provide hello-style response document. response.isWritablePrimary = response[LEGACY_HELLO_COMMAND]; From 2e9687f88f49b3861212ae7a159e3dc5b35a2604 Mon Sep 17 00:00:00 2001 From: Neal Beeken Date: Wed, 22 Mar 2023 19:18:20 -0400 Subject: [PATCH 05/11] chore: mv test files to ts --- package.json | 2 +- test/manual/{kerberos.test.js => kerberos.test.ts} | 11 ++++++----- .../cmap/auth/{gssapi.test.js => gssapi.test.ts} | 14 +++++--------- 3 files changed, 12 insertions(+), 15 deletions(-) rename test/manual/{kerberos.test.js => kerberos.test.ts} (98%) rename test/unit/cmap/auth/{gssapi.test.js => gssapi.test.ts} (97%) diff --git a/package.json b/package.json index 3328fba5550..a564ad07a72 100644 --- a/package.json +++ b/package.json @@ -129,7 +129,7 @@ "check:aws": "mocha --config test/mocha_mongodb.json test/integration/auth/mongodb_aws.test.ts", "check:oidc": "mocha --config test/manual/mocharc.json test/manual/mongodb_oidc.prose.test.ts", "check:ocsp": "mocha --config test/manual/mocharc.json test/manual/ocsp_support.test.js", - "check:kerberos": "mocha --config test/manual/mocharc.json test/manual/kerberos.test.js", + "check:kerberos": "mocha --config test/manual/mocharc.json test/manual/kerberos.test.ts", "check:tls": "mocha --config test/manual/mocharc.json test/manual/tls_support.test.js", "check:ldap": "mocha --config test/manual/mocharc.json test/manual/ldap.test.js", "check:socks5": "mocha --config test/manual/mocharc.json test/manual/socks5.test.ts", diff --git a/test/manual/kerberos.test.js b/test/manual/kerberos.test.ts similarity index 98% rename from test/manual/kerberos.test.js rename to test/manual/kerberos.test.ts index 20f7f6a8778..bee06172324 100644 --- a/test/manual/kerberos.test.js +++ b/test/manual/kerberos.test.ts @@ -1,10 +1,11 @@ -'use strict'; -const { MongoClient } = require('../mongodb'); -const chai = require('chai'); -const sinon = require('sinon'); -const dns = require('dns'); +import * as chai from 'chai'; +import { promises as dns } from 'dns'; +import * as sinon from 'sinon'; + +import { MongoClient } from '../mongodb'; const expect = chai.expect; +// eslint-disable-next-line @typescript-eslint/no-var-requires chai.use(require('sinon-chai')); function verifyKerberosAuthentication(client, done) { diff --git a/test/unit/cmap/auth/gssapi.test.js b/test/unit/cmap/auth/gssapi.test.ts similarity index 97% rename from test/unit/cmap/auth/gssapi.test.js rename to test/unit/cmap/auth/gssapi.test.ts index f879ad83149..3288afaa2fd 100644 --- a/test/unit/cmap/auth/gssapi.test.js +++ b/test/unit/cmap/auth/gssapi.test.ts @@ -1,16 +1,12 @@ -const chai = require('chai'); -const dns = require('dns'); -const sinon = require('sinon'); -const sinonChai = require('sinon-chai'); +import { expect } from 'chai'; +import { promises as dns } from 'dns'; +import * as sinon from 'sinon'; -const { +import { GSSAPICanonicalizationValue, performGSSAPICanonicalizeHostName, resolveCname -} = require('../../../mongodb'); - -const expect = chai.expect; -chai.use(sinonChai); +} from '../../../mongodb'; describe('GSSAPI', () => { const sandbox = sinon.createSandbox(); From 1bbdf058e538fe84594de62d8d365dab502d4d56 Mon Sep 17 00:00:00 2001 From: Neal Beeken Date: Wed, 22 Mar 2023 19:19:36 -0400 Subject: [PATCH 06/11] refactor: make auth async --- src/cmap/auth/auth_provider.ts | 23 +- src/cmap/auth/gssapi.ts | 249 +++++++++------------ src/cmap/auth/mongocr.ts | 66 +++--- src/cmap/auth/mongodb_aws.ts | 348 ++++++++++++++--------------- src/cmap/auth/mongodb_oidc.ts | 25 +-- src/cmap/auth/plain.ts | 12 +- src/cmap/auth/scram.ts | 107 +++------ src/cmap/auth/x509.ts | 21 +- src/cmap/connect.ts | 165 ++++++-------- src/cmap/connection_pool.ts | 22 +- test/manual/kerberos.test.ts | 268 +++++++++------------- test/unit/cmap/auth/gssapi.test.ts | 298 ++++++++++-------------- 12 files changed, 663 insertions(+), 941 deletions(-) diff --git a/src/cmap/auth/auth_provider.ts b/src/cmap/auth/auth_provider.ts index 4fa7fb46b0f..1cd5e67b122 100644 --- a/src/cmap/auth/auth_provider.ts +++ b/src/cmap/auth/auth_provider.ts @@ -1,6 +1,6 @@ import type { Document } from '../../bson'; import { MongoRuntimeError } from '../../error'; -import type { Callback, ClientMetadataOptions } from '../../utils'; +import type { ClientMetadataOptions } from '../../utils'; import type { HandshakeDocument } from '../connect'; import type { Connection, ConnectionOptions } from '../connection'; import type { MongoCredentials } from './mongo_credentials'; @@ -38,7 +38,7 @@ export class AuthContext { } } -export class AuthProvider { +export abstract class AuthProvider { /** * Prepare the handshake document before the initial handshake. * @@ -58,26 +58,23 @@ export class AuthProvider { * @param context - A shared context for authentication flow * @param callback - The callback to return the result from the authentication */ - auth(_context: AuthContext, callback: Callback): void { - // TODO(NODE-3483): Replace this with MongoMethodOverrideError - callback(new MongoRuntimeError('`auth` method must be overridden by subclass')); - } + abstract auth(_context: AuthContext): Promise; /** * Reauthenticate. * @param context - The shared auth context. * @param callback - The callback. */ - reauth(context: AuthContext, callback: Callback): void { + async reauth(context: AuthContext): Promise { // If we are already reauthenticating this is a no-op. if (context.reauthenticating) { - return callback(new MongoRuntimeError('Reauthentication already in progress.')); + throw new MongoRuntimeError('Reauthentication already in progress.'); } - context.reauthenticating = true; - const cb: Callback = (error, result) => { + try { + context.reauthenticating = true; + await this.auth(context); + } finally { context.reauthenticating = false; - callback(error, result); - }; - this.auth(context, cb); + } } } diff --git a/src/cmap/auth/gssapi.ts b/src/cmap/auth/gssapi.ts index 61eaf60e53f..4a4ae3e1bbb 100644 --- a/src/cmap/auth/gssapi.ts +++ b/src/cmap/auth/gssapi.ts @@ -1,15 +1,9 @@ import * as dns from 'dns'; -import type { Document } from '../../bson'; import { Kerberos, KerberosClient } from '../../deps'; -import { - MongoError, - MongoInvalidArgumentError, - MongoMissingCredentialsError, - MongoMissingDependencyError, - MongoRuntimeError -} from '../../error'; -import { Callback, ns } from '../../utils'; +import { MongoInvalidArgumentError, MongoMissingCredentialsError } from '../../error'; +import { ns } from '../../utils'; +import type { Connection } from '../connection'; import { AuthContext, AuthProvider } from './auth_provider'; /** @public */ @@ -32,70 +26,59 @@ type MechanismProperties = { SERVICE_REALM?: string; }; +async function externalCommand( + connection: Connection, + command: ReturnType | ReturnType +): Promise<{ payload: string; conversationId: any }> { + return connection.commandAsync(ns('$external.$cmd'), command, undefined) as Promise<{ + payload: string; + conversationId: any; + }>; +} + export class GSSAPI extends AuthProvider { - override auth(authContext: AuthContext, callback: Callback): void { + override async auth(authContext: AuthContext): Promise { const { connection, credentials } = authContext; - if (credentials == null) - return callback( - new MongoMissingCredentialsError('Credentials required for GSSAPI authentication') - ); - const { username } = credentials; - function externalCommand( - command: Document, - cb: Callback<{ payload: string; conversationId: any }> - ) { - return connection.command(ns('$external.$cmd'), command, undefined, cb); + if (credentials == null) { + throw new MongoMissingCredentialsError('Credentials required for GSSAPI authentication'); } - makeKerberosClient(authContext, (err, client) => { - if (err) return callback(err); - if (client == null) return callback(new MongoMissingDependencyError('GSSAPI client missing')); - client.step('', (err, payload) => { - if (err) return callback(err); - - externalCommand(saslStart(payload), (err, result) => { - if (err) return callback(err); - if (result == null) return callback(); - negotiate(client, 10, result.payload, (err, payload) => { - if (err) return callback(err); - - externalCommand(saslContinue(payload, result.conversationId), (err, result) => { - if (err) return callback(err); - if (result == null) return callback(); - finalize(client, username, result.payload, (err, payload) => { - if (err) return callback(err); - - externalCommand( - { - saslContinue: 1, - conversationId: result.conversationId, - payload - }, - (err, result) => { - if (err) return callback(err); - - callback(undefined, result); - } - ); - }); - }); - }); - }); - }); + + const { username } = credentials; + + const client = await makeKerberosClient(authContext); + + const payload = await client.step(''); + + const saslStartResponse = await externalCommand(connection, saslStart(payload)); + + const negotiatedPayload = await negotiate(client, 10, saslStartResponse.payload); + + const saslContinueResponse = await externalCommand( + connection, + saslContinue(negotiatedPayload, saslStartResponse.conversationId) + ); + + const finalizePayload = await finalize(client, username, saslContinueResponse.payload); + + await externalCommand(connection, { + saslContinue: 1, + conversationId: saslContinueResponse.conversationId, + payload: finalizePayload }); } } -function makeKerberosClient(authContext: AuthContext, callback: Callback): void { +async function makeKerberosClient(authContext: AuthContext): Promise { const { hostAddress } = authContext.options; const { credentials } = authContext; if (!hostAddress || typeof hostAddress.host !== 'string' || !credentials) { - return callback( - new MongoInvalidArgumentError('Connection must have host and port and credentials defined.') + throw new MongoInvalidArgumentError( + 'Connection must have host and port and credentials defined.' ); } if ('kModuleError' in Kerberos) { - return callback(Kerberos['kModuleError']); + throw Kerberos['kModuleError']; } const { initializeClient } = Kerberos; @@ -104,95 +87,71 @@ function makeKerberosClient(authContext: AuthContext, callback: Callback { - if (err) return callback(err); - - const initOptions = {}; - if (password != null) { - Object.assign(initOptions, { user: username, password: password }); - } - - const spnHost = mechanismProperties.SERVICE_HOST ?? host; - let spn = `${serviceName}${process.platform === 'win32' ? '/' : '@'}${spnHost}`; - if ('SERVICE_REALM' in mechanismProperties) { - spn = `${spn}@${mechanismProperties.SERVICE_REALM}`; - } - - initializeClient(spn, initOptions, (err: string, client: KerberosClient): void => { - // TODO(NODE-3483) - if (err) return callback(new MongoRuntimeError(err)); - callback(undefined, client); - }); - } - ); + const host = await performGSSAPICanonicalizeHostName(hostAddress.host, mechanismProperties); + + const initOptions = {}; + if (password != null) { + // TODO(NODE-XXXX): These do not match the typescript options in initializeClient + Object.assign(initOptions, { user: username, password: password }); + } + + const spnHost = mechanismProperties.SERVICE_HOST ?? host; + let spn = `${serviceName}${process.platform === 'win32' ? '/' : '@'}${spnHost}`; + if ('SERVICE_REALM' in mechanismProperties) { + spn = `${spn}@${mechanismProperties.SERVICE_REALM}`; + } + + return initializeClient(spn, initOptions); } -function saslStart(payload?: string): Document { +function saslStart(payload: string) { return { saslStart: 1, mechanism: 'GSSAPI', payload, autoAuthorize: 1 - }; + } as const; } -function saslContinue(payload?: string, conversationId?: number): Document { +function saslContinue(payload: string, conversationId: number) { return { saslContinue: 1, conversationId, payload - }; + } as const; } -function negotiate( +async function negotiate( client: KerberosClient, retries: number, - payload: string, - callback: Callback -): void { - client.step(payload, (err, response) => { - // Retries exhausted, raise error - if (err && retries === 0) return callback(err); - + payload: string +): Promise { + try { + const response = await client.step(payload); + return response || ''; + } catch (error) { + if (retries === 0) { + // Retries exhausted, raise error + throw error; + } // Adjust number of retries and call step again - if (err) return negotiate(client, retries - 1, payload, callback); - - // Return the payload - callback(undefined, response || ''); - }); + return negotiate(client, retries - 1, payload); + } } -function finalize( - client: KerberosClient, - user: string, - payload: string, - callback: Callback -): void { +async function finalize(client: KerberosClient, user: string, payload: string): Promise { // GSS Client Unwrap - client.unwrap(payload, (err, response) => { - if (err) return callback(err); - - // Wrap the response - client.wrap(response || '', { user }, (err, wrapped) => { - if (err) return callback(err); - - // Return the payload - callback(undefined, wrapped); - }); - }); + const response = await client.unwrap(payload); + return client.wrap(response || '', { user }); } -export function performGSSAPICanonicalizeHostName( +export async function performGSSAPICanonicalizeHostName( host: string, - mechanismProperties: MechanismProperties, - callback: Callback -): void { + mechanismProperties: MechanismProperties +): Promise { const mode = mechanismProperties.CANONICALIZE_HOST_NAME; if (!mode || mode === GSSAPICanonicalizationValue.none) { - return callback(undefined, host); + return host; } // If forward and reverse or true @@ -201,39 +160,33 @@ export function performGSSAPICanonicalizeHostName( mode === GSSAPICanonicalizationValue.forwardAndReverse ) { // Perform the lookup of the ip address. - dns.lookup(host, (error, address) => { - // No ip found, return the error. - if (error) return callback(error); + const { address } = await dns.promises.lookup(host); + try { // Perform a reverse ptr lookup on the ip address. - dns.resolvePtr(address, (err, results) => { - // This can error as ptr records may not exist for all ips. In this case - // fallback to a cname lookup as dns.lookup() does not return the - // cname. - if (err) { - return resolveCname(host, callback); - } - // If the ptr did not error but had no results, return the host. - callback(undefined, results.length > 0 ? results[0] : host); - }); - }); + const results = await dns.promises.resolvePtr(address); + // If the ptr did not error but had no results, return the host. + return results.length > 0 ? results[0] : host; + } catch (error) { + // This can error as ptr records may not exist for all ips. In this case + // fallback to a cname lookup as dns.lookup() does not return the + // cname. + return resolveCname(host); + } } else { // The case for forward is just to resolve the cname as dns.lookup() // will not return it. - resolveCname(host, callback); + return resolveCname(host); } } -export function resolveCname(host: string, callback: Callback): void { +export async function resolveCname(host: string): Promise { // Attempt to resolve the host name - dns.resolveCname(host, (err, r) => { - if (err) return callback(undefined, host); - - // Get the first resolve host id - if (r.length > 0) { - return callback(undefined, r[0]); - } - - callback(undefined, host); - }); + try { + const results = await dns.promises.resolveCname(host); + // Get the first resolved host id + return results.length > 0 ? results[0] : host; + } catch { + return host; + } } diff --git a/src/cmap/auth/mongocr.ts b/src/cmap/auth/mongocr.ts index 232378f0d49..579069e9b61 100644 --- a/src/cmap/auth/mongocr.ts +++ b/src/cmap/auth/mongocr.ts @@ -1,47 +1,41 @@ import * as crypto from 'crypto'; import { MongoMissingCredentialsError } from '../../error'; -import { Callback, ns } from '../../utils'; +import { ns } from '../../utils'; import { AuthContext, AuthProvider } from './auth_provider'; export class MongoCR extends AuthProvider { - override auth(authContext: AuthContext, callback: Callback): void { + override async auth(authContext: AuthContext): Promise { const { connection, credentials } = authContext; if (!credentials) { - return callback(new MongoMissingCredentialsError('AuthContext must provide credentials.')); + throw new MongoMissingCredentialsError('AuthContext must provide credentials.'); } - const username = credentials.username; - const password = credentials.password; - const source = credentials.source; - connection.command(ns(`${source}.$cmd`), { getnonce: 1 }, undefined, (err, r) => { - let nonce = null; - let key = null; - - // Get nonce - if (err == null) { - nonce = r.nonce; - - // Use node md5 generator - let md5 = crypto.createHash('md5'); - - // Generate keys used for authentication - md5.update(`${username}:mongo:${password}`, 'utf8'); - const hash_password = md5.digest('hex'); - - // Final key - md5 = crypto.createHash('md5'); - md5.update(nonce + username + hash_password, 'utf8'); - key = md5.digest('hex'); - } - - const authenticateCommand = { - authenticate: 1, - user: username, - nonce, - key - }; - - connection.command(ns(`${source}.$cmd`), authenticateCommand, undefined, callback); - }); + + const { username, password, source } = credentials; + + const r = await connection.commandAsync(ns(`${source}.$cmd`), { getnonce: 1 }, undefined); + + // Get nonce + const nonce = r.nonce; + + const hashPassword = crypto + .createHash('md5') + .update(`${username}:mongo:${password}`, 'utf8') + .digest('hex'); + + // Final key + const key = crypto + .createHash('md5') + .update(nonce + username + hashPassword, 'utf8') + .digest('hex'); + + const authenticateCommand = { + authenticate: 1, + user: username, + nonce, + key + }; + + await connection.commandAsync(ns(`${source}.$cmd`), authenticateCommand, undefined); } } diff --git a/src/cmap/auth/mongodb_aws.ts b/src/cmap/auth/mongodb_aws.ts index 5d9007dcb54..031a73df84e 100644 --- a/src/cmap/auth/mongodb_aws.ts +++ b/src/cmap/auth/mongodb_aws.ts @@ -1,6 +1,7 @@ import * as crypto from 'crypto'; import * as http from 'http'; import * as url from 'url'; +import { promisify } from 'util'; import type { Binary, BSONSerializeOptions } from '../../bson'; import * as BSON from '../../bson'; @@ -11,7 +12,7 @@ import { MongoMissingCredentialsError, MongoRuntimeError } from '../../error'; -import { ByteUtils, Callback, maxWireVersion, ns } from '../../utils'; +import { ByteUtils, maxWireVersion, ns } from '../../utils'; import { AuthContext, AuthProvider } from './auth_provider'; import { MongoCredentials } from './mongo_credentials'; import { AuthMechanism } from './providers'; @@ -35,35 +36,33 @@ interface AWSSaslContinuePayload { } export class MongoDBAWS extends AuthProvider { - override auth(authContext: AuthContext, callback: Callback): void { + randomBytesAsync: (size: number) => Promise; + + constructor() { + super(); + this.randomBytesAsync = promisify(crypto.randomBytes); + } + + override async auth(authContext: AuthContext): Promise { const { connection, credentials } = authContext; if (!credentials) { - return callback(new MongoMissingCredentialsError('AuthContext must provide credentials.')); + throw new MongoMissingCredentialsError('AuthContext must provide credentials.'); } if ('kModuleError' in aws4) { - return callback(aws4['kModuleError']); + throw aws4['kModuleError']; } const { sign } = aws4; if (maxWireVersion(connection) < 9) { - callback( - new MongoCompatibilityError( - 'MONGODB-AWS authentication requires MongoDB version 4.4 or later' - ) + throw new MongoCompatibilityError( + 'MONGODB-AWS authentication requires MongoDB version 4.4 or later' ); - return; } if (!credentials.username) { - makeTempCredentials(credentials, (err, tempCredentials) => { - if (err || !tempCredentials) return callback(err); - - authContext.credentials = tempCredentials; - this.auth(authContext, callback); - }); - - return; + authContext.credentials = await makeTempCredentials(credentials); + return this.auth(authContext); } const accessKeyId = credentials.username; @@ -79,87 +78,75 @@ export class MongoDBAWS extends AuthProvider { : undefined; const db = credentials.source; - crypto.randomBytes(32, (err, nonce) => { - if (err) { - callback(err); - return; - } + const nonce = await this.randomBytesAsync(32); + + const saslStart = { + saslStart: 1, + mechanism: 'MONGODB-AWS', + payload: BSON.serialize({ r: nonce, p: ASCII_N }, bsonOptions) + }; + + const saslStartResponse = await connection.commandAsync(ns(`${db}.$cmd`), saslStart, undefined); + + const serverResponse = BSON.deserialize(saslStartResponse.payload.buffer, bsonOptions) as { + s: Binary; + h: string; + }; + const host = serverResponse.h; + const serverNonce = serverResponse.s.buffer; + if (serverNonce.length !== 64) { + // TODO(NODE-3483) + throw new MongoRuntimeError(`Invalid server nonce length ${serverNonce.length}, expected 64`); + } - const saslStart = { - saslStart: 1, - mechanism: 'MONGODB-AWS', - payload: BSON.serialize({ r: nonce, p: ASCII_N }, bsonOptions) - }; - - connection.command(ns(`${db}.$cmd`), saslStart, undefined, (err, res) => { - if (err) return callback(err); - - const serverResponse = BSON.deserialize(res.payload.buffer, bsonOptions) as { - s: Binary; - h: string; - }; - const host = serverResponse.h; - const serverNonce = serverResponse.s.buffer; - if (serverNonce.length !== 64) { - callback( - // TODO(NODE-3483) - new MongoRuntimeError(`Invalid server nonce length ${serverNonce.length}, expected 64`) - ); + if (!ByteUtils.equals(serverNonce.subarray(0, nonce.byteLength), nonce)) { + // throw because the serverNonce's leading 32 bytes must equal the client nonce's 32 bytes + // https://github.com/mongodb/specifications/blob/875446db44aade414011731840831f38a6c668df/source/auth/auth.rst#id11 - return; - } + // TODO(NODE-3483) + throw new MongoRuntimeError('Server nonce does not begin with client nonce'); + } - if (!ByteUtils.equals(serverNonce.subarray(0, nonce.byteLength), nonce)) { - // throw because the serverNonce's leading 32 bytes must equal the client nonce's 32 bytes - // https://github.com/mongodb/specifications/blob/875446db44aade414011731840831f38a6c668df/source/auth/auth.rst#id11 + if (host.length < 1 || host.length > 255 || host.indexOf('..') !== -1) { + // TODO(NODE-3483) + throw new MongoRuntimeError(`Server returned an invalid host: "${host}"`); + } - // TODO(NODE-3483) - callback(new MongoRuntimeError('Server nonce does not begin with client nonce')); - return; - } + const body = 'Action=GetCallerIdentity&Version=2011-06-15'; + const options = sign( + { + method: 'POST', + host, + region: deriveRegion(serverResponse.h), + service: 'sts', + headers: { + 'Content-Type': 'application/x-www-form-urlencoded', + 'Content-Length': body.length, + 'X-MongoDB-Server-Nonce': ByteUtils.toBase64(serverNonce), + 'X-MongoDB-GS2-CB-Flag': 'n' + }, + path: '/', + body + }, + awsCredentials + ); - if (host.length < 1 || host.length > 255 || host.indexOf('..') !== -1) { - // TODO(NODE-3483) - callback(new MongoRuntimeError(`Server returned an invalid host: "${host}"`)); - return; - } + const payload: AWSSaslContinuePayload = { + a: options.headers.Authorization, + d: options.headers['X-Amz-Date'] + }; - const body = 'Action=GetCallerIdentity&Version=2011-06-15'; - const options = sign( - { - method: 'POST', - host, - region: deriveRegion(serverResponse.h), - service: 'sts', - headers: { - 'Content-Type': 'application/x-www-form-urlencoded', - 'Content-Length': body.length, - 'X-MongoDB-Server-Nonce': ByteUtils.toBase64(serverNonce), - 'X-MongoDB-GS2-CB-Flag': 'n' - }, - path: '/', - body - }, - awsCredentials - ); - - const payload: AWSSaslContinuePayload = { - a: options.headers.Authorization, - d: options.headers['X-Amz-Date'] - }; - if (sessionToken) { - payload.t = sessionToken; - } + if (sessionToken) { + payload.t = sessionToken; + } - const saslContinue = { - saslContinue: 1, - conversationId: 1, - payload: BSON.serialize(payload, bsonOptions) - }; + const saslContinue = { + saslContinue: 1, + conversationId: 1, + payload: BSON.serialize(payload, bsonOptions) + }; - connection.command(ns(`${db}.$cmd`), saslContinue, undefined, callback); - }); - }); + await connection.commandAsync(ns(`${db}.$cmd`), saslContinue, undefined); } } @@ -179,27 +166,21 @@ export interface AWSCredentials { expiration?: Date; } -function makeTempCredentials(credentials: MongoCredentials, callback: Callback) { - function done(creds: AWSTempCredentials) { +async function makeTempCredentials(credentials: MongoCredentials): Promise { + function makeMongoCredentialsFromAWSTemp(creds: AWSTempCredentials) { if (!creds.AccessKeyId || !creds.SecretAccessKey || !creds.Token) { - callback( - new MongoMissingCredentialsError('Could not obtain temporary MONGODB-AWS credentials') - ); - return; + throw new MongoMissingCredentialsError('Could not obtain temporary MONGODB-AWS credentials'); } - callback( - undefined, - new MongoCredentials({ - username: creds.AccessKeyId, - password: creds.SecretAccessKey, - source: credentials.source, - mechanism: AuthMechanism.MONGODB_AWS, - mechanismProperties: { - AWS_SESSION_TOKEN: creds.Token - } - }) - ); + return new MongoCredentials({ + username: creds.AccessKeyId, + password: creds.SecretAccessKey, + source: credentials.source, + mechanism: AuthMechanism.MONGODB_AWS, + mechanismProperties: { + AWS_SESSION_TOKEN: creds.Token + } + }); } const credentialProvider = getAwsCredentialProvider(); @@ -210,47 +191,32 @@ function makeTempCredentials(credentials: MongoCredentials, callback: Callback { - if (err) return callback(err); - done(res); - } + return makeMongoCredentialsFromAWSTemp( + await request(`${AWS_RELATIVE_URI}${process.env.AWS_CONTAINER_CREDENTIALS_RELATIVE_URI}`) ); - - return; } // Otherwise assume we are on an EC2 instance // get a token - request( - `${AWS_EC2_URI}/latest/api/token`, - { method: 'PUT', json: false, headers: { 'X-aws-ec2-metadata-token-ttl-seconds': 30 } }, - (err, token) => { - if (err) return callback(err); - - // get role name - request( - `${AWS_EC2_URI}/${AWS_EC2_PATH}`, - { json: false, headers: { 'X-aws-ec2-metadata-token': token } }, - (err, roleName) => { - if (err) return callback(err); - - // get temp credentials - request( - `${AWS_EC2_URI}/${AWS_EC2_PATH}/${roleName}`, - { headers: { 'X-aws-ec2-metadata-token': token } }, - (err, creds) => { - if (err) return callback(err); - done(creds); - } - ); - } - ); - } - ); + const token = await request(`${AWS_EC2_URI}/latest/api/token`, { + method: 'PUT', + json: false, + headers: { 'X-aws-ec2-metadata-token-ttl-seconds': 30 } + }); + + // get role name + const roleName = await request(`${AWS_EC2_URI}/${AWS_EC2_PATH}`, { + json: false, + headers: { 'X-aws-ec2-metadata-token': token } + }); + + // get temp credentials + const creds = await request(`${AWS_EC2_URI}/${AWS_EC2_PATH}/${roleName}`, { + headers: { 'X-aws-ec2-metadata-token': token } + }); + + return makeMongoCredentialsFromAWSTemp(creds); } else { /* * Creates a credential provider that will attempt to find credentials from the @@ -264,18 +230,17 @@ function makeTempCredentials(credentials: MongoCredentials, callback: Callback { - done({ - AccessKeyId: creds.accessKeyId, - SecretAccessKey: creds.secretAccessKey, - Token: creds.sessionToken, - Expiration: creds.expiration - }); - }) - .catch((error: Error) => { - callback(new MongoAWSError(error.message)); + try { + const creds = await provider(); + return makeMongoCredentialsFromAWSTemp({ + AccessKeyId: creds.accessKeyId, + SecretAccessKey: creds.secretAccessKey, + Token: creds.sessionToken, + Expiration: creds.expiration }); + } catch (error) { + throw new MongoAWSError(error.message); + } } } @@ -295,42 +260,53 @@ interface RequestOptions { headers?: http.OutgoingHttpHeaders; } -function request(uri: string, _options: RequestOptions | undefined, callback: Callback) { - const options = Object.assign( - { +async function request(uri: string): Promise>; +async function request( + uri: string, + options?: { json?: true } & RequestOptions +): Promise>; +async function request(uri: string, options?: { json: false } & RequestOptions): Promise; +async function request( + uri: string, + options: RequestOptions = {} +): Promise> { + return new Promise>((resolve, reject) => { + const requestOptions = { method: 'GET', timeout: 10000, - json: true - }, - url.parse(uri), - _options - ); - - const req = http.request(options, res => { - res.setEncoding('utf8'); - - let data = ''; - res.on('data', d => (data += d)); - res.on('end', () => { - if (options.json === false) { - callback(undefined, data); - return; - } + json: true, + ...url.parse(uri), + ...options + }; - try { - const parsed = JSON.parse(data); - callback(undefined, parsed); - } catch (err) { - // TODO(NODE-3483) - callback(new MongoRuntimeError(`Invalid JSON response: "${data}"`)); - } + const req = http.request(requestOptions, res => { + res.setEncoding('utf8'); + + let data = ''; + res.on('data', d => { + data += d; + }); + + res.once('end', () => { + if (options.json === false) { + resolve(data); + return; + } + + try { + const parsed = JSON.parse(data); + resolve(parsed); + } catch { + // TODO(NODE-3483) + reject(new MongoRuntimeError(`Invalid JSON response: "${data}"`)); + } + }); }); - }); - req.on('timeout', () => { - req.destroy(new MongoAWSError(`AWS request to ${uri} timed out after ${options.timeout} ms`)); + req.once('timeout', () => + req.destroy(new MongoAWSError(`AWS request to ${uri} timed out after ${options.timeout} ms`)) + ); + req.once('error', error => reject(error)); + req.end(); }); - - req.on('error', err => callback(err)); - req.end(); } diff --git a/src/cmap/auth/mongodb_oidc.ts b/src/cmap/auth/mongodb_oidc.ts index 4911e05e7ac..41e6a3b02cc 100644 --- a/src/cmap/auth/mongodb_oidc.ts +++ b/src/cmap/auth/mongodb_oidc.ts @@ -3,7 +3,6 @@ import { MongoMissingCredentialsError, MongoRuntimeError } from '../../error'; -import type { Callback } from '../../utils'; import type { HandshakeDocument } from '../connect'; import { type AuthContext, AuthProvider } from './auth_provider'; import type { MongoCredentials } from './mongo_credentials'; @@ -77,34 +76,20 @@ export class MongoDBOIDC extends AuthProvider { /** * Authenticate using OIDC */ - override auth(authContext: AuthContext, callback: Callback): void { + override async auth(authContext: AuthContext): Promise { const { connection, credentials, response, reauthenticating } = authContext; if (response?.speculativeAuthenticate) { - return callback(); + return; } if (!credentials) { - return callback(new MongoMissingCredentialsError('AuthContext must provide credentials.')); + throw new MongoMissingCredentialsError('AuthContext must provide credentials.'); } - let workflow; - - try { - workflow = getWorkflow(credentials); - } catch (error) { - callback(error); - return; - } + const workflow = getWorkflow(credentials); - workflow.execute(connection, credentials, reauthenticating).then( - result => { - return callback(undefined, result); - }, - error => { - callback(error); - } - ); + await workflow.execute(connection, credentials, reauthenticating); } /** diff --git a/src/cmap/auth/plain.ts b/src/cmap/auth/plain.ts index 94b19a52b0c..b1fd060ef2f 100644 --- a/src/cmap/auth/plain.ts +++ b/src/cmap/auth/plain.ts @@ -1,16 +1,16 @@ import { Binary } from '../../bson'; import { MongoMissingCredentialsError } from '../../error'; -import { Callback, ns } from '../../utils'; +import { ns } from '../../utils'; import { AuthContext, AuthProvider } from './auth_provider'; export class Plain extends AuthProvider { - override auth(authContext: AuthContext, callback: Callback): void { + override async auth(authContext: AuthContext): Promise { const { connection, credentials } = authContext; if (!credentials) { - return callback(new MongoMissingCredentialsError('AuthContext must provide credentials.')); + throw new MongoMissingCredentialsError('AuthContext must provide credentials.'); } - const username = credentials.username; - const password = credentials.password; + + const { username, password } = credentials; const payload = new Binary(Buffer.from(`\x00${username}\x00${password}`)); const command = { @@ -20,6 +20,6 @@ export class Plain extends AuthProvider { autoAuthorize: 1 }; - connection.command(ns('$external.$cmd'), command, undefined, callback); + await connection.commandAsync(ns('$external.$cmd'), command, undefined); } } diff --git a/src/cmap/auth/scram.ts b/src/cmap/auth/scram.ts index 3ce2c7ab91a..4321e5548e5 100644 --- a/src/cmap/auth/scram.ts +++ b/src/cmap/auth/scram.ts @@ -4,13 +4,11 @@ import { promisify } from 'util'; import { Binary, Document } from '../../bson'; import { saslprep } from '../../deps'; import { - AnyError, MongoInvalidArgumentError, MongoMissingCredentialsError, - MongoRuntimeError, - MongoServerError + MongoRuntimeError } from '../../error'; -import { Callback, emitWarning, ns } from '../../utils'; +import { emitWarning, ns } from '../../utils'; import type { HandshakeDocument } from '../connect'; import { AuthContext, AuthProvider } from './auth_provider'; import type { MongoCredentials } from './mongo_credentials'; @@ -55,20 +53,16 @@ class ScramSHA extends AuthProvider { return request; } - override auth(authContext: AuthContext, callback: Callback) { + override async auth(authContext: AuthContext) { const { reauthenticating, response } = authContext; if (response?.speculativeAuthenticate && !reauthenticating) { - continueScramConversation( + return continueScramConversation( this.cryptoMethod, response.speculativeAuthenticate, - authContext, - callback + authContext ); - - return; } - - executeScram(this.cryptoMethod, authContext, callback); + return executeScram(this.cryptoMethod, authContext); } } @@ -109,43 +103,34 @@ function makeFirstMessage( }; } -function executeScram(cryptoMethod: CryptoMethod, authContext: AuthContext, callback: Callback) { +async function executeScram(cryptoMethod: CryptoMethod, authContext: AuthContext): Promise { const { connection, credentials } = authContext; if (!credentials) { - return callback(new MongoMissingCredentialsError('AuthContext must provide credentials.')); + throw new MongoMissingCredentialsError('AuthContext must provide credentials.'); } if (!authContext.nonce) { - return callback( - new MongoInvalidArgumentError('AuthContext must contain a valid nonce property') - ); + throw new MongoInvalidArgumentError('AuthContext must contain a valid nonce property'); } const nonce = authContext.nonce; const db = credentials.source; const saslStartCmd = makeFirstMessage(cryptoMethod, credentials, nonce); - connection.command(ns(`${db}.$cmd`), saslStartCmd, undefined, (_err, result) => { - const err = resolveError(_err, result); - if (err) { - return callback(err); - } - - continueScramConversation(cryptoMethod, result, authContext, callback); - }); + const response = await connection.commandAsync(ns(`${db}.$cmd`), saslStartCmd, undefined); + await continueScramConversation(cryptoMethod, response, authContext); } -function continueScramConversation( +async function continueScramConversation( cryptoMethod: CryptoMethod, response: Document, - authContext: AuthContext, - callback: Callback -) { + authContext: AuthContext +): Promise { const connection = authContext.connection; const credentials = authContext.credentials; if (!credentials) { - return callback(new MongoMissingCredentialsError('AuthContext must provide credentials.')); + throw new MongoMissingCredentialsError('AuthContext must provide credentials.'); } if (!authContext.nonce) { - return callback(new MongoInvalidArgumentError('Unable to continue SCRAM without valid nonce')); + throw new MongoInvalidArgumentError('Unable to continue SCRAM without valid nonce'); } const nonce = authContext.nonce; @@ -157,11 +142,7 @@ function continueScramConversation( if (cryptoMethod === 'sha256') { processedPassword = 'kModuleError' in saslprep ? password : saslprep(password); } else { - try { - processedPassword = passwordDigest(username, password); - } catch (e) { - return callback(e); - } + processedPassword = passwordDigest(username, password); } const payload = Buffer.isBuffer(response.payload) @@ -171,20 +152,15 @@ function continueScramConversation( const iterations = parseInt(dict.i, 10); if (iterations && iterations < 4096) { - callback( - // TODO(NODE-3483) - new MongoRuntimeError(`Server returned an invalid iteration count ${iterations}`), - false - ); - return; + // TODO(NODE-3483) + throw new MongoRuntimeError(`Server returned an invalid iteration count ${iterations}`); } const salt = dict.s; const rnonce = dict.r; if (rnonce.startsWith('nonce')) { // TODO(NODE-3483) - callback(new MongoRuntimeError(`Server returned an invalid nonce: ${rnonce}`), false); - return; + throw new MongoRuntimeError(`Server returned an invalid nonce: ${rnonce}`); } // Set up start of proof @@ -214,30 +190,25 @@ function continueScramConversation( payload: new Binary(Buffer.from(clientFinal)) }; - connection.command(ns(`${db}.$cmd`), saslContinueCmd, undefined, (_err, r) => { - const err = resolveError(_err, r); - if (err) { - return callback(err); - } + const r = await connection.commandAsync(ns(`${db}.$cmd`), saslContinueCmd, undefined); + const parsedResponse = parsePayload(r.payload.value()); - const parsedResponse = parsePayload(r.payload.value()); - if (!compareDigest(Buffer.from(parsedResponse.v, 'base64'), serverSignature)) { - callback(new MongoRuntimeError('Server returned an invalid signature')); - return; - } + if (!compareDigest(Buffer.from(parsedResponse.v, 'base64'), serverSignature)) { + throw new MongoRuntimeError('Server returned an invalid signature'); + } - if (!r || r.done !== false) { - return callback(err, r); - } + if (r.done !== false) { + // If the server sends r.done === true we can save one RTT + return; + } - const retrySaslContinueCmd = { - saslContinue: 1, - conversationId: r.conversationId, - payload: Buffer.alloc(0) - }; + const retrySaslContinueCmd = { + saslContinue: 1, + conversationId: r.conversationId, + payload: Buffer.alloc(0) + }; - connection.command(ns(`${db}.$cmd`), retrySaslContinueCmd, undefined, callback); - }); + await connection.commandAsync(ns(`${db}.$cmd`), retrySaslContinueCmd, undefined); } function parsePayload(payload: string) { @@ -366,14 +337,6 @@ function compareDigest(lhs: Buffer, rhs: Uint8Array) { return result === 0; } -function resolveError(err?: AnyError, result?: Document) { - if (err) return err; - if (result) { - if (result.$err || result.errmsg) return new MongoServerError(result); - } - return; -} - export class ScramSHA1 extends ScramSHA { constructor() { super('sha1'); diff --git a/src/cmap/auth/x509.ts b/src/cmap/auth/x509.ts index 2f55e27ad90..05f4b034f8d 100644 --- a/src/cmap/auth/x509.ts +++ b/src/cmap/auth/x509.ts @@ -1,6 +1,6 @@ import type { Document } from '../../bson'; import { MongoMissingCredentialsError } from '../../error'; -import { Callback, ns } from '../../utils'; +import { ns } from '../../utils'; import type { HandshakeDocument } from '../connect'; import { AuthContext, AuthProvider } from './auth_provider'; import type { MongoCredentials } from './mongo_credentials'; @@ -14,30 +14,25 @@ export class X509 extends AuthProvider { if (!credentials) { throw new MongoMissingCredentialsError('AuthContext must provide credentials.'); } - Object.assign(handshakeDoc, { - speculativeAuthenticate: x509AuthenticateCommand(credentials) - }); - - return handshakeDoc; + return { ...handshakeDoc, speculativeAuthenticate: x509AuthenticateCommand(credentials) }; } - override auth(authContext: AuthContext, callback: Callback): void { + override async auth(authContext: AuthContext) { const connection = authContext.connection; const credentials = authContext.credentials; if (!credentials) { - return callback(new MongoMissingCredentialsError('AuthContext must provide credentials.')); + throw new MongoMissingCredentialsError('AuthContext must provide credentials.'); } const response = authContext.response; - if (response && response.speculativeAuthenticate) { - return callback(); + if (response?.speculativeAuthenticate) { + return; } - connection.command( + await connection.commandAsync( ns('$external.$cmd'), x509AuthenticateCommand(credentials), - undefined, - callback + undefined ); } } diff --git a/src/cmap/connect.ts b/src/cmap/connect.ts index 887c9f29c54..39849cbd0b7 100644 --- a/src/cmap/connect.ts +++ b/src/cmap/connect.ts @@ -60,7 +60,16 @@ export function connect(options: ConnectionOptions, callback: Callback callback(undefined, connection), + error => { + connection.destroy({ force: false }); + callback(error); + } + ); }); } @@ -91,115 +100,89 @@ function checkSupportedServer(hello: Document, options: ConnectionOptions) { return new MongoCompatibilityError(message); } -function performInitialHandshake( +async function performInitialHandshake( conn: Connection, - options: ConnectionOptions, - _callback: Callback -) { - const callback: Callback = function (err, ret) { - if (err && conn) { - conn.destroy({ force: false }); - } - _callback(err, ret); - }; - + options: ConnectionOptions +): Promise { const credentials = options.credentials; + if (credentials) { if ( !(credentials.mechanism === AuthMechanism.MONGODB_DEFAULT) && !AUTH_PROVIDERS.get(credentials.mechanism) ) { - callback( - new MongoInvalidArgumentError(`AuthMechanism '${credentials.mechanism}' not supported`) - ); - return; + throw new MongoInvalidArgumentError(`AuthMechanism '${credentials.mechanism}' not supported`); } } const authContext = new AuthContext(conn, credentials, options); conn.authContext = authContext; - prepareHandshakeDocument(authContext).then( - handshakeDoc => { - // @ts-expect-error: TODO(NODE-XXXX): The options need to be filtered properly, Connection options differ from Command options - const handshakeOptions: CommandOptions = { ...options }; - if (typeof options.connectTimeoutMS === 'number') { - // The handshake technically is a monitoring check, so its socket timeout should be connectTimeoutMS - handshakeOptions.socketTimeoutMS = options.connectTimeoutMS; - } - const start = new Date().getTime(); - conn.command(ns('admin.$cmd'), handshakeDoc, handshakeOptions, (err, response) => { - if (err) { - callback(err); - return; - } + const handshakeDoc = await prepareHandshakeDocument(authContext); - if (!('isWritablePrimary' in response)) { - // Provide hello-style response document. - response.isWritablePrimary = response[LEGACY_HELLO_COMMAND]; - } + // @ts-expect-error: TODO(NODE-XXXX): The options need to be filtered properly, Connection options differ from Command options + const handshakeOptions: CommandOptions = { ...options }; + if (typeof options.connectTimeoutMS === 'number') { + // The handshake technically is a monitoring check, so its socket timeout should be connectTimeoutMS + handshakeOptions.socketTimeoutMS = options.connectTimeoutMS; + } - if (response.helloOk) { - conn.helloOk = true; - } + const start = new Date().getTime(); + const response = await conn.commandAsync(ns('admin.$cmd'), handshakeDoc, handshakeOptions); - const supportedServerErr = checkSupportedServer(response, options); - if (supportedServerErr) { - callback(supportedServerErr); - return; - } + if (!('isWritablePrimary' in response)) { + // Provide hello-style response document. + response.isWritablePrimary = response[LEGACY_HELLO_COMMAND]; + } - if (options.loadBalanced) { - if (!response.serviceId) { - return callback( - new MongoCompatibilityError( - 'Driver attempted to initialize in load balancing mode, ' + - 'but the server does not support this mode.' - ) - ); - } - } + if (response.helloOk) { + conn.helloOk = true; + } - // NOTE: This is metadata attached to the connection while porting away from - // handshake being done in the `Server` class. Likely, it should be - // relocated, or at very least restructured. - conn.hello = response; - conn.lastHelloMS = new Date().getTime() - start; - - if (!response.arbiterOnly && credentials) { - // store the response on auth context - authContext.response = response; - - const resolvedCredentials = credentials.resolveAuthMechanism(response); - const provider = AUTH_PROVIDERS.get(resolvedCredentials.mechanism); - if (!provider) { - return callback( - new MongoInvalidArgumentError( - `No AuthProvider for ${resolvedCredentials.mechanism} defined.` - ) - ); - } - provider.auth(authContext, err => { - if (err) { - if (err instanceof MongoError) { - err.addErrorLabel(MongoErrorLabel.HandshakeError); - if (needsRetryableWriteLabel(err, response.maxWireVersion)) { - err.addErrorLabel(MongoErrorLabel.RetryableWriteError); - } - } - return callback(err); - } - callback(undefined, conn); - }); - - return; - } + const supportedServerErr = checkSupportedServer(response, options); + if (supportedServerErr) { + throw supportedServerErr; + } - callback(undefined, conn); - }); - }, - error => callback(error) - ); + if (options.loadBalanced) { + if (!response.serviceId) { + throw new MongoCompatibilityError( + 'Driver attempted to initialize in load balancing mode, ' + + 'but the server does not support this mode.' + ); + } + } + + // NOTE: This is metadata attached to the connection while porting away from + // handshake being done in the `Server` class. Likely, it should be + // relocated, or at very least restructured. + conn.hello = response; + conn.lastHelloMS = new Date().getTime() - start; + + if (!response.arbiterOnly && credentials) { + // store the response on auth context + authContext.response = response; + + const resolvedCredentials = credentials.resolveAuthMechanism(response); + const provider = AUTH_PROVIDERS.get(resolvedCredentials.mechanism); + if (!provider) { + throw new MongoInvalidArgumentError( + `No AuthProvider for ${resolvedCredentials.mechanism} defined.` + ); + } + + try { + await provider.auth(authContext); + } catch (error) { + if (error instanceof MongoError) { + error.addErrorLabel(MongoErrorLabel.HandshakeError); + if (needsRetryableWriteLabel(error, response.maxWireVersion)) { + error.addErrorLabel(MongoErrorLabel.RetryableWriteError); + } + } + throw error; + } + } } export interface HandshakeDocument extends Document { diff --git a/src/cmap/connection_pool.ts b/src/cmap/connection_pool.ts index 5365d19d07c..e3d42281355 100644 --- a/src/cmap/connection_pool.ts +++ b/src/cmap/connection_pool.ts @@ -620,17 +620,17 @@ export class ConnectionPool extends TypedEventEmitter { ) ); } - provider.reauth(authContext, error => { - if (error) { - return callback(error); - } - return fn(undefined, connection, (fnErr, fnResult) => { - if (fnErr) { - return callback(fnErr); - } - callback(undefined, fnResult); - }); - }); + provider.reauth(authContext).then( + () => { + fn(undefined, connection, (fnErr, fnResult) => { + if (fnErr) { + return callback(fnErr); + } + callback(undefined, fnResult); + }); + }, + error => callback(error) + ); } /** Clear the min pool size timer */ diff --git a/test/manual/kerberos.test.ts b/test/manual/kerberos.test.ts index bee06172324..63a60158747 100644 --- a/test/manual/kerberos.test.ts +++ b/test/manual/kerberos.test.ts @@ -8,33 +8,23 @@ const expect = chai.expect; // eslint-disable-next-line @typescript-eslint/no-var-requires chai.use(require('sinon-chai')); -function verifyKerberosAuthentication(client, done) { - client - .db('kerberos') - .collection('test') - .find() - .toArray(function (err, docs) { - let expectError; - try { - expect(err).to.not.exist; - expect(docs).to.have.length(1); - expect(docs[0].kerberos).to.be.true; - } catch (e) { - expectError = e; - } - client.close(e => done(expectError || e)); - }); +async function verifyKerberosAuthentication(client) { + const docs = await client.db('kerberos').collection('test').find().toArray(); + expect(docs).to.have.nested.property('[0].kerberos', true); } describe('Kerberos', function () { - const sandbox = sinon.createSandbox(); + let resolvePtrSpy; + let resolveCnameSpy; - beforeEach(function () { - sandbox.spy(dns); + beforeEach(() => { + sinon.spy(dns, 'lookup'); + resolvePtrSpy = sinon.spy(dns, 'resolvePtr'); + resolveCnameSpy = sinon.spy(dns, 'resolveCname'); }); afterEach(function () { - sandbox.restore(); + sinon.restore(); }); if (process.env.MONGODB_URI == null) { @@ -58,12 +48,10 @@ describe('Kerberos', function () { krb5Uri = `${parts[0]}:${process.env.LDAPTEST_PASSWORD}@${parts[1]}`; } - it('should authenticate with original uri', function (done) { + it('should authenticate with original uri', async function () { const client = new MongoClient(krb5Uri); - client.connect(function (err, client) { - expect(err).to.not.exist; - verifyKerberosAuthentication(client, done); - }); + await client.connect(); + await verifyKerberosAuthentication(client); }); context('when passing in CANONICALIZE_HOST_NAME', function () { @@ -76,31 +64,27 @@ describe('Kerberos', function () { }); context('when the value is forward', function () { - it('authenticates with a forward cname lookup', function (done) { + it('authenticates with a forward cname lookup', async function () { const client = new MongoClient( `${krb5Uri}&authMechanismProperties=SERVICE_NAME:mongodb,CANONICALIZE_HOST_NAME:forward&maxPoolSize=1` ); - client.connect(function (err, client) { - if (err) return done(err); - expect(dns.resolveCname).to.be.calledOnceWith(host); - verifyKerberosAuthentication(client, done); - }); + await client.connect(); + expect(dns.resolveCname).to.be.calledOnceWith(host); + await verifyKerberosAuthentication(client); }); }); for (const option of [false, 'none']) { context(`when the value is ${option}`, function () { - it('authenticates with no dns lookups', function (done) { + it('authenticates with no dns lookups', async function () { const client = new MongoClient( `${krb5Uri}&authMechanismProperties=SERVICE_NAME:mongodb,CANONICALIZE_HOST_NAME:${option}&maxPoolSize=1` ); - client.connect(function (err, client) { - if (err) return done(err); - expect(dns.resolveCname).to.not.be.called; - // 2 calls when establishing connection - expect no third call. - expect(dns.lookup).to.be.calledTwice; - verifyKerberosAuthentication(client, done); - }); + await client.connect(); + expect(dns.resolveCname).to.not.be.called; + // There are 2 calls to establish connection, however they use the callback form of dns.lookup + expect(dns.lookup).to.not.be.called; + await verifyKerberosAuthentication(client); }); }); } @@ -108,150 +92,124 @@ describe('Kerberos', function () { for (const option of [true, 'forwardAndReverse']) { context(`when the value is ${option}`, function () { context('when the reverse lookup succeeds', function () { - const resolveStub = (address, callback) => { - callback(null, [host]); - }; - beforeEach(function () { - dns.resolvePtr.restore(); - sinon.stub(dns, 'resolvePtr').callsFake(resolveStub); + resolvePtrSpy.restore(); + sinon.stub(dns, 'resolvePtr').resolves([host]); }); - it('authenticates with a forward dns lookup and a reverse ptr lookup', function (done) { + it('authenticates with a forward dns lookup and a reverse ptr lookup', async function () { const client = new MongoClient( `${krb5Uri}&authMechanismProperties=SERVICE_NAME:mongodb,CANONICALIZE_HOST_NAME:${option}&maxPoolSize=1` ); - client.connect(function (err, client) { - if (err) return done(err); - // 2 calls to establish connection, 1 call in canonicalization. - expect(dns.lookup).to.be.calledThrice; - expect(dns.resolvePtr).to.be.calledOnce; - verifyKerberosAuthentication(client, done); - }); + await client.connect(); + // There are 2 calls to establish connection, however they use the callback form of dns.lookup + // 1 dns.promises.lookup call in canonicalization. + expect(dns.lookup).to.be.calledOnce; + expect(dns.resolvePtr).to.be.calledOnce; + await verifyKerberosAuthentication(client); }); }); context('when the reverse lookup is empty', function () { - const resolveStub = (address, callback) => { - callback(null, []); - }; - beforeEach(function () { - dns.resolvePtr.restore(); - sinon.stub(dns, 'resolvePtr').callsFake(resolveStub); + resolvePtrSpy.restore(); + sinon.stub(dns, 'resolvePtr').resolves([]); }); - it('authenticates with a fallback cname lookup', function (done) { + it('authenticates with a fallback cname lookup', async function () { const client = new MongoClient( `${krb5Uri}&authMechanismProperties=SERVICE_NAME:mongodb,CANONICALIZE_HOST_NAME:${option}&maxPoolSize=1` ); - client.connect(function (err, client) { - if (err) return done(err); - // 2 calls to establish connection, 1 call in canonicalization. - expect(dns.lookup).to.be.calledThrice; - // This fails. - expect(dns.resolvePtr).to.be.calledOnce; - // Expect the fallback to the host name. - expect(dns.resolveCname).to.not.be.called; - verifyKerberosAuthentication(client, done); - }); + + await client.connect(); + // There are 2 calls to establish connection, however they use the callback form of dns.lookup + // 1 dns.promises.lookup call in canonicalization. + expect(dns.lookup).to.be.calledOnce; + // This fails. + expect(dns.resolvePtr).to.be.calledOnce; + // Expect the fallback to the host name. + expect(dns.resolveCname).to.not.be.called; + await verifyKerberosAuthentication(client); }); }); context('when the reverse lookup fails', function () { - const resolveStub = (address, callback) => { - callback(new Error('not found'), null); - }; - beforeEach(function () { - dns.resolvePtr.restore(); - sinon.stub(dns, 'resolvePtr').callsFake(resolveStub); + resolvePtrSpy.restore(); + sinon.stub(dns, 'resolvePtr').rejects(new Error('not found')); }); - it('authenticates with a fallback cname lookup', function (done) { + it('authenticates with a fallback cname lookup', async function () { const client = new MongoClient( `${krb5Uri}&authMechanismProperties=SERVICE_NAME:mongodb,CANONICALIZE_HOST_NAME:${option}&maxPoolSize=1` ); - client.connect(function (err, client) { - if (err) return done(err); - // 2 calls to establish connection, 1 call in canonicalization. - expect(dns.lookup).to.be.calledThrice; - // This fails. - expect(dns.resolvePtr).to.be.calledOnce; - // Expect the fallback to be called. - expect(dns.resolveCname).to.be.calledOnceWith(host); - verifyKerberosAuthentication(client, done); - }); + + await client.connect(); + // There are 2 calls to establish connection, however they use the callback form of dns.lookup + // 1 dns.promises.lookup call in canonicalization. + expect(dns.lookup).to.be.calledOnce; + // This fails. + expect(dns.resolvePtr).to.be.calledOnce; + // Expect the fallback to be called. + expect(dns.resolveCname).to.be.calledOnceWith(host); + await verifyKerberosAuthentication(client); }); }); context('when the cname lookup fails', function () { - const resolveStub = (address, callback) => { - callback(new Error('not found'), null); - }; - beforeEach(function () { - dns.resolveCname.restore(); - sinon.stub(dns, 'resolveCname').callsFake(resolveStub); + resolveCnameSpy.restore(); + sinon.stub(dns, 'resolveCname').rejects(new Error('not found')); }); - it('authenticates with a fallback host name', function (done) { + it('authenticates with a fallback host name', async function () { const client = new MongoClient( `${krb5Uri}&authMechanismProperties=SERVICE_NAME:mongodb,CANONICALIZE_HOST_NAME:${option}&maxPoolSize=1` ); - client.connect(function (err, client) { - if (err) return done(err); - // 2 calls to establish connection, 1 call in canonicalization. - expect(dns.lookup).to.be.calledThrice; - // This fails. - expect(dns.resolvePtr).to.be.calledOnce; - // Expect the fallback to be called. - expect(dns.resolveCname).to.be.calledOnceWith(host); - verifyKerberosAuthentication(client, done); - }); + await client.connect(); + // There are 2 calls to establish connection, however they use the callback form of dns.lookup + // 1 dns.promises.lookup call in canonicalization. + expect(dns.lookup).to.be.calledOnce; + // This fails. + expect(dns.resolvePtr).to.be.calledOnce; + // Expect the fallback to be called. + expect(dns.resolveCname).to.be.calledOnceWith(host); + await verifyKerberosAuthentication(client); }); }); context('when the cname lookup is empty', function () { - const resolveStub = (address, callback) => { - callback(null, []); - }; - beforeEach(function () { - dns.resolveCname.restore(); - sinon.stub(dns, 'resolveCname').callsFake(resolveStub); + resolveCnameSpy.restore(); + sinon.stub(dns, 'resolveCname').resolves([]); }); - it('authenticates with a fallback host name', function (done) { + it('authenticates with a fallback host name', async function () { const client = new MongoClient( `${krb5Uri}&authMechanismProperties=SERVICE_NAME:mongodb,CANONICALIZE_HOST_NAME:${option}&maxPoolSize=1` ); - client.connect(function (err, client) { - if (err) return done(err); - // 2 calls to establish connection, 1 call in canonicalization. - expect(dns.lookup).to.be.calledThrice; - // This fails. - expect(dns.resolvePtr).to.be.calledOnce; - // Expect the fallback to be called. - expect(dns.resolveCname).to.be.calledOnceWith(host); - verifyKerberosAuthentication(client, done); - }); + await client.connect(); + // There are 2 calls to establish connection, however they use the callback form of dns.lookup + // 1 dns.promises.lookup call in canonicalization. + expect(dns.lookup).to.be.calledOnce; + // This fails. + expect(dns.resolvePtr).to.be.calledOnce; + // Expect the fallback to be called. + expect(dns.resolveCname).to.be.calledOnceWith(host); + await verifyKerberosAuthentication(client); }); }); }); } }); - // Unskip this test when a proper setup is available - see NODE-3060 - it.skip('validate that SERVICE_REALM and CANONICALIZE_HOST_NAME can be passed in', function (done) { + it.skip('validate that SERVICE_REALM and CANONICALIZE_HOST_NAME can be passed in', async function () { const client = new MongoClient( `${krb5Uri}&authMechanismProperties=SERVICE_NAME:mongodb,CANONICALIZE_HOST_NAME:false,SERVICE_REALM:windows&maxPoolSize=1` ); - client.connect(function (err, client) { - expect(err).to.not.exist; - verifyKerberosAuthentication(client, done); - }); - }); + await client.connect(); + await verifyKerberosAuthentication(client); + }).skipReason = 'TODO(NODE-3060): Unskip this test when a proper setup is available'; context('when passing SERVICE_HOST as an auth mech option', function () { context('when the SERVICE_HOST is invalid', function () { @@ -262,10 +220,7 @@ describe('Kerberos', function () { }); it('fails to authenticate', async function () { - let expectedError; - await client.connect().catch(e => { - expectedError = e; - }); + const expectedError = await client.connect().catch(e => e); if (!expectedError) { expect.fail('Expected connect with invalid SERVICE_HOST to fail'); } @@ -280,53 +235,48 @@ describe('Kerberos', function () { } }); - it('authenticates', function (done) { - client.connect(function (err, client) { - expect(err).to.not.exist; - verifyKerberosAuthentication(client, done); - }); + afterEach(async () => { + await client.close(); + }); + + it('authenticates', async function () { + await client.connect(); + await verifyKerberosAuthentication(client); }); }); }); describe('should use the SERVICE_NAME property', function () { - it('as an option handed to the MongoClient', function (done) { + it('as an option handed to the MongoClient', async function () { const client = new MongoClient(`${krb5Uri}&maxPoolSize=1`, { authMechanismProperties: { SERVICE_NAME: 'alternate' } }); - client.connect(function (err) { - expect(err).to.exist; - expect(err.message).to.match( - /(Error from KDC: LOOKING_UP_SERVER)|(not found in Kerberos database)|(UNKNOWN_SERVER)/ - ); - done(); - }); + + const err = await client.connect().catch(e => e); + expect(err.message).to.match( + /(Error from KDC: LOOKING_UP_SERVER)|(not found in Kerberos database)|(UNKNOWN_SERVER)/ + ); }); - it('as part of the query string parameters', function (done) { + it('as part of the query string parameters', async function () { const client = new MongoClient( `${krb5Uri}&authMechanismProperties=SERVICE_NAME:alternate&maxPoolSize=1` ); - client.connect(function (err) { - expect(err).to.exist; - expect(err.message).to.match( - /(Error from KDC: LOOKING_UP_SERVER)|(not found in Kerberos database)|(UNKNOWN_SERVER)/ - ); - done(); - }); + + const err = await client.connect().catch(e => e); + expect(err.message).to.match( + /(Error from KDC: LOOKING_UP_SERVER)|(not found in Kerberos database)|(UNKNOWN_SERVER)/ + ); }); }); - it('should fail to authenticate with bad credentials', function (done) { + it('should fail to authenticate with bad credentials', async function () { const client = new MongoClient( krb5Uri.replace(encodeURIComponent(process.env.KRB5_PRINCIPAL), 'bad%40creds.cc') ); - client.connect(function (err) { - expect(err).to.exist; - expect(err.message).to.match(/Authentication failed/); - done(); - }); + const err = await client.connect().catch(e => e); + expect(err.message).to.match(/Authentication failed/); }); }); diff --git a/test/unit/cmap/auth/gssapi.test.ts b/test/unit/cmap/auth/gssapi.test.ts index 3288afaa2fd..1d803a05528 100644 --- a/test/unit/cmap/auth/gssapi.test.ts +++ b/test/unit/cmap/auth/gssapi.test.ts @@ -9,14 +9,18 @@ import { } from '../../../mongodb'; describe('GSSAPI', () => { - const sandbox = sinon.createSandbox(); + let lookupSpy; + let resolvePtrSpy; + let resolveCnameSpy; beforeEach(() => { - sandbox.spy(dns); + lookupSpy = sinon.spy(dns, 'lookup'); + resolvePtrSpy = sinon.spy(dns, 'resolvePtr'); + resolveCnameSpy = sinon.spy(dns, 'resolveCname'); }); afterEach(() => { - sandbox.restore(); + sinon.restore(); }); describe('.performGSSAPICanonicalizeHostName', () => { @@ -24,47 +28,34 @@ describe('GSSAPI', () => { for (const mode of [GSSAPICanonicalizationValue.off, GSSAPICanonicalizationValue.none]) { context(`when the mode is ${mode}`, () => { - it('performs no dns lookups', done => { - performGSSAPICanonicalizeHostName( - hostName, - { CANONICALIZE_HOST_NAME: mode }, - (error, host) => { - if (error) return done(error); - expect(host).to.equal(hostName); - expect(dns.lookup).to.not.be.called; - expect(dns.resolvePtr).to.not.be.called; - expect(dns.resolveCname).to.not.be.called; - done(); - } - ); + it('performs no dns lookups', async () => { + const host = await performGSSAPICanonicalizeHostName(hostName, { + CANONICALIZE_HOST_NAME: mode + }); + expect(host).to.equal(hostName); + expect(dns.lookup).to.not.be.called; + expect(dns.resolvePtr).to.not.be.called; + expect(dns.resolveCname).to.not.be.called; }); }); } context(`when the mode is forward`, () => { const resolved = '10gen.cc'; - const resolveStub = (host, callback) => { - callback(undefined, [resolved]); - }; beforeEach(() => { - dns.resolveCname.restore(); - sinon.stub(dns, 'resolveCname').callsFake(resolveStub); + resolveCnameSpy.restore(); + sinon.stub(dns, 'resolveCname').resolves([resolved]); }); - it('performs a cname lookup', done => { - performGSSAPICanonicalizeHostName( - hostName, - { CANONICALIZE_HOST_NAME: GSSAPICanonicalizationValue.forward }, - (error, host) => { - if (error) return done(error); - expect(host).to.equal(resolved); - expect(dns.lookup).to.not.be.called; - expect(dns.resolvePtr).to.not.be.called; - expect(dns.resolveCname).to.be.calledOnceWith(hostName); - done(); - } - ); + it('performs a cname lookup', async () => { + const host = await performGSSAPICanonicalizeHostName(hostName, { + CANONICALIZE_HOST_NAME: GSSAPICanonicalizationValue.forward + }); + expect(host).to.equal(resolved); + expect(dns.lookup).to.not.be.called; + expect(dns.resolvePtr).to.not.be.called; + expect(dns.resolveCname).to.be.calledOnceWith(hostName); }); }); @@ -74,152 +65,111 @@ describe('GSSAPI', () => { ]) { context(`when the mode is ${mode}`, () => { context('when the forward lookup succeeds', () => { - const lookedUp = '1.1.1.1'; - const lookupStub = (host, callback) => { - callback(undefined, lookedUp); - }; + const lookedUp = { address: '1.1.1.1', family: 4 }; context('when the reverse lookup succeeds', () => { context('when there is 1 result', () => { const resolved = '10gen.cc'; - const resolveStub = (host, callback) => { - callback(undefined, [resolved]); - }; beforeEach(() => { - dns.lookup.restore(); - dns.resolvePtr.restore(); - sinon.stub(dns, 'lookup').callsFake(lookupStub); - sinon.stub(dns, 'resolvePtr').callsFake(resolveStub); + lookupSpy.restore(); + resolvePtrSpy.restore(); + sinon.stub(dns, 'lookup').resolves(lookedUp); + sinon.stub(dns, 'resolvePtr').resolves([resolved]); }); - it('uses the reverse lookup host', done => { - performGSSAPICanonicalizeHostName( - hostName, - { CANONICALIZE_HOST_NAME: mode }, - (error, host) => { - if (error) return done(error); - expect(host).to.equal(resolved); - expect(dns.lookup).to.be.calledOnceWith(hostName); - expect(dns.resolvePtr).to.be.calledOnceWith(lookedUp); - expect(dns.resolveCname).to.not.be.called; - done(); - } - ); + it('uses the reverse lookup host', async () => { + const host = await performGSSAPICanonicalizeHostName(hostName, { + CANONICALIZE_HOST_NAME: mode + }); + expect(host).to.equal(resolved); + expect(dns.lookup).to.be.calledOnceWith(hostName); + expect(dns.resolvePtr).to.be.calledOnceWith(lookedUp.address); + expect(dns.resolveCname).to.not.be.called; }); }); context('when there is more than 1 result', () => { const resolved = '10gen.cc'; - const resolveStub = (host, callback) => { - callback(undefined, [resolved, 'example.com']); - }; beforeEach(() => { - dns.lookup.restore(); - dns.resolvePtr.restore(); - sinon.stub(dns, 'lookup').callsFake(lookupStub); - sinon.stub(dns, 'resolvePtr').callsFake(resolveStub); + lookupSpy.restore(); + resolvePtrSpy.restore(); + sinon.stub(dns, 'lookup').resolves(lookedUp); + sinon.stub(dns, 'resolvePtr').resolves([resolved, 'example.com']); }); - it('uses the first found reverse lookup host', done => { - performGSSAPICanonicalizeHostName( - hostName, - { CANONICALIZE_HOST_NAME: mode }, - (error, host) => { - if (error) return done(error); - expect(host).to.equal(resolved); - expect(dns.lookup).to.be.calledOnceWith(hostName); - expect(dns.resolvePtr).to.be.calledOnceWith(lookedUp); - expect(dns.resolveCname).to.not.be.called; - done(); - } - ); + it('uses the first found reverse lookup host', async () => { + const host = await performGSSAPICanonicalizeHostName(hostName, { + CANONICALIZE_HOST_NAME: mode + }); + expect(host).to.equal(resolved); + expect(dns.lookup).to.be.calledOnceWith(hostName); + expect(dns.resolvePtr).to.be.calledOnceWith(lookedUp.address); + expect(dns.resolveCname).to.not.be.called; }); }); }); context('when the reverse lookup fails', () => { const cname = 'test.com'; - const resolveStub = (host, callback) => { - callback(new Error('failed'), undefined); - }; - const cnameStub = (host, callback) => { - callback(undefined, [cname]); - }; beforeEach(() => { - dns.lookup.restore(); - dns.resolvePtr.restore(); - dns.resolveCname.restore(); - sinon.stub(dns, 'lookup').callsFake(lookupStub); - sinon.stub(dns, 'resolvePtr').callsFake(resolveStub); - sinon.stub(dns, 'resolveCname').callsFake(cnameStub); + lookupSpy.restore(); + resolvePtrSpy.restore(); + resolveCnameSpy.restore(); + sinon.stub(dns, 'lookup').resolves(lookedUp); + sinon.stub(dns, 'resolvePtr').rejects(new Error('failed')); + sinon.stub(dns, 'resolveCname').resolves([cname]); }); - it('falls back to a cname lookup', done => { - performGSSAPICanonicalizeHostName( - hostName, - { CANONICALIZE_HOST_NAME: mode }, - (error, host) => { - if (error) return done(error); - expect(host).to.equal(cname); - expect(dns.lookup).to.be.calledOnceWith(hostName); - expect(dns.resolvePtr).to.be.calledOnceWith(lookedUp); - expect(dns.resolveCname).to.be.calledWith(hostName); - done(); - } - ); + it('falls back to a cname lookup', async () => { + const host = await performGSSAPICanonicalizeHostName(hostName, { + CANONICALIZE_HOST_NAME: mode + }); + + expect(host).to.equal(cname); + expect(dns.lookup).to.be.calledOnceWith(hostName); + expect(dns.resolvePtr).to.be.calledOnceWith(lookedUp.address); + expect(dns.resolveCname).to.be.calledWith(hostName); }); }); context('when the reverse lookup is empty', () => { - const resolveStub = (host, callback) => { - callback(undefined, []); - }; - beforeEach(() => { - dns.lookup.restore(); - dns.resolvePtr.restore(); - sinon.stub(dns, 'lookup').callsFake(lookupStub); - sinon.stub(dns, 'resolvePtr').callsFake(resolveStub); + lookupSpy.restore(); + resolvePtrSpy.restore(); + sinon.stub(dns, 'lookup').resolves(lookedUp); + sinon.stub(dns, 'resolvePtr').resolves([]); }); - it('uses the provided host', done => { - performGSSAPICanonicalizeHostName( - hostName, - { CANONICALIZE_HOST_NAME: mode }, - (error, host) => { - if (error) return done(error); - expect(host).to.equal(hostName); - expect(dns.lookup).to.be.calledOnceWith(hostName); - expect(dns.resolvePtr).to.be.calledOnceWith(lookedUp); - expect(dns.resolveCname).to.not.be.called; - done(); - } - ); + it('uses the provided host', async () => { + const host = await performGSSAPICanonicalizeHostName(hostName, { + CANONICALIZE_HOST_NAME: mode + }); + expect(host).to.equal(hostName); + expect(dns.lookup).to.be.calledOnceWith(hostName); + expect(dns.resolvePtr).to.be.calledOnceWith(lookedUp.address); + expect(dns.resolveCname).to.not.be.called; }); }); }); context('when the forward lookup fails', () => { - const lookupStub = (host, callback) => { - callback(new Error('failed'), undefined); - }; - beforeEach(() => { - dns.lookup.restore(); - sinon.stub(dns, 'lookup').callsFake(lookupStub); + lookupSpy.restore(); + sinon.stub(dns, 'lookup').rejects(new Error('failed')); }); - it('fails with the error', done => { - performGSSAPICanonicalizeHostName(hostName, { CANONICALIZE_HOST_NAME: mode }, error => { - expect(error.message).to.equal('failed'); - expect(dns.lookup).to.be.calledOnceWith(hostName); - expect(dns.resolvePtr).to.not.be.called; - expect(dns.resolveCname).to.not.be.called; - done(); - }); + it('fails with the error', async () => { + const error = await performGSSAPICanonicalizeHostName(hostName, { + CANONICALIZE_HOST_NAME: mode + }).catch(error => error); + + expect(error.message).to.equal('failed'); + expect(dns.lookup).to.be.calledOnceWith(hostName); + expect(dns.resolvePtr).to.not.be.called; + expect(dns.resolveCname).to.not.be.called; }); }); }); @@ -229,22 +179,16 @@ describe('GSSAPI', () => { describe('.resolveCname', () => { context('when the cname call errors', () => { const hostName = 'example.com'; - const resolveStub = (host, callback) => { - callback(new Error('failed')); - }; beforeEach(() => { - dns.resolveCname.restore(); - sinon.stub(dns, 'resolveCname').callsFake(resolveStub); + resolveCnameSpy.restore(); + sinon.stub(dns, 'resolveCname').rejects(new Error('failed')); }); - it('falls back to the provided host name', done => { - resolveCname(hostName, (error, host) => { - if (error) return done(error); - expect(host).to.equal(hostName); - expect(dns.resolveCname).to.be.calledOnceWith(hostName); - done(); - }); + it('falls back to the provided host name', async () => { + const host = await resolveCname(hostName); + expect(host).to.equal(hostName); + expect(dns.resolveCname).to.be.calledOnceWith(hostName); }); }); @@ -252,66 +196,48 @@ describe('GSSAPI', () => { context('when there is one result', () => { const hostName = 'example.com'; const resolved = '10gen.cc'; - const resolveStub = (host, callback) => { - callback(undefined, [resolved]); - }; beforeEach(() => { - dns.resolveCname.restore(); - sinon.stub(dns, 'resolveCname').callsFake(resolveStub); + resolveCnameSpy.restore(); + sinon.stub(dns, 'resolveCname').resolves([resolved]); }); - it('uses the result', done => { - resolveCname(hostName, (error, host) => { - if (error) return done(error); - expect(host).to.equal(resolved); - expect(dns.resolveCname).to.be.calledOnceWith(hostName); - done(); - }); + it('uses the result', async () => { + const host = await resolveCname(hostName); + expect(host).to.equal(resolved); + expect(dns.resolveCname).to.be.calledOnceWith(hostName); }); }); context('when there is more than one result', () => { const hostName = 'example.com'; const resolved = '10gen.cc'; - const resolveStub = (host, callback) => { - callback(undefined, [resolved, hostName]); - }; beforeEach(() => { - dns.resolveCname.restore(); - sinon.stub(dns, 'resolveCname').callsFake(resolveStub); + resolveCnameSpy.restore(); + sinon.stub(dns, 'resolveCname').resolves([resolved, hostName]); }); - it('uses the first result', done => { - resolveCname(hostName, (error, host) => { - if (error) return done(error); - expect(host).to.equal(resolved); - expect(dns.resolveCname).to.be.calledOnceWith(hostName); - done(); - }); + it('uses the first result', async () => { + const host = await resolveCname(hostName); + expect(host).to.equal(resolved); + expect(dns.resolveCname).to.be.calledOnceWith(hostName); }); }); }); context('when the cname call returns no results', () => { const hostName = 'example.com'; - const resolveStub = (host, callback) => { - callback(undefined, []); - }; beforeEach(() => { - dns.resolveCname.restore(); - sinon.stub(dns, 'resolveCname').callsFake(resolveStub); + resolveCnameSpy.restore(); + sinon.stub(dns, 'resolveCname').resolves([]); }); - it('falls back to using the provided host', done => { - resolveCname(hostName, (error, host) => { - if (error) return done(error); - expect(host).to.equal(hostName); - expect(dns.resolveCname).to.be.calledOnceWith(hostName); - done(); - }); + it('falls back to using the provided host', async () => { + const host = await resolveCname(hostName); + expect(host).to.equal(hostName); + expect(dns.resolveCname).to.be.calledOnceWith(hostName); }); }); }); From 40845446aa71c6e5fcec0925cd2df949ec696ddf Mon Sep 17 00:00:00 2001 From: Neal Beeken Date: Thu, 23 Mar 2023 10:36:37 -0400 Subject: [PATCH 07/11] chore: clean up client and get more cov --- package.json | 6 +++--- test/manual/kerberos.test.ts | 38 +++++++++++++++++++----------------- 2 files changed, 23 insertions(+), 21 deletions(-) diff --git a/package.json b/package.json index a564ad07a72..2e685db951c 100644 --- a/package.json +++ b/package.json @@ -126,12 +126,12 @@ "check:ts": "node ./node_modules/typescript/bin/tsc -v && node ./node_modules/typescript/bin/tsc --noEmit", "check:atlas": "mocha --config test/manual/mocharc.json test/manual/atlas_connectivity.test.js", "check:adl": "mocha --config test/mocha_mongodb.json test/manual/atlas-data-lake-testing", - "check:aws": "mocha --config test/mocha_mongodb.json test/integration/auth/mongodb_aws.test.ts", + "check:aws": "nyc mocha --config test/mocha_mongodb.json test/integration/auth/mongodb_aws.test.ts", "check:oidc": "mocha --config test/manual/mocharc.json test/manual/mongodb_oidc.prose.test.ts", "check:ocsp": "mocha --config test/manual/mocharc.json test/manual/ocsp_support.test.js", - "check:kerberos": "mocha --config test/manual/mocharc.json test/manual/kerberos.test.ts", + "check:kerberos": "nyc mocha --config test/manual/mocharc.json test/manual/kerberos.test.ts", "check:tls": "mocha --config test/manual/mocharc.json test/manual/tls_support.test.js", - "check:ldap": "mocha --config test/manual/mocharc.json test/manual/ldap.test.js", + "check:ldap": "nyc mocha --config test/manual/mocharc.json test/manual/ldap.test.js", "check:socks5": "mocha --config test/manual/mocharc.json test/manual/socks5.test.ts", "check:csfle": "mocha --config test/mocha_mongodb.json test/integration/client-side-encryption", "check:snappy": "mocha test/unit/assorted/snappy.test.js", diff --git a/test/manual/kerberos.test.ts b/test/manual/kerberos.test.ts index 63a60158747..11179be3e50 100644 --- a/test/manual/kerberos.test.ts +++ b/test/manual/kerberos.test.ts @@ -16,6 +16,7 @@ async function verifyKerberosAuthentication(client) { describe('Kerberos', function () { let resolvePtrSpy; let resolveCnameSpy; + let client; beforeEach(() => { sinon.spy(dns, 'lookup'); @@ -27,6 +28,11 @@ describe('Kerberos', function () { sinon.restore(); }); + afterEach(async () => { + await client.close(); + client = null; + }); + if (process.env.MONGODB_URI == null) { console.error('skipping Kerberos tests, MONGODB_URI environment variable is not defined'); return; @@ -49,7 +55,7 @@ describe('Kerberos', function () { } it('should authenticate with original uri', async function () { - const client = new MongoClient(krb5Uri); + client = new MongoClient(krb5Uri); await client.connect(); await verifyKerberosAuthentication(client); }); @@ -65,7 +71,7 @@ describe('Kerberos', function () { context('when the value is forward', function () { it('authenticates with a forward cname lookup', async function () { - const client = new MongoClient( + client = new MongoClient( `${krb5Uri}&authMechanismProperties=SERVICE_NAME:mongodb,CANONICALIZE_HOST_NAME:forward&maxPoolSize=1` ); await client.connect(); @@ -77,7 +83,7 @@ describe('Kerberos', function () { for (const option of [false, 'none']) { context(`when the value is ${option}`, function () { it('authenticates with no dns lookups', async function () { - const client = new MongoClient( + client = new MongoClient( `${krb5Uri}&authMechanismProperties=SERVICE_NAME:mongodb,CANONICALIZE_HOST_NAME:${option}&maxPoolSize=1` ); await client.connect(); @@ -98,7 +104,7 @@ describe('Kerberos', function () { }); it('authenticates with a forward dns lookup and a reverse ptr lookup', async function () { - const client = new MongoClient( + client = new MongoClient( `${krb5Uri}&authMechanismProperties=SERVICE_NAME:mongodb,CANONICALIZE_HOST_NAME:${option}&maxPoolSize=1` ); await client.connect(); @@ -117,7 +123,7 @@ describe('Kerberos', function () { }); it('authenticates with a fallback cname lookup', async function () { - const client = new MongoClient( + client = new MongoClient( `${krb5Uri}&authMechanismProperties=SERVICE_NAME:mongodb,CANONICALIZE_HOST_NAME:${option}&maxPoolSize=1` ); @@ -140,7 +146,7 @@ describe('Kerberos', function () { }); it('authenticates with a fallback cname lookup', async function () { - const client = new MongoClient( + client = new MongoClient( `${krb5Uri}&authMechanismProperties=SERVICE_NAME:mongodb,CANONICALIZE_HOST_NAME:${option}&maxPoolSize=1` ); @@ -163,7 +169,7 @@ describe('Kerberos', function () { }); it('authenticates with a fallback host name', async function () { - const client = new MongoClient( + client = new MongoClient( `${krb5Uri}&authMechanismProperties=SERVICE_NAME:mongodb,CANONICALIZE_HOST_NAME:${option}&maxPoolSize=1` ); await client.connect(); @@ -185,7 +191,7 @@ describe('Kerberos', function () { }); it('authenticates with a fallback host name', async function () { - const client = new MongoClient( + client = new MongoClient( `${krb5Uri}&authMechanismProperties=SERVICE_NAME:mongodb,CANONICALIZE_HOST_NAME:${option}&maxPoolSize=1` ); await client.connect(); @@ -204,7 +210,7 @@ describe('Kerberos', function () { }); it.skip('validate that SERVICE_REALM and CANONICALIZE_HOST_NAME can be passed in', async function () { - const client = new MongoClient( + client = new MongoClient( `${krb5Uri}&authMechanismProperties=SERVICE_NAME:mongodb,CANONICALIZE_HOST_NAME:false,SERVICE_REALM:windows&maxPoolSize=1` ); await client.connect(); @@ -213,7 +219,7 @@ describe('Kerberos', function () { context('when passing SERVICE_HOST as an auth mech option', function () { context('when the SERVICE_HOST is invalid', function () { - const client = new MongoClient(`${krb5Uri}&maxPoolSize=1`, { + client = new MongoClient(`${krb5Uri}&maxPoolSize=1`, { authMechanismProperties: { SERVICE_HOST: 'example.com' } @@ -229,16 +235,12 @@ describe('Kerberos', function () { }); context('when the SERVICE_HOST is valid', function () { - const client = new MongoClient(`${krb5Uri}&maxPoolSize=1`, { + client = new MongoClient(`${krb5Uri}&maxPoolSize=1`, { authMechanismProperties: { SERVICE_HOST: 'ldaptest.10gen.cc' } }); - afterEach(async () => { - await client.close(); - }); - it('authenticates', async function () { await client.connect(); await verifyKerberosAuthentication(client); @@ -248,7 +250,7 @@ describe('Kerberos', function () { describe('should use the SERVICE_NAME property', function () { it('as an option handed to the MongoClient', async function () { - const client = new MongoClient(`${krb5Uri}&maxPoolSize=1`, { + client = new MongoClient(`${krb5Uri}&maxPoolSize=1`, { authMechanismProperties: { SERVICE_NAME: 'alternate' } @@ -261,7 +263,7 @@ describe('Kerberos', function () { }); it('as part of the query string parameters', async function () { - const client = new MongoClient( + client = new MongoClient( `${krb5Uri}&authMechanismProperties=SERVICE_NAME:alternate&maxPoolSize=1` ); @@ -273,7 +275,7 @@ describe('Kerberos', function () { }); it('should fail to authenticate with bad credentials', async function () { - const client = new MongoClient( + client = new MongoClient( krb5Uri.replace(encodeURIComponent(process.env.KRB5_PRINCIPAL), 'bad%40creds.cc') ); const err = await client.connect().catch(e => e); From 5d7d0f79802373ef87d358d31d6f154c3afb3698 Mon Sep 17 00:00:00 2001 From: Neal Beeken Date: Thu, 23 Mar 2023 10:50:31 -0400 Subject: [PATCH 08/11] fix: all clients created inside it blocks --- test/manual/kerberos.test.ts | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/test/manual/kerberos.test.ts b/test/manual/kerberos.test.ts index 11179be3e50..4dc02c00216 100644 --- a/test/manual/kerberos.test.ts +++ b/test/manual/kerberos.test.ts @@ -29,7 +29,7 @@ describe('Kerberos', function () { }); afterEach(async () => { - await client.close(); + await client?.close(); client = null; }); @@ -219,13 +219,13 @@ describe('Kerberos', function () { context('when passing SERVICE_HOST as an auth mech option', function () { context('when the SERVICE_HOST is invalid', function () { - client = new MongoClient(`${krb5Uri}&maxPoolSize=1`, { - authMechanismProperties: { - SERVICE_HOST: 'example.com' - } - }); - it('fails to authenticate', async function () { + client = new MongoClient(`${krb5Uri}&maxPoolSize=1`, { + authMechanismProperties: { + SERVICE_HOST: 'example.com' + } + }); + const expectedError = await client.connect().catch(e => e); if (!expectedError) { expect.fail('Expected connect with invalid SERVICE_HOST to fail'); @@ -235,13 +235,13 @@ describe('Kerberos', function () { }); context('when the SERVICE_HOST is valid', function () { - client = new MongoClient(`${krb5Uri}&maxPoolSize=1`, { - authMechanismProperties: { - SERVICE_HOST: 'ldaptest.10gen.cc' - } - }); - it('authenticates', async function () { + client = new MongoClient(`${krb5Uri}&maxPoolSize=1`, { + authMechanismProperties: { + SERVICE_HOST: 'ldaptest.10gen.cc' + } + }); + await client.connect(); await verifyKerberosAuthentication(client); }); From d4a395beb0ef6b4d4ea5514b009e2bec442f2cab Mon Sep 17 00:00:00 2001 From: Neal Beeken Date: Fri, 24 Mar 2023 10:57:25 -0400 Subject: [PATCH 09/11] comments p1 --- src/cmap/auth/auth_provider.ts | 5 +-- src/cmap/auth/mongocr.ts | 11 ++--- src/cmap/auth/mongodb_aws.ts | 1 - src/cmap/auth/mongodb_oidc.ts | 12 +----- test/tools/uri_spec_runner.ts | 1 + test/unit/cmap/auth/mongodb_oidc.test.ts | 51 ++++++++++++++++++++++++ 6 files changed, 61 insertions(+), 20 deletions(-) create mode 100644 test/unit/cmap/auth/mongodb_oidc.test.ts diff --git a/src/cmap/auth/auth_provider.ts b/src/cmap/auth/auth_provider.ts index 1cd5e67b122..98c93669f8b 100644 --- a/src/cmap/auth/auth_provider.ts +++ b/src/cmap/auth/auth_provider.ts @@ -56,17 +56,14 @@ export abstract class AuthProvider { * Authenticate * * @param context - A shared context for authentication flow - * @param callback - The callback to return the result from the authentication */ - abstract auth(_context: AuthContext): Promise; + abstract auth(context: AuthContext): Promise; /** * Reauthenticate. * @param context - The shared auth context. - * @param callback - The callback. */ async reauth(context: AuthContext): Promise { - // If we are already reauthenticating this is a no-op. if (context.reauthenticating) { throw new MongoRuntimeError('Reauthentication already in progress.'); } diff --git a/src/cmap/auth/mongocr.ts b/src/cmap/auth/mongocr.ts index 579069e9b61..4bc0003461e 100644 --- a/src/cmap/auth/mongocr.ts +++ b/src/cmap/auth/mongocr.ts @@ -13,10 +13,11 @@ export class MongoCR extends AuthProvider { const { username, password, source } = credentials; - const r = await connection.commandAsync(ns(`${source}.$cmd`), { getnonce: 1 }, undefined); - - // Get nonce - const nonce = r.nonce; + const { nonce } = await connection.commandAsync( + ns(`${source}.$cmd`), + { getnonce: 1 }, + undefined + ); const hashPassword = crypto .createHash('md5') @@ -26,7 +27,7 @@ export class MongoCR extends AuthProvider { // Final key const key = crypto .createHash('md5') - .update(nonce + username + hashPassword, 'utf8') + .update(`${nonce}${username}${hashPassword}`, 'utf8') .digest('hex'); const authenticateCommand = { diff --git a/src/cmap/auth/mongodb_aws.ts b/src/cmap/auth/mongodb_aws.ts index 031a73df84e..91eb98ba0ce 100644 --- a/src/cmap/auth/mongodb_aws.ts +++ b/src/cmap/auth/mongodb_aws.ts @@ -62,7 +62,6 @@ export class MongoDBAWS extends AuthProvider { if (!credentials.username) { authContext.credentials = await makeTempCredentials(credentials); - return this.auth(authContext); } const accessKeyId = credentials.username; diff --git a/src/cmap/auth/mongodb_oidc.ts b/src/cmap/auth/mongodb_oidc.ts index 41e6a3b02cc..69ae3e0d3b3 100644 --- a/src/cmap/auth/mongodb_oidc.ts +++ b/src/cmap/auth/mongodb_oidc.ts @@ -1,8 +1,4 @@ -import { - MongoInvalidArgumentError, - MongoMissingCredentialsError, - MongoRuntimeError -} from '../../error'; +import { MongoInvalidArgumentError, MongoMissingCredentialsError } from '../../error'; import type { HandshakeDocument } from '../connect'; import { type AuthContext, AuthProvider } from './auth_provider'; import type { MongoCredentials } from './mongo_credentials'; @@ -106,11 +102,7 @@ export class MongoDBOIDC extends AuthProvider { } const workflow = getWorkflow(credentials); - if (!workflow) { - throw new MongoRuntimeError( - `Could not load workflow for provider ${credentials.mechanismProperties.PROVIDER_NAME}` - ); - } + const result = await workflow.speculativeAuth(); return { ...handshakeDoc, ...result }; } diff --git a/test/tools/uri_spec_runner.ts b/test/tools/uri_spec_runner.ts index a31a25fa2dc..492043aa11b 100644 --- a/test/tools/uri_spec_runner.ts +++ b/test/tools/uri_spec_runner.ts @@ -24,6 +24,7 @@ interface UriTest extends UriTestBase { }; options: Record; } + interface AuthTest extends UriTestBase { credential: { username: string; diff --git a/test/unit/cmap/auth/mongodb_oidc.test.ts b/test/unit/cmap/auth/mongodb_oidc.test.ts new file mode 100644 index 00000000000..121244688e9 --- /dev/null +++ b/test/unit/cmap/auth/mongodb_oidc.test.ts @@ -0,0 +1,51 @@ +import { expect } from 'chai'; + +import { + AuthContext, + MongoCredentials, + MongoDBOIDC, + MongoInvalidArgumentError +} from '../../../mongodb'; + +describe('class MongoDBOIDC', () => { + context('when an unknown OIDC provider name is set', () => { + it('prepare rejects with MongoInvalidArgumentError', async () => { + const oidc = new MongoDBOIDC(); + const error = await oidc + .prepare( + {}, + new AuthContext( + {}, + new MongoCredentials({ + mechanism: 'MONGODB-OIDC', + mechanismProperties: { PROVIDER_NAME: 'iLoveJavaScript' } + }), + {} + ) + ) + .catch(error => error); + + expect(error).to.be.instanceOf(MongoInvalidArgumentError); + expect(error).to.match(/workflow for provider/); + }); + + it('auth rejects with MongoInvalidArgumentError', async () => { + const oidc = new MongoDBOIDC(); + const error = await oidc + .auth( + new AuthContext( + {}, + new MongoCredentials({ + mechanism: 'MONGODB-OIDC', + mechanismProperties: { PROVIDER_NAME: 'iLoveJavaScript' } + }), + {} + ) + ) + .catch(error => error); + + expect(error).to.be.instanceOf(MongoInvalidArgumentError); + expect(error).to.match(/workflow for provider/); + }); + }); +}); From d4636f9161a8ec4f911eee9701102799e23ac512 Mon Sep 17 00:00:00 2001 From: Neal Beeken Date: Fri, 24 Mar 2023 12:12:04 -0400 Subject: [PATCH 10/11] fix aws --- src/cmap/auth/mongodb_aws.ts | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/cmap/auth/mongodb_aws.ts b/src/cmap/auth/mongodb_aws.ts index 91eb98ba0ce..fa94b93c67f 100644 --- a/src/cmap/auth/mongodb_aws.ts +++ b/src/cmap/auth/mongodb_aws.ts @@ -44,8 +44,8 @@ export class MongoDBAWS extends AuthProvider { } override async auth(authContext: AuthContext): Promise { - const { connection, credentials } = authContext; - if (!credentials) { + const { connection } = authContext; + if (!authContext.credentials) { throw new MongoMissingCredentialsError('AuthContext must provide credentials.'); } @@ -60,10 +60,12 @@ export class MongoDBAWS extends AuthProvider { ); } - if (!credentials.username) { - authContext.credentials = await makeTempCredentials(credentials); + if (!authContext.credentials.username) { + authContext.credentials = await makeTempCredentials(authContext.credentials); } + const { credentials } = authContext; + const accessKeyId = credentials.username; const secretAccessKey = credentials.password; const sessionToken = credentials.mechanismProperties.AWS_SESSION_TOKEN; From 4ffa648374fcdb7eb7458541ad5d11160102bb18 Mon Sep 17 00:00:00 2001 From: Neal Beeken Date: Fri, 24 Mar 2023 12:25:39 -0400 Subject: [PATCH 11/11] add tickets --- src/cmap/auth/gssapi.ts | 2 +- src/cmap/connect.ts | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/cmap/auth/gssapi.ts b/src/cmap/auth/gssapi.ts index 4a4ae3e1bbb..8b5d6613e76 100644 --- a/src/cmap/auth/gssapi.ts +++ b/src/cmap/auth/gssapi.ts @@ -91,7 +91,7 @@ async function makeKerberosClient(authContext: AuthContext): Promise