diff --git a/packages/bolt-connection/src/bolt/bolt-protocol-v1.js b/packages/bolt-connection/src/bolt/bolt-protocol-v1.js index 35c8b5dfa..87b736ce7 100644 --- a/packages/bolt-connection/src/bolt/bolt-protocol-v1.js +++ b/packages/bolt-connection/src/bolt/bolt-protocol-v1.js @@ -205,6 +205,7 @@ export default class BoltProtocol { onError: onError }) + // TODO: Verify the Neo4j version in the message const error = newError( 'Driver is connected to a database that does not support logoff. ' + 'Please upgrade to Neo4j 5.5.0 or later in order to use this functionality.' @@ -233,6 +234,7 @@ export default class BoltProtocol { onError: (error) => this._onLoginError(error, onError) }) + // TODO: Verify the Neo4j version in the message const error = newError( 'Driver is connected to a database that does not support logon. ' + 'Please upgrade to Neo4j 5.5.0 or later in order to use this functionality.' diff --git a/packages/bolt-connection/src/connection-provider/authentication-provider.js b/packages/bolt-connection/src/connection-provider/authentication-provider.js new file mode 100644 index 000000000..7f406d127 --- /dev/null +++ b/packages/bolt-connection/src/connection-provider/authentication-provider.js @@ -0,0 +1,66 @@ +/** + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [http://neo4j.com] + * + * This file is part of Neo4j. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { expirationBasedAuthTokenManager } from 'neo4j-driver-core' +import { object } from '../lang' + +/** + * Class which provides Authorization for {@link Connection} + */ +export default class AuthenticationProvider { + constructor ({ authTokenManager, userAgent }) { + this._authTokenManager = authTokenManager || expirationBasedAuthTokenManager({ + tokenProvider: () => {} + }) + this._userAgent = userAgent + } + + async authenticate ({ connection, auth, skipReAuth, waitReAuth, forceReAuth }) { + if (auth != null) { + const shouldReAuth = connection.supportsReAuth === true && ( + (!object.equals(connection.authToken, auth) && skipReAuth !== true) || + forceReAuth === true + ) + if (connection.authToken == null || shouldReAuth) { + return await connection.connect(this._userAgent, auth, waitReAuth || false) + } + return connection + } + + const authToken = await this._authTokenManager.getToken() + + if (!object.equals(authToken, connection.authToken)) { + return await connection.connect(this._userAgent, authToken) + } + + return connection + } + + handleError ({ connection, code }) { + if ( + connection && + [ + 'Neo.ClientError.Security.Unauthorized', + 'Neo.ClientError.Security.TokenExpired' + ].includes(code) + ) { + this._authTokenManager.onTokenExpired(connection.authToken) + } + } +} diff --git a/packages/bolt-connection/src/connection-provider/connection-provider-direct.js b/packages/bolt-connection/src/connection-provider/connection-provider-direct.js index b4cb1ce61..fd5ed21bd 100644 --- a/packages/bolt-connection/src/connection-provider/connection-provider-direct.js +++ b/packages/bolt-connection/src/connection-provider/connection-provider-direct.js @@ -26,14 +26,19 @@ import { import { internal, error } from 'neo4j-driver-core' const { - constants: { BOLT_PROTOCOL_V3, BOLT_PROTOCOL_V4_0, BOLT_PROTOCOL_V4_4 } + constants: { + BOLT_PROTOCOL_V3, + BOLT_PROTOCOL_V4_0, + BOLT_PROTOCOL_V4_4, + BOLT_PROTOCOL_V5_1 + } } = internal const { SERVICE_UNAVAILABLE } = error export default class DirectConnectionProvider extends PooledConnectionProvider { - constructor ({ id, config, log, address, userAgent, authToken }) { - super({ id, config, log, userAgent, authToken }) + constructor ({ id, config, log, address, userAgent, authTokenManager, newPool }) { + super({ id, config, log, userAgent, authTokenManager, newPool }) this._address = address } @@ -42,27 +47,33 @@ export default class DirectConnectionProvider extends PooledConnectionProvider { * See {@link ConnectionProvider} for more information about this method and * its arguments. */ - acquireConnection ({ accessMode, database, bookmarks } = {}) { + async acquireConnection ({ accessMode, database, bookmarks, auth, forceReAuth } = {}) { const databaseSpecificErrorHandler = ConnectionErrorHandler.create({ errorCode: SERVICE_UNAVAILABLE, - handleAuthorizationExpired: (error, address) => - this._handleAuthorizationExpired(error, address, database) + handleAuthorizationExpired: (error, address, conn) => + this._handleAuthorizationExpired(error, address, conn, database) }) - return this._connectionPool - .acquire(this._address) - .then( - connection => - new DelegateConnection(connection, databaseSpecificErrorHandler) - ) + const connection = await this._connectionPool.acquire({ auth, forceReAuth }, this._address) + + if (auth) { + await this._verifyStickyConnection({ + auth, + connection, + address: this._address + }) + return connection + } + + return new DelegateConnection(connection, databaseSpecificErrorHandler) } - _handleAuthorizationExpired (error, address, database) { + _handleAuthorizationExpired (error, address, connection, database) { this._log.warn( `Direct driver ${this._id} will close connection to ${address} for database '${database}' because of an error ${error.code} '${error.message}'` ) - this._connectionPool.purge(address).catch(() => {}) - return error + + return super._handleAuthorizationExpired(error, address, connection) } async _hasProtocolVersion (versionPredicate) { @@ -111,6 +122,19 @@ export default class DirectConnectionProvider extends PooledConnectionProvider { ) } + async supportsSessionAuth () { + return await this._hasProtocolVersion( + version => version >= BOLT_PROTOCOL_V5_1 + ) + } + + async verifyAuthentication ({ auth }) { + return this._verifyAuthentication({ + auth, + getAddress: () => this._address + }) + } + async verifyConnectivityAndGetServerInfo () { return await this._verifyConnectivityAndGetServerVersion({ address: this._address }) } diff --git a/packages/bolt-connection/src/connection-provider/connection-provider-pooled.js b/packages/bolt-connection/src/connection-provider/connection-provider-pooled.js index f6424d5ba..2c297daa9 100644 --- a/packages/bolt-connection/src/connection-provider/connection-provider-pooled.js +++ b/packages/bolt-connection/src/connection-provider/connection-provider-pooled.js @@ -19,12 +19,21 @@ import { createChannelConnection, ConnectionErrorHandler } from '../connection' import Pool, { PoolConfig } from '../pool' -import { error, ConnectionProvider, ServerInfo } from 'neo4j-driver-core' +import { error, ConnectionProvider, ServerInfo, newError, isStaticAuthTokenManger } from 'neo4j-driver-core' +import AuthenticationProvider from './authentication-provider' +import { object } from '../lang' const { SERVICE_UNAVAILABLE } = error +const AUTHENTICATION_ERRORS = [ + 'Neo.ClientError.Security.CredentialsExpired', + 'Neo.ClientError.Security.Forbidden', + 'Neo.ClientError.Security.TokenExpired', + 'Neo.ClientError.Security.Unauthorized' +] + export default class PooledConnectionProvider extends ConnectionProvider { constructor ( - { id, config, log, userAgent, authToken }, + { id, config, log, userAgent, authTokenManager, newPool = (...args) => new Pool(...args) }, createChannelConnectionHook = null ) { super() @@ -32,8 +41,8 @@ export default class PooledConnectionProvider extends ConnectionProvider { this._id = id this._config = config this._log = log - this._userAgent = userAgent - this._authToken = authToken + this._authTokenManager = authTokenManager + this._authenticationProvider = new AuthenticationProvider({ authTokenManager, userAgent }) this._createChannelConnection = createChannelConnectionHook || (address => { @@ -44,10 +53,11 @@ export default class PooledConnectionProvider extends ConnectionProvider { this._log ) }) - this._connectionPool = new Pool({ + this._connectionPool = newPool({ create: this._createConnection.bind(this), destroy: this._destroyConnection.bind(this), - validate: this._validateConnection.bind(this), + validateOnAcquire: this._validateConnectionOnAcquire.bind(this), + validateOnRelease: this._validateConnectionOnRelease.bind(this), installIdleObserver: PooledConnectionProvider._installIdleObserverOnConnection.bind( this ), @@ -57,6 +67,7 @@ export default class PooledConnectionProvider extends ConnectionProvider { config: PoolConfig.fromDriverConfig(config), log: this._log }) + this._userAgent = userAgent this._openConnections = {} } @@ -69,14 +80,13 @@ export default class PooledConnectionProvider extends ConnectionProvider { * @return {Promise} promise resolved with a new connection or rejected when failed to connect. * @access private */ - _createConnection (address, release) { + _createConnection ({ auth }, address, release) { return this._createChannelConnection(address).then(connection => { connection._release = () => { return release(address, connection) } this._openConnections[connection.id] = connection - return connection - .connect(this._userAgent, this._authToken) + return this._authenticationProvider.authenticate({ connection, auth }) .catch(error => { // let's destroy this connection this._destroyConnection(connection) @@ -86,6 +96,26 @@ export default class PooledConnectionProvider extends ConnectionProvider { }) } + async _validateConnectionOnAcquire ({ auth, skipReAuth }, conn) { + if (!this._validateConnection(conn)) { + return false + } + + try { + await this._authenticationProvider.authenticate({ connection: conn, auth, skipReAuth }) + return true + } catch (error) { + this._log.debug( + `The connection ${conn.id} is not valid because of an error ${error.code} '${error.message}'` + ) + return false + } + } + + _validateConnectionOnRelease (conn) { + return conn._sticky !== true && this._validateConnection(conn) + } + /** * Check that a connection is usable * @return {boolean} true if the connection is open @@ -98,7 +128,11 @@ export default class PooledConnectionProvider extends ConnectionProvider { const maxConnectionLifetime = this._config.maxConnectionLifetime const lifetime = Date.now() - conn.creationTimestamp - return lifetime <= maxConnectionLifetime + if (lifetime > maxConnectionLifetime) { + return false + } + + return true } /** @@ -118,7 +152,7 @@ export default class PooledConnectionProvider extends ConnectionProvider { * @return {Promise} the server info */ async _verifyConnectivityAndGetServerVersion ({ address }) { - const connection = await this._connectionPool.acquire(address) + const connection = await this._connectionPool.acquire({}, address) const serverInfo = new ServerInfo(connection.server, connection.protocol().version) try { if (!connection.protocol().isLastMessageLogon()) { @@ -130,6 +164,47 @@ export default class PooledConnectionProvider extends ConnectionProvider { return serverInfo } + async _verifyAuthentication ({ getAddress, auth }) { + const connectionsToRelease = [] + try { + const address = await getAddress() + const connection = await this._connectionPool.acquire({ auth, skipReAuth: true }, address) + connectionsToRelease.push(connection) + + const lastMessageIsNotLogin = !connection.protocol().isLastMessageLogon() + + if (!connection.supportsReAuth) { + throw newError('Driver is connected to a database that does not support user switch.') + } + if (lastMessageIsNotLogin && connection.supportsReAuth) { + await this._authenticationProvider.authenticate({ connection, auth, waitReAuth: true, forceReAuth: true }) + } else if (lastMessageIsNotLogin && !connection.supportsReAuth) { + const stickyConnection = await this._connectionPool.acquire({ auth }, address, { requireNew: true }) + stickyConnection._sticky = true + connectionsToRelease.push(stickyConnection) + } + return true + } catch (error) { + if (AUTHENTICATION_ERRORS.includes(error.code)) { + return false + } + throw error + } finally { + await Promise.all(connectionsToRelease.map(conn => conn._release())) + } + } + + async _verifyStickyConnection ({ auth, connection, address }) { + const connectionWithSameCredentials = object.equals(auth, connection.authToken) + const shouldCreateStickyConnection = !connectionWithSameCredentials + connection._sticky = connectionWithSameCredentials && !connection.supportsReAuth + + if (shouldCreateStickyConnection || connection._sticky) { + await connection._release() + throw newError('Driver is connected to a database that does not support user switch.') + } + } + async close () { // purge all idle connections in the connection pool await this._connectionPool.close() @@ -146,4 +221,22 @@ export default class PooledConnectionProvider extends ConnectionProvider { static _removeIdleObserverOnConnection (conn) { conn._updateCurrentObserver() } + + _handleAuthorizationExpired (error, address, connection) { + this._authenticationProvider.handleError({ connection, code: error.code }) + + if (error.code === 'Neo.ClientError.Security.AuthorizationExpired') { + this._connectionPool.apply(address, (conn) => { conn.authToken = null }) + } + + if (connection) { + connection.close().catch(() => undefined) + } + + if (error.code === 'Neo.ClientError.Security.TokenExpired' && !isStaticAuthTokenManger(this._authTokenManager)) { + error.retriable = true + } + + return error + } } diff --git a/packages/bolt-connection/src/connection-provider/connection-provider-routing.js b/packages/bolt-connection/src/connection-provider/connection-provider-routing.js index f0442ab23..65e077100 100644 --- a/packages/bolt-connection/src/connection-provider/connection-provider-routing.js +++ b/packages/bolt-connection/src/connection-provider/connection-provider-routing.js @@ -37,7 +37,8 @@ const { ACCESS_MODE_WRITE: WRITE, BOLT_PROTOCOL_V3, BOLT_PROTOCOL_V4_0, - BOLT_PROTOCOL_V4_4 + BOLT_PROTOCOL_V4_4, + BOLT_PROTOCOL_V5_1 } } = internal @@ -51,6 +52,7 @@ const AUTHORIZATION_EXPIRED_CODE = const INVALID_ARGUMENT_ERROR = 'Neo.ClientError.Statement.ArgumentError' const INVALID_REQUEST_ERROR = 'Neo.ClientError.Request.Invalid' const STATEMENT_TYPE_ERROR = 'Neo.ClientError.Statement.TypeError' +const NOT_AVAILABLE = 'N/A' const SYSTEM_DB_NAME = 'system' const DEFAULT_DB_NAME = null @@ -65,10 +67,11 @@ export default class RoutingConnectionProvider extends PooledConnectionProvider config, log, userAgent, - authToken, - routingTablePurgeDelay + authTokenManager, + routingTablePurgeDelay, + newPool }) { - super({ id, config, log, userAgent, authToken }, address => { + super({ id, config, log, userAgent, authTokenManager, newPool }, address => { return createChannelConnection( address, this._config, @@ -109,12 +112,12 @@ export default class RoutingConnectionProvider extends PooledConnectionProvider return error } - _handleAuthorizationExpired (error, address, database) { + _handleAuthorizationExpired (error, address, connection, database) { this._log.warn( `Routing driver ${this._id} will close connections to ${address} for database '${database}' because of an error ${error.code} '${error.message}'` ) - this._connectionPool.purge(address).catch(() => {}) - return error + + return super._handleAuthorizationExpired(error, address, connection, database) } _handleWriteFailure (error, address, database) { @@ -133,7 +136,7 @@ export default class RoutingConnectionProvider extends PooledConnectionProvider * See {@link ConnectionProvider} for more information about this method and * its arguments. */ - async acquireConnection ({ accessMode, database, bookmarks, impersonatedUser, onDatabaseNameResolved } = {}) { + async acquireConnection ({ accessMode, database, bookmarks, impersonatedUser, onDatabaseNameResolved, auth } = {}) { let name let address const context = { database: database || DEFAULT_DB_NAME } @@ -142,8 +145,8 @@ export default class RoutingConnectionProvider extends PooledConnectionProvider SESSION_EXPIRED, (error, address) => this._handleUnavailability(error, address, context.database), (error, address) => this._handleWriteFailure(error, address, context.database), - (error, address) => - this._handleAuthorizationExpired(error, address, context.database) + (error, address, conn) => + this._handleAuthorizationExpired(error, address, conn, context.database) ) const routingTable = await this._freshRoutingTable({ @@ -151,6 +154,7 @@ export default class RoutingConnectionProvider extends PooledConnectionProvider database: context.database, bookmarks, impersonatedUser, + auth, onDatabaseNameResolved: (databaseName) => { context.database = context.database || databaseName if (onDatabaseNameResolved) { @@ -179,11 +183,16 @@ export default class RoutingConnectionProvider extends PooledConnectionProvider } try { - const connection = await this._acquireConnectionToServer( - address, - name, - routingTable - ) + const connection = await this._connectionPool.acquire({ auth }, address) + + if (auth) { + await this._verifyStickyConnection({ + auth, + connection, + address + }) + return connection + } return new DelegateConnection(connection, databaseSpecificErrorHandler) } catch (error) { @@ -248,6 +257,12 @@ export default class RoutingConnectionProvider extends PooledConnectionProvider ) } + async supportsSessionAuth () { + return await this._hasProtocolVersion( + version => version >= BOLT_PROTOCOL_V5_1 + ) + } + getNegotiatedProtocolVersion () { return new Promise((resolve, reject) => { this._hasProtocolVersion(resolve) @@ -255,6 +270,35 @@ export default class RoutingConnectionProvider extends PooledConnectionProvider }) } + async verifyAuthentication ({ database, accessMode, auth }) { + return this._verifyAuthentication({ + auth, + getAddress: async () => { + const context = { database: database || DEFAULT_DB_NAME } + + const routingTable = await this._freshRoutingTable({ + accessMode, + database: context.database, + auth, + onDatabaseNameResolved: (databaseName) => { + context.database = context.database || databaseName + } + }) + + const servers = accessMode === WRITE ? routingTable.writers : routingTable.readers + + if (servers.length === 0) { + throw newError( + `No servers available for database '${context.database}' with access mode '${accessMode}'`, + SERVICE_UNAVAILABLE + ) + } + + return servers[0] + } + }) + } + async verifyConnectivityAndGetServerInfo ({ database, accessMode }) { const context = { database: database || DEFAULT_DB_NAME } @@ -300,11 +344,7 @@ export default class RoutingConnectionProvider extends PooledConnectionProvider }) } - _acquireConnectionToServer (address, serverName, routingTable) { - return this._connectionPool.acquire(address) - } - - _freshRoutingTable ({ accessMode, database, bookmarks, impersonatedUser, onDatabaseNameResolved } = {}) { + _freshRoutingTable ({ accessMode, database, bookmarks, impersonatedUser, onDatabaseNameResolved, auth } = {}) { const currentRoutingTable = this._routingTableRegistry.get( database, () => new RoutingTable({ database }) @@ -316,10 +356,10 @@ export default class RoutingConnectionProvider extends PooledConnectionProvider this._log.info( `Routing table is stale for database: "${database}" and access mode: "${accessMode}": ${currentRoutingTable}` ) - return this._refreshRoutingTable(currentRoutingTable, bookmarks, impersonatedUser, onDatabaseNameResolved) + return this._refreshRoutingTable(currentRoutingTable, bookmarks, impersonatedUser, onDatabaseNameResolved, auth) } - _refreshRoutingTable (currentRoutingTable, bookmarks, impersonatedUser, onDatabaseNameResolved) { + _refreshRoutingTable (currentRoutingTable, bookmarks, impersonatedUser, onDatabaseNameResolved, auth) { const knownRouters = currentRoutingTable.routers if (this._useSeedRouter) { @@ -328,7 +368,8 @@ export default class RoutingConnectionProvider extends PooledConnectionProvider currentRoutingTable, bookmarks, impersonatedUser, - onDatabaseNameResolved + onDatabaseNameResolved, + auth ) } return this._fetchRoutingTableFromKnownRoutersFallbackToSeedRouter( @@ -336,7 +377,8 @@ export default class RoutingConnectionProvider extends PooledConnectionProvider currentRoutingTable, bookmarks, impersonatedUser, - onDatabaseNameResolved + onDatabaseNameResolved, + auth ) } @@ -345,7 +387,8 @@ export default class RoutingConnectionProvider extends PooledConnectionProvider currentRoutingTable, bookmarks, impersonatedUser, - onDatabaseNameResolved + onDatabaseNameResolved, + auth ) { // we start with seed router, no routers were probed before const seenRouters = [] @@ -354,7 +397,8 @@ export default class RoutingConnectionProvider extends PooledConnectionProvider this._seedRouter, currentRoutingTable, bookmarks, - impersonatedUser + impersonatedUser, + auth ) if (newRoutingTable) { @@ -365,7 +409,8 @@ export default class RoutingConnectionProvider extends PooledConnectionProvider knownRouters, currentRoutingTable, bookmarks, - impersonatedUser + impersonatedUser, + auth ) newRoutingTable = newRoutingTable2 error = error2 || error @@ -384,13 +429,15 @@ export default class RoutingConnectionProvider extends PooledConnectionProvider currentRoutingTable, bookmarks, impersonatedUser, - onDatabaseNameResolved + onDatabaseNameResolved, + auth ) { let [newRoutingTable, error] = await this._fetchRoutingTableUsingKnownRouters( knownRouters, currentRoutingTable, bookmarks, - impersonatedUser + impersonatedUser, + auth ) if (!newRoutingTable) { @@ -400,7 +447,8 @@ export default class RoutingConnectionProvider extends PooledConnectionProvider this._seedRouter, currentRoutingTable, bookmarks, - impersonatedUser + impersonatedUser, + auth ) } @@ -416,13 +464,15 @@ export default class RoutingConnectionProvider extends PooledConnectionProvider knownRouters, currentRoutingTable, bookmarks, - impersonatedUser + impersonatedUser, + auth ) { const [newRoutingTable, error] = await this._fetchRoutingTable( knownRouters, currentRoutingTable, bookmarks, - impersonatedUser + impersonatedUser, + auth ) if (newRoutingTable) { @@ -447,7 +497,8 @@ export default class RoutingConnectionProvider extends PooledConnectionProvider seedRouter, routingTable, bookmarks, - impersonatedUser + impersonatedUser, + auth ) { const resolvedAddresses = await this._resolveSeedRouter(seedRouter) @@ -456,7 +507,7 @@ export default class RoutingConnectionProvider extends PooledConnectionProvider address => seenRouters.indexOf(address) < 0 ) - return await this._fetchRoutingTable(newAddresses, routingTable, bookmarks, impersonatedUser) + return await this._fetchRoutingTable(newAddresses, routingTable, bookmarks, impersonatedUser, auth) } async _resolveSeedRouter (seedRouter) { @@ -468,7 +519,7 @@ export default class RoutingConnectionProvider extends PooledConnectionProvider return [].concat.apply([], dnsResolvedAddresses) } - async _fetchRoutingTable (routerAddresses, routingTable, bookmarks, impersonatedUser) { + async _fetchRoutingTable (routerAddresses, routingTable, bookmarks, impersonatedUser, auth) { return routerAddresses.reduce( async (refreshedTablePromise, currentRouter, currentIndex) => { const [newRoutingTable] = await refreshedTablePromise @@ -491,7 +542,8 @@ export default class RoutingConnectionProvider extends PooledConnectionProvider const [session, error] = await this._createSessionForRediscovery( currentRouter, bookmarks, - impersonatedUser + impersonatedUser, + auth ) if (session) { try { @@ -516,17 +568,28 @@ export default class RoutingConnectionProvider extends PooledConnectionProvider ) } - async _createSessionForRediscovery (routerAddress, bookmarks, impersonatedUser) { + async _createSessionForRediscovery (routerAddress, bookmarks, impersonatedUser, auth) { try { - const connection = await this._connectionPool.acquire(routerAddress) + const connection = await this._connectionPool.acquire({ auth }, routerAddress) + + if (auth) { + await this._verifyStickyConnection({ + auth, + connection, + address: routerAddress + }) + } const databaseSpecificErrorHandler = ConnectionErrorHandler.create({ errorCode: SESSION_EXPIRED, - handleAuthorizationExpired: (error, address) => this._handleAuthorizationExpired(error, address) + handleAuthorizationExpired: (error, address, conn) => this._handleAuthorizationExpired(error, address, conn) }) - const connectionProvider = new SingleConnectionProvider( - new DelegateConnection(connection, databaseSpecificErrorHandler)) + const delegateConnection = !connection._sticky + ? new DelegateConnection(connection, databaseSpecificErrorHandler) + : new DelegateConnection(connection) + + const connectionProvider = new SingleConnectionProvider(delegateConnection) const protocolVersion = connection.protocol().version if (protocolVersion < 4.0) { @@ -709,7 +772,8 @@ function _isFailFastError (error) { INVALID_BOOKMARK_MIXTURE_CODE, INVALID_ARGUMENT_ERROR, INVALID_REQUEST_ERROR, - STATEMENT_TYPE_ERROR + STATEMENT_TYPE_ERROR, + NOT_AVAILABLE ].includes(error.code) } diff --git a/packages/bolt-connection/src/connection/connection-channel.js b/packages/bolt-connection/src/connection/connection-channel.js index 8b270519f..38bcd2fa5 100644 --- a/packages/bolt-connection/src/connection/connection-channel.js +++ b/packages/bolt-connection/src/connection/connection-channel.js @@ -124,7 +124,7 @@ export default class ChannelConnection extends Connection { protocolSupplier ) { super(errorHandler) - + this._authToken = null this._reseting = false this._resetObservers = [] this._id = idGenerator++ @@ -156,6 +156,18 @@ export default class ChannelConnection extends Connection { } } + get authToken () { + return this._authToken + } + + set authToken (value) { + this._authToken = value + } + + get supportsReAuth () { + return this._protocol.supportsReAuth + } + get id () { return this._id } @@ -174,8 +186,36 @@ export default class ChannelConnection extends Connection { * @param {Object} authToken the object containing auth information. * @return {Promise} promise resolved with the current connection if connection is successful. Rejected promise otherwise. */ - connect (userAgent, authToken) { - return this._initialize(userAgent, authToken) + async connect (userAgent, authToken, waitReAuth) { + if (this._protocol.initialized && !this._protocol.supportsReAuth) { + throw newError('Connection does not support re-auth') + } + + this._authToken = authToken + + if (!this._protocol.initialized) { + return await this._initialize(userAgent, authToken) + } + + if (waitReAuth) { + return await new Promise((resolve, reject) => { + this._protocol.logoff({ + onError: reject + }) + + this._protocol.logon({ + authToken, + onError: reject, + onComplete: () => resolve(this), + flush: true + }) + }) + } + + this._protocol.logoff() + this._protocol.logon({ authToken, flush: true }) + + return this } /** diff --git a/packages/bolt-connection/src/connection/connection-delegate.js b/packages/bolt-connection/src/connection/connection-delegate.js index 16d2da48a..1f77823f7 100644 --- a/packages/bolt-connection/src/connection/connection-delegate.js +++ b/packages/bolt-connection/src/connection/connection-delegate.js @@ -51,6 +51,18 @@ export default class DelegateConnection extends Connection { return this._delegate.server } + get authToken () { + return this._delegate.authToken + } + + get supportsReAuth () { + return this._delegate.supportsReAuth + } + + set authToken (value) { + this._delegate.authToken = value + } + get address () { return this._delegate.address } @@ -71,8 +83,8 @@ export default class DelegateConnection extends Connection { return this._delegate.protocol() } - connect (userAgent, authToken) { - return this._delegate.connect(userAgent, authToken) + connect (userAgent, authToken, waitReAuth) { + return this._delegate.connect(userAgent, authToken, waitReAuth) } write (message, observer, flush) { diff --git a/packages/bolt-connection/src/connection/connection-error-handler.js b/packages/bolt-connection/src/connection/connection-error-handler.js index de844e96f..91f855c11 100644 --- a/packages/bolt-connection/src/connection/connection-error-handler.js +++ b/packages/bolt-connection/src/connection/connection-error-handler.js @@ -62,15 +62,15 @@ export default class ConnectionErrorHandler { * @param {ServerAddress} address the address of the connection where the error happened. * @return {Neo4jError} new error that should be propagated to the user. */ - handleAndTransformError (error, address) { + handleAndTransformError (error, address, connection) { if (isAutorizationExpiredError(error)) { - return this._handleAuthorizationExpired(error, address) + return this._handleAuthorizationExpired(error, address, connection) } if (isAvailabilityError(error)) { - return this._handleUnavailability(error, address) + return this._handleUnavailability(error, address, connection) } if (isFailureToWrite(error)) { - return this._handleWriteFailure(error, address) + return this._handleWriteFailure(error, address, connection) } return error } diff --git a/packages/bolt-connection/src/connection/connection.js b/packages/bolt-connection/src/connection/connection.js index d3c692712..9d1107023 100644 --- a/packages/bolt-connection/src/connection/connection.js +++ b/packages/bolt-connection/src/connection/connection.js @@ -39,6 +39,18 @@ export default class Connection { throw new Error('not implemented') } + get authToken () { + throw new Error('not implemented') + } + + set authToken (value) { + throw new Error('not implemented') + } + + get supportsReAuth () { + throw new Error('not implemented') + } + /** * @returns {boolean} whether this connection is in a working condition */ @@ -124,7 +136,7 @@ export default class Connection { */ handleAndTransformError (error, address) { if (this._errorHandler) { - return this._errorHandler.handleAndTransformError(error, address) + return this._errorHandler.handleAndTransformError(error, address, this) } return error diff --git a/packages/bolt-connection/src/lang/index.js b/packages/bolt-connection/src/lang/index.js index 9d565fe8f..ab67c87c5 100644 --- a/packages/bolt-connection/src/lang/index.js +++ b/packages/bolt-connection/src/lang/index.js @@ -18,3 +18,4 @@ */ export * as functional from './functional' +export * as object from './object' diff --git a/packages/bolt-connection/src/lang/object.js b/packages/bolt-connection/src/lang/object.js new file mode 100644 index 000000000..e2a862c04 --- /dev/null +++ b/packages/bolt-connection/src/lang/object.js @@ -0,0 +1,47 @@ +/** + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [http://neo4j.com] + * + * This file is part of Neo4j. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +export function equals (a, b) { + if (a === b) { + return true + } + + if (a === null || b === null) { + return false + } + + if (typeof a === 'object' && typeof b === 'object') { + const keysA = Object.keys(a) + const keysB = Object.keys(b) + + if (keysA.length !== keysB.length) { + return false + } + + for (const key of keysA) { + if (a[key] !== b[key]) { + return false + } + } + + return true + } + + return false +} diff --git a/packages/bolt-connection/src/pool/pool.js b/packages/bolt-connection/src/pool/pool.js index d23f8f8cf..187393f5f 100644 --- a/packages/bolt-connection/src/pool/pool.js +++ b/packages/bolt-connection/src/pool/pool.js @@ -26,15 +26,18 @@ const { class Pool { /** - * @param {function(address: ServerAddress, function(address: ServerAddress, resource: object): Promise): Promise} create + * @param {function(acquisitionContext: object, address: ServerAddress, function(address: ServerAddress, resource: object): Promise): Promise} create * an allocation function that creates a promise with a new resource. It's given an address for which to * allocate the connection and a function that will return the resource to the pool if invoked, which is * meant to be called on .dispose or .close or whatever mechanism the resource uses to finalize. + * @param {function(acquisitionContext: object, resource: object): boolean} validateOnAcquire + * called at various times when an instance is acquired + * If this returns false, the resource will be evicted + * @param {function(resource: object): boolean} validateOnRelease + * called at various times when an instance is released + * If this returns false, the resource will be evicted * @param {function(resource: object): Promise} destroy * called with the resource when it is evicted from this pool - * @param {function(resource: object): boolean} validate - * called at various times (like when an instance is acquired and when it is returned. - * If this returns false, the resource will be evicted * @param {function(resource: object, observer: { onError }): void} installIdleObserver * called when the resource is released back to pool * @param {function(resource: object): void} removeIdleObserver @@ -43,9 +46,10 @@ class Pool { * @param {Logger} log the driver logger. */ constructor ({ - create = (address, release) => Promise.resolve(), + create = (acquisitionContext, address, release) => Promise.resolve(), destroy = conn => Promise.resolve(), - validate = conn => true, + validateOnAcquire = (acquisitionContext, conn) => true, + validateOnRelease = (conn) => true, installIdleObserver = (conn, observer) => {}, removeIdleObserver = conn => {}, config = PoolConfig.defaultConfig(), @@ -53,7 +57,8 @@ class Pool { } = {}) { this._create = create this._destroy = destroy - this._validate = validate + this._validateOnAcquire = validateOnAcquire + this._validateOnRelease = validateOnRelease this._installIdleObserver = installIdleObserver this._removeIdleObserver = removeIdleObserver this._maxSize = config.maxSize @@ -69,10 +74,13 @@ class Pool { /** * Acquire and idle resource fom the pool or create a new one. + * @param {object} acquisitionContext the acquisition context used for create and validateOnAcquire connection * @param {ServerAddress} address the address for which we're acquiring. + * @param {object} config the config + * @param {boolean} config.requireNew Indicate it requires a new resource * @return {Promise} resource that is ready to use. */ - acquire (address) { + acquire (acquisitionContext, address, config) { const key = address.asKey() // We're out of resources and will try to acquire later on when an existing resource is released. @@ -108,7 +116,7 @@ class Pool { } }, this._acquisitionTimeout) - request = new PendingRequest(key, resolve, reject, timeoutId, this._log) + request = new PendingRequest(key, acquisitionContext, config, resolve, reject, timeoutId, this._log) allRequests[key].push(request) this._processPendingAcquireRequests(address) }) @@ -123,6 +131,14 @@ class Pool { return this._purgeKey(address.asKey()) } + apply (address, resourceConsumer) { + const key = address.asKey() + + if (key in this._pools) { + this._pools[key].apply(resourceConsumer) + } + } + /** * Destroy all idle resources in this pool. * @returns {Promise} A promise that is resolved when the resources are purged @@ -185,29 +201,32 @@ class Pool { return pool } - async _acquire (address) { + async _acquire (acquisitionContext, address, requireNew) { if (this._closed) { throw newError('Pool is closed, it is no more able to serve requests.') } const key = address.asKey() const pool = this._getOrInitializePoolFor(key) - while (pool.length) { - const resource = pool.pop() + if (!requireNew) { + while (pool.length) { + const resource = pool.pop() - if (this._validate(resource)) { if (this._removeIdleObserver) { this._removeIdleObserver(resource) } - // idle resource is valid and can be acquired - resourceAcquired(key, this._activeResourceCounts) - if (this._log.isDebugEnabled()) { - this._log.debug(`${resource} acquired from the pool ${key}`) + if (await this._validateOnAcquire(acquisitionContext, resource)) { + // idle resource is valid and can be acquired + resourceAcquired(key, this._activeResourceCounts) + if (this._log.isDebugEnabled()) { + this._log.debug(`${resource} acquired from the pool ${key}`) + } + return { resource, pool } + } else { + pool.removeInUse(resource) + await this._destroy(resource) } - return { resource, pool } - } else { - await this._destroy(resource) } } @@ -228,9 +247,19 @@ class Pool { this._pendingCreates[key] = this._pendingCreates[key] + 1 let resource try { - // Invoke callback that creates actual connection - resource = await this._create(address, (address, resource) => this._release(address, resource, pool)) + const numConnections = this.activeResourceCount(address) + pool.length + if (numConnections >= this._maxSize && requireNew) { + const resource = pool.pop() + if (this._removeIdleObserver) { + this._removeIdleObserver(resource) + } + pool.removeInUse(resource) + await this._destroy(resource) + } + // Invoke callback that creates actual connection + resource = await this._create(acquisitionContext, address, (address, resource) => this._release(address, resource, pool)) + pool.pushInUse(resource) resourceAcquired(key, this._activeResourceCounts) if (this._log.isDebugEnabled()) { this._log.debug(`${resource} created for the pool ${key}`) @@ -246,12 +275,13 @@ class Pool { if (pool.isActive()) { // there exist idle connections for the given key - if (!this._validate(resource)) { + if (!await this._validateOnRelease(resource)) { if (this._log.isDebugEnabled()) { this._log.debug( `${resource} destroyed and can't be released to the pool ${key} because it is not functional` ) } + pool.removeInUse(resource) await this._destroy(resource) } else { if (this._installIdleObserver) { @@ -263,6 +293,7 @@ class Pool { const pool = this._pools[key] if (pool) { this._pools[key] = pool.filter(r => r !== resource) + pool.removeInUse(resource) } // let's not care about background clean-ups due to errors but just trigger the destroy // process for the resource, we especially catch any errors and ignore them to avoid @@ -283,6 +314,7 @@ class Pool { `${resource} destroyed and can't be released to the pool ${key} because pool has been purged` ) } + pool.removeInUse(resource) await this._destroy(resource) } resourceReleased(key, this._activeResourceCounts) @@ -314,7 +346,7 @@ class Pool { const pendingRequest = requests.shift() // pop a pending acquire request if (pendingRequest) { - this._acquire(address) + this._acquire(pendingRequest.context, address, pendingRequest.requireNew) .catch(error => { // failed to acquire/create a new connection to resolve the pending acquire request // propagate the error by failing the pending request @@ -378,13 +410,23 @@ function resourceReleased (key, activeResourceCounts) { } class PendingRequest { - constructor (key, resolve, reject, timeoutId, log) { + constructor (key, context, config, resolve, reject, timeoutId, log) { this._key = key + this._context = context this._resolve = resolve this._reject = reject this._timeoutId = timeoutId this._log = log this._completed = false + this._config = config || {} + } + + get context () { + return this._context + } + + get requireNew () { + return this._config.requireNew || false } isCompleted () { @@ -419,6 +461,7 @@ class SingleAddressPool { constructor () { this._active = true this._elements = [] + this._elementsInUse = new Set() } isActive () { @@ -427,6 +470,8 @@ class SingleAddressPool { close () { this._active = false + this._elements = [] + this._elementsInUse = new Set() } filter (predicate) { @@ -434,17 +479,33 @@ class SingleAddressPool { return this } + apply (resourceConsumer) { + this._elements.forEach(resourceConsumer) + this._elementsInUse.forEach(resourceConsumer) + } + get length () { return this._elements.length } pop () { - return this._elements.pop() + const element = this._elements.pop() + this._elementsInUse.add(element) + return element } push (element) { + this._elementsInUse.delete(element) return this._elements.push(element) } + + pushInUse (element) { + this._elementsInUse.add(element) + } + + removeInUse (element) { + this._elementsInUse.delete(element) + } } export default Pool diff --git a/packages/bolt-connection/test/connection-provider/authentication-provider.test.js b/packages/bolt-connection/test/connection-provider/authentication-provider.test.js new file mode 100644 index 000000000..99dc57600 --- /dev/null +++ b/packages/bolt-connection/test/connection-provider/authentication-provider.test.js @@ -0,0 +1,847 @@ +/** + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [http://neo4j.com] + * + * This file is part of Neo4j. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +import { expirationBasedAuthTokenManager } from 'neo4j-driver-core' +import AuthenticationProvider from '../../src/connection-provider/authentication-provider' + +describe('AuthenticationProvider', () => { + const USER_AGENT = 'javascript-driver/5.5.0' + + describe('.authenticate()', () => { + describe('when called without an auth', () => { + describe('and first call', () => { + describe('and connection.authToken is different of new AuthToken', () => { + it('should refresh the auth token', async () => { + const authTokenProvider = jest.fn(() => toRenewableToken({ scheme: 'none' })) + const authenticationProvider = createAuthenticationProvider(authTokenProvider) + const connection = mockConnection() + + await authenticationProvider.authenticate({ connection }) + + expect(authTokenProvider).toHaveBeenCalledTimes(1) + }) + + it('should refresh authToken only once', async () => { + const authTokenProvider = jest.fn(() => new Promise((resolve) => { + setTimeout(() => { + resolve(toRenewableToken({ scheme: 'none' })) + }, 100) + })) + + const authenticationProvider = createAuthenticationProvider(authTokenProvider) + const connections = [mockConnection(), mockConnection()] + + await Promise.all(connections.map(connection => authenticationProvider.authenticate({ connection }))) + + expect(authTokenProvider).toHaveBeenCalledTimes(1) + }) + + it('should return the connection', async () => { + const authTokenProvider = jest.fn(() => toRenewableToken({ scheme: 'none' })) + const authenticationProvider = createAuthenticationProvider(authTokenProvider) + const connection = mockConnection() + + const resultedConnection = await authenticationProvider.authenticate({ connection }) + + expect(resultedConnection).toBe(connection) + }) + + it('should call connection.connect', async () => { + const authToken = { scheme: 'none' } + const authTokenProvider = jest.fn(() => toRenewableToken(authToken)) + const authenticationProvider = createAuthenticationProvider(authTokenProvider) + const connection = mockConnection() + + await authenticationProvider.authenticate({ connection }) + + expect(connection.connect).toHaveBeenCalledWith(USER_AGENT, authToken) + }) + + it('should throw errors happened during token refresh', async () => { + const error = new Error('ops') + const authTokenProvider = jest.fn(() => Promise.reject(error)) + const authenticationProvider = createAuthenticationProvider(authTokenProvider) + const connection = mockConnection() + + await expect(authenticationProvider.authenticate({ connection })).rejects.toThrow(error) + }) + + it('should throw errors happened during connection.connect', async () => { + const error = new Error('ops') + const authTokenProvider = jest.fn(() => toRenewableToken({ scheme: 'none' })) + const authenticationProvider = createAuthenticationProvider(authTokenProvider) + const connection = mockConnection({ + connect: () => Promise.reject(error) + }) + + await expect(authenticationProvider.authenticate({ connection })).rejects.toThrow(error) + }) + }) + + describe('when connection.authToken is equal to new AuthToken', () => { + it('should refresh the auth token', async () => { + const authTokenProvider = jest.fn(() => toRenewableToken({ scheme: 'none' })) + const authenticationProvider = createAuthenticationProvider(authTokenProvider) + const connection = mockConnection() + + await authenticationProvider.authenticate({ connection }) + + expect(authTokenProvider).toHaveBeenCalledTimes(1) + }) + + it('should refresh authToken only once', async () => { + const authTokenProvider = jest.fn(() => new Promise((resolve) => { + setTimeout(() => { + resolve(toRenewableToken({ scheme: 'none' })) + }, 100) + })) + + const authenticationProvider = createAuthenticationProvider(authTokenProvider) + const connections = [mockConnection(), mockConnection()] + + await Promise.all(connections.map(connection => authenticationProvider.authenticate({ connection }))) + + expect(authTokenProvider).toHaveBeenCalledTimes(1) + }) + + it('should return the connection', async () => { + const authTokenProvider = jest.fn(() => toRenewableToken({ scheme: 'none' })) + const authenticationProvider = createAuthenticationProvider(authTokenProvider) + const connection = mockConnection() + + const resultedConnection = await authenticationProvider.authenticate({ connection }) + + expect(resultedConnection).toBe(connection) + }) + + it('should not call connection.connect', async () => { + const authTokenProvider = jest.fn(() => toRenewableToken({ scheme: 'none' })) + const authenticationProvider = createAuthenticationProvider(authTokenProvider) + const connection = mockConnection({ authToken: { scheme: 'none' } }) + + await authenticationProvider.authenticate({ connection }) + + expect(connection.connect).toHaveBeenCalledTimes(0) + }) + + it('should throw errors happened during token refresh', async () => { + const error = new Error('ops') + const authTokenProvider = jest.fn(() => Promise.reject(error)) + const authenticationProvider = createAuthenticationProvider(authTokenProvider) + const connection = mockConnection() + + await expect(authenticationProvider.authenticate({ connection })).rejects.toThrow(error) + }) + }) + }) + + describe('and token has expired', () => { + describe('and connection.authToken is different of new AuthToken', () => { + it('should refresh the auth token', async () => { + const authTokenProvider = jest.fn(() => toRenewableToken({ scheme: 'none' })) + const authenticationProvider = createAuthenticationProvider(authTokenProvider, { + renewableAuthToken: toExpiredRenewableToken({ scheme: 'none' }) + }) + const connection = mockConnection() + + await authenticationProvider.authenticate({ connection }) + + expect(authTokenProvider).toHaveBeenCalledTimes(1) + }) + + it('should refresh authToken only once', async () => { + const authTokenProvider = jest.fn(() => new Promise((resolve) => { + setTimeout(() => { + resolve(toRenewableToken({ scheme: 'none' })) + }, 100) + })) + + const authenticationProvider = createAuthenticationProvider(authTokenProvider) + const connections = [mockConnection(), mockConnection()] + + await Promise.all(connections.map(connection => authenticationProvider.authenticate({ connection }))) + + expect(authTokenProvider).toHaveBeenCalledTimes(1) + }) + + it('should return the connection', async () => { + const authTokenProvider = jest.fn(() => toRenewableToken({ scheme: 'none' })) + const authenticationProvider = createAuthenticationProvider(authTokenProvider, { + renewableAuthToken: toExpiredRenewableToken({ scheme: 'none' }) + }) + const connection = mockConnection() + + const resultedConnection = await authenticationProvider.authenticate({ connection }) + + expect(resultedConnection).toBe(connection) + }) + + it('should call connection.connect', async () => { + const authToken = { scheme: 'none' } + const authTokenProvider = jest.fn(() => toRenewableToken(authToken)) + const authenticationProvider = createAuthenticationProvider(authTokenProvider, { + renewableAuthToken: toExpiredRenewableToken({ scheme: 'none' }) + }) + const connection = mockConnection() + + await authenticationProvider.authenticate({ connection }) + + expect(connection.connect).toHaveBeenCalledWith(USER_AGENT, authToken) + }) + + it('should throw errors happened during token refresh', async () => { + const error = new Error('ops') + const authTokenProvider = jest.fn(() => Promise.reject(error)) + const authenticationProvider = createAuthenticationProvider(authTokenProvider, { + renewableAuthToken: toExpiredRenewableToken({ scheme: 'none' }) + }) + const connection = mockConnection() + + await expect(authenticationProvider.authenticate({ connection })).rejects.toThrow(error) + }) + + it('should throw errors happened during connection.connect', async () => { + const error = new Error('ops') + const authTokenProvider = jest.fn(() => toRenewableToken({ scheme: 'none' })) + const authenticationProvider = createAuthenticationProvider(authTokenProvider, { + renewableAuthToken: toExpiredRenewableToken({ scheme: 'none' }) + }) + const connection = mockConnection({ + connect: () => Promise.reject(error) + }) + + await expect(authenticationProvider.authenticate({ connection })).rejects.toThrow(error) + }) + }) + + describe('when connection.authToken is equal to new AuthToken', () => { + it('should refresh the auth token', async () => { + const authTokenProvider = jest.fn(() => toRenewableToken({ scheme: 'none' })) + const authenticationProvider = createAuthenticationProvider(authTokenProvider, { + renewableAuthToken: toExpiredRenewableToken({ scheme: 'none' }) + }) + const connection = mockConnection() + + await authenticationProvider.authenticate({ connection }) + + expect(authTokenProvider).toHaveBeenCalledTimes(1) + }) + + it('should refresh authToken only once', async () => { + const authTokenProvider = jest.fn(() => new Promise((resolve) => { + setTimeout(() => { + resolve(toRenewableToken({ scheme: 'none' })) + }, 100) + })) + + const authenticationProvider = createAuthenticationProvider(authTokenProvider) + const connections = [mockConnection(), mockConnection()] + + await Promise.all(connections.map(connection => authenticationProvider.authenticate({ connection }))) + + expect(authTokenProvider).toHaveBeenCalledTimes(1) + }) + + it('should return the connection', async () => { + const authTokenProvider = jest.fn(() => toRenewableToken({ scheme: 'none' })) + const authenticationProvider = createAuthenticationProvider(authTokenProvider, { + renewableAuthToken: toExpiredRenewableToken({ scheme: 'none' }) + }) + const connection = mockConnection() + + const resultedConnection = await authenticationProvider.authenticate({ connection }) + + expect(resultedConnection).toBe(connection) + }) + + it('should not call connection.connect', async () => { + const authTokenProvider = jest.fn(() => toRenewableToken({ scheme: 'none' })) + const authenticationProvider = createAuthenticationProvider(authTokenProvider, { + renewableAuthToken: toExpiredRenewableToken({ scheme: 'none' }) + }) + const connection = mockConnection({ authToken: { scheme: 'none' } }) + + await authenticationProvider.authenticate({ connection }) + + expect(connection.connect).toHaveBeenCalledTimes(0) + }) + + it('should throw errors happened during token refresh', async () => { + const error = new Error('ops') + const authTokenProvider = jest.fn(() => Promise.reject(error)) + const authenticationProvider = createAuthenticationProvider(authTokenProvider, { + renewableAuthToken: toExpiredRenewableToken({ scheme: 'none' }) + }) + const connection = mockConnection() + + await expect(authenticationProvider.authenticate({ connection })).rejects.toThrow(error) + }) + }) + }) + + describe('and token is not expired', () => { + describe('and connection.authToken is different of provider.authToken', () => { + it('should not refresh the auth token', async () => { + const authTokenProvider = jest.fn(() => toRenewableToken({})) + const authenticationProvider = createAuthenticationProvider(authTokenProvider, { + renewableAuthToken: toRenewableToken({ scheme: 'none' }) + }) + const connection = mockConnection() + + await authenticationProvider.authenticate({ connection }) + + expect(authTokenProvider).toHaveBeenCalledTimes(0) + }) + + it('should return the connection', async () => { + const authTokenProvider = jest.fn(() => toRenewableToken({})) + const authenticationProvider = createAuthenticationProvider(authTokenProvider, { + renewableAuthToken: toRenewableToken({ scheme: 'none' }) + }) + const connection = mockConnection() + + const resultedConnection = await authenticationProvider.authenticate({ connection }) + + expect(resultedConnection).toBe(connection) + }) + + it('should call connection.connect', async () => { + const authToken = { scheme: 'none' } + const authTokenProvider = jest.fn(() => toRenewableToken({})) + const authenticationProvider = createAuthenticationProvider(authTokenProvider, { + renewableAuthToken: toRenewableToken({ scheme: 'none' }) + }) + const connection = mockConnection() + + await authenticationProvider.authenticate({ connection }) + + expect(connection.connect).toHaveBeenCalledWith(USER_AGENT, authToken) + }) + + it('should throw errors happened during connection.connect', async () => { + const error = new Error('ops') + const authTokenProvider = jest.fn(() => toRenewableToken({})) + const authenticationProvider = createAuthenticationProvider(authTokenProvider, { + renewableAuthToken: toRenewableToken({ scheme: 'none' }) + }) + const connection = mockConnection({ + connect: () => Promise.reject(error) + }) + + await expect(authenticationProvider.authenticate({ connection })).rejects.toThrow(error) + }) + }) + + describe('when connection.authToken is equal to provider.authToken', () => { + it('should not refresh the auth token', async () => { + const authTokenProvider = jest.fn(() => toRenewableToken({})) + const authenticationProvider = createAuthenticationProvider(authTokenProvider, { + renewableAuthToken: toRenewableToken({ scheme: 'none' }) + }) + const connection = mockConnection() + + await authenticationProvider.authenticate({ connection }) + + expect(authTokenProvider).toHaveBeenCalledTimes(0) + }) + + it('should return the connection', async () => { + const authTokenProvider = jest.fn(() => toRenewableToken({})) + const authenticationProvider = createAuthenticationProvider(authTokenProvider, { + renewableAuthToken: toRenewableToken({ scheme: 'none' }) + }) + const connection = mockConnection() + + const resultedConnection = await authenticationProvider.authenticate({ connection }) + + expect(resultedConnection).toBe(connection) + }) + + it('should not call connection.connect', async () => { + const authTokenProvider = jest.fn(() => toRenewableToken({})) + const authenticationProvider = createAuthenticationProvider(authTokenProvider, { + renewableAuthToken: toRenewableToken({ scheme: 'none' }) + }) + const connection = mockConnection({ authToken: { scheme: 'none' } }) + + await authenticationProvider.authenticate({ connection }) + + expect(connection.connect).toHaveBeenCalledTimes(0) + }) + }) + }) + }) + + describe.each([ + ['and first call', createAuthenticationProvider], + ['and token has expired', (authTokenProvider) => createAuthenticationProvider(authTokenProvider, { + renewableAuthToken: toExpiredRenewableToken({ scheme: 'none', credentials: 'token expired' }) + })], + ['and toke is not expired', (authTokenProvider) => createAuthenticationProvider(authTokenProvider, { + renewableAuthToken: toExpiredRenewableToken({ scheme: 'none' }) + })] + ])('when called with an auth and %s', (_, createAuthenticationProvider) => { + describe.each([false, true])('and connection is not authenticated (supportsReAuth=%s)', (supportsReAuth) => { + it('should call connection connect with the supplied auth', async () => { + const auth = { scheme: 'bearer', credentials: 'my token' } + const authTokenProvider = jest.fn(() => toRenewableToken({})) + const authenticationProvider = createAuthenticationProvider(authTokenProvider) + const connection = mockConnection({ supportsReAuth }) + + await authenticationProvider.authenticate({ connection, auth }) + + expect(connection.connect).toHaveBeenCalledWith(USER_AGENT, auth, false) + }) + + it('should return the connection', async () => { + const auth = { scheme: 'bearer', credentials: 'my token' } + const authTokenProvider = jest.fn(() => toRenewableToken({})) + const authenticationProvider = createAuthenticationProvider(authTokenProvider) + const connection = mockConnection({ supportsReAuth }) + + await expect(authenticationProvider.authenticate({ connection, auth })).resolves.toBe(connection) + }) + + it('should not refresh the token', async () => { + const auth = { scheme: 'bearer', credentials: 'my token' } + const authTokenProvider = jest.fn(() => toRenewableToken({})) + const authenticationProvider = createAuthenticationProvider(authTokenProvider) + const connection = mockConnection({ supportsReAuth }) + + await authenticationProvider.authenticate({ connection, auth }) + + expect(authTokenProvider).not.toHaveBeenCalled() + }) + + it('should throws if connection fails', async () => { + const error = new Error('nope') + const auth = { scheme: 'bearer', credentials: 'my token' } + const authTokenProvider = jest.fn(() => toRenewableToken({})) + const authenticationProvider = createAuthenticationProvider(authTokenProvider) + const connection = mockConnection({ + supportsReAuth, + connect: jest.fn(() => Promise.reject(error)) + }) + + await expect(authenticationProvider.authenticate({ connection, auth })).rejects.toThrow(error) + }) + }) + + describe.each([false, true])('and connection is authenticated with same token (supportsReAuth=%s)', (supportsReAuth) => { + it('should not call connection connect with the supplied auth', async () => { + const auth = { scheme: 'bearer', credentials: 'my token' } + const authTokenProvider = jest.fn(() => toRenewableToken({})) + const authenticationProvider = createAuthenticationProvider(authTokenProvider) + const connection = mockConnection({ supportsReAuth, authToken: { ...auth } }) + + await authenticationProvider.authenticate({ connection, auth }) + + expect(connection.connect).not.toHaveBeenCalledWith(USER_AGENT, auth) + }) + + it('should not call connection connect with the supplied auth and skipReAuth=true', async () => { + const auth = { scheme: 'bearer', credentials: 'my token' } + const authTokenProvider = jest.fn(() => toRenewableToken({})) + const authenticationProvider = createAuthenticationProvider(authTokenProvider) + const connection = mockConnection({ supportsReAuth, authToken: { ...auth } }) + + await authenticationProvider.authenticate({ connection, auth, skipReAuth: true }) + + expect(connection.connect).not.toHaveBeenCalledWith(USER_AGENT, auth) + }) + + if (supportsReAuth) { + it('should call connection connect with the supplied auth if forceReAuth=true', async () => { + const auth = { scheme: 'bearer', credentials: 'my token' } + const authTokenProvider = jest.fn(() => toRenewableToken({})) + const authenticationProvider = createAuthenticationProvider(authTokenProvider) + const connection = mockConnection({ supportsReAuth, authToken: { scheme: 'bearer', credentials: 'other' } }) + + await authenticationProvider.authenticate({ connection, auth, forceReAuth: true }) + + expect(connection.connect).toHaveBeenCalledWith(USER_AGENT, auth, false) + }) + } else { + it('should not call connection connect with the supplied auth if forceReAuth=true', async () => { + const auth = { scheme: 'bearer', credentials: 'my token' } + const authTokenProvider = jest.fn(() => toRenewableToken({})) + const authenticationProvider = createAuthenticationProvider(authTokenProvider) + const connection = mockConnection({ supportsReAuth, authToken: { ...auth } }) + + await authenticationProvider.authenticate({ connection, auth, forceReAuth: true }) + + expect(connection.connect).not.toHaveBeenCalledWith(USER_AGENT, auth) + }) + } + + it('should return the connection', async () => { + const auth = { scheme: 'bearer', credentials: 'my token' } + const authTokenProvider = jest.fn(() => toRenewableToken({})) + const authenticationProvider = createAuthenticationProvider(authTokenProvider) + const connection = mockConnection({ supportsReAuth, authToken: { ...auth } }) + + await expect(authenticationProvider.authenticate({ connection, auth })).resolves.toBe(connection) + }) + + it('should not refresh the token', async () => { + const auth = { scheme: 'bearer', credentials: 'my token' } + const authTokenProvider = jest.fn(() => toRenewableToken({})) + const authenticationProvider = createAuthenticationProvider(authTokenProvider) + const connection = mockConnection({ supportsReAuth, authToken: { ...auth } }) + + await authenticationProvider.authenticate({ connection, auth }) + + expect(authTokenProvider).not.toHaveBeenCalled() + }) + }) + + describe.each([true])('and connection is authenticated with different token (supportsReAuth=%s)', (supportsReAuth) => { + it('should call connection connect with the supplied auth', async () => { + const auth = { scheme: 'bearer', credentials: 'my token' } + const authTokenProvider = jest.fn(() => toRenewableToken({})) + const authenticationProvider = createAuthenticationProvider(authTokenProvider) + const connection = mockConnection({ supportsReAuth, authToken: { scheme: 'bearer', credentials: 'other' } }) + + await authenticationProvider.authenticate({ connection, auth }) + + expect(connection.connect).toHaveBeenCalledWith(USER_AGENT, auth, false) + }) + + it('should return the connection', async () => { + const auth = { scheme: 'bearer', credentials: 'my token' } + const authTokenProvider = jest.fn(() => toRenewableToken({})) + const authenticationProvider = createAuthenticationProvider(authTokenProvider) + const connection = mockConnection({ supportsReAuth, authToken: { scheme: 'bearer', credentials: 'other' } }) + + await expect(authenticationProvider.authenticate({ connection, auth })).resolves.toBe(connection) + }) + + it('should not refresh the token', async () => { + const auth = { scheme: 'bearer', credentials: 'my token' } + const authTokenProvider = jest.fn(() => toRenewableToken({})) + const authenticationProvider = createAuthenticationProvider(authTokenProvider) + const connection = mockConnection({ supportsReAuth, authToken: { scheme: 'bearer', credentials: 'other' } }) + + await authenticationProvider.authenticate({ connection, auth }) + + expect(authTokenProvider).not.toHaveBeenCalled() + }) + + it('should throws if connection fails', async () => { + const error = new Error('nope') + const auth = { scheme: 'bearer', credentials: 'my token' } + const authTokenProvider = jest.fn(() => toRenewableToken({})) + const authenticationProvider = createAuthenticationProvider(authTokenProvider) + const connection = mockConnection({ + supportsReAuth, + connect: jest.fn(() => Promise.reject(error)), + authToken: { scheme: 'bearer', credentials: 'other' } + }) + + await expect(authenticationProvider.authenticate({ connection, auth })).rejects.toThrow(error) + }) + + it.each([ + [true, true], + [false, false], + [undefined, false], + [null, false] + ])('should redirect `waitReAuth=%s` as `%s` to the connection.connect()', async (waitReAuth, expectedWaitForReAuth) => { + const auth = { scheme: 'bearer', credentials: 'my token' } + const authTokenProvider = jest.fn(() => toRenewableToken({})) + const authenticationProvider = createAuthenticationProvider(authTokenProvider) + const connection = mockConnection({ supportsReAuth, authToken: { scheme: 'bearer', credentials: 'other' } }) + + await authenticationProvider.authenticate({ connection, auth, waitReAuth }) + + expect(connection.connect).toHaveBeenCalledWith(USER_AGENT, auth, expectedWaitForReAuth) + }) + + it('should not call connect when skipReAuth=true', async () => { + const auth = { scheme: 'bearer', credentials: 'my token' } + const authTokenProvider = jest.fn(() => toRenewableToken({})) + const authenticationProvider = createAuthenticationProvider(authTokenProvider) + const connection = mockConnection({ supportsReAuth, authToken: { scheme: 'bearer', credentials: 'other' } }) + + await authenticationProvider.authenticate({ connection, auth, skipReAuth: true }) + + expect(connection.connect).not.toBeCalled() + }) + }) + + describe.each([false])('and connection is authenticated with different token (supportsReAuth=%s)', (supportsReAuth) => { + it('should not call connection connect with the supplied auth', async () => { + const auth = { scheme: 'bearer', credentials: 'my token' } + const authTokenProvider = jest.fn(() => toRenewableToken({})) + const authenticationProvider = createAuthenticationProvider(authTokenProvider) + const connection = mockConnection({ supportsReAuth, authToken: { ...auth, credentials: 'other' } }) + + await authenticationProvider.authenticate({ connection, auth }) + + expect(connection.connect).not.toHaveBeenCalledWith(USER_AGENT, auth) + }) + + it('should not call connection connect with the supplied auth and forceReAuth=true', async () => { + const auth = { scheme: 'bearer', credentials: 'my token' } + const authTokenProvider = jest.fn(() => toRenewableToken({})) + const authenticationProvider = createAuthenticationProvider(authTokenProvider) + const connection = mockConnection({ supportsReAuth, authToken: { ...auth, credentials: 'other' } }) + + await authenticationProvider.authenticate({ connection, auth, forceReAuth: true }) + + expect(connection.connect).not.toHaveBeenCalledWith(USER_AGENT, auth) + }) + + it('should return the connection', async () => { + const auth = { scheme: 'bearer', credentials: 'my token' } + const authTokenProvider = jest.fn(() => toRenewableToken({})) + const authenticationProvider = createAuthenticationProvider(authTokenProvider) + const connection = mockConnection({ supportsReAuth, authToken: { ...auth, credentials: 'other' } }) + + await expect(authenticationProvider.authenticate({ connection, auth })).resolves.toBe(connection) + }) + + it('should not refresh the token', async () => { + const auth = { scheme: 'bearer', credentials: 'my token' } + const authTokenProvider = jest.fn(() => toRenewableToken({})) + const authenticationProvider = createAuthenticationProvider(authTokenProvider) + const connection = mockConnection({ supportsReAuth, authToken: { ...auth, credentials: 'other' } }) + + await authenticationProvider.authenticate({ connection, auth }) + + expect(authTokenProvider).not.toHaveBeenCalled() + }) + }) + }) + }) + + describe('.handleError()', () => { + it.each( + shouldNotScheduleRefreshScenarios() + )('should not schedule a refresh when %s', (_, createScenario) => { + const { + connection, + code, + authTokenProvider, + authenticationProvider + } = createScenario() + + authenticationProvider.handleError({ code, connection }) + + expect(authTokenProvider).not.toHaveBeenCalled() + }) + + it.each( + errorCodeTriggerRefreshAuth() + )('should schedule refresh when auth are the same, valid error code (%s) and no refresh schedule', async (code) => { + const authToken = { scheme: 'bearer', credentials: 'token' } + const newTokenPromiseState = {} + const newTokenPromise = new Promise((resolve) => { newTokenPromiseState.resolve = resolve }) + const authTokenProvider = jest.fn(() => newTokenPromise) + const renewableAuthToken = toRenewableToken(authToken) + const connection = mockConnection({ + authToken: { ...authToken } + }) + const authenticationProvider = createAuthenticationProvider(authTokenProvider, { + renewableAuthToken + }) + + authenticationProvider.handleError({ code, connection }) + + expect(authTokenProvider).toHaveBeenCalled() + + // Test implementation details + expect(authenticationProvider._authTokenManager._currentAuthData).toEqual(undefined) + + const newRenewableToken = toRenewableToken({ scheme: 'bearer', credentials: 'token2' }) + newTokenPromiseState.resolve(newRenewableToken) + + await newTokenPromise + + expect(authenticationProvider._authTokenManager._currentAuthData).toBe(newRenewableToken) + }) + + function shouldNotScheduleRefreshScenarios () { + return [ + ...nonValidCodesScenarios(), + ...validCodesWithDifferentAuthScenarios(), + ...nonValidCodesWithDifferentAuthScenarios(), + ...validCodesWithSameAuthButWithRescheduleInPlaceScenarios() + ] + + function nonValidCodesScenarios () { + return polyfillFlatMap([ + 'Neo.ClientError.Security.AuthorizationExpired', + 'Neo.ClientError.General.ForbiddenOnReadOnlyDatabase', + 'Neo.Made.Up.Error' + ]).flatMap(code => [ + [ + `connection and provider has same auth token and error code does not trigger re-fresh (code=${code})`, () => { + const authToken = { scheme: 'bearer', credentials: 'token' } + const authTokenProvider = jest.fn(() => {}) + return { + connection: mockConnection({ + authToken: { ...authToken } + }), + code, + authTokenProvider, + authenticationProvider: createAuthenticationProvider(authTokenProvider, { + renewableAuthToken: toRenewableToken(authToken) + }) + } + } + ] + ]) + } + + function validCodesWithDifferentAuthScenarios () { + return errorCodeTriggerRefreshAuth().flatMap(code => [ + [ + `connection and provider has different auth token and error code does trigger re-fresh (code=${code})`, + () => { + const authToken = { scheme: 'bearer', credentials: 'token' } + const authTokenProvider = jest.fn(() => {}) + return { + connection: mockConnection({ + authToken: { ...authToken, credentials: 'token2' } + }), + code, + authTokenProvider, + authenticationProvider: createAuthenticationProvider(authTokenProvider, { + renewableAuthToken: toRenewableToken(authToken) + }) + } + } + ] + + ]) + } + + function nonValidCodesWithDifferentAuthScenarios () { + return polyfillFlatMap([ + 'Neo.ClientError.Security.AuthorizationExpired', + 'Neo.ClientError.General.ForbiddenOnReadOnlyDatabase', + 'Neo.Made.Up.Error' + ]).flatMap(code => [ + [ + `connection and provider has different auth token and error code does not trigger re-fresh (code=${code})`, + () => { + const authToken = { scheme: 'bearer', credentials: 'token' } + const authTokenProvider = jest.fn(() => {}) + return { + connection: mockConnection({ + authToken: { ...authToken, credentials: 'token2' } + }), + code, + authTokenProvider, + authenticationProvider: createAuthenticationProvider(authTokenProvider, { + renewableAuthToken: toRenewableToken(authToken) + }) + } + } + ] + + ]) + } + + function validCodesWithSameAuthButWithRescheduleInPlaceScenarios () { + return errorCodeTriggerRefreshAuth().flatMap(code => [ + [ + `connection and provider has same auth token and error code does trigger re-fresh (code=${code}), but refresh already schedule`, () => { + const authToken = { scheme: 'bearer', credentials: 'token' } + const authTokenProvider = jest.fn(() => {}) + return { + connection: mockConnection({ + authToken: { ...authToken } + }), + code, + authTokenProvider, + authenticationProvider: createAuthenticationProvider(authTokenProvider, { + renewableAuthToken: toRenewableToken(authToken), + refreshObserver: refreshObserverMock() + }) + } + } + ] + ]) + } + } + }) + + function createAuthenticationProvider (authTokenProvider, mocks) { + const authTokenManager = expirationBasedAuthTokenManager({ tokenProvider: authTokenProvider }) + const provider = new AuthenticationProvider({ + authTokenManager, + userAgent: USER_AGENT + }) + + if (mocks) { + authTokenManager._currentAuthData = mocks.renewableAuthToken + authTokenManager._refreshObservable = mocks.refreshObserver + } + + return provider + } + + function mockConnection ({ connect, authToken, supportsReAuth } = {}) { + const connection = { + connect: connect || jest.fn(() => Promise.resolve(connection)), + authToken, + supportsReAuth + } + return connection + } + + function toRenewableToken (token, expiration) { + return { + token, + expiration + } + } + + function toExpiredRenewableToken (token) { + return toRenewableToken(token, new Date(new Date().getTime() - 1)) + } + + function errorCodeTriggerRefreshAuth () { + return polyfillFlatMap([ + 'Neo.ClientError.Security.Unauthorized', + 'Neo.ClientError.Security.TokenExpired' + ]) + } + + function polyfillFlatMap (arr) { + /** Polyfill flatMap for Node10 tests */ + if (!arr.flatMap) { + arr.flatMap = function (callback, thisArg) { + return arr.concat.apply([], arr.map(callback, thisArg)) + } + } + return arr + } + + function refreshObserverMock () { + const subscribers = [] + + return { + subscribe: (sub) => subscribers.push(sub), + onCompleted: (data) => subscribers.forEach(sub => sub.onCompleted(data)), + onError: (e) => subscribers.forEach(sub => sub.onError(e)) + } + } +}) diff --git a/packages/bolt-connection/test/connection-provider/connection-provider-direct.test.js b/packages/bolt-connection/test/connection-provider/connection-provider-direct.test.js index 2b659de63..83ff25abf 100644 --- a/packages/bolt-connection/test/connection-provider/connection-provider-direct.test.js +++ b/packages/bolt-connection/test/connection-provider/connection-provider-direct.test.js @@ -20,7 +20,9 @@ import DirectConnectionProvider from '../../src/connection-provider/connection-provider-direct' import { Pool } from '../../src/pool' import { Connection, DelegateConnection } from '../../src/connection' -import { internal, newError, ServerInfo } from 'neo4j-driver-core' +import { internal, newError, ServerInfo, staticAuthTokenManager, expirationBasedAuthTokenManager } from 'neo4j-driver-core' +import AuthenticationProvider from '../../src/connection-provider/authentication-provider' +import { functional } from '../../src/lang' const { serverAddress: { ServerAddress }, @@ -56,10 +58,11 @@ describe('#unit DirectConnectionProvider', () => { expect(conn instanceof DelegateConnection).toBeTruthy() }) - it('should purge connections for address when AuthorizationExpired happens', async () => { + it('should close connection and remove authToken for address when AuthorizationExpired happens', async () => { const address = ServerAddress.fromUrl('localhost:123') const pool = newPool() jest.spyOn(pool, 'purge') + jest.spyOn(pool, 'apply') const connectionProvider = newDirectConnectionProvider(address, pool) const conn = await connectionProvider.acquireConnection({ @@ -72,12 +75,48 @@ describe('#unit DirectConnectionProvider', () => { 'Neo.ClientError.Security.AuthorizationExpired' ) + jest.spyOn(conn, 'close') + + conn.handleAndTransformError(error, address) + + expect(conn.close).toHaveBeenCalled() + expect(pool.purge).not.toHaveBeenCalledWith(address) + expect(pool.apply).toHaveBeenCalledTimes(1) + + const [[calledAddress, appliedFunction]] = pool.apply.mock.calls + + expect(calledAddress).toBe(address) + + const fakeConn = { authToken: 'some token' } + + appliedFunction(fakeConn) + expect(fakeConn.authToken).toBe(null) + pool.apply(address, conn => expect(conn.authToken).toBe(null)) + }) + + it('should call authenticationAuthProvider.handleError when AuthorizationExpired happens', async () => { + const address = ServerAddress.fromUrl('localhost:123') + const pool = newPool() + const connectionProvider = newDirectConnectionProvider(address, pool) + + const handleError = jest.spyOn(connectionProvider._authenticationProvider, 'handleError') + + const conn = await connectionProvider.acquireConnection({ + accessMode: 'READ', + database: '' + }) + + const error = newError( + 'Message', + 'Neo.ClientError.Security.AuthorizationExpired' + ) + conn.handleAndTransformError(error, address) - expect(pool.purge).toHaveBeenCalledWith(address) + expect(handleError).toBeCalledWith({ connection: conn, code: 'Neo.ClientError.Security.AuthorizationExpired' }) }) - it('should purge not change error when AuthorizationExpired happens', async () => { + it('should not change error when AuthorizationExpired happens', async () => { const address = ServerAddress.fromUrl('localhost:123') const pool = newPool() const connectionProvider = newDirectConnectionProvider(address, pool) @@ -98,10 +137,11 @@ describe('#unit DirectConnectionProvider', () => { }) }) -it('should purge connections for address when TokenExpired happens', async () => { +it('should close the connection when TokenExpired happens', async () => { const address = ServerAddress.fromUrl('localhost:123') const pool = newPool() jest.spyOn(pool, 'purge') + jest.spyOn(pool, 'apply') const connectionProvider = newDirectConnectionProvider(address, pool) const conn = await connectionProvider.acquireConnection({ @@ -114,9 +154,14 @@ it('should purge connections for address when TokenExpired happens', async () => 'Neo.ClientError.Security.TokenExpired' ) + jest.spyOn(conn, 'close') + conn.handleAndTransformError(error, address) - expect(pool.purge).toHaveBeenCalledWith(address) + expect(conn.close).toHaveBeenCalled() + expect(pool.purge).not.toHaveBeenCalledWith(address) + expect(pool.apply).toHaveBeenCalledTimes(0) + pool.apply(address, conn => expect(conn.authToken).toBeDefined()) }) it('should not change error when TokenExpired happens', async () => { @@ -139,6 +184,378 @@ it('should not change error when TokenExpired happens', async () => { expect(error).toBe(expectedError) }) +it('should call authenticationAuthProvider.handleError when TokenExpired happens', async () => { + const address = ServerAddress.fromUrl('localhost:123') + const pool = newPool() + const connectionProvider = newDirectConnectionProvider(address, pool) + + const handleError = jest.spyOn(connectionProvider._authenticationProvider, 'handleError') + + const conn = await connectionProvider.acquireConnection({ + accessMode: 'READ', + database: '' + }) + + const error = newError( + 'Message', + 'Neo.ClientError.Security.TokenExpired' + ) + + conn.handleAndTransformError(error, address) + + expect(handleError).toBeCalledWith({ connection: conn, code: 'Neo.ClientError.Security.TokenExpired' }) +}) + +it('should change error to retriable when error when TokenExpired happens and staticAuthTokenManager is not being used', async () => { + const address = ServerAddress.fromUrl('localhost:123') + const pool = newPool() + const connectionProvider = newDirectConnectionProvider(address, pool, expirationBasedAuthTokenManager({ tokenProvider: () => null })) + + const conn = await connectionProvider.acquireConnection({ + accessMode: 'READ', + database: '' + }) + + const expectedError = newError( + 'Message', + 'Neo.ClientError.Security.TokenExpired' + ) + + const error = conn.handleAndTransformError(expectedError, address) + + expect(error.retriable).toBe(true) +}) + +it('should not change error to retriable when error when TokenExpired happens and staticAuthTokenManager is being used', async () => { + const address = ServerAddress.fromUrl('localhost:123') + const pool = newPool() + const connectionProvider = newDirectConnectionProvider(address, pool, staticAuthTokenManager({ authToken: null })) + + const conn = await connectionProvider.acquireConnection({ + accessMode: 'READ', + database: '' + }) + + const expectedError = newError( + 'Message', + 'Neo.ClientError.Security.TokenExpired' + ) + + const error = conn.handleAndTransformError(expectedError, address) + + expect(error.retriable).toBe(false) +}) + +describe('constructor', () => { + describe('newPool', () => { + const server0 = ServerAddress.fromUrl('localhost:123') + const server01 = ServerAddress.fromUrl('localhost:1235') + + describe('param.create', () => { + it('should create connection', async () => { + const { create, createChannelConnectionHook, provider } = setup() + + const connection = await create({}, server0, undefined) + + expect(createChannelConnectionHook).toHaveBeenCalledWith(server0) + expect(provider._openConnections[connection.id]).toBe(connection) + await expect(createChannelConnectionHook.mock.results[0].value).resolves.toBe(connection) + }) + + it('should register the release function into the connection', async () => { + const { create } = setup() + const releaseResult = { property: 'some property' } + const release = jest.fn(() => releaseResult) + + const connection = await create({}, server0, release) + + const released = connection._release() + + expect(released).toBe(releaseResult) + expect(release).toHaveBeenCalledWith(server0, connection) + }) + + it.each([ + null, + undefined, + { scheme: 'bearer', credentials: 'token01' } + ])('should authenticate connection (auth = %o)', async (auth) => { + const { create, authenticationProviderHook } = setup() + + const connection = await create({ auth }, server0) + + expect(authenticationProviderHook.authenticate).toHaveBeenCalledWith({ + connection, + auth + }) + }) + + it('should handle create connection failures', async () => { + const error = newError('some error') + const createConnection = jest.fn(() => Promise.reject(error)) + const { create, authenticationProviderHook, provider } = setup({ createConnection }) + const openConnections = { ...provider._openConnections } + + await expect(create({}, server0)).rejects.toThrow(error) + + expect(authenticationProviderHook.authenticate).not.toHaveBeenCalled() + expect(provider._openConnections).toEqual(openConnections) + }) + + it.each([ + null, + undefined, + { scheme: 'bearer', credentials: 'token01' } + ])('should handle authentication failures (auth = %o)', async (auth) => { + const error = newError('some error') + const authenticationProvider = jest.fn(() => Promise.reject(error)) + const { create, authenticationProviderHook, createChannelConnectionHook, provider } = setup({ authenticationProvider }) + const openConnections = { ...provider._openConnections } + + await expect(create({ auth }, server0)).rejects.toThrow(error) + + const connection = await createChannelConnectionHook.mock.results[0].value + expect(authenticationProviderHook.authenticate).toHaveBeenCalledWith({ auth, connection }) + expect(provider._openConnections).toEqual(openConnections) + expect(connection._closed).toBe(true) + }) + }) + + describe('param.destroy', () => { + it('should close connection and unregister it', async () => { + const { create, destroy, provider } = setup() + const openConnections = { ...provider._openConnections } + const connection = await create({}, server0, undefined) + + await destroy(connection) + + expect(connection._closed).toBe(true) + expect(provider._openConnections).toEqual(openConnections) + }) + }) + + describe('param.validateOnAcquire', () => { + it.each([ + null, + undefined, + { scheme: 'bearer', credentials: 'token01' } + ])('should return true when connection is open and within the lifetime and authentication succeed (auth=%o)', async (auth) => { + const connection = new FakeConnection(server0) + connection.creationTimestamp = Date.now() + + const { validateOnAcquire, authenticationProviderHook } = setup() + + await expect(validateOnAcquire({ auth }, connection)).resolves.toBe(true) + + expect(authenticationProviderHook.authenticate).toHaveBeenCalledWith({ + connection, auth + }) + }) + + it.each([ + null, + undefined, + { scheme: 'bearer', credentials: 'token01' } + ])('should return true when connection is open and within the lifetime and authentication fails (auth=%o)', async (auth) => { + const connection = new FakeConnection(server0) + const error = newError('failed') + const authenticationProvider = jest.fn(() => Promise.reject(error)) + connection.creationTimestamp = Date.now() + + const { validateOnAcquire, authenticationProviderHook, log } = setup({ authenticationProvider }) + + await expect(validateOnAcquire({ auth }, connection)).resolves.toBe(false) + + expect(authenticationProviderHook.authenticate).toHaveBeenCalledWith({ + connection, auth + }) + + expect(log.debug).toHaveBeenCalledWith( + `The connection ${connection.id} is not valid because of an error ${error.code} '${error.message}'` + ) + }) + + it.each([ + true, + false + ])('should call authenticationProvider.authenticate with skipReAuth=%s', async (skipReAuth) => { + const connection = new FakeConnection(server0) + const auth = {} + connection.creationTimestamp = Date.now() + + const { validateOnAcquire, authenticationProviderHook } = setup() + + await expect(validateOnAcquire({ auth, skipReAuth }, connection)).resolves.toBe(true) + + expect(authenticationProviderHook.authenticate).toHaveBeenCalledWith({ + connection, auth, skipReAuth + }) + }) + + it('should return false when connection is closed and within the lifetime', async () => { + const connection = new FakeConnection(server0) + connection.creationTimestamp = Date.now() + await connection.close() + + const { validateOnAcquire, authenticationProviderHook } = setup() + + await expect(validateOnAcquire({}, connection)).resolves.toBe(false) + expect(authenticationProviderHook.authenticate).not.toHaveBeenCalled() + }) + + it('should return false when connection is open and out of the lifetime', async () => { + const connection = new FakeConnection(server0) + connection.creationTimestamp = Date.now() - 4000 + + const { validateOnAcquire, authenticationProviderHook } = setup({ maxConnectionLifetime: 3000 }) + + await expect(validateOnAcquire({}, connection)).resolves.toBe(false) + expect(authenticationProviderHook.authenticate).not.toHaveBeenCalled() + }) + + it('should return false when connection is closed and out of the lifetime', async () => { + const connection = new FakeConnection(server0) + await connection.close() + connection.creationTimestamp = Date.now() - 4000 + + const { validateOnAcquire, authenticationProviderHook } = setup({ maxConnectionLifetime: 3000 }) + + await expect(validateOnAcquire({}, connection)).resolves.toBe(false) + expect(authenticationProviderHook.authenticate).not.toHaveBeenCalled() + }) + }) + + describe('param.validateOnRelease', () => { + it('should return true when connection is open and within the lifetime', () => { + const connection = new FakeConnection(server0) + connection.creationTimestamp = Date.now() + + const { validateOnRelease } = setup() + + expect(validateOnRelease(connection)).toBe(true) + }) + + it('should return false when connection is closed and within the lifetime', async () => { + const connection = new FakeConnection(server0) + connection.creationTimestamp = Date.now() + await connection.close() + + const { validateOnRelease } = setup() + + expect(validateOnRelease(connection)).toBe(false) + }) + + it('should return false when connection is open and out of the lifetime', () => { + const connection = new FakeConnection(server0) + connection.creationTimestamp = Date.now() - 4000 + + const { validateOnRelease } = setup({ maxConnectionLifetime: 3000 }) + + expect(validateOnRelease(connection)).toBe(false) + }) + + it('should return false when connection is closed and out of the lifetime', async () => { + const connection = new FakeConnection(server0) + await connection.close() + connection.creationTimestamp = Date.now() - 4000 + + const { validateOnRelease } = setup({ maxConnectionLifetime: 3000 }) + + expect(validateOnRelease(connection)).toBe(false) + }) + + it('should return false when connection is sticky', async () => { + const connection = new FakeConnection(server0) + connection._sticky = true + + const { validateOnRelease } = setup() + + expect(validateOnRelease(connection)).toBe(false) + }) + }) + + function setup ({ createConnection, authenticationProvider, maxConnectionLifetime } = {}) { + const newPool = jest.fn((...args) => new Pool(...args)) + const log = new Logger('debug', () => undefined) + jest.spyOn(log, 'debug') + const createChannelConnectionHook = createConnection || jest.fn(async (address) => new FakeConnection(address)) + const authenticationProviderHook = new AuthenticationProvider({ }) + jest.spyOn(authenticationProviderHook, 'authenticate') + .mockImplementation(authenticationProvider || jest.fn(({ connection }) => Promise.resolve(connection))) + const provider = new DirectConnectionProvider({ + newPool, + config: { + maxConnectionLifetime: maxConnectionLifetime || 1000 + }, + address: server01, + log + }) + provider._createChannelConnection = createChannelConnectionHook + provider._authenticationProvider = authenticationProviderHook + return { + provider, + ...newPool.mock.calls[0][0], + createChannelConnectionHook, + authenticationProviderHook, + log + } + } + }) +}) + +describe('user-switching', () => { + describe('should not allow sticky connections', () => { + describe('when does not supports re-auth', () => { + it.each([ + ['new connection', { other: 'auth' }, { other: 'auth' }, true], + ['old connection', { some: 'auth' }, { other: 'token' }, false] + ])('should raise and error when try switch user on acquire [%s]', async (_, connAuth, acquireAuth, isStickyConn) => { + const address = ServerAddress.fromUrl('localhost:123') + const pool = newPool() + const connection = new FakeConnection(address, () => {}, undefined, connAuth) + const poolAcquire = jest.spyOn(pool, 'acquire').mockResolvedValue(connection) + const connectionProvider = newDirectConnectionProvider(address, pool) + + const error = await connectionProvider + .acquireConnection({ + accessMode: 'READ', + database: '', + auth: acquireAuth + }) + .catch(functional.identity) + + expect(error).toEqual(newError('Driver is connected to a database that does not support user switch.')) + expect(poolAcquire).toHaveBeenCalledWith({ auth: acquireAuth }, address) + expect(connection._release).toHaveBeenCalled() + expect(connection._sticky).toEqual(isStickyConn) + }) + }) + + describe('when supports re-auth', () => { + const connAuth = { some: 'auth' } + const acquireAuth = connAuth + + it('should return connection when try switch user on acquire', async () => { + const address = ServerAddress.fromUrl('localhost:123') + const pool = newPool() + const connection = new FakeConnection(address, () => {}, undefined, connAuth, { supportsReAuth: true }) + jest.spyOn(pool, 'acquire').mockResolvedValue(connection) + const connectionProvider = newDirectConnectionProvider(address, pool) + + const acquiredConnection = await connectionProvider + .acquireConnection({ + accessMode: 'READ', + database: '', + auth: acquireAuth + }) + + expect(acquiredConnection).toBe(connection) + expect(acquiredConnection._sticky).toEqual(false) + }) + }) + }) +}) + describe('.verifyConnectivityAndGetServerInfo()', () => { describe('when connection is available in the pool', () => { it('should return the server info', async () => { @@ -313,38 +730,56 @@ describe('.verifyConnectivityAndGetServerInfo()', () => { }) }) -function newDirectConnectionProvider (address, pool) { +function newDirectConnectionProvider (address, pool, authTokenManager) { const connectionProvider = new DirectConnectionProvider({ id: 0, config: {}, log: Logger.noOp(), - address: address + address: address, + authTokenManager }) connectionProvider._connectionPool = pool return connectionProvider } function newPool ({ create, config } = {}) { + const auth = { scheme: 'bearer', credentials: 'my token' } const _create = (address, release) => { if (create) { return create(address, release) } - return new FakeConnection(address, release) + return new FakeConnection(address, release, undefined, auth) } return new Pool({ config, - create: (address, release) => + create: (_, address, release) => Promise.resolve(_create(address, release)) }) } class FakeConnection extends Connection { - constructor (address, release, server) { + constructor (address, release, server, auth, { supportsReAuth } = {}) { super(null) this._address = address this._release = jest.fn(() => release(address, this)) this._server = server + this._authToken = auth + this._closed = false + this._id = 1 + this._supportsReAuth = supportsReAuth || false + } + + get id () { + return this._id + } + + get authToken () { + return this._authToken + } + + set authToken (authToken) { + this._authToken = authToken } get address () { @@ -354,4 +789,16 @@ class FakeConnection extends Connection { get server () { return this._server } + + get supportsReAuth () { + return this._supportsReAuth + } + + async close () { + this._closed = true + } + + isOpen () { + return !this._closed + } } diff --git a/packages/bolt-connection/test/connection-provider/connection-provider-routing.test.js b/packages/bolt-connection/test/connection-provider/connection-provider-routing.test.js index 8e742cf3c..2ec3512a1 100644 --- a/packages/bolt-connection/test/connection-provider/connection-provider-routing.test.js +++ b/packages/bolt-connection/test/connection-provider/connection-provider-routing.test.js @@ -24,13 +24,17 @@ import { Integer, int, internal, - ServerInfo + ServerInfo, + staticAuthTokenManager, + expirationBasedAuthTokenManager } from 'neo4j-driver-core' import { RoutingTable } from '../../src/rediscovery/' import { Pool } from '../../src/pool' import SimpleHostNameResolver from '../../src/channel/browser/browser-host-name-resolver' import RoutingConnectionProvider from '../../src/connection-provider/connection-provider-routing' import { DelegateConnection, Connection } from '../../src/connection' +import AuthenticationProvider from '../../src/connection-provider/authentication-provider' +import { functional } from '../../src/lang' const { serverAddress: { ServerAddress }, @@ -48,7 +52,8 @@ describe.each([ 4.2, 4.3, 4.4, - 5.0 + 5.0, + 5.1 ])('#unit RoutingConnectionProvider (PROTOCOL_VERSION=%d)', (PROTOCOL_VERSION) => { const server0 = ServerAddress.fromUrl('server0') const server1 = ServerAddress.fromUrl('server1') @@ -129,9 +134,9 @@ describe.each([ it('purges connections when address is forgotten', () => { const pool = newPool() - pool.acquire(server1) - pool.acquire(server3) - pool.acquire(server5) + pool.acquire({}, server1) + pool.acquire({}, server3) + pool.acquire({}, server5) expectPoolToContain(pool, [server1, server3, server5]) const connectionProvider = newRoutingConnectionProvider( @@ -1434,10 +1439,11 @@ describe.each([ }) }, 10000) - it.each(usersDataSet)('should purge connections for address when AuthorizationExpired happens [user=%s]', async (user) => { + it.each(usersDataSet)('should close connection and erase authToken for connection with address when AuthorizationExpired happens [user=%s]', async (user) => { const pool = newPool() jest.spyOn(pool, 'purge') + jest.spyOn(pool, 'apply') const connectionProvider = newRoutingConnectionProvider( [ @@ -1468,11 +1474,78 @@ describe.each([ impersonatedUser: user }) + jest.spyOn(server2Connection, 'close') + jest.spyOn(server3Connection, 'close') + server3Connection.handleAndTransformError(error, server3) server2Connection.handleAndTransformError(error, server2) - expect(pool.purge).toHaveBeenCalledWith(server3) - expect(pool.purge).toHaveBeenCalledWith(server2) + expect(server2Connection.close).toHaveBeenCalled() + expect(server3Connection.close).toHaveBeenCalled() + + expect(pool.purge).not.toHaveBeenCalledWith(server3) + expect(pool.purge).not.toHaveBeenCalledWith(server2) + + expect(pool.apply).toHaveBeenCalledTimes(2) + + const [[ + calledAddress1, appliedFunction1 + ], + [ + calledAddress2, appliedFunction2 + ]] = pool.apply.mock.calls + + expect(calledAddress1).toBe(server3) + let fakeConn = { authToken: 'some token' } + appliedFunction1(fakeConn) + expect(fakeConn.authToken).toBe(null) + pool.apply(server3, conn => expect(conn.authToken).toBe(null)) + + expect(calledAddress2).toBe(server2) + fakeConn = { authToken: 'some token' } + appliedFunction2(fakeConn) + expect(fakeConn.authToken).toBe(null) + pool.apply(server2, conn => expect(conn.authToken).toBe(null)) + }) + + it.each(usersDataSet)('should call authenticationAuthProvider.handleError when AuthorizationExpired happens [user=%s]', async (user) => { + const pool = newPool() + const connectionProvider = newRoutingConnectionProvider( + [ + newRoutingTable( + null, + [server1, server2], + [server3, server2], + [server2, server4] + ) + ], + pool + ) + + const handleError = jest.spyOn(connectionProvider._authenticationProvider, 'handleError') + + const error = newError( + 'Message', + 'Neo.ClientError.Security.AuthorizationExpired' + ) + + const server2Connection = await connectionProvider.acquireConnection({ + accessMode: 'WRITE', + database: null, + impersonatedUser: user + }) + + const server3Connection = await connectionProvider.acquireConnection({ + accessMode: 'READ', + database: null, + impersonatedUser: user + }) + + server3Connection.handleAndTransformError(error, server3) + server2Connection.handleAndTransformError(error, server2) + + expect(handleError).toBeCalledWith({ connection: server3Connection, code: 'Neo.ClientError.Security.AuthorizationExpired' }) + expect(handleError).toBeCalledWith({ connection: server2Connection, code: 'Neo.ClientError.Security.AuthorizationExpired' }) }) it.each(usersDataSet)('should purge not change error when AuthorizationExpired happens [user=%s]', async (user) => { @@ -1511,11 +1584,53 @@ describe.each([ expect(error).toBe(expectedError) }) - it.each(usersDataSet)('should purge connections for address when TokenExpired happens [user=%s]', async (user) => { + it.each(usersDataSet)('should not purge connections for address when TokenExpired happens [user=%s]', async (user) => { const pool = newPool() jest.spyOn(pool, 'purge') + jest.spyOn(pool, 'apply') + + const connectionProvider = newRoutingConnectionProvider( + [ + newRoutingTable( + null, + [server1, server2], + [server3, server2], + [server2, server4] + ) + ], + pool + ) + const error = newError( + 'Message', + 'Neo.ClientError.Security.TokenExpired' + ) + + const server2Connection = await connectionProvider.acquireConnection({ + accessMode: 'WRITE', + database: null, + impersonatedUser: user + }) + + const server3Connection = await connectionProvider.acquireConnection({ + accessMode: 'READ', + database: null, + impersonatedUser: user + }) + + server3Connection.handleAndTransformError(error, server3) + server2Connection.handleAndTransformError(error, server2) + + expect(pool.purge).not.toHaveBeenCalledWith(server3) + expect(pool.purge).not.toHaveBeenCalledWith(server2) + expect(pool.apply).toHaveBeenCalledTimes(0) + pool.apply(server2, conn => expect(conn.authToken).toBeDefined()) + pool.apply(server3, conn => expect(conn.authToken).toBeDefined()) + }) + + it.each(usersDataSet)('should call authenticationAuthProvider.handleError when TokenExpired happens [user=%s]', async (user) => { + const pool = newPool() const connectionProvider = newRoutingConnectionProvider( [ newRoutingTable( @@ -1528,6 +1643,8 @@ describe.each([ pool ) + const handleError = jest.spyOn(connectionProvider._authenticationProvider, 'handleError') + const error = newError( 'Message', 'Neo.ClientError.Security.TokenExpired' @@ -1548,8 +1665,8 @@ describe.each([ server3Connection.handleAndTransformError(error, server3) server2Connection.handleAndTransformError(error, server2) - expect(pool.purge).toHaveBeenCalledWith(server3) - expect(pool.purge).toHaveBeenCalledWith(server2) + expect(handleError).toBeCalledWith({ connection: server3Connection, code: 'Neo.ClientError.Security.TokenExpired' }) + expect(handleError).toBeCalledWith({ connection: server2Connection, code: 'Neo.ClientError.Security.TokenExpired' }) }) it.each(usersDataSet)('should not change error when TokenExpired happens [user=%s]', async (user) => { @@ -1588,6 +1705,84 @@ describe.each([ expect(error).toBe(expectedError) }) + it.each(usersDataSet)('should change error to retriable when error when TokenExpired happens and staticAuthTokenManager is not being used [user=%s]', async (user) => { + const pool = newPool() + const connectionProvider = newRoutingConnectionProvider( + [ + newRoutingTable( + null, + [server1, server2], + [server3, server2], + [server2, server4] + ) + ], + pool + ) + connectionProvider._authTokenManager = expirationBasedAuthTokenManager({ tokenProvider: () => null }) + + const error = newError( + 'Message', + 'Neo.ClientError.Security.TokenExpired' + ) + + const server2Connection = await connectionProvider.acquireConnection({ + accessMode: 'WRITE', + database: null, + impersonatedUser: user + }) + + const server3Connection = await connectionProvider.acquireConnection({ + accessMode: 'READ', + database: null, + impersonatedUser: user + }) + + const error1 = server3Connection.handleAndTransformError(error, server3) + const error2 = server2Connection.handleAndTransformError(error, server2) + + expect(error1.retriable).toBe(true) + expect(error2.retriable).toBe(true) + }) + + it.each(usersDataSet)('should not change error to retriable when error when TokenExpired happens and staticAuthTokenManager is being used [user=%s]', async (user) => { + const pool = newPool() + const connectionProvider = newRoutingConnectionProvider( + [ + newRoutingTable( + null, + [server1, server2], + [server3, server2], + [server2, server4] + ) + ], + pool + ) + connectionProvider._authTokenManager = staticAuthTokenManager({ authToken: null }) + + const error = newError( + 'Message', + 'Neo.ClientError.Security.TokenExpired' + ) + + const server2Connection = await connectionProvider.acquireConnection({ + accessMode: 'WRITE', + database: null, + impersonatedUser: user + }) + + const server3Connection = await connectionProvider.acquireConnection({ + accessMode: 'READ', + database: null, + impersonatedUser: user + }) + + const error1 = server3Connection.handleAndTransformError(error, server3) + const error2 = server2Connection.handleAndTransformError(error, server2) + + expect(error1.retriable).toBe(false) + expect(error2.retriable).toBe(false) + }) + it.each(usersDataSet)('should use resolved seed router after accepting table with no writers [user=%s]', (user, done) => { const routingTable1 = newRoutingTable( null, @@ -1710,7 +1905,7 @@ describe.each([ expect(capturedError.code).toBe(SERVICE_UNAVAILABLE) expect(capturedError.message).toBe( 'Server at server-non-existing-seed-router:7687 can\'t ' + - 'perform routing. Make sure you are connecting to a causal cluster' + 'perform routing. Make sure you are connecting to a causal cluster' ) // Error should be the cause of the given capturedError expect(capturedError).toEqual(newError(capturedError.message, capturedError.code, error)) @@ -1784,7 +1979,7 @@ describe.each([ expect(conn2.address).toBe(serverA) }, 10000) - it.each(usersDataSet)('should purge connections for address when AuthorizationExpired happens [user=%s]', async (user) => { + it.each(usersDataSet)('should not purge connections for address when AuthorizationExpired happens [user=%s]', async (user) => { const pool = newPool() jest.spyOn(pool, 'purge') @@ -1822,11 +2017,11 @@ describe.each([ serverAConnection.handleAndTransformError(error, serverA) server2Connection.handleAndTransformError(error, server2) - expect(pool.purge).toHaveBeenCalledWith(serverA) - expect(pool.purge).toHaveBeenCalledWith(server2) + expect(pool.purge).not.toHaveBeenCalledWith(serverA) + expect(pool.purge).not.toHaveBeenCalledWith(server2) }) - it.each(usersDataSet)('should purge not change error when AuthorizationExpired happens [user=%s]', async (user) => { + it.each(usersDataSet)('should not purge change error when AuthorizationExpired happens [user=%s]', async (user) => { const pool = newPool() const connectionProvider = newRoutingConnectionProvider( @@ -1861,7 +2056,7 @@ describe.each([ expect(error).toBe(expectedError) }) - it.each(usersDataSet)('should purge connections for address when TokenExpired happens [user=%s]', async (user) => { + it.each(usersDataSet)('should not purge connections for address when TokenExpired happens [user=%s]', async (user) => { const pool = newPool() jest.spyOn(pool, 'purge') @@ -1899,8 +2094,8 @@ describe.each([ serverAConnection.handleAndTransformError(error, serverA) server2Connection.handleAndTransformError(error, server2) - expect(pool.purge).toHaveBeenCalledWith(serverA) - expect(pool.purge).toHaveBeenCalledWith(server2) + expect(pool.purge).not.toHaveBeenCalledWith(serverA) + expect(pool.purge).not.toHaveBeenCalledWith(server2) }) it.each(usersDataSet)('should not change error when TokenExpired happens [user=%s]', async (user) => { @@ -2589,7 +2784,7 @@ describe.each([ const targetServers = accessMode === WRITE ? routingTable.writers : routingTable.readers const address = targetServers[0] - expect(acquireSpy).toHaveBeenCalledWith(address) + expect(acquireSpy).toHaveBeenCalledWith({}, address) const connections = seenConnectionsPerAddress.get(address) @@ -2608,7 +2803,7 @@ describe.each([ const targetServers = accessMode === WRITE ? routingTable.writers : routingTable.readers const address = targetServers[0] - expect(acquireSpy).toHaveBeenCalledWith(address) + expect(acquireSpy).toHaveBeenCalledWith({}, address) const connections = seenConnectionsPerAddress.get(address) @@ -2628,7 +2823,7 @@ describe.each([ const targetServers = accessMode === WRITE ? routingTable.readers : routingTable.writers for (const address of targetServers) { - expect(acquireSpy).not.toHaveBeenCalledWith(address) + expect(acquireSpy).not.toHaveBeenCalledWith({}, address) expect(seenConnectionsPerAddress.get(address)).toBeUndefined() } }) @@ -2711,7 +2906,7 @@ describe.each([ } finally { const targetServers = accessMode === WRITE ? routingTable.writers : routingTable.readers for (const address of targetServers) { - expect(acquireSpy).toHaveBeenCalledWith(address) + expect(acquireSpy).toHaveBeenCalledWith({}, address) const connections = seenConnectionsPerAddress.get(address) @@ -2792,7 +2987,7 @@ describe.each([ }) }) - function setup ({ resetAndFlush, releaseMock, newConnection } = { }) { + function setup ({ resetAndFlush, releaseMock, newConnection } = {}) { const routingTable = newRoutingTable( database || null, [server1, server2], @@ -2934,6 +3129,567 @@ describe.each([ }) }) + describe('constructor', () => { + describe('newPool', () => { + describe('param.create', () => { + it('should create connection', async () => { + const { create, createChannelConnectionHook, provider } = setup() + + const connection = await create({}, server0, undefined) + + expect(createChannelConnectionHook).toHaveBeenCalledWith(server0) + expect(provider._openConnections[connection.id]).toBe(connection) + await expect(createChannelConnectionHook.mock.results[0].value).resolves.toBe(connection) + }) + + it('should register the release function into the connection', async () => { + const { create } = setup() + const releaseResult = { property: 'some property' } + const release = jest.fn(() => releaseResult) + + const connection = await create({}, server0, release) + + const released = connection._release() + + expect(released).toBe(releaseResult) + expect(release).toHaveBeenCalledWith(server0, connection) + }) + + it.each([ + null, + undefined, + { scheme: 'bearer', credentials: 'token01' } + ])('should authenticate connection (auth = %o)', async (auth) => { + const { create, authenticationProviderHook } = setup() + + const connection = await create({ auth }, server0) + + expect(authenticationProviderHook.authenticate).toHaveBeenCalledWith({ + connection, + auth + }) + }) + + it('should handle create connection failures', async () => { + const error = newError('some error') + const createConnection = jest.fn(() => Promise.reject(error)) + const { create, authenticationProviderHook, provider } = setup({ createConnection }) + const openConnections = { ...provider._openConnections } + + await expect(create({}, server0)).rejects.toThrow(error) + + expect(authenticationProviderHook.authenticate).not.toHaveBeenCalled() + expect(provider._openConnections).toEqual(openConnections) + }) + + it.each([ + null, + undefined, + { scheme: 'bearer', credentials: 'token01' } + ])('should handle authentication failures (auth = %o)', async (auth) => { + const error = newError('some error') + const authenticationProvider = jest.fn(() => Promise.reject(error)) + const { create, authenticationProviderHook, createChannelConnectionHook, provider } = setup({ authenticationProvider }) + const openConnections = { ...provider._openConnections } + + await expect(create({ auth }, server0)).rejects.toThrow(error) + + const connection = await createChannelConnectionHook.mock.results[0].value + expect(authenticationProviderHook.authenticate).toHaveBeenCalledWith({ auth, connection }) + expect(provider._openConnections).toEqual(openConnections) + expect(connection._closed).toBe(true) + }) + }) + + describe('param.destroy', () => { + it('should close connection and unregister it', async () => { + const { create, destroy, provider } = setup() + const openConnections = { ...provider._openConnections } + const connection = await create({}, server0, undefined) + + await destroy(connection) + + expect(connection._closed).toBe(true) + expect(provider._openConnections).toEqual(openConnections) + }) + }) + + describe('param.validateOnAcquire', () => { + it.each([ + null, + undefined, + { scheme: 'bearer', credentials: 'token01' } + ])('should return true when connection is open and within the lifetime and authentication succeed (auth=%o)', async (auth) => { + const connection = new FakeConnection(server0) + connection.creationTimestamp = Date.now() + + const { validateOnAcquire, authenticationProviderHook } = setup() + + await expect(validateOnAcquire({ auth }, connection)).resolves.toBe(true) + + expect(authenticationProviderHook.authenticate).toHaveBeenCalledWith({ + connection, auth + }) + }) + + it.each([ + null, + undefined, + { scheme: 'bearer', credentials: 'token01' } + ])('should return true when connection is open and within the lifetime and authentication fails (auth=%o)', async (auth) => { + const connection = new FakeConnection(server0) + const error = newError('failed') + const authenticationProvider = jest.fn(() => Promise.reject(error)) + connection.creationTimestamp = Date.now() + + const { validateOnAcquire, authenticationProviderHook, log } = setup({ authenticationProvider }) + + await expect(validateOnAcquire({ auth }, connection)).resolves.toBe(false) + + expect(authenticationProviderHook.authenticate).toHaveBeenCalledWith({ + connection, auth + }) + + expect(log.debug).toHaveBeenCalledWith( + `The connection ${connection.id} is not valid because of an error ${error.code} '${error.message}'` + ) + }) + + it.each([ + true, + false + ])('should call authenticationProvider.authenticate with skipReAuth=%s', async (skipReAuth) => { + const connection = new FakeConnection(server0) + const auth = {} + connection.creationTimestamp = Date.now() + + const { validateOnAcquire, authenticationProviderHook } = setup() + + await expect(validateOnAcquire({ auth, skipReAuth }, connection)).resolves.toBe(true) + + expect(authenticationProviderHook.authenticate).toHaveBeenCalledWith({ + connection, auth, skipReAuth + }) + }) + + it('should return false when connection is closed and within the lifetime', async () => { + const connection = new FakeConnection(server0) + connection.creationTimestamp = Date.now() + await connection.close() + + const { validateOnAcquire, authenticationProviderHook } = setup() + + await expect(validateOnAcquire({}, connection)).resolves.toBe(false) + expect(authenticationProviderHook.authenticate).not.toHaveBeenCalled() + }) + + it('should return false when connection is open and out of the lifetime', async () => { + const connection = new FakeConnection(server0) + connection.creationTimestamp = Date.now() - 4000 + + const { validateOnAcquire, authenticationProviderHook } = setup({ maxConnectionLifetime: 3000 }) + + await expect(validateOnAcquire({}, connection)).resolves.toBe(false) + expect(authenticationProviderHook.authenticate).not.toHaveBeenCalled() + }) + + it('should return false when connection is closed and out of the lifetime', async () => { + const connection = new FakeConnection(server0) + await connection.close() + connection.creationTimestamp = Date.now() - 4000 + + const { validateOnAcquire, authenticationProviderHook } = setup({ maxConnectionLifetime: 3000 }) + + await expect(validateOnAcquire({}, connection)).resolves.toBe(false) + expect(authenticationProviderHook.authenticate).not.toHaveBeenCalled() + }) + }) + + describe('param.validateOnRelease', () => { + it('should return true when connection is open and within the lifetime', () => { + const connection = new FakeConnection(server0) + connection.creationTimestamp = Date.now() + + const { validateOnRelease } = setup() + + expect(validateOnRelease(connection)).toBe(true) + }) + + it('should return false when connection is closed and within the lifetime', async () => { + const connection = new FakeConnection(server0) + connection.creationTimestamp = Date.now() + await connection.close() + + const { validateOnRelease } = setup() + + expect(validateOnRelease(connection)).toBe(false) + }) + + it('should return false when connection is open and out of the lifetime', () => { + const connection = new FakeConnection(server0) + connection.creationTimestamp = Date.now() - 4000 + + const { validateOnRelease } = setup({ maxConnectionLifetime: 3000 }) + + expect(validateOnRelease(connection)).toBe(false) + }) + + it('should return false when connection is closed and out of the lifetime', async () => { + const connection = new FakeConnection(server0) + await connection.close() + connection.creationTimestamp = Date.now() - 4000 + + const { validateOnRelease } = setup({ maxConnectionLifetime: 3000 }) + + expect(validateOnRelease(connection)).toBe(false) + }) + + it('should return false when connection is sticky', async () => { + const connection = new FakeConnection(server0) + connection._sticky = true + + const { validateOnRelease } = setup() + + expect(validateOnRelease(connection)).toBe(false) + }) + }) + + function setup ({ createConnection, authenticationProvider, maxConnectionLifetime } = {}) { + const newPool = jest.fn((...args) => new Pool(...args)) + const log = new Logger('debug', () => undefined) + jest.spyOn(log, 'debug') + const createChannelConnectionHook = createConnection || jest.fn(async (address) => new FakeConnection(address)) + const authenticationProviderHook = new AuthenticationProvider({}) + jest.spyOn(authenticationProviderHook, 'authenticate') + .mockImplementation(authenticationProvider || jest.fn(({ connection }) => Promise.resolve(connection))) + const provider = new RoutingConnectionProvider({ + newPool, + config: { + maxConnectionLifetime: maxConnectionLifetime || 1000 + }, + address: server01, + log + }) + provider._createChannelConnection = createChannelConnectionHook + provider._authenticationProvider = authenticationProviderHook + return { + provider, + ...newPool.mock.calls[0][0], + createChannelConnectionHook, + authenticationProviderHook, + log + } + } + }) + }) + + describe('user-switching', () => { + describe('when does not support re-auth', () => { + describe.each([ + ['new connection', { other: 'auth' }, { other: 'auth' }, true], + ['old connection', { some: 'auth' }, { other: 'token' }, false] + ])('%s', (_, connAuth, acquireAuth, isStickyConn) => { + it('should raise and error when try switch user on acquire', async () => { + const address = ServerAddress.fromUrl('localhost:123') + const pool = newPool() + const connection = new FakeConnection(address, () => { }, undefined, PROTOCOL_VERSION, null, connAuth) + const poolAcquire = jest.spyOn(pool, 'acquire').mockResolvedValue(connection) + const connectionProvider = newRoutingConnectionProvider([ + newRoutingTable( + null, + [server1, server2], + [server3, server2], + [server2, server4] + ) + ], + pool + ) + const auth = acquireAuth + + const error = await connectionProvider + .acquireConnection({ + accessMode: 'READ', + database: '', + auth + }) + .catch(functional.identity) + + expect(error).toEqual(newError('Driver is connected to a database that does not support user switch.')) + expect(poolAcquire).toHaveBeenCalledWith({ auth }, server3) + expect(connection._release).toHaveBeenCalled() + expect(connection._sticky).toEqual(isStickyConn) + }) + + it('should raise and error when try switch user on acquire [expired rt]', async () => { + const address = ServerAddress.fromUrl('localhost:123') + const pool = newPool() + const connection = new FakeConnection(address, () => { }, undefined, PROTOCOL_VERSION, null, connAuth) + const poolAcquire = jest.spyOn(pool, 'acquire').mockResolvedValue(connection) + const connectionProvider = newRoutingConnectionProvider([ + newRoutingTable( + 'dba', + [server1, server2], + [server3, server2], + [server2, server4], + int(0) // expired + ) + ], + pool, + { + dba: newRoutingTable( + 'dba', + [server1, server2], + [server3, server2], + [server2, server4] + ) + } + ) + connectionProvider._useSeedRouter = false + + const auth = acquireAuth + + const error = await connectionProvider + .acquireConnection({ + accessMode: 'READ', + database: 'dba', + auth + }) + .catch(functional.identity) + + expect(error).toEqual(newError('Driver is connected to a database that does not support user switch.')) + expect(poolAcquire).toHaveBeenCalledWith({ auth }, server1) + expect(connection._release).toHaveBeenCalled() + expect(connection._sticky).toEqual(isStickyConn) + }) + + it('should raise and error when try switch user on acquire [expired rt and userSeedRouter]', async () => { + const address = ServerAddress.fromUrl('localhost:123') + const pool = newPool() + const connection = new FakeConnection(address, () => { }, undefined, PROTOCOL_VERSION, null, connAuth) + const poolAcquire = jest.spyOn(pool, 'acquire').mockResolvedValue(connection) + const connectionProvider = newRoutingConnectionProviderWithSeedRouter( + server0, + [server0], + [ + newRoutingTable( + 'dba', + [server1, server2], + [server3, server2], + [server2, server4], + int(0) // expired + ) + ], + { + dba: newRoutingTable( + 'dba', + [server1, server2], + [server3, server2], + [server2, server4] + ) + }, + pool + ) + + const auth = acquireAuth + + const error = await connectionProvider + .acquireConnection({ + accessMode: 'READ', + database: 'dba', + auth + }) + .catch(functional.identity) + + expect(error).toEqual(newError('Driver is connected to a database that does not support user switch.')) + expect(poolAcquire).toHaveBeenCalledWith({ auth }, server0) + expect(connection._release).toHaveBeenCalled() + expect(connection._sticky).toEqual(isStickyConn) + }) + + it('should raise and error when try switch user on acquire [firstCall and userSeedRouter]', async () => { + const address = ServerAddress.fromUrl('localhost:123') + const pool = newPool() + const connection = new FakeConnection(address, () => { }, undefined, PROTOCOL_VERSION, null, connAuth) + const poolAcquire = jest.spyOn(pool, 'acquire').mockResolvedValue(connection) + const connectionProvider = newRoutingConnectionProviderWithSeedRouter( + server0, + [server0], + [], + {}, + pool + ) + + const auth = acquireAuth + + const error = await connectionProvider + .acquireConnection({ + accessMode: 'READ', + database: 'dba', + auth + }) + .catch(functional.identity) + + expect(error).toEqual(newError('Driver is connected to a database that does not support user switch.')) + expect(poolAcquire).toHaveBeenCalledWith({ auth }, server0) + expect(connection._release).toHaveBeenCalled() + expect(connection._sticky).toEqual(isStickyConn) + }) + }) + }) + + describe('when does it support re-auth', () => { + const connAuth = { myAuth: 'auth' } + const acquireAuth = connAuth + + it('should return connection when try switch user on acquire', async () => { + const address = ServerAddress.fromUrl('localhost:123') + const pool = newPool() + const connection = new FakeConnection(address, () => { }, undefined, PROTOCOL_VERSION, null, connAuth, { supportsReAuth: true }) + jest.spyOn(pool, 'acquire').mockResolvedValue(connection) + const connectionProvider = newRoutingConnectionProvider([ + newRoutingTable( + null, + [server1, server2], + [server3, server2], + [server2, server4] + ) + ], + pool + ) + const auth = acquireAuth + + const acquiredConnection = await connectionProvider + .acquireConnection({ + accessMode: 'READ', + database: '', + auth + }) + + expect(acquiredConnection).toBe(connection) + expect(acquiredConnection._sticky).toEqual(false) + }) + + it('should return connection when try switch user on acquire [expired rt]', async () => { + const address = ServerAddress.fromUrl('localhost:123') + const pool = newPool() + const connection = new FakeConnection(address, () => { }, undefined, PROTOCOL_VERSION, null, connAuth, { supportsReAuth: true }) + jest.spyOn(pool, 'acquire').mockResolvedValue(connection) + const connectionProvider = newRoutingConnectionProvider([ + newRoutingTable( + 'dba', + [server1, server2], + [server3, server2], + [server2, server4], + int(0) // expired + ) + ], + pool, + { + dba: { + [server1.asHostPort()]: newRoutingTable( + 'dba', + [server1, server2], + [server3, server2], + [server2, server4] + ) + } + } + ) + connectionProvider._useSeedRouter = false + + const auth = acquireAuth + + const acquiredConnection = await connectionProvider + .acquireConnection({ + accessMode: 'READ', + database: 'dba', + auth + }) + + expect(acquiredConnection).toBe(connection) + expect(acquiredConnection._sticky).toEqual(false) + }) + + it('should return connection when try switch user on acquire [expired rt and userSeedRouter]', async () => { + const address = ServerAddress.fromUrl('localhost:123') + const pool = newPool() + const connection = new FakeConnection(address, () => { }, undefined, PROTOCOL_VERSION, null, connAuth, { supportsReAuth: true }) + jest.spyOn(pool, 'acquire').mockResolvedValue(connection) + const connectionProvider = newRoutingConnectionProviderWithSeedRouter( + server0, + [server0], + [ + newRoutingTable( + 'dba', + [server1, server2], + [server3, server2], + [server2, server4], + int(0) // expired + ) + ], + { + dba: { + [server0.asHostPort()]: newRoutingTable( + 'dba', + [server1, server2], + [server3, server2], + [server2, server4] + ) + } + }, + pool + ) + + const auth = acquireAuth + + const acquiredConnection = await connectionProvider + .acquireConnection({ + accessMode: 'READ', + database: 'dba', + auth + }) + + expect(acquiredConnection).toBe(connection) + expect(acquiredConnection._sticky).toEqual(false) + }) + + it('should not delegated connection when try switch user on acquire [firstCall and userSeedRouter]', async () => { + const address = ServerAddress.fromUrl('localhost:123') + const pool = newPool() + const connection = new FakeConnection(address, () => { }, undefined, PROTOCOL_VERSION, null, connAuth, { supportsReAuth: true }) + jest.spyOn(pool, 'acquire').mockResolvedValue(connection) + const connectionProvider = newRoutingConnectionProviderWithSeedRouter( + server0, + [server0], + [], + { + dba: { + [server0.asHostPort()]: newRoutingTable( + 'dba', + [server1, server2], + [server3, server2], + [server2, server4] + ) + } + }, + pool + ) + + const auth = acquireAuth + + const acquiredConnection = await connectionProvider + .acquireConnection({ + accessMode: 'READ', + database: 'dba', + auth + }) + + expect(acquiredConnection).toBe(connection) + expect(acquiredConnection._sticky).toEqual(false) + }) + }) + }) + function newPool ({ create, config } = {}) { const _create = (address, release) => { if (create) { @@ -2943,11 +3699,14 @@ describe.each([ return Promise.reject(e) } } - return Promise.resolve(new FakeConnection(address, release, 'version', PROTOCOL_VERSION)) + return Promise.resolve(new FakeConnection(address, release, 'version', PROTOCOL_VERSION, undefined, { + scheme: 'bearer', + credentials: 'token' + })) } return new Pool({ config, - create: (address, release) => _create(address, release) + create: (_, address, release) => _create(address, release) }) } @@ -3093,7 +3852,7 @@ function expectPoolToNotContain (pool, addresses) { } class FakeConnection extends Connection { - constructor (address, release, version, protocolVersion, server) { + constructor (address, release, version, protocolVersion, server, authToken, { supportsReAuth } = {}) { super(null) this._address = address @@ -3103,6 +3862,22 @@ class FakeConnection extends Connection { this._release = jest.fn(() => release(address, this)) this.resetAndFlush = jest.fn(() => Promise.resolve()) this._server = server + this._authToken = authToken + this._id = 1 + this._closed = false + this._supportsReAuth = supportsReAuth || false + } + + get id () { + return this._id + } + + get authToken () { + return this._authToken + } + + set authToken (authToken) { + this._authToken = authToken } get address () { @@ -3117,6 +3892,18 @@ class FakeConnection extends Connection { return this._server } + get supportsReAuth () { + return this._supportsReAuth + } + + async close () { + this._closed = true + } + + isOpen () { + return !this._closed + } + protocol () { return { version: this._protocolVersion, @@ -3127,7 +3914,7 @@ class FakeConnection extends Connection { class FakeRediscovery { constructor (routerToRoutingTable, error) { - this._routerToRoutingTable = routerToRoutingTable + this._routerToRoutingTable = routerToRoutingTable || {} this._error = error } diff --git a/packages/bolt-connection/test/connection/connection-channel.test.js b/packages/bolt-connection/test/connection/connection-channel.test.js index 18200eea5..446747647 100644 --- a/packages/bolt-connection/test/connection/connection-channel.test.js +++ b/packages/bolt-connection/test/connection/connection-channel.test.js @@ -149,6 +149,182 @@ describe('ChannelConnection', () => { expect(call.notificationFilter).toBe(notificationFilter) } ) + it('should set the AuthToken in the context', async () => { + const authToken = { + scheme: 'none' + } + const protocol = { + initialize: jest.fn(observer => observer.onComplete({})) + } + const protocolSupplier = () => protocol + const connection = spyOnConnectionChannel({ protocolSupplier }) + + await connection.connect('userAgent', authToken) + + expect(connection.authToken).toEqual(authToken) + }) + + describe('re-auth', () => { + describe('when protocol support re-auth', () => { + it('should call logoff and login', async () => { + const authToken = { + scheme: 'none' + } + const protocol = { + initialize: jest.fn(observer => observer.onComplete({})), + logoff: jest.fn(() => undefined), + logon: jest.fn(() => undefined), + initialized: true, + supportsReAuth: true + } + + const protocolSupplier = () => protocol + const connection = spyOnConnectionChannel({ protocolSupplier }) + + await connection.connect('userAgent', authToken) + + expect(protocol.initialize).not.toHaveBeenCalled() + expect(protocol.logoff).toHaveBeenCalledWith() + expect(protocol.logon).toHaveBeenCalledWith({ authToken, flush: true }) + expect(connection.authToken).toEqual(authToken) + }) + + describe('when waitReAuth=true', () => { + it('should wait for login complete', async () => { + const authToken = { + scheme: 'none' + } + + const onCompleteObservers = [] + const protocol = { + initialize: jest.fn(observer => observer.onComplete({})), + logoff: jest.fn(() => undefined), + logon: jest.fn(({ onComplete }) => onCompleteObservers.push(onComplete)), + initialized: true, + supportsReAuth: true + } + + const protocolSupplier = () => protocol + const connection = spyOnConnectionChannel({ protocolSupplier }) + + const connectionPromise = connection.connect('userAgent', authToken, true) + + const isPending = await Promise.race([connectionPromise, Promise.resolve(true)]) + expect(isPending).toEqual(true) + expect(onCompleteObservers.length).toEqual(1) + + expect(protocol.initialize).not.toHaveBeenCalled() + expect(protocol.logoff).toHaveBeenCalled() + expect(protocol.logon).toHaveBeenCalledWith(expect.objectContaining({ + authToken, + flush: true + })) + + expect(connection.authToken).toEqual(authToken) + + onCompleteObservers.forEach(onComplete => onComplete({})) + await expect(connectionPromise).resolves.toBe(connection) + }) + + it('should notify logoff errors', async () => { + const authToken = { + scheme: 'none' + } + + const onLogoffErrors = [] + const protocol = { + initialize: jest.fn(observer => observer.onComplete({})), + logoff: jest.fn(({ onError }) => onLogoffErrors.push(onError)), + logon: jest.fn(() => undefined), + initialized: true, + supportsReAuth: true + } + + const protocolSupplier = () => protocol + const connection = spyOnConnectionChannel({ protocolSupplier }) + + const connectionPromise = connection.connect('userAgent', authToken, true) + + const isPending = await Promise.race([connectionPromise, Promise.resolve(true)]) + expect(isPending).toEqual(true) + expect(onLogoffErrors.length).toEqual(1) + + expect(protocol.initialize).not.toHaveBeenCalled() + expect(protocol.logoff).toHaveBeenCalled() + expect(protocol.logon).toHaveBeenCalledWith(expect.objectContaining({ + authToken, + flush: true + })) + + const expectedError = newError('something wrong is not right.') + onLogoffErrors.forEach(onError => onError(expectedError)) + await expect(connectionPromise).rejects.toBe(expectedError) + }) + + it('should notify logon errors', async () => { + const authToken = { + scheme: 'none' + } + + const onLoginErrors = [] + const protocol = { + initialize: jest.fn(observer => observer.onComplete({})), + logoff: jest.fn(() => undefined), + logon: jest.fn(({ onError }) => onLoginErrors.push(onError)), + initialized: true, + supportsReAuth: true + } + + const protocolSupplier = () => protocol + const connection = spyOnConnectionChannel({ protocolSupplier }) + + const connectionPromise = connection.connect('userAgent', authToken, true) + + const isPending = await Promise.race([connectionPromise, Promise.resolve(true)]) + expect(isPending).toEqual(true) + expect(onLoginErrors.length).toEqual(1) + + expect(protocol.initialize).not.toHaveBeenCalled() + expect(protocol.logoff).toHaveBeenCalled() + expect(protocol.logon).toHaveBeenCalledWith(expect.objectContaining({ + authToken, + flush: true + })) + + const expectedError = newError('something wrong is not right.') + onLoginErrors.forEach(onError => onError(expectedError)) + await expect(connectionPromise).rejects.toBe(expectedError) + }) + }) + }) + + describe('when protocol does not support re-auth', () => { + it('should throw connection does not support re-auth', async () => { + const authToken = { + scheme: 'none' + } + const protocol = { + initialize: jest.fn(observer => observer.onComplete({})), + logoff: jest.fn(() => undefined), + logon: jest.fn(() => undefined), + initialized: true, + supportsReAuth: false + } + + const protocolSupplier = () => protocol + const connection = spyOnConnectionChannel({ protocolSupplier }) + + await expect(connection.connect('userAgent', authToken)).rejects.toThrow( + newError('Connection does not support re-auth') + ) + + expect(protocol.initialize).not.toHaveBeenCalled() + expect(protocol.logoff).not.toHaveBeenCalled() + expect(protocol.logon).not.toHaveBeenCalled() + expect(connection.authToken).toEqual(null) + }) + }) + }) }) describe('._handleFatalError()', () => { diff --git a/packages/bolt-connection/test/pool/pool.test.js b/packages/bolt-connection/test/pool/pool.test.js index 2037cd8b5..b224f9512 100644 --- a/packages/bolt-connection/test/pool/pool.test.js +++ b/packages/bolt-connection/test/pool/pool.test.js @@ -33,13 +33,13 @@ describe('#unit Pool', () => { let counter = 0 const address = ServerAddress.fromUrl('bolt://localhost:7687') const pool = new Pool({ - create: (server, release) => + create: (_, server, release) => Promise.resolve(new Resource(server, counter++, release)) }) // When - const r0 = await pool.acquire(address) - const r1 = await pool.acquire(address) + const r0 = await pool.acquire({}, address) + const r1 = await pool.acquire({}, address) // Then expect(r0.id).toBe(0) @@ -52,15 +52,15 @@ describe('#unit Pool', () => { let counter = 0 const address = ServerAddress.fromUrl('bolt://localhost:7687') const pool = new Pool({ - create: (server, release) => + create: (_, server, release) => Promise.resolve(new Resource(server, counter++, release)) }) // When - const r0 = await pool.acquire(address) + const r0 = await pool.acquire({}, address) await r0.close() - const r1 = await pool.acquire(address) + const r1 = await pool.acquire({}, address) // Then expect(r0.id).toBe(0) @@ -74,17 +74,17 @@ describe('#unit Pool', () => { const address1 = ServerAddress.fromUrl('bolt://localhost:7687') const address2 = ServerAddress.fromUrl('bolt://localhost:7688') const pool = new Pool({ - create: (server, release) => + create: (_, server, release) => Promise.resolve(new Resource(server, counter++, release)) }) // When - const r0 = await pool.acquire(address1) - const r1 = await pool.acquire(address2) + const r0 = await pool.acquire({}, address1) + const r1 = await pool.acquire({}, address2) await r0.close() - const r2 = await pool.acquire(address1) - const r3 = await pool.acquire(address2) + const r2 = await pool.acquire({}, address1) + const r3 = await pool.acquire({}, address2) // Then expect(r0.id).toBe(0) @@ -102,19 +102,19 @@ describe('#unit Pool', () => { const destroyed = [] const address = ServerAddress.fromUrl('bolt://localhost:7687') const pool = new Pool({ - create: (server, release) => + create: (_acquisitionContext, server, release) => Promise.resolve(new Resource(server, counter++, release)), destroy: res => { destroyed.push(res) return Promise.resolve() }, - validate: res => false, + validateOnRelease: res => false, config: new PoolConfig(1000, 60000) }) // When - const r0 = await pool.acquire(address) - const r1 = await pool.acquire(address) + const r0 = await pool.acquire({}, address) + const r1 = await pool.acquire({}, address) // Then await r0.close() @@ -125,13 +125,133 @@ describe('#unit Pool', () => { expect(destroyed[1].id).toBe(r1.id) }) + it('frees if validateOnRelease returns Promise.resolve(false)', async () => { + // Given a pool that allocates + let counter = 0 + const destroyed = [] + const address = ServerAddress.fromUrl('bolt://localhost:7687') + const pool = new Pool({ + create: (_acquisitionContext, server, release) => + Promise.resolve(new Resource(server, counter++, release)), + destroy: res => { + destroyed.push(res) + return Promise.resolve() + }, + validateOnRelease: res => Promise.resolve(false), + config: new PoolConfig(1000, 60000) + }) + + // When + const r0 = await pool.acquire({}, address) + const r1 = await pool.acquire({}, address) + + // Then + await r0.close() + await r1.close() + + expect(destroyed.length).toBe(2) + expect(destroyed[0].id).toBe(r0.id) + expect(destroyed[1].id).toBe(r1.id) + }) + + it('does not free if validateOnRelease returns Promise.resolve(true)', async () => { + // Given a pool that allocates + let counter = 0 + const destroyed = [] + const address = ServerAddress.fromUrl('bolt://localhost:7687') + const pool = new Pool({ + create: (_acquisitionContext, server, release) => + Promise.resolve(new Resource(server, counter++, release)), + destroy: res => { + destroyed.push(res) + return Promise.resolve() + }, + validateOnRelease: res => Promise.resolve(true), + config: new PoolConfig(1000, 60000) + }) + + // When + const r0 = await pool.acquire({}, address) + const r1 = await pool.acquire({}, address) + + // Then + await r0.close() + await r1.close() + + expect(destroyed.length).toBe(0) + }) + + it('frees if validateOnAcquire returns Promise.resolve(false)', async () => { + // Given a pool that allocates + let counter = 0 + const destroyed = [] + const address = ServerAddress.fromUrl('bolt://localhost:7687') + const pool = new Pool({ + create: (_acquisitionContext, server, release) => + Promise.resolve(new Resource(server, counter++, release)), + destroy: res => { + destroyed.push(res) + return Promise.resolve() + }, + validateOnAcquire: res => Promise.resolve(false), + config: new PoolConfig(1000, 60000) + }) + + // When + const r0 = await pool.acquire({}, address) + const r1 = await pool.acquire({}, address) + await r1.close() + await r0.close() + + // Then + const r2 = await pool.acquire({}, address) + + // Closing + await r2.close() + + expect(destroyed.length).toBe(2) + expect(destroyed[0].id).toBe(r0.id) + expect(destroyed[1].id).toBe(r1.id) + }) + + it('does not free if validateOnAcquire returns Promise.resolve(true)', async () => { + // Given a pool that allocates + let counter = 0 + const destroyed = [] + const address = ServerAddress.fromUrl('bolt://localhost:7687') + const pool = new Pool({ + create: (_acquisitionContext, server, release) => + Promise.resolve(new Resource(server, counter++, release)), + destroy: res => { + destroyed.push(res) + return Promise.resolve() + }, + validateOnAcquire: res => Promise.resolve(true), + config: new PoolConfig(1000, 60000) + }) + + // When + const r0 = await pool.acquire({}, address) + const r1 = await pool.acquire({}, address) + await r0.close() + await r1.close() + + // Then + const r2 = await pool.acquire({}, address) + + // Closing + await r2.close() + + expect(destroyed.length).toBe(0) + }) + it('purges keys', async () => { // Given a pool that allocates let counter = 0 const address1 = ServerAddress.fromUrl('bolt://localhost:7687') const address2 = ServerAddress.fromUrl('bolt://localhost:7688') const pool = new Pool({ - create: (server, release) => + create: (_, server, release) => Promise.resolve(new Resource(server, counter++, release)), destroy: res => { res.destroyed = true @@ -140,8 +260,8 @@ describe('#unit Pool', () => { }) // When - const r0 = await pool.acquire(address1) - const r1 = await pool.acquire(address2) + const r0 = await pool.acquire({}, address1) + const r1 = await pool.acquire({}, address2) await r0.close() await r1.close() @@ -154,8 +274,8 @@ describe('#unit Pool', () => { expect(pool.has(address1)).toBeFalsy() expect(pool.has(address2)).toBeTruthy() - const r2 = await pool.acquire(address1) - const r3 = await pool.acquire(address2) + const r2 = await pool.acquire({}, address1) + const r3 = await pool.acquire({}, address2) // Then expect(r0.id).toBe(0) @@ -171,7 +291,7 @@ describe('#unit Pool', () => { const address1 = ServerAddress.fromUrl('bolt://localhost:7687') const address2 = ServerAddress.fromUrl('bolt://localhost:7688') const pool = new Pool({ - create: (server, release) => + create: (_, server, release) => Promise.resolve(new Resource(server, counter++, release)), destroy: res => { res.destroyed = true @@ -180,11 +300,11 @@ describe('#unit Pool', () => { }) // When - const r00 = await pool.acquire(address1) - const r01 = await pool.acquire(address1) - await pool.acquire(address2) - await pool.acquire(address2) - await pool.acquire(address2) + const r00 = await pool.acquire({}, address1) + const r01 = await pool.acquire({}, address1) + await pool.acquire({}, address2) + await pool.acquire({}, address2) + await pool.acquire({}, address2) expect(pool.activeResourceCount(address1)).toEqual(2) expect(pool.activeResourceCount(address2)).toEqual(3) @@ -216,7 +336,7 @@ describe('#unit Pool', () => { let counter = 0 const address = ServerAddress.fromUrl('bolt://localhost:7687') const pool = new Pool({ - create: (server, release) => + create: (_, server, release) => Promise.resolve(new Resource(server, counter++, release)), destroy: res => { res.destroyed = true @@ -224,7 +344,7 @@ describe('#unit Pool', () => { } }) - const r0 = await pool.acquire(address) + const r0 = await pool.acquire({}, address) expect(pool.has(address)).toBeTruthy() expect(r0.id).toEqual(0) @@ -241,7 +361,7 @@ describe('#unit Pool', () => { let counter = 0 const address = ServerAddress.fromUrl('bolt://localhost:7687') const pool = new Pool({ - create: (server, release) => + create: (_, server, release) => Promise.resolve(new Resource(server, counter++, release)), destroy: res => { res.destroyed = true @@ -250,7 +370,7 @@ describe('#unit Pool', () => { }) // Acquire resource - const r0 = await pool.acquire(address) + const r0 = await pool.acquire({}, address) expect(pool.has(address)).toBeTruthy() expect(r0.id).toEqual(0) @@ -260,7 +380,7 @@ describe('#unit Pool', () => { expect(r0.destroyed).toBeFalsy() // Acquiring second resource should recreate the pool - const r1 = await pool.acquire(address) + const r1 = await pool.acquire({}, address) expect(pool.has(address)).toBeTruthy() expect(r1.id).toEqual(1) @@ -283,7 +403,7 @@ describe('#unit Pool', () => { const address3 = ServerAddress.fromUrl('bolt://localhost:7689') const pool = new Pool({ - create: (server, release) => + create: (_, server, release) => Promise.resolve(new Resource(server, counter++, release)), destroy: res => { res.destroyed = true @@ -292,12 +412,12 @@ describe('#unit Pool', () => { }) const acquiredResources = [ - pool.acquire(address1), - pool.acquire(address2), - pool.acquire(address3), - pool.acquire(address1), - pool.acquire(address2), - pool.acquire(address3) + pool.acquire({}, address2), + pool.acquire({}, address3), + pool.acquire({}, address1), + pool.acquire({}, address1), + pool.acquire({}, address2), + pool.acquire({}, address3) ] const values = await Promise.all(acquiredResources) await Promise.all(values.map(resource => resource.close())) @@ -310,7 +430,7 @@ describe('#unit Pool', () => { it('should fail to acquire when closed', async () => { const address = ServerAddress.fromUrl('bolt://localhost:7687') const pool = new Pool({ - create: (server, release) => + create: (_, server, release) => Promise.resolve(new Resource(server, 0, release)), destroy: res => { return Promise.resolve() @@ -320,7 +440,7 @@ describe('#unit Pool', () => { // Close the pool await pool.close() - await expect(pool.acquire(address)).rejects.toMatchObject({ + await expect(pool.acquire({}, address)).rejects.toMatchObject({ message: expect.stringMatching('Pool is closed') }) }) @@ -329,7 +449,7 @@ describe('#unit Pool', () => { const address = ServerAddress.fromUrl('bolt://localhost:7687') const pool = new Pool({ - create: (server, release) => + create: (_, server, release) => Promise.resolve(new Resource(server, 0, release)), destroy: res => { return Promise.resolve() @@ -337,16 +457,17 @@ describe('#unit Pool', () => { }) // Acquire and release a resource - const resource = await pool.acquire(address) + const resource = await pool.acquire({}, address) await resource.close() // Close the pool await pool.close() - await expect(pool.acquire(address)).rejects.toMatchObject({ + await expect(pool.acquire({}, address)).rejects.toMatchObject({ message: expect.stringMatching('Pool is closed') }) }) + it('purges keys other than the ones to keep', async () => { let counter = 0 @@ -355,7 +476,7 @@ describe('#unit Pool', () => { const address3 = ServerAddress.fromUrl('bolt://localhost:7689') const pool = new Pool({ - create: (server, release) => + create: (_, server, release) => Promise.resolve(new Resource(server, counter++, release)), destroy: res => { res.destroyed = true @@ -364,12 +485,12 @@ describe('#unit Pool', () => { }) const acquiredResources = [ - pool.acquire(address1), - pool.acquire(address2), - pool.acquire(address3), - pool.acquire(address1), - pool.acquire(address2), - pool.acquire(address3) + pool.acquire({}, address1), + pool.acquire({}, address2), + pool.acquire({}, address3), + pool.acquire({}, address1), + pool.acquire({}, address2), + pool.acquire({}, address3) ] await Promise.all(acquiredResources) @@ -392,7 +513,7 @@ describe('#unit Pool', () => { const address3 = ServerAddress.fromUrl('bolt://localhost:7689') const pool = new Pool({ - create: (server, release) => + create: (_, server, release) => Promise.resolve(new Resource(server, counter++, release)), destroy: res => { res.destroyed = true @@ -401,12 +522,12 @@ describe('#unit Pool', () => { }) const acquiredResources = [ - pool.acquire(address1), - pool.acquire(address2), - pool.acquire(address3), - pool.acquire(address1), - pool.acquire(address2), - pool.acquire(address3) + pool.acquire({}, address1), + pool.acquire({}, address2), + pool.acquire({}, address3), + pool.acquire({}, address1), + pool.acquire({}, address2), + pool.acquire({}, address3) ] await Promise.all(acquiredResources) @@ -422,28 +543,28 @@ describe('#unit Pool', () => { }) it('skips broken connections during acquire', async () => { - let validated = false + let validated = true let counter = 0 const address = ServerAddress.fromUrl('bolt://localhost:7687') const pool = new Pool({ - create: (server, release) => + create: (_, server, release) => Promise.resolve(new Resource(server, counter++, release)), destroy: res => { res.destroyed = true return Promise.resolve() }, - validate: () => { - if (validated) { - return false + validateOnAcquire: (context, _res) => { + if (context.triggerValidation) { + validated = !validated + return validated } - validated = true return true } }) - const r0 = await pool.acquire(address) + const r0 = await pool.acquire({ triggerValidation: false }, address) await r0.close() - const r1 = await pool.acquire(address) + const r1 = await pool.acquire({ triggerValidation: true }, address) expect(r1).not.toBe(r0) }) @@ -453,12 +574,12 @@ describe('#unit Pool', () => { const absentAddress = ServerAddress.fromUrl('bolt://localhost:7688') const pool = new Pool({ - create: (server, release) => + create: (_, server, release) => Promise.resolve(new Resource(server, 42, release)) }) - await pool.acquire(existingAddress) - await pool.acquire(existingAddress) + await pool.acquire({}, existingAddress) + await pool.acquire({}, existingAddress) expect(pool.has(existingAddress)).toBeTruthy() expect(pool.has(absentAddress)).toBeFalsy() @@ -466,7 +587,7 @@ describe('#unit Pool', () => { it('reports zero active resources when empty', () => { const pool = new Pool({ - create: (server, release) => + create: (_, server, release) => Promise.resolve(new Resource(server, 42, release)) }) @@ -484,14 +605,14 @@ describe('#unit Pool', () => { it('reports active resources', async () => { const address = ServerAddress.fromUrl('bolt://localhost:7687') const pool = new Pool({ - create: (server, release) => + create: (_, server, release) => Promise.resolve(new Resource(server, 42, release)) }) const acquiredResources = [ - pool.acquire(address), - pool.acquire(address), - pool.acquire(address) + pool.acquire({}, address), + pool.acquire({}, address), + pool.acquire({}, address) ] const values = await Promise.all(acquiredResources) @@ -503,21 +624,21 @@ describe('#unit Pool', () => { it('reports active resources when they are acquired', async () => { const address = ServerAddress.fromUrl('bolt://localhost:7687') const pool = new Pool({ - create: (server, release) => + create: (_, server, release) => Promise.resolve(new Resource(server, 42, release)) }) // three new resources are created and returned to the pool - const r0 = await pool.acquire(address) - const r1 = await pool.acquire(address) - const r2 = await pool.acquire(address) + const r0 = await pool.acquire({}, address) + const r1 = await pool.acquire({}, address) + const r2 = await pool.acquire({}, address) await [r0, r1, r2].map(v => v.close()) // three idle resources are acquired from the pool const acquiredResources = [ - pool.acquire(address), - pool.acquire(address), - pool.acquire(address) + pool.acquire({}, address), + pool.acquire({}, address), + pool.acquire({}, address) ] const resources = await Promise.all(acquiredResources) @@ -531,13 +652,13 @@ describe('#unit Pool', () => { it('does not report resources that are returned to the pool', async () => { const address = ServerAddress.fromUrl('bolt://localhost:7687') const pool = new Pool({ - create: (server, release) => + create: (_, server, release) => Promise.resolve(new Resource(server, 42, release)) }) - const r0 = await pool.acquire(address) - const r1 = await pool.acquire(address) - const r2 = await pool.acquire(address) + const r0 = await pool.acquire({}, address) + const r1 = await pool.acquire({}, address) + const r2 = await pool.acquire({}, address) expect(pool.activeResourceCount(address)).toEqual(3) await r0.close() @@ -549,7 +670,7 @@ describe('#unit Pool', () => { await r2.close() expect(pool.activeResourceCount(address)).toEqual(0) - const r3 = await pool.acquire(address) + const r3 = await pool.acquire({}, address) expect(pool.activeResourceCount(address)).toEqual(1) await r3.close() @@ -561,22 +682,21 @@ describe('#unit Pool', () => { const address = ServerAddress.fromUrl('bolt://localhost:7687') const pool = new Pool({ - create: (server, release) => + create: (_, server, release) => Promise.resolve(new Resource(server, counter++, release)), destroy: res => Promise.resolve(), - validate: res => true, config: new PoolConfig(2, 5000) }) - await pool.acquire(address) - const r1 = await pool.acquire(address) + await pool.acquire({}, address) + const r1 = await pool.acquire({}, address) setTimeout(() => { expectNumberOfAcquisitionRequests(pool, address, 1) r1.close() }, 1000) - const r2 = await pool.acquire(address) + const r2 = await pool.acquire({}, address) expect(r2).toBe(r1) }) @@ -585,17 +705,16 @@ describe('#unit Pool', () => { const address = ServerAddress.fromUrl('bolt://localhost:7687') const pool = new Pool({ - create: (server, release) => + create: (_, server, release) => Promise.resolve(new Resource(server, counter++, release)), destroy: res => Promise.resolve(), - validate: res => true, config: new PoolConfig(2, 1000) }) - await pool.acquire(address) - await pool.acquire(address) + await pool.acquire({}, address) + await pool.acquire({}, address) - await expect(pool.acquire(address)).rejects.toMatchObject({ + await expect(pool.acquire({}, address)).rejects.toMatchObject({ message: expect.stringMatching('acquisition timed out') }) expectNumberOfAcquisitionRequests(pool, address, 0) @@ -608,7 +727,7 @@ describe('#unit Pool', () => { const pool = new Pool({ // Hook into connection creation to track when and what connections that are // created. - create: (server, release) => { + create: (_, server, release) => { // Create a fake connection that makes it possible control when it's connected // and released from the outer scope. const conn = { @@ -630,13 +749,13 @@ describe('#unit Pool', () => { // Make the first request for a connection, this will be hanging waiting for the // connect promise to be resolved. - const req1 = pool.acquire(address) + const req1 = pool.acquire({}, address) expect(conns.length).toEqual(1) // Make another request to the same server, this should not try to acquire another // connection since the pool will be full when the connection for the first request // is resolved. - const req2 = pool.acquire(address) + const req2 = pool.acquire({}, address) expect(conns.length).toEqual(1) // Let's fulfill the connect promise belonging to the first request. @@ -655,16 +774,15 @@ describe('#unit Pool', () => { let counter = 0 const pool = new Pool({ - create: (server, release) => + create: (_, server, release) => Promise.resolve(new Resource(server, counter++, release)), - destroy: res => Promise.resolve(), - validate: res => true + destroy: res => Promise.resolve() }) - await pool.acquire(address) - await pool.acquire(address) + await pool.acquire({}, address) + await pool.acquire({}, address) - const r2 = await pool.acquire(address) + const r2 = await pool.acquire({}, address) expect(r2.id).toEqual(2) expectNoPendingAcquisitionRequests(pool) }) @@ -675,17 +793,16 @@ describe('#unit Pool', () => { const address = ServerAddress.fromUrl('bolt://localhost:7687') const pool = new Pool({ - create: (server, release) => + create: (_, server, release) => Promise.resolve(new Resource(server, counter++, release)), destroy: res => Promise.resolve(), - validate: res => true, config: new PoolConfig(2, acquisitionTimeout) }) - const resource1 = await pool.acquire(address) + const resource1 = await pool.acquire({}, address) expect(resource1.id).toEqual(0) - const resource2 = await pool.acquire(address) + const resource2 = await pool.acquire({}, address) expect(resource2.id).toEqual(1) // try to release both resources around the time acquisition fails with timeout @@ -699,7 +816,7 @@ describe('#unit Pool', () => { // Remember that both code paths are ok with this test, either a success with a valid resource // or a time out error due to acquisition timeout being kicked in. await pool - .acquire(address) + .acquire({}, address) .then(someResource => { expect(someResource).toBeDefined() expect(someResource).not.toBeNull() @@ -718,14 +835,15 @@ describe('#unit Pool', () => { let counter = 0 const pool = new Pool({ - create: (server, release) => + create: (_, server, release) => Promise.resolve(new Resource(server, counter++, release)), destroy: res => Promise.resolve(), - validate: resourceValidOnlyOnceValidationFunction, + validateOnAcquire: (_, res) => resourceValidOnlyOnceValidationFunction(res), + validateOnRelease: resourceValidOnlyOnceValidationFunction, config: new PoolConfig(1, acquisitionTimeout) }) - const resource1 = await pool.acquire(address) + const resource1 = await pool.acquire({}, address) expect(resource1.id).toEqual(0) expect(pool.activeResourceCount(address)).toEqual(1) @@ -735,7 +853,7 @@ describe('#unit Pool', () => { resource1.close() }, acquisitionTimeout / 2) - const resource2 = await pool.acquire(address) + const resource2 = await pool.acquire({}, address) expect(resource2.id).toEqual(1) expectNoPendingAcquisitionRequests(pool) expect(pool.activeResourceCount(address)).toEqual(1) @@ -747,18 +865,19 @@ describe('#unit Pool', () => { let counter = 0 const pool = new Pool({ - create: (server, release) => + create: (_, server, release) => Promise.resolve(new Resource(server, counter++, release)), destroy: res => Promise.resolve(), - validate: resourceValidOnlyOnceValidationFunction, + validateOnAcquire: (_, res) => resourceValidOnlyOnceValidationFunction(res), + validateOnRelease: resourceValidOnlyOnceValidationFunction, config: new PoolConfig(2, acquisitionTimeout) }) - const resource1 = await pool.acquire(address) + const resource1 = await pool.acquire({}, address) expect(resource1.id).toEqual(0) expect(pool.activeResourceCount(address)).toEqual(1) - const resource2 = await pool.acquire(address) + const resource2 = await pool.acquire({}, address) expect(resource2.id).toEqual(1) expect(pool.activeResourceCount(address)).toEqual(2) @@ -769,7 +888,7 @@ describe('#unit Pool', () => { resource2.close() }, acquisitionTimeout / 2) - const resource3 = await pool.acquire(address) + const resource3 = await pool.acquire({}, address) expect(resource3.id).toEqual(2) expectNoPendingAcquisitionRequests(pool) expect(pool.activeResourceCount(address)).toEqual(1) @@ -782,10 +901,9 @@ describe('#unit Pool', () => { let removeIdleObserverCount = 0 const pool = new Pool({ - create: (server, release) => + create: (_, server, release) => Promise.resolve(new Resource(server, resourceCount++, release)), destroy: res => Promise.resolve(), - validate: res => true, installIdleObserver: (resource, observer) => { installIdleObserverCount++ }, @@ -794,17 +912,17 @@ describe('#unit Pool', () => { } }) - const r1 = await pool.acquire(address) - const r2 = await pool.acquire(address) - const r3 = await pool.acquire(address) + const r1 = await pool.acquire({}, address) + const r2 = await pool.acquire({}, address) + const r3 = await pool.acquire({}, address) await [r1, r2, r3].map(r => r.close()) expect(installIdleObserverCount).toEqual(3) expect(removeIdleObserverCount).toEqual(0) - await pool.acquire(address) - await pool.acquire(address) - await pool.acquire(address) + await pool.acquire({}, address) + await pool.acquire({}, address) + await pool.acquire({}, address) expect(installIdleObserverCount).toEqual(3) expect(removeIdleObserverCount).toEqual(3) @@ -815,10 +933,9 @@ describe('#unit Pool', () => { let resourceCount = 0 const pool = new Pool({ - create: (server, release) => + create: (_, server, release) => Promise.resolve(new Resource(server, resourceCount++, release)), destroy: res => Promise.resolve(), - validate: res => true, installIdleObserver: (resource, observer) => { resource.observer = observer }, @@ -827,8 +944,8 @@ describe('#unit Pool', () => { } }) - const resource1 = await pool.acquire(address) - const resource2 = await pool.acquire(address) + const resource1 = await pool.acquire({}, address) + const resource2 = await pool.acquire({}, address) expect(pool.activeResourceCount(address)).toBe(2) await resource1.close() @@ -855,10 +972,9 @@ describe('#unit Pool', () => { let resourceCount = 0 const pool = new Pool({ - create: (server, release) => + create: (_, server, release) => Promise.resolve(new Resource(server, resourceCount++, release)), destroy: res => Promise.resolve(), - validate: res => true, installIdleObserver: (resource, observer) => { resource.observer = observer }, @@ -867,8 +983,8 @@ describe('#unit Pool', () => { } }) - const resource1 = await pool.acquire(address) - const resource2 = await pool.acquire(address) + const resource1 = await pool.acquire({}, address) + const resource2 = await pool.acquire({}, address) await resource1.close() await resource2.close() @@ -884,17 +1000,18 @@ describe('#unit Pool', () => { let counter = 0 const pool = new Pool({ - create: (server, release) => + create: (_, server, release) => new Promise(resolve => setTimeout( () => resolve(new Resource(server, counter++, release)) , acquisitionTimeout + 10)), destroy: res => Promise.resolve(), - validate: resourceValidOnlyOnceValidationFunction, + validateOnAcquire: (_, res) => resourceValidOnlyOnceValidationFunction(res), + validateOnRelease: resourceValidOnlyOnceValidationFunction, config: new PoolConfig(1, acquisitionTimeout) }) try { - await pool.acquire(address) + await pool.acquire({}, address) fail('should have thrown') } catch (e) { expect(e).toEqual( @@ -922,7 +1039,7 @@ describe('#unit Pool', () => { }) const pool = new Pool({ - create: (server, release) => + create: (_, server, release) => Promise.resolve(new Resource(server, resourceCount++, release)), destroy: res => { resourcesReleased.push(res) @@ -933,12 +1050,11 @@ describe('#unit Pool', () => { resolveRelease() } return releasePromise - }, - validate: res => true + } }) - const resource1 = await pool.acquire(address) - const resource2 = await pool.acquire(address) + const resource1 = await pool.acquire({}, address) + const resource2 = await pool.acquire({}, address) await resource1.close() await resource2.close() @@ -948,6 +1064,196 @@ describe('#unit Pool', () => { resource2, resource1 ]) }) + + describe('when acquire force new', () => { + it('allocates if pool is empty', async () => { + // Given + let counter = 0 + const address = ServerAddress.fromUrl('bolt://localhost:7687') + const pool = new Pool({ + create: (_, server, release) => + Promise.resolve(new Resource(server, counter++, release)) + }) + + // When + const r0 = await pool.acquire({}, address) + const r1 = await pool.acquire({}, address, { requireNew: true }) + + // Then + expect(r0.id).toBe(0) + expect(r1.id).toBe(1) + expect(r0).not.toBe(r1) + }) + + it('not pools if resources are returned', async () => { + // Given a pool that allocates + let counter = 0 + const address = ServerAddress.fromUrl('bolt://localhost:7687') + const pool = new Pool({ + create: (_, server, release) => + Promise.resolve(new Resource(server, counter++, release)) + }) + + // When + const r0 = await pool.acquire({}, address) + await r0.close() + + const r1 = await pool.acquire({}, address, { requireNew: true }) + + // Then + expect(r0.id).toBe(0) + expect(r1.id).toBe(1) + expect(r0).not.toBe(r1) + }) + + it('should fail to acquire when closed', async () => { + const address = ServerAddress.fromUrl('bolt://localhost:7687') + const pool = new Pool({ + create: (_, server, release) => + Promise.resolve(new Resource(server, 0, release)), + destroy: res => { + return Promise.resolve() + } + }) + + // Close the pool + await pool.close() + + await expect(pool.acquire({}, address, { requireNew: true })).rejects.toMatchObject({ + message: expect.stringMatching('Pool is closed') + }) + }) + + it('should fail to acquire when closed with idle connections', async () => { + const address = ServerAddress.fromUrl('bolt://localhost:7687') + + const pool = new Pool({ + create: (_, server, release) => + Promise.resolve(new Resource(server, 0, release)), + destroy: res => { + return Promise.resolve() + } + }) + + // Acquire and release a resource + const resource = await pool.acquire({}, address) + await resource.close() + + // Close the pool + await pool.close() + + await expect(pool.acquire({}, address, { requireNew: true })).rejects.toMatchObject({ + message: expect.stringMatching('Pool is closed') + }) + }) + + it('should wait for a returned connection when max pool size is reached', async () => { + let counter = 0 + + const address = ServerAddress.fromUrl('bolt://localhost:7687') + const pool = new Pool({ + create: (_, server, release) => + Promise.resolve(new Resource(server, counter++, release)), + destroy: res => Promise.resolve(), + config: new PoolConfig(2, 5000) + }) + + const r0 = await pool.acquire({}, address) + const r1 = await pool.acquire({}, address, { requireNew: true }) + + setTimeout(() => { + expectNumberOfAcquisitionRequests(pool, address, 1) + r1.close() + }, 1000) + + expect(r1).not.toBe(r0) + const r2 = await pool.acquire({}, address) + expect(r2).toBe(r1) + }) + + it('should wait for a returned connection when max pool size is reached and return new', async () => { + let counter = 0 + + const address = ServerAddress.fromUrl('bolt://localhost:7687') + const pool = new Pool({ + create: (_, server, release) => + Promise.resolve(new Resource(server, counter++, release)), + destroy: res => Promise.resolve(), + config: new PoolConfig(2, 5000) + }) + + const r0 = await pool.acquire({}, address) + const r1 = await pool.acquire({}, address, { requireNew: true }) + + setTimeout(() => { + expectNumberOfAcquisitionRequests(pool, address, 1) + r1.close() + }, 1000) + + expect(r1).not.toBe(r0) + const r2 = await pool.acquire({}, address, { requireNew: true }) + expect(r2).not.toBe(r1) + }) + + it('should handle a sequence of request new and the regular request', async () => { + let counter = 0 + + const destroy = jest.fn(res => Promise.resolve()) + const removeIdleObserver = jest.fn(res => undefined) + const address = ServerAddress.fromUrl('bolt://localhost:7687') + const pool = new Pool({ + create: (_, server, release) => + Promise.resolve(new Resource(server, counter++, release)), + destroy: destroy, + removeIdleObserver: removeIdleObserver, + config: new PoolConfig(1, 5000) + }) + + const r0 = await pool.acquire({}, address, { requireNew: true }) + expect(pool.activeResourceCount(address)).toEqual(1) + expect(idleResources(pool, address)).toBe(0) + expect(resourceInUse(pool, address)).toBe(1) + + setTimeout(() => { + expectNumberOfAcquisitionRequests(pool, address, 1) + r0.close() + }, 1000) + + const r1 = await pool.acquire({}, address, { requireNew: true }) + expect(destroy).toHaveBeenCalledWith(r0) + expect(removeIdleObserver).toHaveBeenCalledWith(r0) + expect(pool.activeResourceCount(address)).toEqual(1) + expect(idleResources(pool, address)).toBe(0) + expect(resourceInUse(pool, address)).toBe(1) + + setTimeout(() => { + expectNumberOfAcquisitionRequests(pool, address, 1) + r1.close() + }, 1000) + + expect(r1).not.toBe(r0) + const r2 = await pool.acquire({}, address, { requireNew: true }) + expect(removeIdleObserver).toHaveBeenCalledWith(r1) + expect(destroy).toHaveBeenCalledWith(r1) + expect(r2).not.toBe(r1) + expect(pool.activeResourceCount(address)).toEqual(1) + expect(idleResources(pool, address)).toBe(0) + expect(resourceInUse(pool, address)).toBe(1) + + setTimeout(() => { + expectNumberOfAcquisitionRequests(pool, address, 1) + r2.close() + }, 1000) + + const r3 = await pool.acquire({}, address) + expect(r3).toBe(r2) + expect(removeIdleObserver).toHaveBeenCalledWith(r2) + expect(destroy).not.toHaveBeenCalledWith(r2) + expect(pool.activeResourceCount(address)).toEqual(1) + expect(idleResources(pool, address)).toBe(0) + expect(resourceInUse(pool, address)).toBe(1) + }) + }) }) function expectNoPendingAcquisitionRequests (pool) { @@ -973,6 +1279,13 @@ function idleResources (pool, address) { return undefined } +function resourceInUse (pool, address) { + if (pool.has(address)) { + return pool._pools[address.asKey()]._elementsInUse.size + } + return undefined +} + function expectNumberOfAcquisitionRequests (pool, address, expectedNumber) { expect(pool._acquireRequests[address.asKey()].length).toEqual(expectedNumber) } diff --git a/packages/core/src/auth-token-manager.ts b/packages/core/src/auth-token-manager.ts new file mode 100644 index 000000000..77a2961cc --- /dev/null +++ b/packages/core/src/auth-token-manager.ts @@ -0,0 +1,227 @@ +/** + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [http://neo4j.com] + * + * This file is part of Neo4j. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import auth from './auth' +import { AuthToken } from './types' +import { util } from './internal' + +/** + * Interface for the piece of software responsible for keeping track of current active {@link AuthToken} across the driver. + * @interface + * @experimental Exposed as preview feature. + * @since 5.8 + */ +export default class AuthTokenManager { + /** + * Returns a valid token. + * + * **Warning**: This method must only ever return auth information belonging to the same identity. + * Switching identities using the `AuthTokenManager` is undefined behavior. + * + * @returns {Promise|AuthToken} The valid auth token or a promise for a valid auth token + */ + getToken (): Promise | AuthToken { + throw new Error('Not Implemented') + } + + /** + * Called to notify a token expiration. + * + * @param {AuthToken} token The expired token. + * @return {void} + */ + onTokenExpired (token: AuthToken): void { + throw new Error('Not implemented') + } +} + +/** + * Interface which defines an {@link AuthToken} with an expiration data time associated + * @interface + * @experimental Exposed as preview feature. + * @since 5.8 + */ +export class AuthTokenAndExpiration { + public readonly token: AuthToken + public readonly expiration?: Date + + private constructor () { + /** + * The {@link AuthToken} used for authenticate connections. + * + * @type {AuthToken} + * @see {auth} + */ + this.token = auth.none() as AuthToken + + /** + * The expected expiration date of the auth token. + * + * This information will be used for triggering the auth token refresh + * in managers created with {@link expirationBasedAuthTokenManager}. + * + * If this value is not defined, the {@link AuthToken} will be considered valid + * until a `Neo.ClientError.Security.TokenExpired` error happens. + * + * @type {Date|undefined} + */ + this.expiration = undefined + } +} + +/** + * Creates a {@link AuthTokenManager} for handle {@link AuthToken} which is expires. + * + * **Warning**: `tokenProvider` must only ever return auth information belonging to the same identity. + * Switching identities using the `AuthTokenManager` is undefined behavior. + * + * @param {object} param0 - The params + * @param {function(): Promise} param0.tokenProvider - Retrieves a new valid auth token. + * Must only ever return auth information belonging to the same identity. + * @returns {AuthTokenManager} The temporal auth data manager. + * @experimental Exposed as preview feature. + */ +export function expirationBasedAuthTokenManager ({ tokenProvider }: { tokenProvider: () => Promise }): AuthTokenManager { + if (typeof tokenProvider !== 'function') { + throw new TypeError(`tokenProvider should be function, but got: ${typeof tokenProvider}`) + } + return new ExpirationBasedAuthTokenManager(tokenProvider) +} + +/** + * Create a {@link AuthTokenManager} for handle static {@link AuthToken} + * + * @private + * @param {param} args - The args + * @param {AuthToken} args.authToken - The static auth token which will always used in the driver. + * @returns {AuthTokenManager} The temporal auth data manager. + */ +export function staticAuthTokenManager ({ authToken }: { authToken: AuthToken }): AuthTokenManager { + return new StaticAuthTokenManager(authToken) +} + +/** + * Checks if the manager is a StaticAuthTokenManager + * + * @private + * @experimental + * @param {AuthTokenManager} manager The auth token manager to be checked. + * @returns {boolean} Manager is StaticAuthTokenManager + */ +export function isStaticAuthTokenManger (manager: AuthTokenManager): manager is StaticAuthTokenManager { + return manager instanceof StaticAuthTokenManager +} + +interface TokenRefreshObserver { + onCompleted: (data: AuthTokenAndExpiration) => void + onError: (error: Error) => void +} + +class TokenRefreshObservable implements TokenRefreshObserver { + constructor (private readonly _subscribers: TokenRefreshObserver[] = []) { + + } + + subscribe (sub: TokenRefreshObserver): void { + this._subscribers.push(sub) + } + + onCompleted (data: AuthTokenAndExpiration): void { + this._subscribers.forEach(sub => sub.onCompleted(data)) + } + + onError (error: Error): void { + this._subscribers.forEach(sub => sub.onError(error)) + } +} + +class ExpirationBasedAuthTokenManager implements AuthTokenManager { + constructor ( + private readonly _tokenProvider: () => Promise, + private _currentAuthData?: AuthTokenAndExpiration, + private _refreshObservable?: TokenRefreshObservable) { + + } + + async getToken (): Promise { + if (this._currentAuthData === undefined || + ( + this._currentAuthData.expiration !== undefined && + this._currentAuthData.expiration < new Date() + )) { + await this._refreshAuthToken() + } + + return this._currentAuthData?.token as AuthToken + } + + onTokenExpired (token: AuthToken): void { + if (util.equals(token, this._currentAuthData?.token)) { + this._scheduleRefreshAuthToken() + } + } + + private _scheduleRefreshAuthToken (observer?: TokenRefreshObserver): void { + if (this._refreshObservable === undefined) { + this._currentAuthData = undefined + this._refreshObservable = new TokenRefreshObservable() + + Promise.resolve(this._tokenProvider()) + .then(data => { + this._currentAuthData = data + this._refreshObservable?.onCompleted(data) + }) + .catch(error => { + this._refreshObservable?.onError(error) + }) + .finally(() => { + this._refreshObservable = undefined + }) + } + + if (observer !== undefined) { + this._refreshObservable.subscribe(observer) + } + } + + private async _refreshAuthToken (): Promise { + return await new Promise((resolve, reject) => { + this._scheduleRefreshAuthToken({ + onCompleted: resolve, + onError: reject + }) + }) + } +} + +class StaticAuthTokenManager implements AuthTokenManager { + constructor ( + private readonly _authToken: AuthToken + ) { + + } + + getToken (): AuthToken { + return this._authToken + } + + onTokenExpired (_: AuthToken): void { + // nothing to do here + } +} diff --git a/packages/core/src/auth.ts b/packages/core/src/auth.ts index 502166a0e..f15db8dee 100644 --- a/packages/core/src/auth.ts +++ b/packages/core/src/auth.ts @@ -53,6 +53,11 @@ const auth = { credentials: base64EncodedToken } }, + none: () => { + return { + scheme: 'none' + } + }, custom: ( principal: string, credentials: string, diff --git a/packages/core/src/connection-provider.ts b/packages/core/src/connection-provider.ts index 8a900a3a9..bd0c96f03 100644 --- a/packages/core/src/connection-provider.ts +++ b/packages/core/src/connection-provider.ts @@ -21,9 +21,10 @@ import Connection from './connection' import { bookmarks } from './internal' import { ServerInfo } from './result-summary' +import { AuthToken } from './types' /** - * Inteface define a common way to acquire a connection + * Interface define a common way to acquire a connection * * @private */ @@ -51,6 +52,7 @@ class ConnectionProvider { bookmarks: bookmarks.Bookmarks impersonatedUser?: string onDatabaseNameResolved?: (databaseName?: string) => void + auth?: AuthToken }): Promise { throw Error('Not implemented') } @@ -85,6 +87,16 @@ class ConnectionProvider { throw Error('Not implemented') } + /** + * This method checks whether the driver session re-auth functionality + * by checking protocol handshake result + * + * @returns {Promise} + */ + supportsSessionAuth (): Promise { + throw Error('Not implemented') + } + /** * This method verifies the connectivity of the database by trying to acquire a connection * for each server available in the cluster. @@ -99,6 +111,22 @@ class ConnectionProvider { throw Error('Not implemented') } + /** + * This method verifies the authorization credentials work by trying to acquire a connection + * to one of the servers with the given credentials. + * + * @param {object} param - object parameter + * @property {AuthToken} param.auth - the target auth for the to-be-acquired connection + * @property {string} param.database - the target database for the to-be-acquired connection + * @property {string} param.accessMode - the access mode for the to-be-acquired connection + * + * @returns {Promise} promise resolved with true if succeed, false if failed with + * authentication issue and rejected with error if non-authentication error happens. + */ + verifyAuthentication (param?: { auth?: AuthToken, database?: string, accessMode?: string }): Promise { + throw Error('Not implemented') + } + /** * Returns the protocol version negotiated via handshake. * diff --git a/packages/core/src/connection.ts b/packages/core/src/connection.ts index 6761adc58..80749db12 100644 --- a/packages/core/src/connection.ts +++ b/packages/core/src/connection.ts @@ -37,6 +37,13 @@ class Connection { return {} } + /** + * @property {object} authToken The auth registered in the connection + */ + get authToken (): any { + return {} + } + /** * @property {ServerAddress} the server address this connection is opened against */ @@ -51,6 +58,13 @@ class Connection { return undefined } + /** + * @property {boolean} supportsReAuth Indicates the connection supports re-auth + */ + get supportsReAuth (): boolean { + return false + } + /** * @returns {boolean} whether this connection is in a working condition */ diff --git a/packages/core/src/driver.ts b/packages/core/src/driver.ts index a3fcb2dc4..1232c77eb 100644 --- a/packages/core/src/driver.ts +++ b/packages/core/src/driver.ts @@ -38,7 +38,8 @@ import { LoggingConfig, TrustStrategy, SessionMode, - Query + Query, + AuthToken } from './types' import { ServerAddress } from './internal/server-address' import BookmarkManager, { bookmarkManager } from './bookmark-manager' @@ -96,6 +97,7 @@ type CreateSession = (args: { impersonatedUser?: string bookmarkManager?: BookmarkManager notificationFilter?: NotificationFilter + auth?: AuthToken }) => Session type CreateQueryExecutor = (createSession: (config: { database?: string, bookmarkManager?: BookmarkManager }) => Session) => QueryExecutor @@ -121,6 +123,7 @@ class SessionConfig { fetchSize?: number bookmarkManager?: BookmarkManager notificationFilter?: NotificationFilter + auth?: AuthToken /** * @constructor @@ -192,6 +195,22 @@ class SessionConfig { */ this.impersonatedUser = undefined + /** + * The {@link AuthToken} which will be used for the duration of the session. + * + * By default, the session will use connections authenticated with {@link AuthToken} configured in the + * driver creation. This configuration allows switch user and/or authorization information for the + * session lifetime. + * + * **Warning**: This option is only enable when the driver is connected with Neo4j Database servers + * which supports Bolt 5.1 and onwards. + * + * @type {AuthToken|undefined} + * @experimental Exposed as preview feature. + * @see {@link driver} + */ + this.auth = undefined + /** * The record fetch size of each batch of this session. * @@ -573,6 +592,26 @@ class Driver { return connectionProvider.verifyConnectivityAndGetServerInfo({ database, accessMode: READ }) } + /** + * This method verifies the authorization credentials work by trying to acquire a connection + * to one of the servers with the given credentials. + * + * @param {object} param - object parameter + * @property {AuthToken} param.auth - the target auth for the to-be-acquired connection + * @property {string} param.database - the target database for the to-be-acquired connection + * + * @returns {Promise} promise resolved with true if succeed, false if failed with + * authentication issue and rejected with error if non-authentication error happens. + */ + async verifyAuthentication ({ database, auth }: { auth?: AuthToken, database?: string } = {}): Promise { + const connectionProvider = this._getOrCreateConnectionProvider() + return await connectionProvider.verifyAuthentication({ + database: database ?? 'system', + auth, + accessMode: READ + }) + } + /** * Get ServerInfo for the giver database. * @@ -624,6 +663,19 @@ class Driver { return connectionProvider.supportsUserImpersonation() } + /** + * Returns whether the driver session re-auth functionality capabilities based on the protocol + * version negotiated via handshake. + * + * Note that this function call _always_ causes a round-trip to the server. + * + * @returns {Promise} promise resolved with a boolean or rejected with error. + */ + supportsSessionAuth (): Promise { + const connectionProvider = this._getOrCreateConnectionProvider() + return connectionProvider.supportsSessionAuth() + } + /** * Returns the protocol version negotiated via handshake. * @@ -696,7 +748,8 @@ class Driver { impersonatedUser, fetchSize, bookmarkManager, - notificationFilter + notificationFilter, + auth }: SessionConfig = {}): Session { return this._newSession({ defaultAccessMode, @@ -707,7 +760,8 @@ class Driver { // eslint-disable-next-line @typescript-eslint/no-non-null-assertion fetchSize: validateFetchSizeValue(fetchSize, this._config.fetchSize!), bookmarkManager, - notificationFilter + notificationFilter, + auth }) } @@ -746,7 +800,8 @@ class Driver { impersonatedUser, fetchSize, bookmarkManager, - notificationFilter + notificationFilter, + auth }: { defaultAccessMode: SessionMode bookmarkOrBookmarks?: string | string[] @@ -756,6 +811,7 @@ class Driver { fetchSize: number bookmarkManager?: BookmarkManager notificationFilter?: NotificationFilter + auth?: AuthToken }): Session { const sessionMode = Session._validateSessionMode(defaultAccessMode) const connectionProvider = this._getOrCreateConnectionProvider() @@ -773,7 +829,8 @@ class Driver { impersonatedUser, fetchSize, bookmarkManager, - notificationFilter + notificationFilter, + auth }) } diff --git a/packages/core/src/index.ts b/packages/core/src/index.ts index 5532f83f0..c329b04e3 100644 --- a/packages/core/src/index.ts +++ b/packages/core/src/index.ts @@ -87,6 +87,7 @@ import Session, { TransactionConfig } from './session' import Driver, * as driver from './driver' import auth from './auth' import BookmarkManager, { BookmarkManagerConfig, bookmarkManager } from './bookmark-manager' +import AuthTokenManager, { expirationBasedAuthTokenManager, staticAuthTokenManager, isStaticAuthTokenManger, AuthTokenAndExpiration } from './auth-token-manager' import { SessionConfig, QueryConfig, RoutingControl, routing } from './driver' import * as types from './types' import * as json from './json' @@ -163,6 +164,7 @@ const forExport = { json, auth, bookmarkManager, + expirationBasedAuthTokenManager, routing, resultTransformers, notificationCategory, @@ -230,6 +232,9 @@ export { json, auth, bookmarkManager, + expirationBasedAuthTokenManager, + staticAuthTokenManager, + isStaticAuthTokenManger, routing, resultTransformers, notificationCategory, @@ -247,6 +252,8 @@ export type { TransactionConfig, BookmarkManager, BookmarkManagerConfig, + AuthTokenManager, + AuthTokenAndExpiration, SessionConfig, QueryConfig, RoutingControl, diff --git a/packages/core/src/internal/connection-holder.ts b/packages/core/src/internal/connection-holder.ts index 17b5690d1..76dda3abf 100644 --- a/packages/core/src/internal/connection-holder.ts +++ b/packages/core/src/internal/connection-holder.ts @@ -24,6 +24,7 @@ import Connection from '../connection' import { ACCESS_MODE_WRITE } from './constants' import { Bookmarks } from './bookmarks' import ConnectionProvider from '../connection-provider' +import { AuthToken } from '../types' /** * @private @@ -85,6 +86,8 @@ class ConnectionHolder implements ConnectionHolderInterface { private readonly _impersonatedUser?: string private readonly _getConnectionAcquistionBookmarks: () => Promise private readonly _onDatabaseNameResolved?: (databaseName?: string) => void + private readonly _auth?: AuthToken + private _closed: boolean /** * @constructor @@ -96,6 +99,7 @@ class ConnectionHolder implements ConnectionHolderInterface { * @property {string?} params.impersonatedUser - the user which will be impersonated * @property {function(databaseName:string)} params.onDatabaseNameResolved - callback called when the database name is resolved * @property {function():Promise} params.getConnectionAcquistionBookmarks - called for getting Bookmarks for acquiring connections + * @property {AuthToken} params.auth - the target auth for the to-be-acquired connection */ constructor ({ mode = ACCESS_MODE_WRITE, @@ -104,7 +108,8 @@ class ConnectionHolder implements ConnectionHolderInterface { connectionProvider, impersonatedUser, onDatabaseNameResolved, - getConnectionAcquistionBookmarks + getConnectionAcquistionBookmarks, + auth }: { mode?: string database?: string @@ -113,8 +118,10 @@ class ConnectionHolder implements ConnectionHolderInterface { impersonatedUser?: string onDatabaseNameResolved?: (databaseName?: string) => void getConnectionAcquistionBookmarks?: () => Promise + auth?: AuthToken } = {}) { this._mode = mode + this._closed = false this._database = database != null ? assertString(database, 'database') : '' this._bookmarks = bookmarks ?? Bookmarks.empty() this._connectionProvider = connectionProvider @@ -122,6 +129,7 @@ class ConnectionHolder implements ConnectionHolderInterface { this._referenceCount = 0 this._connectionPromise = Promise.resolve(null) this._onDatabaseNameResolved = onDatabaseNameResolved + this._auth = auth this._getConnectionAcquistionBookmarks = getConnectionAcquistionBookmarks ?? (() => Promise.resolve(Bookmarks.empty())) } @@ -166,7 +174,8 @@ class ConnectionHolder implements ConnectionHolderInterface { database: this._database, bookmarks: await this._getBookmarks(), impersonatedUser: this._impersonatedUser, - onDatabaseNameResolved: this._onDatabaseNameResolved + onDatabaseNameResolved: this._onDatabaseNameResolved, + auth: this._auth }) } @@ -192,6 +201,7 @@ class ConnectionHolder implements ConnectionHolderInterface { } close (hasTx?: boolean): Promise { + this._closed = true if (this._referenceCount === 0) { return this._connectionPromise } diff --git a/packages/core/src/internal/util.ts b/packages/core/src/internal/util.ts index 1d7e152d7..903881f19 100644 --- a/packages/core/src/internal/util.ts +++ b/packages/core/src/internal/util.ts @@ -223,6 +223,44 @@ function isString (str: any): str is string { return Object.prototype.toString.call(str) === '[object String]' } +/** + * Verifies if object are the equals + * @param {unknown} a + * @param {unknown} b + * @returns {boolean} + */ +function equals (a: unknown, b: unknown): boolean { + if (a === b) { + return true + } + + if (a === null || b === null) { + return false + } + + if (typeof a === 'object' && typeof b === 'object') { + const keysA = Object.keys(a) + const keysB = Object.keys(b) + + if (keysA.length !== keysB.length) { + return false + } + + type AObjectKey = keyof typeof a + type BObjectKey = keyof typeof b + + for (const key of keysA) { + if (!equals(a[key as AObjectKey], b[key as BObjectKey])) { + return false + } + } + + return true + } + + return false +} + export { isEmptyObjectOrNull, isObject, @@ -233,6 +271,7 @@ export { assertNumberOrInteger, assertValidDate, validateQueryAndParameters, + equals, ENCRYPTION_ON, ENCRYPTION_OFF } diff --git a/packages/core/src/session.ts b/packages/core/src/session.ts index d29509bf9..80ee7b24a 100644 --- a/packages/core/src/session.ts +++ b/packages/core/src/session.ts @@ -30,7 +30,7 @@ import { TransactionExecutor } from './internal/transaction-executor' import { Bookmarks } from './internal/bookmarks' import { TxConfig } from './internal/tx-config' import ConnectionProvider from './connection-provider' -import { Query, SessionMode } from './types' +import { AuthToken, Query, SessionMode } from './types' import Connection from './connection' import { NumberOrInteger } from './graph-types' import TransactionPromise from './transaction-promise' @@ -86,6 +86,7 @@ class Session { * @param {boolean} args.reactive - Whether this session should create reactive streams * @param {number} args.fetchSize - Defines how many records is pulled in each pulling batch * @param {string} args.impersonatedUser - The username which the user wants to impersonate for the duration of the session. + * @param {AuthToken} args.auth - the target auth for the to-be-acquired connection * @param {NotificationFilter} args.notificationFilter - The notification filter used for this session. */ constructor ({ @@ -98,7 +99,8 @@ class Session { fetchSize, impersonatedUser, bookmarkManager, - notificationFilter + notificationFilter, + auth }: { mode: SessionMode connectionProvider: ConnectionProvider @@ -110,6 +112,7 @@ class Session { impersonatedUser?: string bookmarkManager?: BookmarkManager notificationFilter?: NotificationFilter + auth?: AuthToken }) { this._mode = mode this._database = database @@ -119,6 +122,7 @@ class Session { this._getConnectionAcquistionBookmarks = this._getConnectionAcquistionBookmarks.bind(this) this._readConnectionHolder = new ConnectionHolder({ mode: ACCESS_MODE_READ, + auth, database, bookmarks, connectionProvider, @@ -128,6 +132,7 @@ class Session { }) this._writeConnectionHolder = new ConnectionHolder({ mode: ACCESS_MODE_WRITE, + auth, database, bookmarks, connectionProvider, diff --git a/packages/core/src/types.ts b/packages/core/src/types.ts index 63140cc16..463dd4713 100644 --- a/packages/core/src/types.ts +++ b/packages/core/src/types.ts @@ -43,11 +43,12 @@ export type TrustStrategy = export interface Parameters { [key: string]: any } export interface AuthToken { scheme: string - principal: string + principal?: string credentials: string realm?: string parameters?: Parameters } + export interface Config { encrypted?: boolean | EncryptionLevel trust?: TrustStrategy diff --git a/packages/core/test/driver.test.ts b/packages/core/test/driver.test.ts index 44de0c37c..e13a77142 100644 --- a/packages/core/test/driver.test.ts +++ b/packages/core/test/driver.test.ts @@ -82,6 +82,26 @@ describe('Driver', () => { expect(createSession).toHaveBeenCalledWith(expectedSessionParams()) }) + it('should create the session with auth', () => { + const auth = { + scheme: 'basic', + principal: 'the imposter', + credentials: 'super safe password' + } + + const session = driver?.session({ auth }) + + expect(session).not.toBeUndefined() + expect(createSession).toHaveBeenCalledWith(expectedSessionParams({ auth })) + }) + + it('should create the session without auth', () => { + const session = driver?.session() + + expect(session).not.toBeUndefined() + expect(createSession).toHaveBeenCalledWith(expectedSessionParams()) + }) + it.each([ [undefined, Bookmarks.empty()], [null, Bookmarks.empty()], @@ -247,6 +267,23 @@ describe('Driver', () => { promise?.catch(_ => 'Do nothing').finally(() => {}) }) + it.each([ + ['Promise.resolve(true)', Promise.resolve(true)], + ['Promise.resolve(false)', Promise.resolve(false)], + [ + "Promise.reject(newError('something went wrong'))", + Promise.reject(newError('something went wrong')) + ] + ])('.supportsSessionAuth() => %s', (_, expectedPromise) => { + connectionProvider.supportsSessionAuth = jest.fn(() => expectedPromise) + + const promise: Promise | undefined = driver?.supportsSessionAuth() + + expect(promise).toBe(expectedPromise) + + promise?.catch(_ => 'Do nothing').finally(() => {}) + }) + it.each([ [{ encrypted: true }, true], [{ encrypted: false }, false], diff --git a/packages/core/test/session.test.ts b/packages/core/test/session.test.ts index d554468ec..803d9c86d 100644 --- a/packages/core/test/session.test.ts +++ b/packages/core/test/session.test.ts @@ -20,6 +20,7 @@ import { ConnectionProvider, Session, Connection, TransactionPromise, Transactio import { bookmarks } from '../src/internal' import { ACCESS_MODE_READ, FETCH_ALL } from '../src/internal/constants' import ManagedTransaction from '../src/transaction-managed' +import { AuthToken } from '../src/types' import FakeConnection from './utils/connection.fake' import { validNotificationFilters } from './utils/notification-filters.fixtures' @@ -432,6 +433,47 @@ describe('session', () => { expect(updateBookmarksSpy).not.toBeCalled() }) + + it('should acquire connection with auth', async () => { + const auth = { + scheme: 'bearer', + credentials: 'bearer some-nice-token' + } + const connection = mockBeginWithSuccess(newFakeConnection()) + + const { session, connectionProvider } = setupSession({ + connection, + auth, + beginTx: false, + database: 'neo4j' + }) + + await session.beginTransaction() + + expect(connectionProvider.acquireConnection).toBeCalledWith( + expect.objectContaining({ auth }) + ) + }) + + it('should acquire connection without auth', async () => { + const auth = { + scheme: 'bearer', + credentials: 'bearer some-nice-token' + } + const connection = mockBeginWithSuccess(newFakeConnection()) + + const { session, connectionProvider } = setupSession({ + connection, + beginTx: false, + database: 'neo4j' + }) + + await session.beginTransaction() + + expect(connectionProvider.acquireConnection).not.toBeCalledWith( + expect.objectContaining({ auth }) + ) + }) }) describe('.commit()', () => { @@ -835,6 +877,47 @@ describe('session', () => { }) ) }) + + it('should acquire with auth', async () => { + const auth = { + scheme: 'bearer', + credentials: 'bearer some-nice-token' + } + const connection = newFakeConnection() + + const { session, connectionProvider } = setupSession({ + connection, + auth, + beginTx: false, + database: 'neo4j' + }) + + await session.run('query') + + expect(connectionProvider.acquireConnection).toBeCalledWith( + expect.objectContaining({ auth }) + ) + }) + + it('should acquire without auth', async () => { + const auth = { + scheme: 'bearer', + credentials: 'bearer some-nice-token' + } + const connection = newFakeConnection() + + const { session, connectionProvider } = setupSession({ + connection, + beginTx: false, + database: 'neo4j' + }) + + await session.run('query') + + expect(connectionProvider.acquireConnection).not.toBeCalledWith( + expect.objectContaining({ auth }) + ) + }) }) }) @@ -887,7 +970,8 @@ function setupSession ({ database = '', lastBookmarks = bookmarks.Bookmarks.empty(), bookmarkManager, - notificationFilter + notificationFilter, + auth }: { connection: Connection beginTx?: boolean @@ -896,6 +980,7 @@ function setupSession ({ database?: string bookmarkManager?: BookmarkManager notificationFilter?: NotificationFilter + auth?: AuthToken }): { session: Session, connectionProvider: ConnectionProvider } { const connectionProvider = new ConnectionProvider() connectionProvider.acquireConnection = jest.fn(async () => await Promise.resolve(connection)) @@ -910,7 +995,8 @@ function setupSession ({ reactive: false, bookmarks: lastBookmarks, bookmarkManager, - notificationFilter + notificationFilter, + auth }) if (beginTx) { diff --git a/packages/neo4j-driver-deno/lib/bolt-connection/bolt/bolt-protocol-v1.js b/packages/neo4j-driver-deno/lib/bolt-connection/bolt/bolt-protocol-v1.js index bafc8d74e..1212bb019 100644 --- a/packages/neo4j-driver-deno/lib/bolt-connection/bolt/bolt-protocol-v1.js +++ b/packages/neo4j-driver-deno/lib/bolt-connection/bolt/bolt-protocol-v1.js @@ -205,6 +205,7 @@ export default class BoltProtocol { onError: onError }) + // TODO: Verify the Neo4j version in the message const error = newError( 'Driver is connected to a database that does not support logoff. ' + 'Please upgrade to Neo4j 5.5.0 or later in order to use this functionality.' @@ -233,6 +234,7 @@ export default class BoltProtocol { onError: (error) => this._onLoginError(error, onError) }) + // TODO: Verify the Neo4j version in the message const error = newError( 'Driver is connected to a database that does not support logon. ' + 'Please upgrade to Neo4j 5.5.0 or later in order to use this functionality.' diff --git a/packages/neo4j-driver-deno/lib/bolt-connection/connection-provider/authentication-provider.js b/packages/neo4j-driver-deno/lib/bolt-connection/connection-provider/authentication-provider.js new file mode 100644 index 000000000..9f8cc81bf --- /dev/null +++ b/packages/neo4j-driver-deno/lib/bolt-connection/connection-provider/authentication-provider.js @@ -0,0 +1,66 @@ +/** + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [http://neo4j.com] + * + * This file is part of Neo4j. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { expirationBasedAuthTokenManager } from '../../core/index.ts' +import { object } from '../lang/index.js' + +/** + * Class which provides Authorization for {@link Connection} + */ +export default class AuthenticationProvider { + constructor ({ authTokenManager, userAgent }) { + this._authTokenManager = authTokenManager || expirationBasedAuthTokenManager({ + tokenProvider: () => {} + }) + this._userAgent = userAgent + } + + async authenticate ({ connection, auth, skipReAuth, waitReAuth, forceReAuth }) { + if (auth != null) { + const shouldReAuth = connection.supportsReAuth === true && ( + (!object.equals(connection.authToken, auth) && skipReAuth !== true) || + forceReAuth === true + ) + if (connection.authToken == null || shouldReAuth) { + return await connection.connect(this._userAgent, auth, waitReAuth || false) + } + return connection + } + + const authToken = await this._authTokenManager.getToken() + + if (!object.equals(authToken, connection.authToken)) { + return await connection.connect(this._userAgent, authToken) + } + + return connection + } + + handleError ({ connection, code }) { + if ( + connection && + [ + 'Neo.ClientError.Security.Unauthorized', + 'Neo.ClientError.Security.TokenExpired' + ].includes(code) + ) { + this._authTokenManager.onTokenExpired(connection.authToken) + } + } +} diff --git a/packages/neo4j-driver-deno/lib/bolt-connection/connection-provider/connection-provider-direct.js b/packages/neo4j-driver-deno/lib/bolt-connection/connection-provider/connection-provider-direct.js index 9895701c7..bc9ddb666 100644 --- a/packages/neo4j-driver-deno/lib/bolt-connection/connection-provider/connection-provider-direct.js +++ b/packages/neo4j-driver-deno/lib/bolt-connection/connection-provider/connection-provider-direct.js @@ -26,14 +26,19 @@ import { import { internal, error } from '../../core/index.ts' const { - constants: { BOLT_PROTOCOL_V3, BOLT_PROTOCOL_V4_0, BOLT_PROTOCOL_V4_4 } + constants: { + BOLT_PROTOCOL_V3, + BOLT_PROTOCOL_V4_0, + BOLT_PROTOCOL_V4_4, + BOLT_PROTOCOL_V5_1 + } } = internal const { SERVICE_UNAVAILABLE } = error export default class DirectConnectionProvider extends PooledConnectionProvider { - constructor ({ id, config, log, address, userAgent, authToken }) { - super({ id, config, log, userAgent, authToken }) + constructor ({ id, config, log, address, userAgent, authTokenManager, newPool }) { + super({ id, config, log, userAgent, authTokenManager, newPool }) this._address = address } @@ -42,27 +47,33 @@ export default class DirectConnectionProvider extends PooledConnectionProvider { * See {@link ConnectionProvider} for more information about this method and * its arguments. */ - acquireConnection ({ accessMode, database, bookmarks } = {}) { + async acquireConnection ({ accessMode, database, bookmarks, auth, forceReAuth } = {}) { const databaseSpecificErrorHandler = ConnectionErrorHandler.create({ errorCode: SERVICE_UNAVAILABLE, - handleAuthorizationExpired: (error, address) => - this._handleAuthorizationExpired(error, address, database) + handleAuthorizationExpired: (error, address, conn) => + this._handleAuthorizationExpired(error, address, conn, database) }) - return this._connectionPool - .acquire(this._address) - .then( - connection => - new DelegateConnection(connection, databaseSpecificErrorHandler) - ) + const connection = await this._connectionPool.acquire({ auth, forceReAuth }, this._address) + + if (auth) { + await this._verifyStickyConnection({ + auth, + connection, + address: this._address + }) + return connection + } + + return new DelegateConnection(connection, databaseSpecificErrorHandler) } - _handleAuthorizationExpired (error, address, database) { + _handleAuthorizationExpired (error, address, connection, database) { this._log.warn( `Direct driver ${this._id} will close connection to ${address} for database '${database}' because of an error ${error.code} '${error.message}'` ) - this._connectionPool.purge(address).catch(() => {}) - return error + + return super._handleAuthorizationExpired(error, address, connection) } async _hasProtocolVersion (versionPredicate) { @@ -111,6 +122,19 @@ export default class DirectConnectionProvider extends PooledConnectionProvider { ) } + async supportsSessionAuth () { + return await this._hasProtocolVersion( + version => version >= BOLT_PROTOCOL_V5_1 + ) + } + + async verifyAuthentication ({ auth }) { + return this._verifyAuthentication({ + auth, + getAddress: () => this._address + }) + } + async verifyConnectivityAndGetServerInfo () { return await this._verifyConnectivityAndGetServerVersion({ address: this._address }) } diff --git a/packages/neo4j-driver-deno/lib/bolt-connection/connection-provider/connection-provider-pooled.js b/packages/neo4j-driver-deno/lib/bolt-connection/connection-provider/connection-provider-pooled.js index 03651b15b..0b0fc5bfa 100644 --- a/packages/neo4j-driver-deno/lib/bolt-connection/connection-provider/connection-provider-pooled.js +++ b/packages/neo4j-driver-deno/lib/bolt-connection/connection-provider/connection-provider-pooled.js @@ -19,12 +19,21 @@ import { createChannelConnection, ConnectionErrorHandler } from '../connection/index.js' import Pool, { PoolConfig } from '../pool/index.js' -import { error, ConnectionProvider, ServerInfo } from '../../core/index.ts' +import { error, ConnectionProvider, ServerInfo, newError, isStaticAuthTokenManger } from '../../core/index.ts' +import AuthenticationProvider from './authentication-provider.js' +import { object } from '../lang/index.js' const { SERVICE_UNAVAILABLE } = error +const AUTHENTICATION_ERRORS = [ + 'Neo.ClientError.Security.CredentialsExpired', + 'Neo.ClientError.Security.Forbidden', + 'Neo.ClientError.Security.TokenExpired', + 'Neo.ClientError.Security.Unauthorized' +] + export default class PooledConnectionProvider extends ConnectionProvider { constructor ( - { id, config, log, userAgent, authToken }, + { id, config, log, userAgent, authTokenManager, newPool = (...args) => new Pool(...args) }, createChannelConnectionHook = null ) { super() @@ -32,8 +41,8 @@ export default class PooledConnectionProvider extends ConnectionProvider { this._id = id this._config = config this._log = log - this._userAgent = userAgent - this._authToken = authToken + this._authTokenManager = authTokenManager + this._authenticationProvider = new AuthenticationProvider({ authTokenManager, userAgent }) this._createChannelConnection = createChannelConnectionHook || (address => { @@ -44,10 +53,11 @@ export default class PooledConnectionProvider extends ConnectionProvider { this._log ) }) - this._connectionPool = new Pool({ + this._connectionPool = newPool({ create: this._createConnection.bind(this), destroy: this._destroyConnection.bind(this), - validate: this._validateConnection.bind(this), + validateOnAcquire: this._validateConnectionOnAcquire.bind(this), + validateOnRelease: this._validateConnectionOnRelease.bind(this), installIdleObserver: PooledConnectionProvider._installIdleObserverOnConnection.bind( this ), @@ -57,6 +67,7 @@ export default class PooledConnectionProvider extends ConnectionProvider { config: PoolConfig.fromDriverConfig(config), log: this._log }) + this._userAgent = userAgent this._openConnections = {} } @@ -69,14 +80,13 @@ export default class PooledConnectionProvider extends ConnectionProvider { * @return {Promise} promise resolved with a new connection or rejected when failed to connect. * @access private */ - _createConnection (address, release) { + _createConnection ({ auth }, address, release) { return this._createChannelConnection(address).then(connection => { connection._release = () => { return release(address, connection) } this._openConnections[connection.id] = connection - return connection - .connect(this._userAgent, this._authToken) + return this._authenticationProvider.authenticate({ connection, auth }) .catch(error => { // let's destroy this connection this._destroyConnection(connection) @@ -86,6 +96,26 @@ export default class PooledConnectionProvider extends ConnectionProvider { }) } + async _validateConnectionOnAcquire ({ auth, skipReAuth }, conn) { + if (!this._validateConnection(conn)) { + return false + } + + try { + await this._authenticationProvider.authenticate({ connection: conn, auth, skipReAuth }) + return true + } catch (error) { + this._log.debug( + `The connection ${conn.id} is not valid because of an error ${error.code} '${error.message}'` + ) + return false + } + } + + _validateConnectionOnRelease (conn) { + return conn._sticky !== true && this._validateConnection(conn) + } + /** * Check that a connection is usable * @return {boolean} true if the connection is open @@ -98,7 +128,11 @@ export default class PooledConnectionProvider extends ConnectionProvider { const maxConnectionLifetime = this._config.maxConnectionLifetime const lifetime = Date.now() - conn.creationTimestamp - return lifetime <= maxConnectionLifetime + if (lifetime > maxConnectionLifetime) { + return false + } + + return true } /** @@ -118,7 +152,7 @@ export default class PooledConnectionProvider extends ConnectionProvider { * @return {Promise} the server info */ async _verifyConnectivityAndGetServerVersion ({ address }) { - const connection = await this._connectionPool.acquire(address) + const connection = await this._connectionPool.acquire({}, address) const serverInfo = new ServerInfo(connection.server, connection.protocol().version) try { if (!connection.protocol().isLastMessageLogon()) { @@ -130,6 +164,47 @@ export default class PooledConnectionProvider extends ConnectionProvider { return serverInfo } + async _verifyAuthentication ({ getAddress, auth }) { + const connectionsToRelease = [] + try { + const address = await getAddress() + const connection = await this._connectionPool.acquire({ auth, skipReAuth: true }, address) + connectionsToRelease.push(connection) + + const lastMessageIsNotLogin = !connection.protocol().isLastMessageLogon() + + if (!connection.supportsReAuth) { + throw newError('Driver is connected to a database that does not support user switch.') + } + if (lastMessageIsNotLogin && connection.supportsReAuth) { + await this._authenticationProvider.authenticate({ connection, auth, waitReAuth: true, forceReAuth: true }) + } else if (lastMessageIsNotLogin && !connection.supportsReAuth) { + const stickyConnection = await this._connectionPool.acquire({ auth }, address, { requireNew: true }) + stickyConnection._sticky = true + connectionsToRelease.push(stickyConnection) + } + return true + } catch (error) { + if (AUTHENTICATION_ERRORS.includes(error.code)) { + return false + } + throw error + } finally { + await Promise.all(connectionsToRelease.map(conn => conn._release())) + } + } + + async _verifyStickyConnection ({ auth, connection, address }) { + const connectionWithSameCredentials = object.equals(auth, connection.authToken) + const shouldCreateStickyConnection = !connectionWithSameCredentials + connection._sticky = connectionWithSameCredentials && !connection.supportsReAuth + + if (shouldCreateStickyConnection || connection._sticky) { + await connection._release() + throw newError('Driver is connected to a database that does not support user switch.') + } + } + async close () { // purge all idle connections in the connection pool await this._connectionPool.close() @@ -146,4 +221,22 @@ export default class PooledConnectionProvider extends ConnectionProvider { static _removeIdleObserverOnConnection (conn) { conn._updateCurrentObserver() } + + _handleAuthorizationExpired (error, address, connection) { + this._authenticationProvider.handleError({ connection, code: error.code }) + + if (error.code === 'Neo.ClientError.Security.AuthorizationExpired') { + this._connectionPool.apply(address, (conn) => { conn.authToken = null }) + } + + if (connection) { + connection.close().catch(() => undefined) + } + + if (error.code === 'Neo.ClientError.Security.TokenExpired' && !isStaticAuthTokenManger(this._authTokenManager)) { + error.retriable = true + } + + return error + } } diff --git a/packages/neo4j-driver-deno/lib/bolt-connection/connection-provider/connection-provider-routing.js b/packages/neo4j-driver-deno/lib/bolt-connection/connection-provider/connection-provider-routing.js index 95480286c..c66d09b97 100644 --- a/packages/neo4j-driver-deno/lib/bolt-connection/connection-provider/connection-provider-routing.js +++ b/packages/neo4j-driver-deno/lib/bolt-connection/connection-provider/connection-provider-routing.js @@ -37,7 +37,8 @@ const { ACCESS_MODE_WRITE: WRITE, BOLT_PROTOCOL_V3, BOLT_PROTOCOL_V4_0, - BOLT_PROTOCOL_V4_4 + BOLT_PROTOCOL_V4_4, + BOLT_PROTOCOL_V5_1 } } = internal @@ -51,6 +52,7 @@ const AUTHORIZATION_EXPIRED_CODE = const INVALID_ARGUMENT_ERROR = 'Neo.ClientError.Statement.ArgumentError' const INVALID_REQUEST_ERROR = 'Neo.ClientError.Request.Invalid' const STATEMENT_TYPE_ERROR = 'Neo.ClientError.Statement.TypeError' +const NOT_AVAILABLE = 'N/A' const SYSTEM_DB_NAME = 'system' const DEFAULT_DB_NAME = null @@ -65,10 +67,11 @@ export default class RoutingConnectionProvider extends PooledConnectionProvider config, log, userAgent, - authToken, - routingTablePurgeDelay + authTokenManager, + routingTablePurgeDelay, + newPool }) { - super({ id, config, log, userAgent, authToken }, address => { + super({ id, config, log, userAgent, authTokenManager, newPool }, address => { return createChannelConnection( address, this._config, @@ -109,12 +112,12 @@ export default class RoutingConnectionProvider extends PooledConnectionProvider return error } - _handleAuthorizationExpired (error, address, database) { + _handleAuthorizationExpired (error, address, connection, database) { this._log.warn( `Routing driver ${this._id} will close connections to ${address} for database '${database}' because of an error ${error.code} '${error.message}'` ) - this._connectionPool.purge(address).catch(() => {}) - return error + + return super._handleAuthorizationExpired(error, address, connection, database) } _handleWriteFailure (error, address, database) { @@ -133,7 +136,7 @@ export default class RoutingConnectionProvider extends PooledConnectionProvider * See {@link ConnectionProvider} for more information about this method and * its arguments. */ - async acquireConnection ({ accessMode, database, bookmarks, impersonatedUser, onDatabaseNameResolved } = {}) { + async acquireConnection ({ accessMode, database, bookmarks, impersonatedUser, onDatabaseNameResolved, auth } = {}) { let name let address const context = { database: database || DEFAULT_DB_NAME } @@ -142,8 +145,8 @@ export default class RoutingConnectionProvider extends PooledConnectionProvider SESSION_EXPIRED, (error, address) => this._handleUnavailability(error, address, context.database), (error, address) => this._handleWriteFailure(error, address, context.database), - (error, address) => - this._handleAuthorizationExpired(error, address, context.database) + (error, address, conn) => + this._handleAuthorizationExpired(error, address, conn, context.database) ) const routingTable = await this._freshRoutingTable({ @@ -151,6 +154,7 @@ export default class RoutingConnectionProvider extends PooledConnectionProvider database: context.database, bookmarks, impersonatedUser, + auth, onDatabaseNameResolved: (databaseName) => { context.database = context.database || databaseName if (onDatabaseNameResolved) { @@ -179,11 +183,16 @@ export default class RoutingConnectionProvider extends PooledConnectionProvider } try { - const connection = await this._acquireConnectionToServer( - address, - name, - routingTable - ) + const connection = await this._connectionPool.acquire({ auth }, address) + + if (auth) { + await this._verifyStickyConnection({ + auth, + connection, + address + }) + return connection + } return new DelegateConnection(connection, databaseSpecificErrorHandler) } catch (error) { @@ -248,6 +257,12 @@ export default class RoutingConnectionProvider extends PooledConnectionProvider ) } + async supportsSessionAuth () { + return await this._hasProtocolVersion( + version => version >= BOLT_PROTOCOL_V5_1 + ) + } + getNegotiatedProtocolVersion () { return new Promise((resolve, reject) => { this._hasProtocolVersion(resolve) @@ -255,6 +270,35 @@ export default class RoutingConnectionProvider extends PooledConnectionProvider }) } + async verifyAuthentication ({ database, accessMode, auth }) { + return this._verifyAuthentication({ + auth, + getAddress: async () => { + const context = { database: database || DEFAULT_DB_NAME } + + const routingTable = await this._freshRoutingTable({ + accessMode, + database: context.database, + auth, + onDatabaseNameResolved: (databaseName) => { + context.database = context.database || databaseName + } + }) + + const servers = accessMode === WRITE ? routingTable.writers : routingTable.readers + + if (servers.length === 0) { + throw newError( + `No servers available for database '${context.database}' with access mode '${accessMode}'`, + SERVICE_UNAVAILABLE + ) + } + + return servers[0] + } + }) + } + async verifyConnectivityAndGetServerInfo ({ database, accessMode }) { const context = { database: database || DEFAULT_DB_NAME } @@ -300,11 +344,7 @@ export default class RoutingConnectionProvider extends PooledConnectionProvider }) } - _acquireConnectionToServer (address, serverName, routingTable) { - return this._connectionPool.acquire(address) - } - - _freshRoutingTable ({ accessMode, database, bookmarks, impersonatedUser, onDatabaseNameResolved } = {}) { + _freshRoutingTable ({ accessMode, database, bookmarks, impersonatedUser, onDatabaseNameResolved, auth } = {}) { const currentRoutingTable = this._routingTableRegistry.get( database, () => new RoutingTable({ database }) @@ -316,10 +356,10 @@ export default class RoutingConnectionProvider extends PooledConnectionProvider this._log.info( `Routing table is stale for database: "${database}" and access mode: "${accessMode}": ${currentRoutingTable}` ) - return this._refreshRoutingTable(currentRoutingTable, bookmarks, impersonatedUser, onDatabaseNameResolved) + return this._refreshRoutingTable(currentRoutingTable, bookmarks, impersonatedUser, onDatabaseNameResolved, auth) } - _refreshRoutingTable (currentRoutingTable, bookmarks, impersonatedUser, onDatabaseNameResolved) { + _refreshRoutingTable (currentRoutingTable, bookmarks, impersonatedUser, onDatabaseNameResolved, auth) { const knownRouters = currentRoutingTable.routers if (this._useSeedRouter) { @@ -328,7 +368,8 @@ export default class RoutingConnectionProvider extends PooledConnectionProvider currentRoutingTable, bookmarks, impersonatedUser, - onDatabaseNameResolved + onDatabaseNameResolved, + auth ) } return this._fetchRoutingTableFromKnownRoutersFallbackToSeedRouter( @@ -336,7 +377,8 @@ export default class RoutingConnectionProvider extends PooledConnectionProvider currentRoutingTable, bookmarks, impersonatedUser, - onDatabaseNameResolved + onDatabaseNameResolved, + auth ) } @@ -345,7 +387,8 @@ export default class RoutingConnectionProvider extends PooledConnectionProvider currentRoutingTable, bookmarks, impersonatedUser, - onDatabaseNameResolved + onDatabaseNameResolved, + auth ) { // we start with seed router, no routers were probed before const seenRouters = [] @@ -354,7 +397,8 @@ export default class RoutingConnectionProvider extends PooledConnectionProvider this._seedRouter, currentRoutingTable, bookmarks, - impersonatedUser + impersonatedUser, + auth ) if (newRoutingTable) { @@ -365,7 +409,8 @@ export default class RoutingConnectionProvider extends PooledConnectionProvider knownRouters, currentRoutingTable, bookmarks, - impersonatedUser + impersonatedUser, + auth ) newRoutingTable = newRoutingTable2 error = error2 || error @@ -384,13 +429,15 @@ export default class RoutingConnectionProvider extends PooledConnectionProvider currentRoutingTable, bookmarks, impersonatedUser, - onDatabaseNameResolved + onDatabaseNameResolved, + auth ) { let [newRoutingTable, error] = await this._fetchRoutingTableUsingKnownRouters( knownRouters, currentRoutingTable, bookmarks, - impersonatedUser + impersonatedUser, + auth ) if (!newRoutingTable) { @@ -400,7 +447,8 @@ export default class RoutingConnectionProvider extends PooledConnectionProvider this._seedRouter, currentRoutingTable, bookmarks, - impersonatedUser + impersonatedUser, + auth ) } @@ -416,13 +464,15 @@ export default class RoutingConnectionProvider extends PooledConnectionProvider knownRouters, currentRoutingTable, bookmarks, - impersonatedUser + impersonatedUser, + auth ) { const [newRoutingTable, error] = await this._fetchRoutingTable( knownRouters, currentRoutingTable, bookmarks, - impersonatedUser + impersonatedUser, + auth ) if (newRoutingTable) { @@ -447,7 +497,8 @@ export default class RoutingConnectionProvider extends PooledConnectionProvider seedRouter, routingTable, bookmarks, - impersonatedUser + impersonatedUser, + auth ) { const resolvedAddresses = await this._resolveSeedRouter(seedRouter) @@ -456,7 +507,7 @@ export default class RoutingConnectionProvider extends PooledConnectionProvider address => seenRouters.indexOf(address) < 0 ) - return await this._fetchRoutingTable(newAddresses, routingTable, bookmarks, impersonatedUser) + return await this._fetchRoutingTable(newAddresses, routingTable, bookmarks, impersonatedUser, auth) } async _resolveSeedRouter (seedRouter) { @@ -468,7 +519,7 @@ export default class RoutingConnectionProvider extends PooledConnectionProvider return [].concat.apply([], dnsResolvedAddresses) } - async _fetchRoutingTable (routerAddresses, routingTable, bookmarks, impersonatedUser) { + async _fetchRoutingTable (routerAddresses, routingTable, bookmarks, impersonatedUser, auth) { return routerAddresses.reduce( async (refreshedTablePromise, currentRouter, currentIndex) => { const [newRoutingTable] = await refreshedTablePromise @@ -491,7 +542,8 @@ export default class RoutingConnectionProvider extends PooledConnectionProvider const [session, error] = await this._createSessionForRediscovery( currentRouter, bookmarks, - impersonatedUser + impersonatedUser, + auth ) if (session) { try { @@ -516,17 +568,28 @@ export default class RoutingConnectionProvider extends PooledConnectionProvider ) } - async _createSessionForRediscovery (routerAddress, bookmarks, impersonatedUser) { + async _createSessionForRediscovery (routerAddress, bookmarks, impersonatedUser, auth) { try { - const connection = await this._connectionPool.acquire(routerAddress) + const connection = await this._connectionPool.acquire({ auth }, routerAddress) + + if (auth) { + await this._verifyStickyConnection({ + auth, + connection, + address: routerAddress + }) + } const databaseSpecificErrorHandler = ConnectionErrorHandler.create({ errorCode: SESSION_EXPIRED, - handleAuthorizationExpired: (error, address) => this._handleAuthorizationExpired(error, address) + handleAuthorizationExpired: (error, address, conn) => this._handleAuthorizationExpired(error, address, conn) }) - const connectionProvider = new SingleConnectionProvider( - new DelegateConnection(connection, databaseSpecificErrorHandler)) + const delegateConnection = !connection._sticky + ? new DelegateConnection(connection, databaseSpecificErrorHandler) + : new DelegateConnection(connection) + + const connectionProvider = new SingleConnectionProvider(delegateConnection) const protocolVersion = connection.protocol().version if (protocolVersion < 4.0) { @@ -709,7 +772,8 @@ function _isFailFastError (error) { INVALID_BOOKMARK_MIXTURE_CODE, INVALID_ARGUMENT_ERROR, INVALID_REQUEST_ERROR, - STATEMENT_TYPE_ERROR + STATEMENT_TYPE_ERROR, + NOT_AVAILABLE ].includes(error.code) } diff --git a/packages/neo4j-driver-deno/lib/bolt-connection/connection/connection-channel.js b/packages/neo4j-driver-deno/lib/bolt-connection/connection/connection-channel.js index 79dd0e4b9..a974d17c7 100644 --- a/packages/neo4j-driver-deno/lib/bolt-connection/connection/connection-channel.js +++ b/packages/neo4j-driver-deno/lib/bolt-connection/connection/connection-channel.js @@ -124,7 +124,7 @@ export default class ChannelConnection extends Connection { protocolSupplier ) { super(errorHandler) - + this._authToken = null this._reseting = false this._resetObservers = [] this._id = idGenerator++ @@ -156,6 +156,18 @@ export default class ChannelConnection extends Connection { } } + get authToken () { + return this._authToken + } + + set authToken (value) { + this._authToken = value + } + + get supportsReAuth () { + return this._protocol.supportsReAuth + } + get id () { return this._id } @@ -174,8 +186,36 @@ export default class ChannelConnection extends Connection { * @param {Object} authToken the object containing auth information. * @return {Promise} promise resolved with the current connection if connection is successful. Rejected promise otherwise. */ - connect (userAgent, authToken) { - return this._initialize(userAgent, authToken) + async connect (userAgent, authToken, waitReAuth) { + if (this._protocol.initialized && !this._protocol.supportsReAuth) { + throw newError('Connection does not support re-auth') + } + + this._authToken = authToken + + if (!this._protocol.initialized) { + return await this._initialize(userAgent, authToken) + } + + if (waitReAuth) { + return await new Promise((resolve, reject) => { + this._protocol.logoff({ + onError: reject + }) + + this._protocol.logon({ + authToken, + onError: reject, + onComplete: () => resolve(this), + flush: true + }) + }) + } + + this._protocol.logoff() + this._protocol.logon({ authToken, flush: true }) + + return this } /** diff --git a/packages/neo4j-driver-deno/lib/bolt-connection/connection/connection-delegate.js b/packages/neo4j-driver-deno/lib/bolt-connection/connection/connection-delegate.js index 6d195d1d9..70b43b2fa 100644 --- a/packages/neo4j-driver-deno/lib/bolt-connection/connection/connection-delegate.js +++ b/packages/neo4j-driver-deno/lib/bolt-connection/connection/connection-delegate.js @@ -51,6 +51,18 @@ export default class DelegateConnection extends Connection { return this._delegate.server } + get authToken () { + return this._delegate.authToken + } + + get supportsReAuth () { + return this._delegate.supportsReAuth + } + + set authToken (value) { + this._delegate.authToken = value + } + get address () { return this._delegate.address } @@ -71,8 +83,8 @@ export default class DelegateConnection extends Connection { return this._delegate.protocol() } - connect (userAgent, authToken) { - return this._delegate.connect(userAgent, authToken) + connect (userAgent, authToken, waitReAuth) { + return this._delegate.connect(userAgent, authToken, waitReAuth) } write (message, observer, flush) { diff --git a/packages/neo4j-driver-deno/lib/bolt-connection/connection/connection-error-handler.js b/packages/neo4j-driver-deno/lib/bolt-connection/connection/connection-error-handler.js index 8544c4c10..ebe305e26 100644 --- a/packages/neo4j-driver-deno/lib/bolt-connection/connection/connection-error-handler.js +++ b/packages/neo4j-driver-deno/lib/bolt-connection/connection/connection-error-handler.js @@ -62,15 +62,15 @@ export default class ConnectionErrorHandler { * @param {ServerAddress} address the address of the connection where the error happened. * @return {Neo4jError} new error that should be propagated to the user. */ - handleAndTransformError (error, address) { + handleAndTransformError (error, address, connection) { if (isAutorizationExpiredError(error)) { - return this._handleAuthorizationExpired(error, address) + return this._handleAuthorizationExpired(error, address, connection) } if (isAvailabilityError(error)) { - return this._handleUnavailability(error, address) + return this._handleUnavailability(error, address, connection) } if (isFailureToWrite(error)) { - return this._handleWriteFailure(error, address) + return this._handleWriteFailure(error, address, connection) } return error } diff --git a/packages/neo4j-driver-deno/lib/bolt-connection/connection/connection.js b/packages/neo4j-driver-deno/lib/bolt-connection/connection/connection.js index dc996522c..d9856921d 100644 --- a/packages/neo4j-driver-deno/lib/bolt-connection/connection/connection.js +++ b/packages/neo4j-driver-deno/lib/bolt-connection/connection/connection.js @@ -39,6 +39,18 @@ export default class Connection { throw new Error('not implemented') } + get authToken () { + throw new Error('not implemented') + } + + set authToken (value) { + throw new Error('not implemented') + } + + get supportsReAuth () { + throw new Error('not implemented') + } + /** * @returns {boolean} whether this connection is in a working condition */ @@ -124,7 +136,7 @@ export default class Connection { */ handleAndTransformError (error, address) { if (this._errorHandler) { - return this._errorHandler.handleAndTransformError(error, address) + return this._errorHandler.handleAndTransformError(error, address, this) } return error diff --git a/packages/neo4j-driver-deno/lib/bolt-connection/lang/index.js b/packages/neo4j-driver-deno/lib/bolt-connection/lang/index.js index 2c7efd846..15b8e8340 100644 --- a/packages/neo4j-driver-deno/lib/bolt-connection/lang/index.js +++ b/packages/neo4j-driver-deno/lib/bolt-connection/lang/index.js @@ -18,3 +18,4 @@ */ export * as functional from './functional.js' +export * as object from './object.js' diff --git a/packages/neo4j-driver-deno/lib/bolt-connection/lang/object.js b/packages/neo4j-driver-deno/lib/bolt-connection/lang/object.js new file mode 100644 index 000000000..e2a862c04 --- /dev/null +++ b/packages/neo4j-driver-deno/lib/bolt-connection/lang/object.js @@ -0,0 +1,47 @@ +/** + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [http://neo4j.com] + * + * This file is part of Neo4j. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +export function equals (a, b) { + if (a === b) { + return true + } + + if (a === null || b === null) { + return false + } + + if (typeof a === 'object' && typeof b === 'object') { + const keysA = Object.keys(a) + const keysB = Object.keys(b) + + if (keysA.length !== keysB.length) { + return false + } + + for (const key of keysA) { + if (a[key] !== b[key]) { + return false + } + } + + return true + } + + return false +} diff --git a/packages/neo4j-driver-deno/lib/bolt-connection/pool/pool.js b/packages/neo4j-driver-deno/lib/bolt-connection/pool/pool.js index 148c5f1b9..c1789564c 100644 --- a/packages/neo4j-driver-deno/lib/bolt-connection/pool/pool.js +++ b/packages/neo4j-driver-deno/lib/bolt-connection/pool/pool.js @@ -26,15 +26,18 @@ const { class Pool { /** - * @param {function(address: ServerAddress, function(address: ServerAddress, resource: object): Promise): Promise} create + * @param {function(acquisitionContext: object, address: ServerAddress, function(address: ServerAddress, resource: object): Promise): Promise} create * an allocation function that creates a promise with a new resource. It's given an address for which to * allocate the connection and a function that will return the resource to the pool if invoked, which is * meant to be called on .dispose or .close or whatever mechanism the resource uses to finalize. + * @param {function(acquisitionContext: object, resource: object): boolean} validateOnAcquire + * called at various times when an instance is acquired + * If this returns false, the resource will be evicted + * @param {function(resource: object): boolean} validateOnRelease + * called at various times when an instance is released + * If this returns false, the resource will be evicted * @param {function(resource: object): Promise} destroy * called with the resource when it is evicted from this pool - * @param {function(resource: object): boolean} validate - * called at various times (like when an instance is acquired and when it is returned. - * If this returns false, the resource will be evicted * @param {function(resource: object, observer: { onError }): void} installIdleObserver * called when the resource is released back to pool * @param {function(resource: object): void} removeIdleObserver @@ -43,9 +46,10 @@ class Pool { * @param {Logger} log the driver logger. */ constructor ({ - create = (address, release) => Promise.resolve(), + create = (acquisitionContext, address, release) => Promise.resolve(), destroy = conn => Promise.resolve(), - validate = conn => true, + validateOnAcquire = (acquisitionContext, conn) => true, + validateOnRelease = (conn) => true, installIdleObserver = (conn, observer) => {}, removeIdleObserver = conn => {}, config = PoolConfig.defaultConfig(), @@ -53,7 +57,8 @@ class Pool { } = {}) { this._create = create this._destroy = destroy - this._validate = validate + this._validateOnAcquire = validateOnAcquire + this._validateOnRelease = validateOnRelease this._installIdleObserver = installIdleObserver this._removeIdleObserver = removeIdleObserver this._maxSize = config.maxSize @@ -69,10 +74,13 @@ class Pool { /** * Acquire and idle resource fom the pool or create a new one. + * @param {object} acquisitionContext the acquisition context used for create and validateOnAcquire connection * @param {ServerAddress} address the address for which we're acquiring. + * @param {object} config the config + * @param {boolean} config.requireNew Indicate it requires a new resource * @return {Promise} resource that is ready to use. */ - acquire (address) { + acquire (acquisitionContext, address, config) { const key = address.asKey() // We're out of resources and will try to acquire later on when an existing resource is released. @@ -108,7 +116,7 @@ class Pool { } }, this._acquisitionTimeout) - request = new PendingRequest(key, resolve, reject, timeoutId, this._log) + request = new PendingRequest(key, acquisitionContext, config, resolve, reject, timeoutId, this._log) allRequests[key].push(request) this._processPendingAcquireRequests(address) }) @@ -123,6 +131,14 @@ class Pool { return this._purgeKey(address.asKey()) } + apply (address, resourceConsumer) { + const key = address.asKey() + + if (key in this._pools) { + this._pools[key].apply(resourceConsumer) + } + } + /** * Destroy all idle resources in this pool. * @returns {Promise} A promise that is resolved when the resources are purged @@ -185,29 +201,32 @@ class Pool { return pool } - async _acquire (address) { + async _acquire (acquisitionContext, address, requireNew) { if (this._closed) { throw newError('Pool is closed, it is no more able to serve requests.') } const key = address.asKey() const pool = this._getOrInitializePoolFor(key) - while (pool.length) { - const resource = pool.pop() + if (!requireNew) { + while (pool.length) { + const resource = pool.pop() - if (this._validate(resource)) { if (this._removeIdleObserver) { this._removeIdleObserver(resource) } - // idle resource is valid and can be acquired - resourceAcquired(key, this._activeResourceCounts) - if (this._log.isDebugEnabled()) { - this._log.debug(`${resource} acquired from the pool ${key}`) + if (await this._validateOnAcquire(acquisitionContext, resource)) { + // idle resource is valid and can be acquired + resourceAcquired(key, this._activeResourceCounts) + if (this._log.isDebugEnabled()) { + this._log.debug(`${resource} acquired from the pool ${key}`) + } + return { resource, pool } + } else { + pool.removeInUse(resource) + await this._destroy(resource) } - return { resource, pool } - } else { - await this._destroy(resource) } } @@ -228,9 +247,19 @@ class Pool { this._pendingCreates[key] = this._pendingCreates[key] + 1 let resource try { - // Invoke callback that creates actual connection - resource = await this._create(address, (address, resource) => this._release(address, resource, pool)) + const numConnections = this.activeResourceCount(address) + pool.length + if (numConnections >= this._maxSize && requireNew) { + const resource = pool.pop() + if (this._removeIdleObserver) { + this._removeIdleObserver(resource) + } + pool.removeInUse(resource) + await this._destroy(resource) + } + // Invoke callback that creates actual connection + resource = await this._create(acquisitionContext, address, (address, resource) => this._release(address, resource, pool)) + pool.pushInUse(resource) resourceAcquired(key, this._activeResourceCounts) if (this._log.isDebugEnabled()) { this._log.debug(`${resource} created for the pool ${key}`) @@ -246,12 +275,13 @@ class Pool { if (pool.isActive()) { // there exist idle connections for the given key - if (!this._validate(resource)) { + if (!await this._validateOnRelease(resource)) { if (this._log.isDebugEnabled()) { this._log.debug( `${resource} destroyed and can't be released to the pool ${key} because it is not functional` ) } + pool.removeInUse(resource) await this._destroy(resource) } else { if (this._installIdleObserver) { @@ -263,6 +293,7 @@ class Pool { const pool = this._pools[key] if (pool) { this._pools[key] = pool.filter(r => r !== resource) + pool.removeInUse(resource) } // let's not care about background clean-ups due to errors but just trigger the destroy // process for the resource, we especially catch any errors and ignore them to avoid @@ -283,6 +314,7 @@ class Pool { `${resource} destroyed and can't be released to the pool ${key} because pool has been purged` ) } + pool.removeInUse(resource) await this._destroy(resource) } resourceReleased(key, this._activeResourceCounts) @@ -314,7 +346,7 @@ class Pool { const pendingRequest = requests.shift() // pop a pending acquire request if (pendingRequest) { - this._acquire(address) + this._acquire(pendingRequest.context, address, pendingRequest.requireNew) .catch(error => { // failed to acquire/create a new connection to resolve the pending acquire request // propagate the error by failing the pending request @@ -378,13 +410,23 @@ function resourceReleased (key, activeResourceCounts) { } class PendingRequest { - constructor (key, resolve, reject, timeoutId, log) { + constructor (key, context, config, resolve, reject, timeoutId, log) { this._key = key + this._context = context this._resolve = resolve this._reject = reject this._timeoutId = timeoutId this._log = log this._completed = false + this._config = config || {} + } + + get context () { + return this._context + } + + get requireNew () { + return this._config.requireNew || false } isCompleted () { @@ -419,6 +461,7 @@ class SingleAddressPool { constructor () { this._active = true this._elements = [] + this._elementsInUse = new Set() } isActive () { @@ -427,6 +470,8 @@ class SingleAddressPool { close () { this._active = false + this._elements = [] + this._elementsInUse = new Set() } filter (predicate) { @@ -434,17 +479,33 @@ class SingleAddressPool { return this } + apply (resourceConsumer) { + this._elements.forEach(resourceConsumer) + this._elementsInUse.forEach(resourceConsumer) + } + get length () { return this._elements.length } pop () { - return this._elements.pop() + const element = this._elements.pop() + this._elementsInUse.add(element) + return element } push (element) { + this._elementsInUse.delete(element) return this._elements.push(element) } + + pushInUse (element) { + this._elementsInUse.add(element) + } + + removeInUse (element) { + this._elementsInUse.delete(element) + } } export default Pool diff --git a/packages/neo4j-driver-deno/lib/core/auth-token-manager.ts b/packages/neo4j-driver-deno/lib/core/auth-token-manager.ts new file mode 100644 index 000000000..c692be877 --- /dev/null +++ b/packages/neo4j-driver-deno/lib/core/auth-token-manager.ts @@ -0,0 +1,227 @@ +/** + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [http://neo4j.com] + * + * This file is part of Neo4j. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import auth from './auth.ts' +import { AuthToken } from './types.ts' +import { util } from './internal/index.ts' + +/** + * Interface for the piece of software responsible for keeping track of current active {@link AuthToken} across the driver. + * @interface + * @experimental Exposed as preview feature. + * @since 5.8 + */ +export default class AuthTokenManager { + /** + * Returns a valid token. + * + * **Warning**: This method must only ever return auth information belonging to the same identity. + * Switching identities using the `AuthTokenManager` is undefined behavior. + * + * @returns {Promise|AuthToken} The valid auth token or a promise for a valid auth token + */ + getToken (): Promise | AuthToken { + throw new Error('Not Implemented') + } + + /** + * Called to notify a token expiration. + * + * @param {AuthToken} token The expired token. + * @return {void} + */ + onTokenExpired (token: AuthToken): void { + throw new Error('Not implemented') + } +} + +/** + * Interface which defines an {@link AuthToken} with an expiration data time associated + * @interface + * @experimental Exposed as preview feature. + * @since 5.8 + */ +export class AuthTokenAndExpiration { + public readonly token: AuthToken + public readonly expiration?: Date + + private constructor () { + /** + * The {@link AuthToken} used for authenticate connections. + * + * @type {AuthToken} + * @see {auth} + */ + this.token = auth.none() as AuthToken + + /** + * The expected expiration date of the auth token. + * + * This information will be used for triggering the auth token refresh + * in managers created with {@link expirationBasedAuthTokenManager}. + * + * If this value is not defined, the {@link AuthToken} will be considered valid + * until a `Neo.ClientError.Security.TokenExpired` error happens. + * + * @type {Date|undefined} + */ + this.expiration = undefined + } +} + +/** + * Creates a {@link AuthTokenManager} for handle {@link AuthToken} which is expires. + * + * **Warning**: `tokenProvider` must only ever return auth information belonging to the same identity. + * Switching identities using the `AuthTokenManager` is undefined behavior. + * + * @param {object} param0 - The params + * @param {function(): Promise} param0.tokenProvider - Retrieves a new valid auth token. + * Must only ever return auth information belonging to the same identity. + * @returns {AuthTokenManager} The temporal auth data manager. + * @experimental Exposed as preview feature. + */ +export function expirationBasedAuthTokenManager ({ tokenProvider }: { tokenProvider: () => Promise }): AuthTokenManager { + if (typeof tokenProvider !== 'function') { + throw new TypeError(`tokenProvider should be function, but got: ${typeof tokenProvider}`) + } + return new ExpirationBasedAuthTokenManager(tokenProvider) +} + +/** + * Create a {@link AuthTokenManager} for handle static {@link AuthToken} + * + * @private + * @param {param} args - The args + * @param {AuthToken} args.authToken - The static auth token which will always used in the driver. + * @returns {AuthTokenManager} The temporal auth data manager. + */ +export function staticAuthTokenManager ({ authToken }: { authToken: AuthToken }): AuthTokenManager { + return new StaticAuthTokenManager(authToken) +} + +/** + * Checks if the manager is a StaticAuthTokenManager + * + * @private + * @experimental + * @param {AuthTokenManager} manager The auth token manager to be checked. + * @returns {boolean} Manager is StaticAuthTokenManager + */ +export function isStaticAuthTokenManger (manager: AuthTokenManager): manager is StaticAuthTokenManager { + return manager instanceof StaticAuthTokenManager +} + +interface TokenRefreshObserver { + onCompleted: (data: AuthTokenAndExpiration) => void + onError: (error: Error) => void +} + +class TokenRefreshObservable implements TokenRefreshObserver { + constructor (private readonly _subscribers: TokenRefreshObserver[] = []) { + + } + + subscribe (sub: TokenRefreshObserver): void { + this._subscribers.push(sub) + } + + onCompleted (data: AuthTokenAndExpiration): void { + this._subscribers.forEach(sub => sub.onCompleted(data)) + } + + onError (error: Error): void { + this._subscribers.forEach(sub => sub.onError(error)) + } +} + +class ExpirationBasedAuthTokenManager implements AuthTokenManager { + constructor ( + private readonly _tokenProvider: () => Promise, + private _currentAuthData?: AuthTokenAndExpiration, + private _refreshObservable?: TokenRefreshObservable) { + + } + + async getToken (): Promise { + if (this._currentAuthData === undefined || + ( + this._currentAuthData.expiration !== undefined && + this._currentAuthData.expiration < new Date() + )) { + await this._refreshAuthToken() + } + + return this._currentAuthData?.token as AuthToken + } + + onTokenExpired (token: AuthToken): void { + if (util.equals(token, this._currentAuthData?.token)) { + this._scheduleRefreshAuthToken() + } + } + + private _scheduleRefreshAuthToken (observer?: TokenRefreshObserver): void { + if (this._refreshObservable === undefined) { + this._currentAuthData = undefined + this._refreshObservable = new TokenRefreshObservable() + + Promise.resolve(this._tokenProvider()) + .then(data => { + this._currentAuthData = data + this._refreshObservable?.onCompleted(data) + }) + .catch(error => { + this._refreshObservable?.onError(error) + }) + .finally(() => { + this._refreshObservable = undefined + }) + } + + if (observer !== undefined) { + this._refreshObservable.subscribe(observer) + } + } + + private async _refreshAuthToken (): Promise { + return await new Promise((resolve, reject) => { + this._scheduleRefreshAuthToken({ + onCompleted: resolve, + onError: reject + }) + }) + } +} + +class StaticAuthTokenManager implements AuthTokenManager { + constructor ( + private readonly _authToken: AuthToken + ) { + + } + + getToken (): AuthToken { + return this._authToken + } + + onTokenExpired (_: AuthToken): void { + // nothing to do here + } +} diff --git a/packages/neo4j-driver-deno/lib/core/auth.ts b/packages/neo4j-driver-deno/lib/core/auth.ts index 502166a0e..f15db8dee 100644 --- a/packages/neo4j-driver-deno/lib/core/auth.ts +++ b/packages/neo4j-driver-deno/lib/core/auth.ts @@ -53,6 +53,11 @@ const auth = { credentials: base64EncodedToken } }, + none: () => { + return { + scheme: 'none' + } + }, custom: ( principal: string, credentials: string, diff --git a/packages/neo4j-driver-deno/lib/core/connection-provider.ts b/packages/neo4j-driver-deno/lib/core/connection-provider.ts index 0326d5931..5de1b320d 100644 --- a/packages/neo4j-driver-deno/lib/core/connection-provider.ts +++ b/packages/neo4j-driver-deno/lib/core/connection-provider.ts @@ -21,9 +21,10 @@ import Connection from './connection.ts' import { bookmarks } from './internal/index.ts' import { ServerInfo } from './result-summary.ts' +import { AuthToken } from './types.ts' /** - * Inteface define a common way to acquire a connection + * Interface define a common way to acquire a connection * * @private */ @@ -51,6 +52,7 @@ class ConnectionProvider { bookmarks: bookmarks.Bookmarks impersonatedUser?: string onDatabaseNameResolved?: (databaseName?: string) => void + auth?: AuthToken }): Promise { throw Error('Not implemented') } @@ -85,6 +87,16 @@ class ConnectionProvider { throw Error('Not implemented') } + /** + * This method checks whether the driver session re-auth functionality + * by checking protocol handshake result + * + * @returns {Promise} + */ + supportsSessionAuth (): Promise { + throw Error('Not implemented') + } + /** * This method verifies the connectivity of the database by trying to acquire a connection * for each server available in the cluster. @@ -99,6 +111,22 @@ class ConnectionProvider { throw Error('Not implemented') } + /** + * This method verifies the authorization credentials work by trying to acquire a connection + * to one of the servers with the given credentials. + * + * @param {object} param - object parameter + * @property {AuthToken} param.auth - the target auth for the to-be-acquired connection + * @property {string} param.database - the target database for the to-be-acquired connection + * @property {string} param.accessMode - the access mode for the to-be-acquired connection + * + * @returns {Promise} promise resolved with true if succeed, false if failed with + * authentication issue and rejected with error if non-authentication error happens. + */ + verifyAuthentication (param?: { auth?: AuthToken, database?: string, accessMode?: string }): Promise { + throw Error('Not implemented') + } + /** * Returns the protocol version negotiated via handshake. * diff --git a/packages/neo4j-driver-deno/lib/core/connection.ts b/packages/neo4j-driver-deno/lib/core/connection.ts index 30c896932..c97c873bb 100644 --- a/packages/neo4j-driver-deno/lib/core/connection.ts +++ b/packages/neo4j-driver-deno/lib/core/connection.ts @@ -37,6 +37,13 @@ class Connection { return {} } + /** + * @property {object} authToken The auth registered in the connection + */ + get authToken (): any { + return {} + } + /** * @property {ServerAddress} the server address this connection is opened against */ @@ -51,6 +58,13 @@ class Connection { return undefined } + /** + * @property {boolean} supportsReAuth Indicates the connection supports re-auth + */ + get supportsReAuth (): boolean { + return false + } + /** * @returns {boolean} whether this connection is in a working condition */ diff --git a/packages/neo4j-driver-deno/lib/core/driver.ts b/packages/neo4j-driver-deno/lib/core/driver.ts index aa61ee18e..4824acc91 100644 --- a/packages/neo4j-driver-deno/lib/core/driver.ts +++ b/packages/neo4j-driver-deno/lib/core/driver.ts @@ -38,7 +38,8 @@ import { LoggingConfig, TrustStrategy, SessionMode, - Query + Query, + AuthToken } from './types.ts' import { ServerAddress } from './internal/server-address.ts' import BookmarkManager, { bookmarkManager } from './bookmark-manager.ts' @@ -96,6 +97,7 @@ type CreateSession = (args: { impersonatedUser?: string bookmarkManager?: BookmarkManager notificationFilter?: NotificationFilter + auth?: AuthToken }) => Session type CreateQueryExecutor = (createSession: (config: { database?: string, bookmarkManager?: BookmarkManager }) => Session) => QueryExecutor @@ -121,6 +123,7 @@ class SessionConfig { fetchSize?: number bookmarkManager?: BookmarkManager notificationFilter?: NotificationFilter + auth?: AuthToken /** * @constructor @@ -192,6 +195,22 @@ class SessionConfig { */ this.impersonatedUser = undefined + /** + * The {@link AuthToken} which will be used for the duration of the session. + * + * By default, the session will use connections authenticated with {@link AuthToken} configured in the + * driver creation. This configuration allows switch user and/or authorization information for the + * session lifetime. + * + * **Warning**: This option is only enable when the driver is connected with Neo4j Database servers + * which supports Bolt 5.1 and onwards. + * + * @type {AuthToken|undefined} + * @experimental Exposed as preview feature. + * @see {@link driver} + */ + this.auth = undefined + /** * The record fetch size of each batch of this session. * @@ -573,6 +592,26 @@ class Driver { return connectionProvider.verifyConnectivityAndGetServerInfo({ database, accessMode: READ }) } + /** + * This method verifies the authorization credentials work by trying to acquire a connection + * to one of the servers with the given credentials. + * + * @param {object} param - object parameter + * @property {AuthToken} param.auth - the target auth for the to-be-acquired connection + * @property {string} param.database - the target database for the to-be-acquired connection + * + * @returns {Promise} promise resolved with true if succeed, false if failed with + * authentication issue and rejected with error if non-authentication error happens. + */ + async verifyAuthentication ({ database, auth }: { auth?: AuthToken, database?: string } = {}): Promise { + const connectionProvider = this._getOrCreateConnectionProvider() + return await connectionProvider.verifyAuthentication({ + database: database ?? 'system', + auth, + accessMode: READ + }) + } + /** * Get ServerInfo for the giver database. * @@ -624,6 +663,19 @@ class Driver { return connectionProvider.supportsUserImpersonation() } + /** + * Returns whether the driver session re-auth functionality capabilities based on the protocol + * version negotiated via handshake. + * + * Note that this function call _always_ causes a round-trip to the server. + * + * @returns {Promise} promise resolved with a boolean or rejected with error. + */ + supportsSessionAuth (): Promise { + const connectionProvider = this._getOrCreateConnectionProvider() + return connectionProvider.supportsSessionAuth() + } + /** * Returns the protocol version negotiated via handshake. * @@ -696,7 +748,8 @@ class Driver { impersonatedUser, fetchSize, bookmarkManager, - notificationFilter + notificationFilter, + auth }: SessionConfig = {}): Session { return this._newSession({ defaultAccessMode, @@ -707,7 +760,8 @@ class Driver { // eslint-disable-next-line @typescript-eslint/no-non-null-assertion fetchSize: validateFetchSizeValue(fetchSize, this._config.fetchSize!), bookmarkManager, - notificationFilter + notificationFilter, + auth }) } @@ -746,7 +800,8 @@ class Driver { impersonatedUser, fetchSize, bookmarkManager, - notificationFilter + notificationFilter, + auth }: { defaultAccessMode: SessionMode bookmarkOrBookmarks?: string | string[] @@ -756,6 +811,7 @@ class Driver { fetchSize: number bookmarkManager?: BookmarkManager notificationFilter?: NotificationFilter + auth?: AuthToken }): Session { const sessionMode = Session._validateSessionMode(defaultAccessMode) const connectionProvider = this._getOrCreateConnectionProvider() @@ -773,7 +829,8 @@ class Driver { impersonatedUser, fetchSize, bookmarkManager, - notificationFilter + notificationFilter, + auth }) } diff --git a/packages/neo4j-driver-deno/lib/core/index.ts b/packages/neo4j-driver-deno/lib/core/index.ts index 12352d6fd..d9a1ab2f6 100644 --- a/packages/neo4j-driver-deno/lib/core/index.ts +++ b/packages/neo4j-driver-deno/lib/core/index.ts @@ -87,6 +87,7 @@ import Session, { TransactionConfig } from './session.ts' import Driver, * as driver from './driver.ts' import auth from './auth.ts' import BookmarkManager, { BookmarkManagerConfig, bookmarkManager } from './bookmark-manager.ts' +import AuthTokenManager, { expirationBasedAuthTokenManager, staticAuthTokenManager, isStaticAuthTokenManger, AuthTokenAndExpiration } from './auth-token-manager.ts' import { SessionConfig, QueryConfig, RoutingControl, routing } from './driver.ts' import * as types from './types.ts' import * as json from './json.ts' @@ -163,6 +164,7 @@ const forExport = { json, auth, bookmarkManager, + expirationBasedAuthTokenManager, routing, resultTransformers, notificationCategory, @@ -230,6 +232,9 @@ export { json, auth, bookmarkManager, + expirationBasedAuthTokenManager, + staticAuthTokenManager, + isStaticAuthTokenManger, routing, resultTransformers, notificationCategory, @@ -247,6 +252,8 @@ export type { TransactionConfig, BookmarkManager, BookmarkManagerConfig, + AuthTokenManager, + AuthTokenAndExpiration, SessionConfig, QueryConfig, RoutingControl, diff --git a/packages/neo4j-driver-deno/lib/core/internal/connection-holder.ts b/packages/neo4j-driver-deno/lib/core/internal/connection-holder.ts index d67044100..7e49502ec 100644 --- a/packages/neo4j-driver-deno/lib/core/internal/connection-holder.ts +++ b/packages/neo4j-driver-deno/lib/core/internal/connection-holder.ts @@ -24,6 +24,7 @@ import Connection from '../connection.ts' import { ACCESS_MODE_WRITE } from './constants.ts' import { Bookmarks } from './bookmarks.ts' import ConnectionProvider from '../connection-provider.ts' +import { AuthToken } from '../types.ts' /** * @private @@ -85,6 +86,8 @@ class ConnectionHolder implements ConnectionHolderInterface { private readonly _impersonatedUser?: string private readonly _getConnectionAcquistionBookmarks: () => Promise private readonly _onDatabaseNameResolved?: (databaseName?: string) => void + private readonly _auth?: AuthToken + private _closed: boolean /** * @constructor @@ -96,6 +99,7 @@ class ConnectionHolder implements ConnectionHolderInterface { * @property {string?} params.impersonatedUser - the user which will be impersonated * @property {function(databaseName:string)} params.onDatabaseNameResolved - callback called when the database name is resolved * @property {function():Promise} params.getConnectionAcquistionBookmarks - called for getting Bookmarks for acquiring connections + * @property {AuthToken} params.auth - the target auth for the to-be-acquired connection */ constructor ({ mode = ACCESS_MODE_WRITE, @@ -104,7 +108,8 @@ class ConnectionHolder implements ConnectionHolderInterface { connectionProvider, impersonatedUser, onDatabaseNameResolved, - getConnectionAcquistionBookmarks + getConnectionAcquistionBookmarks, + auth }: { mode?: string database?: string @@ -113,8 +118,10 @@ class ConnectionHolder implements ConnectionHolderInterface { impersonatedUser?: string onDatabaseNameResolved?: (databaseName?: string) => void getConnectionAcquistionBookmarks?: () => Promise + auth?: AuthToken } = {}) { this._mode = mode + this._closed = false this._database = database != null ? assertString(database, 'database') : '' this._bookmarks = bookmarks ?? Bookmarks.empty() this._connectionProvider = connectionProvider @@ -122,6 +129,7 @@ class ConnectionHolder implements ConnectionHolderInterface { this._referenceCount = 0 this._connectionPromise = Promise.resolve(null) this._onDatabaseNameResolved = onDatabaseNameResolved + this._auth = auth this._getConnectionAcquistionBookmarks = getConnectionAcquistionBookmarks ?? (() => Promise.resolve(Bookmarks.empty())) } @@ -166,7 +174,8 @@ class ConnectionHolder implements ConnectionHolderInterface { database: this._database, bookmarks: await this._getBookmarks(), impersonatedUser: this._impersonatedUser, - onDatabaseNameResolved: this._onDatabaseNameResolved + onDatabaseNameResolved: this._onDatabaseNameResolved, + auth: this._auth }) } @@ -192,6 +201,7 @@ class ConnectionHolder implements ConnectionHolderInterface { } close (hasTx?: boolean): Promise { + this._closed = true if (this._referenceCount === 0) { return this._connectionPromise } diff --git a/packages/neo4j-driver-deno/lib/core/internal/util.ts b/packages/neo4j-driver-deno/lib/core/internal/util.ts index 3a181fdae..d0b691b8e 100644 --- a/packages/neo4j-driver-deno/lib/core/internal/util.ts +++ b/packages/neo4j-driver-deno/lib/core/internal/util.ts @@ -223,6 +223,44 @@ function isString (str: any): str is string { return Object.prototype.toString.call(str) === '[object String]' } +/** + * Verifies if object are the equals + * @param {unknown} a + * @param {unknown} b + * @returns {boolean} + */ +function equals (a: unknown, b: unknown): boolean { + if (a === b) { + return true + } + + if (a === null || b === null) { + return false + } + + if (typeof a === 'object' && typeof b === 'object') { + const keysA = Object.keys(a) + const keysB = Object.keys(b) + + if (keysA.length !== keysB.length) { + return false + } + + type AObjectKey = keyof typeof a + type BObjectKey = keyof typeof b + + for (const key of keysA) { + if (!equals(a[key as AObjectKey], b[key as BObjectKey])) { + return false + } + } + + return true + } + + return false +} + export { isEmptyObjectOrNull, isObject, @@ -233,6 +271,7 @@ export { assertNumberOrInteger, assertValidDate, validateQueryAndParameters, + equals, ENCRYPTION_ON, ENCRYPTION_OFF } diff --git a/packages/neo4j-driver-deno/lib/core/session.ts b/packages/neo4j-driver-deno/lib/core/session.ts index 26fea9162..051045979 100644 --- a/packages/neo4j-driver-deno/lib/core/session.ts +++ b/packages/neo4j-driver-deno/lib/core/session.ts @@ -30,7 +30,7 @@ import { TransactionExecutor } from './internal/transaction-executor.ts' import { Bookmarks } from './internal/bookmarks.ts' import { TxConfig } from './internal/tx-config.ts' import ConnectionProvider from './connection-provider.ts' -import { Query, SessionMode } from './types.ts' +import { AuthToken, Query, SessionMode } from './types.ts' import Connection from './connection.ts' import { NumberOrInteger } from './graph-types.ts' import TransactionPromise from './transaction-promise.ts' @@ -86,6 +86,7 @@ class Session { * @param {boolean} args.reactive - Whether this session should create reactive streams * @param {number} args.fetchSize - Defines how many records is pulled in each pulling batch * @param {string} args.impersonatedUser - The username which the user wants to impersonate for the duration of the session. + * @param {AuthToken} args.auth - the target auth for the to-be-acquired connection * @param {NotificationFilter} args.notificationFilter - The notification filter used for this session. */ constructor ({ @@ -98,7 +99,8 @@ class Session { fetchSize, impersonatedUser, bookmarkManager, - notificationFilter + notificationFilter, + auth }: { mode: SessionMode connectionProvider: ConnectionProvider @@ -110,6 +112,7 @@ class Session { impersonatedUser?: string bookmarkManager?: BookmarkManager notificationFilter?: NotificationFilter + auth?: AuthToken }) { this._mode = mode this._database = database @@ -119,6 +122,7 @@ class Session { this._getConnectionAcquistionBookmarks = this._getConnectionAcquistionBookmarks.bind(this) this._readConnectionHolder = new ConnectionHolder({ mode: ACCESS_MODE_READ, + auth, database, bookmarks, connectionProvider, @@ -128,6 +132,7 @@ class Session { }) this._writeConnectionHolder = new ConnectionHolder({ mode: ACCESS_MODE_WRITE, + auth, database, bookmarks, connectionProvider, diff --git a/packages/neo4j-driver-deno/lib/core/types.ts b/packages/neo4j-driver-deno/lib/core/types.ts index 63140cc16..463dd4713 100644 --- a/packages/neo4j-driver-deno/lib/core/types.ts +++ b/packages/neo4j-driver-deno/lib/core/types.ts @@ -43,11 +43,12 @@ export type TrustStrategy = export interface Parameters { [key: string]: any } export interface AuthToken { scheme: string - principal: string + principal?: string credentials: string realm?: string parameters?: Parameters } + export interface Config { encrypted?: boolean | EncryptionLevel trust?: TrustStrategy diff --git a/packages/neo4j-driver-deno/lib/mod.ts b/packages/neo4j-driver-deno/lib/mod.ts index 633060a56..2715877ce 100644 --- a/packages/neo4j-driver-deno/lib/mod.ts +++ b/packages/neo4j-driver-deno/lib/mod.ts @@ -93,7 +93,11 @@ import { NotificationFilterDisabledCategory, NotificationFilterMinimumSeverityLevel, notificationFilterDisabledCategory, - notificationFilterMinimumSeverityLevel + notificationFilterMinimumSeverityLevel, + AuthTokenManager, + expirationBasedAuthTokenManager, + AuthTokenAndExpiration, + staticAuthTokenManager } from './core/index.ts' // @deno-types=./bolt-connection/types/index.d.ts import { @@ -117,6 +121,32 @@ const { urlUtil } = internal +function isAuthTokenManager (value: unknown): value is AuthTokenManager { + if (typeof value === 'object' && + value != null && + 'getToken' in value && + 'onTokenExpired' in value) { + const manager = value as AuthTokenManager + + return typeof manager.getToken === 'function' && + typeof manager.onTokenExpired === 'function' + } + + return false +} + +function createAuthManager (authTokenOrProvider: AuthToken | AuthTokenManager): AuthTokenManager { + if (isAuthTokenManager(authTokenOrProvider)) { + return authTokenOrProvider + } + + let authToken: AuthToken = authTokenOrProvider + // Sanitize authority token. Nicer error from server when a scheme is set. + authToken = authToken ?? {} + authToken.scheme = authToken.scheme ?? 'none' + return staticAuthTokenManager({ authToken }) +} + /** * Construct a new Neo4j Driver. This is your main entry point for this * library. @@ -247,13 +277,13 @@ const { * } * * @param {string} url The URL for the Neo4j database, for instance "neo4j://localhost" and/or "bolt://localhost" - * @param {Map} authToken Authentication credentials. See {@link auth} for helpers. + * @param {Map| function()} authToken Authentication credentials. See {@link auth} for helpers. * @param {Object} config Configuration object. See the configuration section above for details. * @returns {Driver} */ function driver ( url: string, - authToken: AuthToken, + authToken: AuthToken | AuthTokenManager, config: Config = {} ): Driver { assertString(url, 'Bolt URL') @@ -303,9 +333,7 @@ function driver ( config.trust = trust } - // Sanitize authority token. Nicer error from server when a scheme is set. - authToken = authToken ?? {} - authToken.scheme = authToken.scheme ?? 'none' + const authTokenManager = createAuthManager(authToken) // Use default user agent or user agent specified by user. config.userAgent = config.userAgent ?? USER_AGENT @@ -332,7 +360,7 @@ function driver ( config, log, hostNameResolver, - authToken, + authTokenManager, address, userAgent: config.userAgent, routingContext: parsedUrl.query @@ -349,7 +377,7 @@ function driver ( id, config, log, - authToken, + authTokenManager, address, userAgent: config.userAgent }) @@ -515,7 +543,8 @@ const forExport = { notificationCategory, notificationSeverityLevel, notificationFilterDisabledCategory, - notificationFilterMinimumSeverityLevel + notificationFilterMinimumSeverityLevel, + expirationBasedAuthTokenManager } export { @@ -581,11 +610,14 @@ export { notificationCategory, notificationSeverityLevel, notificationFilterDisabledCategory, - notificationFilterMinimumSeverityLevel + notificationFilterMinimumSeverityLevel, + expirationBasedAuthTokenManager } export type { QueryResult, AuthToken, + AuthTokenManager, + AuthTokenAndExpiration, Config, EncryptionLevel, TrustStrategy, diff --git a/packages/neo4j-driver-lite/src/index.ts b/packages/neo4j-driver-lite/src/index.ts index e64639855..41fb4d73e 100644 --- a/packages/neo4j-driver-lite/src/index.ts +++ b/packages/neo4j-driver-lite/src/index.ts @@ -93,7 +93,11 @@ import { NotificationFilterDisabledCategory, NotificationFilterMinimumSeverityLevel, notificationFilterDisabledCategory, - notificationFilterMinimumSeverityLevel + notificationFilterMinimumSeverityLevel, + AuthTokenManager, + expirationBasedAuthTokenManager, + AuthTokenAndExpiration, + staticAuthTokenManager } from 'neo4j-driver-core' import { DirectConnectionProvider, @@ -116,6 +120,32 @@ const { urlUtil } = internal +function isAuthTokenManager (value: unknown): value is AuthTokenManager { + if (typeof value === 'object' && + value != null && + 'getToken' in value && + 'onTokenExpired' in value) { + const manager = value as AuthTokenManager + + return typeof manager.getToken === 'function' && + typeof manager.onTokenExpired === 'function' + } + + return false +} + +function createAuthManager (authTokenOrProvider: AuthToken | AuthTokenManager): AuthTokenManager { + if (isAuthTokenManager(authTokenOrProvider)) { + return authTokenOrProvider + } + + let authToken: AuthToken = authTokenOrProvider + // Sanitize authority token. Nicer error from server when a scheme is set. + authToken = authToken ?? {} + authToken.scheme = authToken.scheme ?? 'none' + return staticAuthTokenManager({ authToken }) +} + /** * Construct a new Neo4j Driver. This is your main entry point for this * library. @@ -246,13 +276,13 @@ const { * } * * @param {string} url The URL for the Neo4j database, for instance "neo4j://localhost" and/or "bolt://localhost" - * @param {Map} authToken Authentication credentials. See {@link auth} for helpers. + * @param {Map| function()} authToken Authentication credentials. See {@link auth} for helpers. * @param {Object} config Configuration object. See the configuration section above for details. * @returns {Driver} */ function driver ( url: string, - authToken: AuthToken, + authToken: AuthToken | AuthTokenManager, config: Config = {} ): Driver { assertString(url, 'Bolt URL') @@ -302,9 +332,7 @@ function driver ( config.trust = trust } - // Sanitize authority token. Nicer error from server when a scheme is set. - authToken = authToken ?? {} - authToken.scheme = authToken.scheme ?? 'none' + const authTokenManager = createAuthManager(authToken) // Use default user agent or user agent specified by user. config.userAgent = config.userAgent ?? USER_AGENT @@ -331,7 +359,7 @@ function driver ( config, log, hostNameResolver, - authToken, + authTokenManager, address, userAgent: config.userAgent, routingContext: parsedUrl.query @@ -348,7 +376,7 @@ function driver ( id, config, log, - authToken, + authTokenManager, address, userAgent: config.userAgent }) @@ -514,7 +542,8 @@ const forExport = { notificationCategory, notificationSeverityLevel, notificationFilterDisabledCategory, - notificationFilterMinimumSeverityLevel + notificationFilterMinimumSeverityLevel, + expirationBasedAuthTokenManager } export { @@ -580,11 +609,14 @@ export { notificationCategory, notificationSeverityLevel, notificationFilterDisabledCategory, - notificationFilterMinimumSeverityLevel + notificationFilterMinimumSeverityLevel, + expirationBasedAuthTokenManager } export type { QueryResult, AuthToken, + AuthTokenManager, + AuthTokenAndExpiration, Config, EncryptionLevel, TrustStrategy, diff --git a/packages/neo4j-driver-lite/test/unit/index.test.ts b/packages/neo4j-driver-lite/test/unit/index.test.ts index a5ebbc4aa..b9b6fd4a1 100644 --- a/packages/neo4j-driver-lite/test/unit/index.test.ts +++ b/packages/neo4j-driver-lite/test/unit/index.test.ts @@ -255,7 +255,9 @@ describe('index', () => { supportsTransactionConfig: async () => true, supportsUserImpersonation: async () => true, verifyConnectivityAndGetServerInfo: async () => new ServerInfo({}), - getNegotiatedProtocolVersion: async () => 5.0 + getNegotiatedProtocolVersion: async () => 5.0, + verifyAuthentication: async () => true, + supportsSessionAuth: async () => true } }) expect(session).toBeDefined() diff --git a/packages/neo4j-driver/src/driver.js b/packages/neo4j-driver/src/driver.js index f91638af4..c71621b9d 100644 --- a/packages/neo4j-driver/src/driver.js +++ b/packages/neo4j-driver/src/driver.js @@ -59,7 +59,8 @@ class Driver extends CoreDriver { fetchSize, impersonatedUser, bookmarkManager, - notificationFilter + notificationFilter, + auth } = {}) { return new RxSession({ session: this._newSession({ @@ -67,6 +68,7 @@ class Driver extends CoreDriver { bookmarkOrBookmarks: bookmarks, database, impersonatedUser, + auth, reactive: false, fetchSize: validateFetchSizeValue(fetchSize, this._config.fetchSize), bookmarkManager, diff --git a/packages/neo4j-driver/src/index.js b/packages/neo4j-driver/src/index.js index 954bff531..012faaf97 100644 --- a/packages/neo4j-driver/src/index.js +++ b/packages/neo4j-driver/src/index.js @@ -73,7 +73,9 @@ import { notificationCategory, notificationSeverityLevel, notificationFilterDisabledCategory, - notificationFilterMinimumSeverityLevel + notificationFilterMinimumSeverityLevel, + expirationBasedAuthTokenManager, + staticAuthTokenManager } from 'neo4j-driver-core' import { DirectConnectionProvider, @@ -91,6 +93,27 @@ const { urlUtil } = internal +function isAuthTokenManager (value) { + return typeof value === 'object' && + value != null && + 'getToken' in value && + 'onTokenExpired' in value && + typeof value.getToken === 'function' && + typeof value.onTokenExpired === 'function' +} + +function createAuthManager (authTokenOrManager) { + if (isAuthTokenManager(authTokenOrManager)) { + return authTokenOrManager + } + + let authToken = authTokenOrManager + // Sanitize authority token. Nicer error from server when a scheme is set. + authToken = authToken || {} + authToken.scheme = authToken.scheme || 'none' + return staticAuthTokenManager({ authToken }) +} + /** * Construct a new Neo4j Driver. This is your main entry point for this * library. @@ -273,9 +296,7 @@ function driver (url, authToken, config = {}) { config.trust = trust } - // Sanitize authority token. Nicer error from server when a scheme is set. - authToken = authToken || {} - authToken.scheme = authToken.scheme || 'none' + const authTokenManager = createAuthManager(authToken) // Use default user agent or user agent specified by user. config.userAgent = config.userAgent || USER_AGENT @@ -297,7 +318,7 @@ function driver (url, authToken, config = {}) { config, log, hostNameResolver, - authToken, + authTokenManager, address, userAgent: config.userAgent, routingContext: parsedUrl.query @@ -314,7 +335,7 @@ function driver (url, authToken, config = {}) { id, config, log, - authToken, + authTokenManager, address, userAgent: config.userAgent }) @@ -496,7 +517,8 @@ const forExport = { notificationCategory, notificationSeverityLevel, notificationFilterDisabledCategory, - notificationFilterMinimumSeverityLevel + notificationFilterMinimumSeverityLevel, + expirationBasedAuthTokenManager } export { @@ -563,6 +585,7 @@ export { notificationCategory, notificationSeverityLevel, notificationFilterDisabledCategory, - notificationFilterMinimumSeverityLevel + notificationFilterMinimumSeverityLevel, + expirationBasedAuthTokenManager } export default forExport diff --git a/packages/neo4j-driver/test/driver.test.js b/packages/neo4j-driver/test/driver.test.js index 36b16c712..fe19b49d8 100644 --- a/packages/neo4j-driver/test/driver.test.js +++ b/packages/neo4j-driver/test/driver.test.js @@ -145,6 +145,20 @@ describe('#unit driver', () => { }) }) + it('should create session using auth', () => { + driver = neo4j.driver( + `neo4j+ssc://${sharedNeo4j.hostname}`, + sharedNeo4j.authToken + ) + + const auth = { scheme: 'none' } + + const session = driver.rxSession({ auth }) + + expect(session._session._readConnectionHolder._auth).toEqual(auth) + expect(session._session._writeConnectionHolder._auth).toEqual(auth) + }) + ;[ [manager, manager], [undefined, undefined] diff --git a/packages/neo4j-driver/test/internal/connection-provider-pooled.test.js b/packages/neo4j-driver/test/internal/connection-provider-pooled.test.js index 159a5a455..546baa031 100644 --- a/packages/neo4j-driver/test/internal/connection-provider-pooled.test.js +++ b/packages/neo4j-driver/test/internal/connection-provider-pooled.test.js @@ -21,20 +21,20 @@ import FakeConnection from './fake-connection' import lolex from 'lolex' describe('#unit PooledConnectionProvider', () => { - it('should treat closed connections as invalid', () => { + it('should treat closed connections as invalid', async () => { const provider = new PooledConnectionProvider({ id: 0, config: {} }) - const connectionValid = provider._validateConnection( + const connectionValid = await provider._validateConnection( new FakeConnection().closed() ) expect(connectionValid).toBeFalsy() }) - it('should treat not old open connections as valid', () => { + it('should treat not old open connections as valid', async () => { const provider = new PooledConnectionProvider({ id: 0, config: { @@ -46,7 +46,7 @@ describe('#unit PooledConnectionProvider', () => { const clock = lolex.install() try { clock.setSystemTime(20) - const connectionValid = provider._validateConnection(connection) + const connectionValid = await provider._validateConnection(connection) expect(connectionValid).toBeTruthy() } finally { @@ -54,7 +54,7 @@ describe('#unit PooledConnectionProvider', () => { } }) - it('should treat old open connections as invalid', () => { + it('should treat old open connections as invalid', async () => { const provider = new PooledConnectionProvider({ id: 0, config: { @@ -66,7 +66,7 @@ describe('#unit PooledConnectionProvider', () => { const clock = lolex.install() try { clock.setSystemTime(20) - const connectionValid = provider._validateConnection(connection) + const connectionValid = await provider._validateConnection(connection) expect(connectionValid).toBeFalsy() } finally { diff --git a/packages/neo4j-driver/test/internal/fake-connection.js b/packages/neo4j-driver/test/internal/fake-connection.js index fff4d6d4c..75b9497a8 100644 --- a/packages/neo4j-driver/test/internal/fake-connection.js +++ b/packages/neo4j-driver/test/internal/fake-connection.js @@ -79,6 +79,14 @@ export default class FakeConnection extends Connection { this._server.version = value } + get authToken () { + return this._authToken + } + + set authToken (authToken) { + this._authToken = authToken + } + protocol () { // return fake protocol object that simply records seen queries and parameters return { @@ -175,4 +183,8 @@ export default class FakeConnection extends Connection { this._open = false return this } + + async close () { + this._open = false + } } diff --git a/packages/neo4j-driver/test/types/driver.test.ts b/packages/neo4j-driver/test/types/driver.test.ts index 577d13465..a39445e5d 100644 --- a/packages/neo4j-driver/test/types/driver.test.ts +++ b/packages/neo4j-driver/test/types/driver.test.ts @@ -38,7 +38,7 @@ const dummy: any = null const authToken: AuthToken = dummy const scheme: string = authToken.scheme -const principal: string = authToken.principal +const principal: string | undefined = authToken.principal const credentials: string = authToken.credentials const realm1: undefined = authToken.realm as undefined const realm2: string = authToken.realm as string diff --git a/packages/neo4j-driver/types/index.d.ts b/packages/neo4j-driver/types/index.d.ts index e0dfa1865..f6e3f7250 100644 --- a/packages/neo4j-driver/types/index.d.ts +++ b/packages/neo4j-driver/types/index.d.ts @@ -83,7 +83,10 @@ import { NotificationFilterDisabledCategory, NotificationFilterMinimumSeverityLevel, notificationFilterDisabledCategory, - notificationFilterMinimumSeverityLevel + notificationFilterMinimumSeverityLevel, + AuthTokenManager, + AuthTokenAndExpiration, + expirationBasedAuthTokenManager } from 'neo4j-driver-core' import { AuthToken, @@ -265,6 +268,7 @@ declare const forExport: { notificationSeverityLevel: typeof notificationSeverityLevel notificationFilterDisabledCategory: typeof notificationFilterDisabledCategory notificationFilterMinimumSeverityLevel: typeof notificationFilterMinimumSeverityLevel + expirationBasedAuthTokenManager: typeof expirationBasedAuthTokenManager } export { @@ -338,7 +342,8 @@ export { notificationCategory, notificationSeverityLevel, notificationFilterDisabledCategory, - notificationFilterMinimumSeverityLevel + notificationFilterMinimumSeverityLevel, + expirationBasedAuthTokenManager } export type { @@ -352,7 +357,9 @@ export type { NotificationSeverityLevel, NotificationFilter, NotificationFilterDisabledCategory, - NotificationFilterMinimumSeverityLevel + NotificationFilterMinimumSeverityLevel, + AuthTokenManager, + AuthTokenAndExpiration } export default forExport diff --git a/packages/testkit-backend/deno/controller.ts b/packages/testkit-backend/deno/controller.ts index eafd454c5..3fb58bc76 100644 --- a/packages/testkit-backend/deno/controller.ts +++ b/packages/testkit-backend/deno/controller.ts @@ -1,4 +1,5 @@ import Context from "../src/context.js"; +import { FakeTime } from "./deps.ts"; import { RequestHandlerMap, TestkitRequest, @@ -74,7 +75,7 @@ export function createHandler( const handleRequest = requestHandlers[name]; - handleRequest(neo4j, context, data, wire); + handleRequest({ neo4j, mock: { FakeTime } }, context, data, wire); } }; } diff --git a/packages/testkit-backend/deno/deps.ts b/packages/testkit-backend/deno/deps.ts index 898990acd..bee7f42fa 100644 --- a/packages/testkit-backend/deno/deps.ts +++ b/packages/testkit-backend/deno/deps.ts @@ -1,4 +1,5 @@ export { iterateReader } from "https://deno.land/std@0.119.0/streams/conversion.ts"; +export { FakeTime } from "https://deno.land/std@0.165.0/testing/time.ts"; export { default as Context } from "../src/context.js"; export { getShouldRunTest } from "../src/skipped-tests/index.js"; export { default as neo4j } from "../../neo4j-driver-deno/lib/mod.ts"; diff --git a/packages/testkit-backend/deno/domain.ts b/packages/testkit-backend/deno/domain.ts index 74d5e817b..41b02fc86 100644 --- a/packages/testkit-backend/deno/domain.ts +++ b/packages/testkit-backend/deno/domain.ts @@ -1,5 +1,6 @@ // deno-lint-ignore-file no-explicit-any import Context from "../src/context.js"; +import { FakeTime } from "./deps.ts"; export interface TestkitRequest { name: string; @@ -11,8 +12,12 @@ export interface TestkitResponse { data?: any; } +export interface Mock { + FakeTime: typeof FakeTime; +} + export interface RequestHandler { - (neo4j: any, c: Context, data: any, wire: any): void; + (service: { neo4j: any; mock: Mock }, c: Context, data: any, wire: any): void; } export interface RequestHandlerMap { diff --git a/packages/testkit-backend/package-lock.json b/packages/testkit-backend/package-lock.json index 3bff0c0dd..202b0a406 100644 --- a/packages/testkit-backend/package-lock.json +++ b/packages/testkit-backend/package-lock.json @@ -18,7 +18,8 @@ "esm": "^3.2.25", "rollup": "^2.77.4-1", "rollup-plugin-inject-process-env": "^1.3.1", - "rollup-plugin-polyfill-node": "^0.11.0" + "rollup-plugin-polyfill-node": "^0.11.0", + "sinon": "^15.0.1" } }, "node_modules/@jridgewell/sourcemap-codec": { @@ -91,6 +92,59 @@ "rollup": "^1.20.0||^2.0.0" } }, + "node_modules/@sinonjs/commons": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/@sinonjs/commons/-/commons-3.0.0.tgz", + "integrity": "sha512-jXBtWAF4vmdNmZgD5FoKsVLv3rPgDnLgPbU84LIJ3otV44vJlDRokVng5v8NFJdCf/da9legHcKaRuZs4L7faA==", + "dev": true, + "dependencies": { + "type-detect": "4.0.8" + } + }, + "node_modules/@sinonjs/fake-timers": { + "version": "10.0.2", + "resolved": "https://registry.npmjs.org/@sinonjs/fake-timers/-/fake-timers-10.0.2.tgz", + "integrity": "sha512-SwUDyjWnah1AaNl7kxsa7cfLhlTYoiyhDAIgyh+El30YvXs/o7OLXpYH88Zdhyx9JExKrmHDJ+10bwIcY80Jmw==", + "dev": true, + "dependencies": { + "@sinonjs/commons": "^2.0.0" + } + }, + "node_modules/@sinonjs/fake-timers/node_modules/@sinonjs/commons": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/@sinonjs/commons/-/commons-2.0.0.tgz", + "integrity": "sha512-uLa0j859mMrg2slwQYdO/AkrOfmH+X6LTVmNTS9CqexuE2IvVORIkSpJLqePAbEnKJ77aMmCwr1NUZ57120Xcg==", + "dev": true, + "dependencies": { + "type-detect": "4.0.8" + } + }, + "node_modules/@sinonjs/samsam": { + "version": "8.0.0", + "resolved": "https://registry.npmjs.org/@sinonjs/samsam/-/samsam-8.0.0.tgz", + "integrity": "sha512-Bp8KUVlLp8ibJZrnvq2foVhP0IVX2CIprMJPK0vqGqgrDa0OHVKeZyBykqskkrdxV6yKBPmGasO8LVjAKR3Gew==", + "dev": true, + "dependencies": { + "@sinonjs/commons": "^2.0.0", + "lodash.get": "^4.4.2", + "type-detect": "^4.0.8" + } + }, + "node_modules/@sinonjs/samsam/node_modules/@sinonjs/commons": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/@sinonjs/commons/-/commons-2.0.0.tgz", + "integrity": "sha512-uLa0j859mMrg2slwQYdO/AkrOfmH+X6LTVmNTS9CqexuE2IvVORIkSpJLqePAbEnKJ77aMmCwr1NUZ57120Xcg==", + "dev": true, + "dependencies": { + "type-detect": "4.0.8" + } + }, + "node_modules/@sinonjs/text-encoding": { + "version": "0.7.2", + "resolved": "https://registry.npmjs.org/@sinonjs/text-encoding/-/text-encoding-0.7.2.tgz", + "integrity": "sha512-sXXKG+uL9IrKqViTtao2Ws6dy0znu9sOaP1di/jKGW1M6VssO8vlpXCQcpZ+jisQ1tTFAC5Jo/EOzFbggBagFQ==", + "dev": true + }, "node_modules/@types/estree": { "version": "0.0.39", "resolved": "https://registry.npmjs.org/@types/estree/-/estree-0.0.39.tgz", @@ -169,6 +223,15 @@ "node": ">=0.10.0" } }, + "node_modules/diff": { + "version": "5.1.0", + "resolved": "https://registry.npmjs.org/diff/-/diff-5.1.0.tgz", + "integrity": "sha512-D+mk+qE8VC/PAUrlAU34N+VfXev0ghe5ywmpqrawphmVZc1bEfn56uo9qpyGp1p4xpzOHkSW4ztBd6L7Xx4ACw==", + "dev": true, + "engines": { + "node": ">=0.3.1" + } + }, "node_modules/esm": { "version": "3.2.25", "resolved": "https://registry.npmjs.org/esm/-/esm-3.2.25.tgz", @@ -242,6 +305,15 @@ "node": ">= 0.4.0" } }, + "node_modules/has-flag": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/has-flag/-/has-flag-4.0.0.tgz", + "integrity": "sha512-EykJT/Q1KjTWctppgIAgfSO0tKVuZUjhgMr17kqTumMl6Afv3EISleU7qZUzoXDFTAHTDC4NOoG/ZxU3EvlMPQ==", + "dev": true, + "engines": { + "node": ">=8" + } + }, "node_modules/inflight": { "version": "1.0.6", "resolved": "https://registry.npmjs.org/inflight/-/inflight-1.0.6.tgz", @@ -300,6 +372,24 @@ "@types/estree": "*" } }, + "node_modules/isarray": { + "version": "0.0.1", + "resolved": "https://registry.npmjs.org/isarray/-/isarray-0.0.1.tgz", + "integrity": "sha512-D2S+3GLxWH+uhrNEcoh/fnmYeP8E8/zHl644d/jdA0g2uyXvy3sb0qxotE+ne0LtccHknQzWwZEzhak7oJ0COQ==", + "dev": true + }, + "node_modules/just-extend": { + "version": "4.2.1", + "resolved": "https://registry.npmjs.org/just-extend/-/just-extend-4.2.1.tgz", + "integrity": "sha512-g3UB796vUFIY90VIv/WX3L2c8CS2MdWUww3CNrYmqza1Fg0DURc2K/O4YrnklBdQarSJ/y8JnJYDGc+1iumQjg==", + "dev": true + }, + "node_modules/lodash.get": { + "version": "4.4.2", + "resolved": "https://registry.npmjs.org/lodash.get/-/lodash.get-4.4.2.tgz", + "integrity": "sha512-z+Uw/vLuy6gQe8cfaFWD7p0wVv8fJl3mbzXh33RS+0oW2wvUqiRXiQ69gLWSLpgB5/6sU+r6BlQR0MBILadqTQ==", + "dev": true + }, "node_modules/magic-string": { "version": "0.25.7", "resolved": "https://registry.npmjs.org/magic-string/-/magic-string-0.25.7.tgz", @@ -337,6 +427,28 @@ "resolved": "https://registry.npmjs.org/minimist/-/minimist-0.0.10.tgz", "integrity": "sha1-3j+YVD2/lggr5IrRoMfNqDYwHc8=" }, + "node_modules/nise": { + "version": "5.1.4", + "resolved": "https://registry.npmjs.org/nise/-/nise-5.1.4.tgz", + "integrity": "sha512-8+Ib8rRJ4L0o3kfmyVCL7gzrohyDe0cMFTBa2d364yIrEGMEoetznKJx899YxjybU6bL9SQkYPSBBs1gyYs8Xg==", + "dev": true, + "dependencies": { + "@sinonjs/commons": "^2.0.0", + "@sinonjs/fake-timers": "^10.0.2", + "@sinonjs/text-encoding": "^0.7.1", + "just-extend": "^4.0.2", + "path-to-regexp": "^1.7.0" + } + }, + "node_modules/nise/node_modules/@sinonjs/commons": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/@sinonjs/commons/-/commons-2.0.0.tgz", + "integrity": "sha512-uLa0j859mMrg2slwQYdO/AkrOfmH+X6LTVmNTS9CqexuE2IvVORIkSpJLqePAbEnKJ77aMmCwr1NUZ57120Xcg==", + "dev": true, + "dependencies": { + "type-detect": "4.0.8" + } + }, "node_modules/node-static": { "version": "0.7.11", "resolved": "https://registry.npmjs.org/node-static/-/node-static-0.7.11.tgz", @@ -386,6 +498,15 @@ "integrity": "sha512-LDJzPVEEEPR+y48z93A0Ed0yXb8pAByGWo/k5YYdYgpY2/2EsOsksJrq7lOHxryrVOn1ejG6oAp8ahvOIQD8sw==", "dev": true }, + "node_modules/path-to-regexp": { + "version": "1.8.0", + "resolved": "https://registry.npmjs.org/path-to-regexp/-/path-to-regexp-1.8.0.tgz", + "integrity": "sha512-n43JRhlUKUAlibEJhPeir1ncUID16QnEjNpwzNdO3Lm4ywrBpBZ5oLD0I6br9evr1Y9JTqwRtAh7JLoOzAQdVA==", + "dev": true, + "dependencies": { + "isarray": "0.0.1" + } + }, "node_modules/picomatch": { "version": "2.3.1", "resolved": "https://registry.npmjs.org/picomatch/-/picomatch-2.3.1.tgz", @@ -515,12 +636,51 @@ "node": ">=12" } }, + "node_modules/sinon": { + "version": "15.0.3", + "resolved": "https://registry.npmjs.org/sinon/-/sinon-15.0.3.tgz", + "integrity": "sha512-si3geiRkeovP7Iel2O+qGL4NrO9vbMf3KsrJEi0ghP1l5aBkB5UxARea5j0FUsSqH3HLBh0dQPAyQ8fObRUqHw==", + "dev": true, + "dependencies": { + "@sinonjs/commons": "^3.0.0", + "@sinonjs/fake-timers": "^10.0.2", + "@sinonjs/samsam": "^8.0.0", + "diff": "^5.1.0", + "nise": "^5.1.4", + "supports-color": "^7.2.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/sinon" + } + }, "node_modules/sourcemap-codec": { "version": "1.4.8", "resolved": "https://registry.npmjs.org/sourcemap-codec/-/sourcemap-codec-1.4.8.tgz", "integrity": "sha512-9NykojV5Uih4lgo5So5dtw+f0JgJX30KCNI8gwhz2J9A15wD0Ml6tjHKwf6fTSa6fAdVBdZeNOs9eJ71qCk8vA==", "dev": true }, + "node_modules/supports-color": { + "version": "7.2.0", + "resolved": "https://registry.npmjs.org/supports-color/-/supports-color-7.2.0.tgz", + "integrity": "sha512-qpCAvRl9stuOHveKsn7HncJRvv501qIacKzQlO/+Lwxc9+0q2wLyv4Dfvt80/DPn2pqOBsJdDiogXGR9+OvwRw==", + "dev": true, + "dependencies": { + "has-flag": "^4.0.0" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/type-detect": { + "version": "4.0.8", + "resolved": "https://registry.npmjs.org/type-detect/-/type-detect-4.0.8.tgz", + "integrity": "sha512-0fr/mIH1dlO+x7TlcMy+bIDqKPsw/70tVyeHW787goQjhmqaZe10uwLujubK9q9Lg6Fiho1KUKDYz0Z7k7g5/g==", + "dev": true, + "engines": { + "node": ">=4" + } + }, "node_modules/wordwrap": { "version": "0.0.3", "resolved": "https://registry.npmjs.org/wordwrap/-/wordwrap-0.0.3.tgz", @@ -611,6 +771,63 @@ "picomatch": "^2.2.2" } }, + "@sinonjs/commons": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/@sinonjs/commons/-/commons-3.0.0.tgz", + "integrity": "sha512-jXBtWAF4vmdNmZgD5FoKsVLv3rPgDnLgPbU84LIJ3otV44vJlDRokVng5v8NFJdCf/da9legHcKaRuZs4L7faA==", + "dev": true, + "requires": { + "type-detect": "4.0.8" + } + }, + "@sinonjs/fake-timers": { + "version": "10.0.2", + "resolved": "https://registry.npmjs.org/@sinonjs/fake-timers/-/fake-timers-10.0.2.tgz", + "integrity": "sha512-SwUDyjWnah1AaNl7kxsa7cfLhlTYoiyhDAIgyh+El30YvXs/o7OLXpYH88Zdhyx9JExKrmHDJ+10bwIcY80Jmw==", + "dev": true, + "requires": { + "@sinonjs/commons": "^2.0.0" + }, + "dependencies": { + "@sinonjs/commons": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/@sinonjs/commons/-/commons-2.0.0.tgz", + "integrity": "sha512-uLa0j859mMrg2slwQYdO/AkrOfmH+X6LTVmNTS9CqexuE2IvVORIkSpJLqePAbEnKJ77aMmCwr1NUZ57120Xcg==", + "dev": true, + "requires": { + "type-detect": "4.0.8" + } + } + } + }, + "@sinonjs/samsam": { + "version": "8.0.0", + "resolved": "https://registry.npmjs.org/@sinonjs/samsam/-/samsam-8.0.0.tgz", + "integrity": "sha512-Bp8KUVlLp8ibJZrnvq2foVhP0IVX2CIprMJPK0vqGqgrDa0OHVKeZyBykqskkrdxV6yKBPmGasO8LVjAKR3Gew==", + "dev": true, + "requires": { + "@sinonjs/commons": "^2.0.0", + "lodash.get": "^4.4.2", + "type-detect": "^4.0.8" + }, + "dependencies": { + "@sinonjs/commons": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/@sinonjs/commons/-/commons-2.0.0.tgz", + "integrity": "sha512-uLa0j859mMrg2slwQYdO/AkrOfmH+X6LTVmNTS9CqexuE2IvVORIkSpJLqePAbEnKJ77aMmCwr1NUZ57120Xcg==", + "dev": true, + "requires": { + "type-detect": "4.0.8" + } + } + } + }, + "@sinonjs/text-encoding": { + "version": "0.7.2", + "resolved": "https://registry.npmjs.org/@sinonjs/text-encoding/-/text-encoding-0.7.2.tgz", + "integrity": "sha512-sXXKG+uL9IrKqViTtao2Ws6dy0znu9sOaP1di/jKGW1M6VssO8vlpXCQcpZ+jisQ1tTFAC5Jo/EOzFbggBagFQ==", + "dev": true + }, "@types/estree": { "version": "0.0.39", "resolved": "https://registry.npmjs.org/@types/estree/-/estree-0.0.39.tgz", @@ -677,6 +894,12 @@ "integrity": "sha512-FJ3UgI4gIl+PHZm53knsuSFpE+nESMr7M4v9QcgB7S63Kj/6WqMiFQJpBBYz1Pt+66bZpP3Q7Lye0Oo9MPKEdg==", "dev": true }, + "diff": { + "version": "5.1.0", + "resolved": "https://registry.npmjs.org/diff/-/diff-5.1.0.tgz", + "integrity": "sha512-D+mk+qE8VC/PAUrlAU34N+VfXev0ghe5ywmpqrawphmVZc1bEfn56uo9qpyGp1p4xpzOHkSW4ztBd6L7Xx4ACw==", + "dev": true + }, "esm": { "version": "3.2.25", "resolved": "https://registry.npmjs.org/esm/-/esm-3.2.25.tgz", @@ -731,6 +954,12 @@ "function-bind": "^1.1.1" } }, + "has-flag": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/has-flag/-/has-flag-4.0.0.tgz", + "integrity": "sha512-EykJT/Q1KjTWctppgIAgfSO0tKVuZUjhgMr17kqTumMl6Afv3EISleU7qZUzoXDFTAHTDC4NOoG/ZxU3EvlMPQ==", + "dev": true + }, "inflight": { "version": "1.0.6", "resolved": "https://registry.npmjs.org/inflight/-/inflight-1.0.6.tgz", @@ -780,6 +1009,24 @@ "@types/estree": "*" } }, + "isarray": { + "version": "0.0.1", + "resolved": "https://registry.npmjs.org/isarray/-/isarray-0.0.1.tgz", + "integrity": "sha512-D2S+3GLxWH+uhrNEcoh/fnmYeP8E8/zHl644d/jdA0g2uyXvy3sb0qxotE+ne0LtccHknQzWwZEzhak7oJ0COQ==", + "dev": true + }, + "just-extend": { + "version": "4.2.1", + "resolved": "https://registry.npmjs.org/just-extend/-/just-extend-4.2.1.tgz", + "integrity": "sha512-g3UB796vUFIY90VIv/WX3L2c8CS2MdWUww3CNrYmqza1Fg0DURc2K/O4YrnklBdQarSJ/y8JnJYDGc+1iumQjg==", + "dev": true + }, + "lodash.get": { + "version": "4.4.2", + "resolved": "https://registry.npmjs.org/lodash.get/-/lodash.get-4.4.2.tgz", + "integrity": "sha512-z+Uw/vLuy6gQe8cfaFWD7p0wVv8fJl3mbzXh33RS+0oW2wvUqiRXiQ69gLWSLpgB5/6sU+r6BlQR0MBILadqTQ==", + "dev": true + }, "magic-string": { "version": "0.25.7", "resolved": "https://registry.npmjs.org/magic-string/-/magic-string-0.25.7.tgz", @@ -808,6 +1055,30 @@ "resolved": "https://registry.npmjs.org/minimist/-/minimist-0.0.10.tgz", "integrity": "sha1-3j+YVD2/lggr5IrRoMfNqDYwHc8=" }, + "nise": { + "version": "5.1.4", + "resolved": "https://registry.npmjs.org/nise/-/nise-5.1.4.tgz", + "integrity": "sha512-8+Ib8rRJ4L0o3kfmyVCL7gzrohyDe0cMFTBa2d364yIrEGMEoetznKJx899YxjybU6bL9SQkYPSBBs1gyYs8Xg==", + "dev": true, + "requires": { + "@sinonjs/commons": "^2.0.0", + "@sinonjs/fake-timers": "^10.0.2", + "@sinonjs/text-encoding": "^0.7.1", + "just-extend": "^4.0.2", + "path-to-regexp": "^1.7.0" + }, + "dependencies": { + "@sinonjs/commons": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/@sinonjs/commons/-/commons-2.0.0.tgz", + "integrity": "sha512-uLa0j859mMrg2slwQYdO/AkrOfmH+X6LTVmNTS9CqexuE2IvVORIkSpJLqePAbEnKJ77aMmCwr1NUZ57120Xcg==", + "dev": true, + "requires": { + "type-detect": "4.0.8" + } + } + } + }, "node-static": { "version": "0.7.11", "resolved": "https://registry.npmjs.org/node-static/-/node-static-0.7.11.tgz", @@ -848,6 +1119,15 @@ "integrity": "sha512-LDJzPVEEEPR+y48z93A0Ed0yXb8pAByGWo/k5YYdYgpY2/2EsOsksJrq7lOHxryrVOn1ejG6oAp8ahvOIQD8sw==", "dev": true }, + "path-to-regexp": { + "version": "1.8.0", + "resolved": "https://registry.npmjs.org/path-to-regexp/-/path-to-regexp-1.8.0.tgz", + "integrity": "sha512-n43JRhlUKUAlibEJhPeir1ncUID16QnEjNpwzNdO3Lm4ywrBpBZ5oLD0I6br9evr1Y9JTqwRtAh7JLoOzAQdVA==", + "dev": true, + "requires": { + "isarray": "0.0.1" + } + }, "picomatch": { "version": "2.3.1", "resolved": "https://registry.npmjs.org/picomatch/-/picomatch-2.3.1.tgz", @@ -938,12 +1218,41 @@ } } }, + "sinon": { + "version": "15.0.3", + "resolved": "https://registry.npmjs.org/sinon/-/sinon-15.0.3.tgz", + "integrity": "sha512-si3geiRkeovP7Iel2O+qGL4NrO9vbMf3KsrJEi0ghP1l5aBkB5UxARea5j0FUsSqH3HLBh0dQPAyQ8fObRUqHw==", + "dev": true, + "requires": { + "@sinonjs/commons": "^3.0.0", + "@sinonjs/fake-timers": "^10.0.2", + "@sinonjs/samsam": "^8.0.0", + "diff": "^5.1.0", + "nise": "^5.1.4", + "supports-color": "^7.2.0" + } + }, "sourcemap-codec": { "version": "1.4.8", "resolved": "https://registry.npmjs.org/sourcemap-codec/-/sourcemap-codec-1.4.8.tgz", "integrity": "sha512-9NykojV5Uih4lgo5So5dtw+f0JgJX30KCNI8gwhz2J9A15wD0Ml6tjHKwf6fTSa6fAdVBdZeNOs9eJ71qCk8vA==", "dev": true }, + "supports-color": { + "version": "7.2.0", + "resolved": "https://registry.npmjs.org/supports-color/-/supports-color-7.2.0.tgz", + "integrity": "sha512-qpCAvRl9stuOHveKsn7HncJRvv501qIacKzQlO/+Lwxc9+0q2wLyv4Dfvt80/DPn2pqOBsJdDiogXGR9+OvwRw==", + "dev": true, + "requires": { + "has-flag": "^4.0.0" + } + }, + "type-detect": { + "version": "4.0.8", + "resolved": "https://registry.npmjs.org/type-detect/-/type-detect-4.0.8.tgz", + "integrity": "sha512-0fr/mIH1dlO+x7TlcMy+bIDqKPsw/70tVyeHW787goQjhmqaZe10uwLujubK9q9Lg6Fiho1KUKDYz0Z7k7g5/g==", + "dev": true + }, "wordwrap": { "version": "0.0.3", "resolved": "https://registry.npmjs.org/wordwrap/-/wordwrap-0.0.3.tgz", diff --git a/packages/testkit-backend/package.json b/packages/testkit-backend/package.json index e10c21db7..923ce63e0 100644 --- a/packages/testkit-backend/package.json +++ b/packages/testkit-backend/package.json @@ -43,6 +43,7 @@ "esm": "^3.2.25", "rollup": "^2.77.4-1", "rollup-plugin-inject-process-env": "^1.3.1", - "rollup-plugin-polyfill-node": "^0.11.0" + "rollup-plugin-polyfill-node": "^0.11.0", + "sinon": "^15.0.1" } } diff --git a/packages/testkit-backend/src/context.js b/packages/testkit-backend/src/context.js index ab4733c0b..f38520c76 100644 --- a/packages/testkit-backend/src/context.js +++ b/packages/testkit-backend/src/context.js @@ -12,6 +12,10 @@ export default class Context { this._bookmarkSupplierRequests = {} this._notifyBookmarksRequests = {} this._bookmarksManagers = {} + this._authTokenManagers = {} + this._authTokenManagerGetAuthRequests = {} + this._authTokenManagerOnAuthExpiredRequests = {} + this._expirationBasedAuthTokenProviderRequests = {} this._binder = binder this._environmentLogLevel = environmentLogLevel } @@ -161,6 +165,54 @@ export default class Context { delete this._bookmarksManagers[id] } + addAuthTokenManager (authTokenManagersFactory) { + this._id++ + this._authTokenManagers[this._id] = authTokenManagersFactory(this._id) + return this._id + } + + getAuthTokenManager (id) { + return this._authTokenManagers[id] + } + + removeAuthTokenManager (id) { + delete this._authTokenManagers[id] + } + + addAuthTokenManagerGetAuthRequest (resolve, reject) { + return this._add(this._authTokenManagerGetAuthRequests, { + resolve, reject + }) + } + + getAuthTokenManagerGetAuthRequest (id) { + return this._authTokenManagerGetAuthRequests[id] + } + + removeAuthTokenManagerGetAuthRequest (id) { + delete this._authTokenManagerGetAuthRequests[id] + } + + addAuthTokenManagerOnAuthExpiredRequest (request) { + return this._add(this._authTokenManagerOnAuthExpiredRequests, request) + } + + removeAuthTokenManagerOnAuthExpiredRequest (id) { + delete this._authTokenManagerOnAuthExpiredRequests[id] + } + + addExpirationBasedAuthTokenProviderRequest (resolve, reject) { + return this._add(this._expirationBasedAuthTokenProviderRequests, { resolve, reject }) + } + + getExpirationBasedAuthTokenProviderRequest (id) { + return this._expirationBasedAuthTokenProviderRequests[id] + } + + removeExpirationBasedAuthTokenProviderRequest (id) { + delete this._expirationBasedAuthTokenProviderRequests[id] + } + _add (map, object) { this._id++ map[this._id] = object diff --git a/packages/testkit-backend/src/controller/local.js b/packages/testkit-backend/src/controller/local.js index 2d5c80af0..9513eab95 100644 --- a/packages/testkit-backend/src/controller/local.js +++ b/packages/testkit-backend/src/controller/local.js @@ -3,6 +3,7 @@ import Controller from './interface' import stringify from '../stringify' import { isFrontendError } from '../request-handlers' import CypherNativeBinders from '../cypher-native-binders' +import FakeTime from '../mock/fake-time' /** * Local controller handles the requests locally by redirecting them to the correct request handler/service. @@ -37,7 +38,12 @@ export default class LocalController extends Controller { throw new Error(`Unknown request: ${name}`) } - return await this._requestHandlers[name](this._neo4j, this._contexts.get(contextId), data, { + return await this._requestHandlers[name]({ + neo4j: this._neo4j, + mock: { + FakeTime + } + }, this._contexts.get(contextId), data, { writeResponse: (response) => this._writeResponse(contextId, response), writeError: (e) => this._writeError(contextId, e), writeBackendError: (msg) => this._writeBackendError(contextId, msg) @@ -53,6 +59,7 @@ export default class LocalController extends Controller { } _writeError (contextId, e) { + console.trace(e) if (e.name) { if (isFrontendError(e)) { this._writeResponse(contextId, newResponse('FrontendError', { diff --git a/packages/testkit-backend/src/cypher-native-binders.js b/packages/testkit-backend/src/cypher-native-binders.js index a0535732f..ff9498e7e 100644 --- a/packages/testkit-backend/src/cypher-native-binders.js +++ b/packages/testkit-backend/src/cypher-native-binders.js @@ -253,10 +253,34 @@ export default function CypherNativeBinders (neo4j) { throw Error(err) } + function parseAuthToken (authToken) { + switch (authToken.scheme) { + case 'basic': + return neo4j.auth.basic( + authToken.principal, + authToken.credentials, + authToken.realm + ) + case 'kerberos': + return neo4j.auth.kerberos(authToken.credentials) + case 'bearer': + return neo4j.auth.bearer(authToken.credentials) + default: + return neo4j.auth.custom( + authToken.principal, + authToken.credentials, + authToken.realm, + authToken.scheme, + authToken.parameters + ) + } + } + this.valueResponse = valueResponse this.objectToCypher = objectToCypher this.objectToNative = objectToNative this.objectMemberBitIntToNumber = objectMemberBitIntToNumber this.nativeToCypher = nativeToCypher this.cypherToNative = cypherToNative + this.parseAuthToken = parseAuthToken } diff --git a/packages/testkit-backend/src/feature/common.js b/packages/testkit-backend/src/feature/common.js index 99181a219..0a2235f98 100644 --- a/packages/testkit-backend/src/feature/common.js +++ b/packages/testkit-backend/src/feature/common.js @@ -1,9 +1,12 @@ const features = [ + 'Backend:MockTime', 'Feature:Auth:Custom', 'Feature:Auth:Kerberos', 'Feature:Auth:Bearer', + 'Feature:Auth:Managed', 'Feature:API:BookmarkManager', + 'Feature:API:Session:AuthConfig', 'Feature:API:SSLConfig', 'Feature:API:SSLSchemes', 'Feature:API:Type.Temporal', @@ -23,8 +26,10 @@ const features = [ 'Feature:API:Driver.ExecuteQuery', 'Feature:API:Driver:NotificationsConfig', 'Feature:API:Driver:GetServerInfo', + 'Feature:API:Driver.VerifyAuthentication', 'Feature:API:Driver.VerifyConnectivity', 'Feature:API:Session:NotificationsConfig', + 'Optimization:AuthPipelining', 'Optimization:EagerTransactionBegin', 'Optimization:ImplicitDefaultArguments', 'Optimization:MinimalBookmarksSet', diff --git a/packages/testkit-backend/src/mock/fake-time.js b/packages/testkit-backend/src/mock/fake-time.js new file mode 100644 index 000000000..53f38546a --- /dev/null +++ b/packages/testkit-backend/src/mock/fake-time.js @@ -0,0 +1,15 @@ +import sinon from 'sinon' + +export default class FakeTime { + constructor (time) { + this._clock = sinon.useFakeTimers(time || new Date().getTime()) + } + + tick (incrementMs) { + this._clock.tick(incrementMs) + } + + restore () { + this._clock.restore() + } +} diff --git a/packages/testkit-backend/src/request-handlers-rx.js b/packages/testkit-backend/src/request-handlers-rx.js index 1d0762856..65a62b6af 100644 --- a/packages/testkit-backend/src/request-handlers-rx.js +++ b/packages/testkit-backend/src/request-handlers-rx.js @@ -9,8 +9,10 @@ export { StartTest, GetFeatures, VerifyConnectivity, + VerifyAuthentication, GetServerInfo, CheckMultiDBSupport, + CheckSessionAuthSupport, ResolverResolutionCompleted, GetRoutingTable, ForcedRoutingTableUpdate, @@ -22,10 +24,19 @@ export { BookmarksSupplierCompleted, BookmarksConsumerCompleted, StartSubTest, - ExecuteQuery + ExecuteQuery, + NewAuthTokenManager, + AuthTokenManagerClose, + AuthTokenManagerGetAuthCompleted, + AuthTokenManagerOnAuthExpiredCompleted, + NewExpirationBasedAuthTokenManager, + ExpirationBasedAuthTokenProviderCompleted, + FakeTimeInstall, + FakeTimeTick, + FakeTimeUninstall } from './request-handlers.js' -export function NewSession (neo4j, context, data, wire) { +export function NewSession ({ neo4j }, context, data, wire) { let { driverId, accessMode, bookmarks, database, fetchSize, impersonatedUser, bookmarkManagerId } = data switch (accessMode) { case 'r': @@ -53,6 +64,10 @@ export function NewSession (neo4j, context, data, wire) { disabledCategories: data.notificationsDisabledCategories } } + const auth = data.authorizationToken != null + ? context.binder.parseAuthToken(data.authorizationToken.data) + : undefined + const driver = context.getDriver(driverId) const session = driver.rxSession({ defaultAccessMode: accessMode, @@ -61,7 +76,8 @@ export function NewSession (neo4j, context, data, wire) { fetchSize, impersonatedUser, bookmarkManager, - notificationFilter + notificationFilter, + auth }) const id = context.addSession(session) wire.writeResponse(responses.Session({ id })) diff --git a/packages/testkit-backend/src/request-handlers.js b/packages/testkit-backend/src/request-handlers.js index 295033053..468a3b8ef 100644 --- a/packages/testkit-backend/src/request-handlers.js +++ b/packages/testkit-backend/src/request-handlers.js @@ -8,36 +8,24 @@ export function isFrontendError (error) { return error.message === 'TestKit FrontendError' } -export function NewDriver (neo4j, context, data, wire) { +export function NewDriver ({ neo4j }, context, data, wire) { const { uri, - authorizationToken: { data: authToken }, + authorizationToken, + authTokenManagerId, userAgent, resolverRegistered } = data - let parsedAuthToken = authToken - switch (authToken.scheme) { - case 'basic': - parsedAuthToken = neo4j.auth.basic( - authToken.principal, - authToken.credentials, - authToken.realm - ) - break - case 'kerberos': - parsedAuthToken = neo4j.auth.kerberos(authToken.credentials) - break - case 'bearer': - parsedAuthToken = neo4j.auth.bearer(authToken.credentials) - break - default: - parsedAuthToken = neo4j.auth.custom( - authToken.principal, - authToken.credentials, - authToken.realm, - authToken.scheme, - authToken.parameters - ) + + let parsedAuthToken = null + + if (authorizationToken != null && authTokenManagerId != null) { + throw new Error('Can not set authorizationToken and authTokenManagerId') + } else if (authorizationToken) { + const { data: authToken } = authorizationToken + parsedAuthToken = context.binder.parseAuthToken(authToken) + } else { + parsedAuthToken = context.getAuthTokenManager(authTokenManagerId) } const resolver = resolverRegistered @@ -112,7 +100,7 @@ export function DriverClose (_, context, data, wire) { .catch(err => wire.writeError(err)) } -export function NewSession (neo4j, context, data, wire) { +export function NewSession ({ neo4j }, context, data, wire) { let { driverId, accessMode, bookmarks, database, fetchSize, impersonatedUser, bookmarkManagerId } = data switch (accessMode) { case 'r': @@ -140,6 +128,10 @@ export function NewSession (neo4j, context, data, wire) { disabledCategories: data.notificationsDisabledCategories } } + const auth = data.authorizationToken != null + ? context.binder.parseAuthToken(data.authorizationToken.data) + : undefined + const driver = context.getDriver(driverId) const session = driver.session({ defaultAccessMode: accessMode, @@ -148,7 +140,8 @@ export function NewSession (neo4j, context, data, wire) { fetchSize, impersonatedUser, bookmarkManager, - notificationFilter + notificationFilter, + auth }) const id = context.addSession(session) wire.writeResponse(responses.Session({ id })) @@ -403,6 +396,18 @@ export function VerifyConnectivity (_, context, { driverId }, wire) { .catch(error => wire.writeError(error)) } +export function VerifyAuthentication (_, context, { driverId, authorizationToken }, wire) { + const auth = authorizationToken != null && authorizationToken.data != null + ? context.binder.parseAuthToken(authorizationToken.data) + : undefined + + const driver = context.getDriver(driverId) + return driver + .verifyAuthentication({ auth }) + .then(authenticated => wire.writeResponse(responses.DriverIsAuthenticated({ id: driverId, authenticated }))) + .catch(error => wire.writeError(error)) +} + export function GetServerInfo (_, context, { driverId }, wire) { const driver = context.getDriver(driverId) return driver @@ -421,6 +426,16 @@ export function CheckMultiDBSupport (_, context, { driverId }, wire) { .catch(error => wire.writeError(error)) } +export function CheckSessionAuthSupport (_, context, { driverId }, wire) { + const driver = context.getDriver(driverId) + return driver + .supportsSessionAuth() + .then(available => + wire.writeResponse(responses.SessionAuthSupport({ id: driverId, available })) + ) + .catch(error => wire.writeError(error)) +} + export function ResolverResolutionCompleted ( _, context, @@ -432,7 +447,7 @@ export function ResolverResolutionCompleted ( } export function NewBookmarkManager ( - neo4j, + { neo4j }, context, { initialBookmarks, @@ -506,6 +521,62 @@ export function BookmarksConsumerCompleted ( notifyBookmarksRequest.resolve() } +export function NewAuthTokenManager (_, context, _data, wire) { + const id = context.addAuthTokenManager((authTokenManagerId) => { + return { + getToken: () => new Promise((resolve, reject) => { + const id = context.addAuthTokenManagerGetAuthRequest(resolve, reject) + wire.writeResponse(responses.AuthTokenManagerGetAuthRequest({ id, authTokenManagerId })) + }), + onTokenExpired: (auth) => { + const id = context.addAuthTokenManagerOnAuthExpiredRequest() + wire.writeResponse(responses.AuthTokenManagerOnAuthExpiredRequest({ id, authTokenManagerId, auth })) + } + } + }) + + wire.writeResponse(responses.AuthTokenManager({ id })) +} + +export function AuthTokenManagerClose (_, context, { id }, wire) { + context.removeAuthTokenManager(id) + wire.writeResponse(responses.AuthTokenManager({ id })) +} + +export function AuthTokenManagerGetAuthCompleted (_, context, { requestId, auth }) { + const request = context.getAuthTokenManagerGetAuthRequest(requestId) + request.resolve(auth.data) + context.removeAuthTokenManagerGetAuthRequest(requestId) +} + +export function AuthTokenManagerOnAuthExpiredCompleted (_, context, { requestId }) { + context.removeAuthTokenManagerOnAuthExpiredRequest(requestId) +} + +export function NewExpirationBasedAuthTokenManager ({ neo4j }, context, _, wire) { + const id = context.addAuthTokenManager((expirationBasedAuthTokenManagerId) => { + return neo4j.expirationBasedAuthTokenManager({ + tokenProvider: () => new Promise((resolve, reject) => { + const id = context.addExpirationBasedAuthTokenProviderRequest(resolve, reject) + wire.writeResponse(responses.ExpirationBasedAuthTokenProviderRequest({ id, expirationBasedAuthTokenManagerId })) + }) + }) + }) + + wire.writeResponse(responses.ExpirationBasedAuthTokenManager({ id })) +} + +export function ExpirationBasedAuthTokenProviderCompleted (_, context, { requestId, auth }) { + const request = context.getExpirationBasedAuthTokenProviderRequest(requestId) + request.resolve({ + expiration: auth.data.expiresInMs != null + ? new Date(new Date().getTime() + auth.data.expiresInMs) + : undefined, + token: context.binder.parseAuthToken(auth.data.auth.data) + }) + context.removeExpirationBasedAuthTokenProviderRequest(requestId) +} + export function GetRoutingTable (_, context, { driverId, database }, wire) { const driver = context.getDriver(driverId) const routingTable = @@ -549,7 +620,7 @@ export function ForcedRoutingTableUpdate (_, context, { driverId, database, book } } -export function ExecuteQuery (neo4j, context, { driverId, cypher, params, config }, wire) { +export function ExecuteQuery ({ neo4j }, context, { driverId, cypher, params, config }, wire) { const driver = context.getDriver(driverId) if (params) { for (const [key, value] of Object.entries(params)) { @@ -601,3 +672,19 @@ export function ExecuteQuery (neo4j, context, { driverId, cypher, params, config }) .catch(e => wire.writeError(e)) } + +export function FakeTimeInstall ({ mock }, context, _data, wire) { + context.clock = new mock.FakeTime() + wire.writeResponse(responses.FakeTimeAck()) +} + +export function FakeTimeTick (_, context, { incrementMs }, wire) { + context.clock.tick(incrementMs) + wire.writeResponse(responses.FakeTimeAck()) +} + +export function FakeTimeUninstall (_, context, _data, wire) { + context.clock.restore() + delete context.clock + wire.writeResponse(responses.FakeTimeAck()) +} diff --git a/packages/testkit-backend/src/responses.js b/packages/testkit-backend/src/responses.js index 7278e9f0c..9d223f0e2 100644 --- a/packages/testkit-backend/src/responses.js +++ b/packages/testkit-backend/src/responses.js @@ -77,6 +77,10 @@ export function MultiDBSupport ({ id, available }) { return response('MultiDBSupport', { id, available }) } +export function SessionAuthSupport ({ id, available }) { + return response('SessionAuthSupport', { id, available }) +} + export function RoutingTable ({ routingTable }) { const serverAddressToString = serverAddress => serverAddress.asHostPort() return response('RoutingTable', { @@ -99,6 +103,34 @@ export function EagerResult ({ keys, records, summary }, { binder }) { }) } +export function AuthTokenManager ({ id }) { + return response('AuthTokenManager', { id }) +} + +export function AuthTokenManagerGetAuthRequest ({ id, authTokenManagerId }) { + return response('AuthTokenManagerGetAuthRequest', { id, authTokenManagerId }) +} + +export function AuthorizationToken (data) { + return response('AuthorizationToken', data) +} + +export function AuthTokenManagerOnAuthExpiredRequest ({ id, authTokenManagerId, auth }) { + return response('AuthTokenManagerOnAuthExpiredRequest', { id, authTokenManagerId, auth: AuthorizationToken(auth) }) +} + +export function ExpirationBasedAuthTokenManager ({ id }) { + return response('ExpirationBasedAuthTokenManager', { id }) +} + +export function ExpirationBasedAuthTokenProviderRequest ({ id, expirationBasedAuthTokenManagerId }) { + return response('ExpirationBasedAuthTokenProviderRequest', { id, expirationBasedAuthTokenManagerId }) +} + +export function DriverIsAuthenticated ({ id, authenticated }) { + return response('DriverIsAuthenticated', { id, authenticated }) +} + // Testkit controller messages export function RunTest () { return response('RunTest', null) @@ -116,6 +148,10 @@ export function FeatureList ({ features }) { return response('FeatureList', { features }) } +export function FakeTimeAck () { + return response('FakeTimeAck', {}) +} + function response (name, data) { return { name, data } }