diff --git a/src/client-side-encryption/auto_encrypter.ts b/src/client-side-encryption/auto_encrypter.ts index a24f8cd6da6..b8f2e42b13b 100644 --- a/src/client-side-encryption/auto_encrypter.ts +++ b/src/client-side-encryption/auto_encrypter.ts @@ -17,7 +17,12 @@ import { autoSelectSocketOptions } from './client_encryption'; import * as cryptoCallbacks from './crypto_callbacks'; import { MongoCryptInvalidArgumentError } from './errors'; import { MongocryptdManager } from './mongocryptd_manager'; -import { type KMSProviders, refreshKMSCredentials } from './providers'; +import { + type CredentialProviders, + isEmptyCredentials, + type KMSProviders, + refreshKMSCredentials +} from './providers'; import { type CSFLEKMSTlsOptions, StateMachine } from './state_machine'; /** @public */ @@ -30,6 +35,8 @@ export interface AutoEncryptionOptions { keyVaultNamespace?: string; /** Configuration options that are used by specific KMS providers during key generation, encryption, and decryption. */ kmsProviders?: KMSProviders; + /** Configuration options for custom credential providers. */ + credentialProviders?: CredentialProviders; /** * A map of namespaces to a local JSON schema for encryption * @@ -153,6 +160,7 @@ export class AutoEncrypter { _kmsProviders: KMSProviders; _bypassMongocryptdAndCryptShared: boolean; _contextCounter: number; + _credentialProviders?: CredentialProviders; _mongocryptdManager?: MongocryptdManager; _mongocryptdClient?: MongoClient; @@ -237,6 +245,13 @@ export class AutoEncrypter { this._proxyOptions = options.proxyOptions || {}; this._tlsOptions = options.tlsOptions || {}; this._kmsProviders = options.kmsProviders || {}; + this._credentialProviders = options.credentialProviders; + + if (options.credentialProviders?.aws && !isEmptyCredentials('aws', this._kmsProviders)) { + throw new MongoCryptInvalidArgumentError( + 'Can only provide a custom AWS credential provider when the state machine is configured for automatic AWS credential fetching' + ); + } const mongoCryptOptions: MongoCryptOptions = { enableMultipleCollinfo: true, @@ -439,7 +454,7 @@ export class AutoEncrypter { * the original ones. */ async askForKMSCredentials(): Promise { - return await refreshKMSCredentials(this._kmsProviders); + return await refreshKMSCredentials(this._kmsProviders, this._credentialProviders); } /** diff --git a/src/client-side-encryption/client_encryption.ts b/src/client-side-encryption/client_encryption.ts index 487969cf4de..b5968fd0d76 100644 --- a/src/client-side-encryption/client_encryption.ts +++ b/src/client-side-encryption/client_encryption.ts @@ -34,6 +34,8 @@ import { } from './errors'; import { type ClientEncryptionDataKeyProvider, + type CredentialProviders, + isEmptyCredentials, type KMSProviders, refreshKMSCredentials } from './providers/index'; @@ -81,6 +83,9 @@ export class ClientEncryption { /** @internal */ _mongoCrypt: MongoCrypt; + /** @internal */ + _credentialProviders?: CredentialProviders; + /** @internal */ static getMongoCrypt(): MongoCryptConstructor { const encryption = getMongoDBClientEncryption(); @@ -125,6 +130,13 @@ export class ClientEncryption { this._kmsProviders = options.kmsProviders || {}; const { timeoutMS } = resolveTimeoutOptions(client, options); this._timeoutMS = timeoutMS; + this._credentialProviders = options.credentialProviders; + + if (options.credentialProviders?.aws && !isEmptyCredentials('aws', this._kmsProviders)) { + throw new MongoCryptInvalidArgumentError( + 'Can only provide a custom AWS credential provider when the state machine is configured for automatic AWS credential fetching' + ); + } if (options.keyVaultNamespace == null) { throw new MongoCryptInvalidArgumentError('Missing required option `keyVaultNamespace`'); @@ -712,7 +724,7 @@ export class ClientEncryption { * the original ones. */ async askForKMSCredentials(): Promise { - return await refreshKMSCredentials(this._kmsProviders); + return await refreshKMSCredentials(this._kmsProviders, this._credentialProviders); } static get libmongocryptVersion() { @@ -858,6 +870,11 @@ export interface ClientEncryptionOptions { */ kmsProviders?: KMSProviders; + /** + * Options for user provided custom credential providers. + */ + credentialProviders?: CredentialProviders; + /** * Options for specifying a Socks5 proxy to use for connecting to the KMS. */ diff --git a/src/client-side-encryption/providers/aws.ts b/src/client-side-encryption/providers/aws.ts index 240e560bd9a..e50bb4d232f 100644 --- a/src/client-side-encryption/providers/aws.ts +++ b/src/client-side-encryption/providers/aws.ts @@ -1,11 +1,17 @@ -import { AWSSDKCredentialProvider } from '../../cmap/auth/aws_temporary_credentials'; +import { + type AWSCredentialProvider, + AWSSDKCredentialProvider +} from '../../cmap/auth/aws_temporary_credentials'; import { type KMSProviders } from '.'; /** * @internal */ -export async function loadAWSCredentials(kmsProviders: KMSProviders): Promise { - const credentialProvider = new AWSSDKCredentialProvider(); +export async function loadAWSCredentials( + kmsProviders: KMSProviders, + provider?: AWSCredentialProvider +): Promise { + const credentialProvider = new AWSSDKCredentialProvider(provider); // We shouldn't ever receive a response from the AWS SDK that doesn't have a `SecretAccessKey` // or `AccessKeyId`. However, TS says these fields are optional. We provide empty strings diff --git a/src/client-side-encryption/providers/index.ts b/src/client-side-encryption/providers/index.ts index f254cf69f92..63108c7ed93 100644 --- a/src/client-side-encryption/providers/index.ts +++ b/src/client-side-encryption/providers/index.ts @@ -1,4 +1,5 @@ import type { Binary } from '../../bson'; +import { type AWSCredentialProvider } from '../../cmap/auth/aws_temporary_credentials'; import { loadAWSCredentials } from './aws'; import { loadAzureCredentials } from './azure'; import { loadGCPCredentials } from './gcp'; @@ -112,6 +113,15 @@ export type GCPKMSProviderConfiguration = accessToken: string; }; +/** + * @public + * Configuration options for custom credential providers for KMS requests. + */ +export interface CredentialProviders { + /* A custom AWS credential provider */ + aws?: AWSCredentialProvider; +} + /** * @public * Configuration options that are used by specific KMS providers during key generation, encryption, and decryption. @@ -176,11 +186,14 @@ export function isEmptyCredentials( * * @internal */ -export async function refreshKMSCredentials(kmsProviders: KMSProviders): Promise { +export async function refreshKMSCredentials( + kmsProviders: KMSProviders, + credentialProviders?: CredentialProviders +): Promise { let finalKMSProviders = kmsProviders; if (isEmptyCredentials('aws', kmsProviders)) { - finalKMSProviders = await loadAWSCredentials(finalKMSProviders); + finalKMSProviders = await loadAWSCredentials(finalKMSProviders, credentialProviders?.aws); } if (isEmptyCredentials('gcp', kmsProviders)) { diff --git a/src/cmap/auth/aws_temporary_credentials.ts b/src/cmap/auth/aws_temporary_credentials.ts index c93456a5453..baa1a64fc81 100644 --- a/src/cmap/auth/aws_temporary_credentials.ts +++ b/src/cmap/auth/aws_temporary_credentials.ts @@ -21,6 +21,9 @@ export interface AWSTempCredentials { Expiration?: Date; } +/** @public **/ +export type AWSCredentialProvider = () => Promise; + /** * @internal * @@ -41,7 +44,20 @@ export abstract class AWSTemporaryCredentialProvider { /** @internal */ export class AWSSDKCredentialProvider extends AWSTemporaryCredentialProvider { - private _provider?: () => Promise; + private _provider?: AWSCredentialProvider; + + /** + * Create the SDK credentials provider. + * @param credentialsProvider - The credentials provider. + */ + constructor(credentialsProvider?: AWSCredentialProvider) { + super(); + + if (credentialsProvider) { + this._provider = credentialsProvider; + } + } + /** * The AWS SDK caches credentials automatically and handles refresh when the credentials have expired. * To ensure this occurs, we need to cache the `provider` returned by the AWS sdk and re-use it when fetching credentials. diff --git a/src/cmap/auth/mongo_credentials.ts b/src/cmap/auth/mongo_credentials.ts index 97c457945df..259147d5b07 100644 --- a/src/cmap/auth/mongo_credentials.ts +++ b/src/cmap/auth/mongo_credentials.ts @@ -6,6 +6,7 @@ import { MongoInvalidArgumentError, MongoMissingCredentialsError } from '../../error'; +import type { AWSCredentialProvider } from './aws_temporary_credentials'; import { GSSAPICanonicalizationValue } from './gssapi'; import type { OIDCCallbackFunction } from './mongodb_oidc'; import { AUTH_MECHS_AUTH_SRC_EXTERNAL, AuthMechanism } from './providers'; @@ -68,6 +69,33 @@ export interface AuthMechanismProperties extends Document { ALLOWED_HOSTS?: string[]; /** The resource token for OIDC auth in Azure and GCP. */ TOKEN_RESOURCE?: string; + /** + * A custom AWS credential provider to use. An example using the AWS SDK default provider chain: + * + * ```ts + * const client = new MongoClient(process.env.MONGODB_URI, { + * authMechanismProperties: { + * AWS_CREDENTIAL_PROVIDER: fromNodeProviderChain() + * } + * }); + * ``` + * + * Using a custom function that returns AWS credentials: + * + * ```ts + * const client = new MongoClient(process.env.MONGODB_URI, { + * authMechanismProperties: { + * AWS_CREDENTIAL_PROVIDER: async () => { + * return { + * accessKeyId: process.env.ACCESS_KEY_ID, + * secretAccessKey: process.env.SECRET_ACCESS_KEY + * } + * } + * } + * }); + * ``` + */ + AWS_CREDENTIAL_PROVIDER?: AWSCredentialProvider; } /** @public */ diff --git a/src/cmap/auth/mongodb_aws.ts b/src/cmap/auth/mongodb_aws.ts index d9071496b54..9cb22c82caa 100644 --- a/src/cmap/auth/mongodb_aws.ts +++ b/src/cmap/auth/mongodb_aws.ts @@ -9,6 +9,7 @@ import { import { ByteUtils, maxWireVersion, ns, randomBytes } from '../../utils'; import { type AuthContext, AuthProvider } from './auth_provider'; import { + type AWSCredentialProvider, AWSSDKCredentialProvider, type AWSTempCredentials, AWSTemporaryCredentialProvider, @@ -34,11 +35,14 @@ interface AWSSaslContinuePayload { export class MongoDBAWS extends AuthProvider { private credentialFetcher: AWSTemporaryCredentialProvider; - constructor() { + private credentialProvider?: AWSCredentialProvider; + + constructor(credentialProvider?: AWSCredentialProvider) { super(); + this.credentialProvider = credentialProvider; this.credentialFetcher = AWSTemporaryCredentialProvider.isAWSSDKInstalled - ? new AWSSDKCredentialProvider() + ? new AWSSDKCredentialProvider(credentialProvider) : new LegacyAWSTemporaryCredentialProvider(); } diff --git a/src/deps.ts b/src/deps.ts index 0947ca6efc5..e9f4a42e39f 100644 --- a/src/deps.ts +++ b/src/deps.ts @@ -78,14 +78,14 @@ export function getZstdLibrary(): ZStandardLib | { kModuleError: MongoMissingDep } /** - * @internal + * @public * Copy of the AwsCredentialIdentityProvider interface from [`smithy/types`](https://socket.dev/npm/package/\@smithy/types/files/1.1.1/dist-types/identity/awsCredentialIdentity.d.ts), * the return type of the aws-sdk's `fromNodeProviderChain().provider()`. */ export interface AWSCredentials { accessKeyId: string; secretAccessKey: string; - sessionToken: string; + sessionToken?: string; expiration?: Date; } diff --git a/src/index.ts b/src/index.ts index dfc8ac4b8e7..476b5affc3b 100644 --- a/src/index.ts +++ b/src/index.ts @@ -128,10 +128,11 @@ export { ReadPreferenceMode } from './read_preference'; export { ServerType, TopologyType } from './sdam/common'; // Helper classes +export type { AWSCredentialProvider } from './cmap/auth/aws_temporary_credentials'; +export type { AWSCredentials } from './deps'; export { ReadConcern } from './read_concern'; export { ReadPreference } from './read_preference'; export { WriteConcern } from './write_concern'; - // events export { CommandFailedEvent, @@ -255,6 +256,7 @@ export type { AWSKMSProviderConfiguration, AzureKMSProviderConfiguration, ClientEncryptionDataKeyProvider, + CredentialProviders, GCPKMSProviderConfiguration, KMIPKMSProviderConfiguration, KMSProviders, diff --git a/src/mongo_client_auth_providers.ts b/src/mongo_client_auth_providers.ts index c23d515e17a..54aab957a56 100644 --- a/src/mongo_client_auth_providers.ts +++ b/src/mongo_client_auth_providers.ts @@ -13,8 +13,14 @@ import { X509 } from './cmap/auth/x509'; import { MongoInvalidArgumentError } from './error'; /** @internal */ -const AUTH_PROVIDERS = new Map AuthProvider>([ - [AuthMechanism.MONGODB_AWS, () => new MongoDBAWS()], +const AUTH_PROVIDERS = new Map< + AuthMechanism | string, + (authMechanismProperties: AuthMechanismProperties) => AuthProvider +>([ + [ + AuthMechanism.MONGODB_AWS, + ({ AWS_CREDENTIAL_PROVIDER }) => new MongoDBAWS(AWS_CREDENTIAL_PROVIDER) + ], [ AuthMechanism.MONGODB_CR, () => { @@ -24,7 +30,7 @@ const AUTH_PROVIDERS = new Map } ], [AuthMechanism.MONGODB_GSSAPI, () => new GSSAPI()], - [AuthMechanism.MONGODB_OIDC, (workflow?: Workflow) => new MongoDBOIDC(workflow)], + [AuthMechanism.MONGODB_OIDC, properties => new MongoDBOIDC(getWorkflow(properties))], [AuthMechanism.MONGODB_PLAIN, () => new Plain()], [AuthMechanism.MONGODB_SCRAM_SHA1, () => new ScramSHA1()], [AuthMechanism.MONGODB_SCRAM_SHA256, () => new ScramSHA256()], @@ -62,37 +68,28 @@ export class MongoClientAuthProviders { throw new MongoInvalidArgumentError(`authMechanism ${name} not supported`); } - let provider; - if (name === AuthMechanism.MONGODB_OIDC) { - provider = providerFunction(this.getWorkflow(authMechanismProperties)); - } else { - provider = providerFunction(); - } - + const provider = providerFunction(authMechanismProperties); this.existingProviders.set(name, provider); return provider; } +} - /** - * Gets either a device workflow or callback workflow. - */ - getWorkflow(authMechanismProperties: AuthMechanismProperties): Workflow { - if (authMechanismProperties.OIDC_HUMAN_CALLBACK) { - return new HumanCallbackWorkflow( - new TokenCache(), - authMechanismProperties.OIDC_HUMAN_CALLBACK +/** + * Gets either a device workflow or callback workflow. + */ +function getWorkflow(authMechanismProperties: AuthMechanismProperties): Workflow { + if (authMechanismProperties.OIDC_HUMAN_CALLBACK) { + return new HumanCallbackWorkflow(new TokenCache(), authMechanismProperties.OIDC_HUMAN_CALLBACK); + } else if (authMechanismProperties.OIDC_CALLBACK) { + return new AutomatedCallbackWorkflow(new TokenCache(), authMechanismProperties.OIDC_CALLBACK); + } else { + const environment = authMechanismProperties.ENVIRONMENT; + const workflow = OIDC_WORKFLOWS.get(environment)?.(); + if (!workflow) { + throw new MongoInvalidArgumentError( + `Could not load workflow for environment ${authMechanismProperties.ENVIRONMENT}` ); - } else if (authMechanismProperties.OIDC_CALLBACK) { - return new AutomatedCallbackWorkflow(new TokenCache(), authMechanismProperties.OIDC_CALLBACK); - } else { - const environment = authMechanismProperties.ENVIRONMENT; - const workflow = OIDC_WORKFLOWS.get(environment)?.(); - if (!workflow) { - throw new MongoInvalidArgumentError( - `Could not load workflow for environment ${authMechanismProperties.ENVIRONMENT}` - ); - } - return workflow; } + return workflow; } } diff --git a/test/integration/auth/mongodb_aws.test.ts b/test/integration/auth/mongodb_aws.test.ts index 74feeff48fc..e65b9c60bda 100644 --- a/test/integration/auth/mongodb_aws.test.ts +++ b/test/integration/auth/mongodb_aws.test.ts @@ -136,6 +136,47 @@ describe('MONGODB-AWS', function () { }); }); + context('when user supplies a credentials provider', function () { + let providerCount = 0; + + beforeEach(function () { + if (!awsSdkPresent) { + this.skipReason = 'only relevant to AssumeRoleWithWebIdentity with SDK installed'; + return this.skip(); + } + // If we have a username the credentials have been set from the URI, options, or environment + // variables per the auth spec stated order. + if (client.options.credentials.username) { + this.skipReason = 'Credentials in the URI on env variables will not use custom provider.'; + return this.skip(); + } + }); + + it('authenticates with a user provided credentials provider', async function () { + // @ts-expect-error We intentionally access a protected variable. + const credentialProvider = AWSTemporaryCredentialProvider.awsSDK; + const provider = async () => { + providerCount++; + return await credentialProvider.fromNodeProviderChain().apply(); + }; + client = this.configuration.newClient(process.env.MONGODB_URI, { + authMechanismProperties: { + AWS_CREDENTIAL_PROVIDER: provider + } + }); + + const result = await client + .db('aws') + .collection('aws_test') + .estimatedDocumentCount() + .catch(error => error); + + expect(result).to.not.be.instanceOf(MongoServerError); + expect(result).to.be.a('number'); + expect(providerCount).to.be.greaterThan(0); + }); + }); + it('should allow empty string in authMechanismProperties.AWS_SESSION_TOKEN to override AWS_SESSION_TOKEN environment variable', function () { client = this.configuration.newClient(this.configuration.url(), { authMechanismProperties: { AWS_SESSION_TOKEN: '' } @@ -426,11 +467,36 @@ describe('AWS KMS Credential Fetching', function () { : undefined; this.currentTest?.skipReason && this.skip(); }); - it('KMS credentials are successfully fetched.', async function () { - const { aws } = await refreshKMSCredentials({ aws: {} }); - expect(aws).to.have.property('accessKeyId'); - expect(aws).to.have.property('secretAccessKey'); + context('when a credential provider is not provided', function () { + it('KMS credentials are successfully fetched.', async function () { + const { aws } = await refreshKMSCredentials({ aws: {} }); + + expect(aws).to.have.property('accessKeyId'); + expect(aws).to.have.property('secretAccessKey'); + }); + }); + + context('when a credential provider is provided', function () { + let credentialProvider; + let providerCount = 0; + + beforeEach(function () { + // @ts-expect-error We intentionally access a protected variable. + const provider = AWSTemporaryCredentialProvider.awsSDK; + credentialProvider = async () => { + providerCount++; + return await provider.fromNodeProviderChain().apply(); + }; + }); + + it('KMS credentials are successfully fetched.', async function () { + const { aws } = await refreshKMSCredentials({ aws: {} }, { aws: credentialProvider }); + + expect(aws).to.have.property('accessKeyId'); + expect(aws).to.have.property('secretAccessKey'); + expect(providerCount).to.be.greaterThan(0); + }); }); it('does not return any extra keys for the `aws` credential provider', async function () { diff --git a/test/integration/client-side-encryption/client_side_encryption.prose.26.custom_aws_credential_providers.test.ts b/test/integration/client-side-encryption/client_side_encryption.prose.26.custom_aws_credential_providers.test.ts new file mode 100644 index 00000000000..da4a90741e3 --- /dev/null +++ b/test/integration/client-side-encryption/client_side_encryption.prose.26.custom_aws_credential_providers.test.ts @@ -0,0 +1,112 @@ +import { expect } from 'chai'; + +/* eslint-disable @typescript-eslint/no-restricted-imports */ +import { ClientEncryption } from '../../../src/client-side-encryption/client_encryption'; +import { AWSTemporaryCredentialProvider, Binary, MongoClient } from '../../mongodb'; +import { getEncryptExtraOptions } from '../../tools/utils'; + +const metadata: MongoDBMetadataUI = { + requires: { + clientSideEncryption: true, + mongodb: '>=4.2.0', + topology: '!load-balanced' + } +} as const; + +const masterKey = { + region: 'us-east-1', + key: 'arn:aws:kms:us-east-1:579766882180:key/89fcc2c4-08b0-4bd9-9f25-e30687b580d0' +}; + +describe('26. Custom AWS Credential Providers', metadata, () => { + let keyVaultClient; + let credentialProvider; + + beforeEach(async function () { + this.currentTest.skipReason = !AWSTemporaryCredentialProvider.isAWSSDKInstalled + ? 'This test must run in an environment where the AWS SDK is installed.' + : undefined; + this.currentTest?.skipReason && this.skip(); + + keyVaultClient = this.configuration.newClient(process.env.MONGODB_UR); + // @ts-expect-error We intentionally access a protected variable. + credentialProvider = AWSTemporaryCredentialProvider.awsSDK; + }); + + afterEach(async () => { + await keyVaultClient?.close(); + }); + + context( + 'Case 1: ClientEncryption with credentialProviders and incorrect kmsProviders', + metadata, + function () { + it('throws an error', metadata, function () { + expect(() => { + new ClientEncryption(keyVaultClient, { + keyVaultNamespace: 'keyvault.datakeys', + kmsProviders: { + aws: { + accessKeyId: process.env.FLE_AWS_KEY, + secretAccessKey: process.env.FLE_AWS_SECRET + } + }, + credentialProviders: { aws: credentialProvider.fromNodeProviderChain() } + }); + }).to.throw(/Can only provide a custom AWS credential provider/); + }); + } + ); + + context('Case 2: ClientEncryption with credentialProviders works', metadata, function () { + let clientEncryption; + let providerCount = 0; + + beforeEach(function () { + const options = { + keyVaultNamespace: 'keyvault.datakeys', + kmsProviders: { aws: {} }, + credentialProviders: { + aws: async () => { + providerCount++; + return { + accessKeyId: process.env.FLE_AWS_KEY, + secretAccessKey: process.env.FLE_AWS_SECRET + }; + } + }, + extraOptions: getEncryptExtraOptions() + }; + clientEncryption = new ClientEncryption(keyVaultClient, options); + }); + + it('is successful', metadata, async function () { + const dk = await clientEncryption.createDataKey('aws', { masterKey }); + expect(dk).to.be.instanceOf(Binary); + expect(providerCount).to.be.greaterThan(0); + }); + }); + + context( + 'Case 3: AutoEncryptionOpts with credentialProviders and incorrect kmsProviders', + metadata, + function () { + it('throws an error', metadata, function () { + expect(() => { + new MongoClient('mongodb://127.0.0.1:27017', { + autoEncryption: { + keyVaultNamespace: 'keyvault.datakeys', + kmsProviders: { + aws: { + accessKeyId: process.env.FLE_AWS_KEY, + secretAccessKey: process.env.FLE_AWS_SECRET + } + }, + credentialProviders: { aws: credentialProvider.fromNodeProviderChain() } + } + }); + }).to.throw(/Can only provide a custom AWS credential provider/); + }); + } + ); +});