Skip to content

Improve AuthTokenManager interface and factory method #1123

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Aug 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,15 @@
* limitations under the License.
*/

import { expirationBasedAuthTokenManager } from 'neo4j-driver-core'
import { staticAuthTokenManager } from 'neo4j-driver-core'
import { object } from '../lang'

/**
* Class which provides Authorization for {@link Connection}
*/
export default class AuthenticationProvider {
constructor ({ authTokenManager, userAgent, boltAgent }) {
this._authTokenManager = authTokenManager || expirationBasedAuthTokenManager({
tokenProvider: () => {}
})
this._authTokenManager = authTokenManager || staticAuthTokenManager({})
this._userAgent = userAgent
this._boltAgent = boltAgent
}
Expand Down Expand Up @@ -56,12 +54,10 @@ export default class AuthenticationProvider {
handleError ({ connection, code }) {
if (
connection &&
[
'Neo.ClientError.Security.Unauthorized',
'Neo.ClientError.Security.TokenExpired'
].includes(code)
code.startsWith('Neo.ClientError.Security.')
) {
this._authTokenManager.onTokenExpired(connection.authToken)
return this._authTokenManager.handleSecurityException(connection.authToken, code)
}
return false
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ export default class DirectConnectionProvider extends PooledConnectionProvider {
async acquireConnection ({ accessMode, database, bookmarks, auth, forceReAuth } = {}) {
const databaseSpecificErrorHandler = ConnectionErrorHandler.create({
errorCode: SERVICE_UNAVAILABLE,
handleAuthorizationExpired: (error, address, conn) =>
this._handleAuthorizationExpired(error, address, conn, database)
handleSecurityError: (error, address, conn) =>
this._handleSecurityError(error, address, conn, database)
})

const connection = await this._connectionPool.acquire({ auth, forceReAuth }, this._address)
Expand All @@ -68,12 +68,12 @@ export default class DirectConnectionProvider extends PooledConnectionProvider {
return new DelegateConnection(connection, databaseSpecificErrorHandler)
}

_handleAuthorizationExpired (error, address, connection, database) {
_handleSecurityError (error, address, connection, database) {
this._log.warn(
`Direct driver ${this._id} will close connection to ${address} for database '${database}' because of an error ${error.code} '${error.message}'`
)

return super._handleAuthorizationExpired(error, address, connection)
return super._handleSecurityError(error, address, connection)
}

async _hasProtocolVersion (versionPredicate) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

import { createChannelConnection, ConnectionErrorHandler } from '../connection'
import Pool, { PoolConfig } from '../pool'
import { error, ConnectionProvider, ServerInfo, newError, isStaticAuthTokenManger } from 'neo4j-driver-core'
import { error, ConnectionProvider, ServerInfo, newError } from 'neo4j-driver-core'
import AuthenticationProvider from './authentication-provider'
import { object } from '../lang'

Expand All @@ -41,7 +41,6 @@ export default class PooledConnectionProvider extends ConnectionProvider {
this._id = id
this._config = config
this._log = log
this._authTokenManager = authTokenManager
this._authenticationProvider = new AuthenticationProvider({ authTokenManager, userAgent, boltAgent })
this._userAgent = userAgent
this._boltAgent = boltAgent
Expand Down Expand Up @@ -224,8 +223,12 @@ export default class PooledConnectionProvider extends ConnectionProvider {
conn._updateCurrentObserver()
}

_handleAuthorizationExpired (error, address, connection) {
this._authenticationProvider.handleError({ connection, code: error.code })
_handleSecurityError (error, address, connection) {
const handled = this._authenticationProvider.handleError({ connection, code: error.code })

if (handled) {
error.retriable = true
}

if (error.code === 'Neo.ClientError.Security.AuthorizationExpired') {
this._connectionPool.apply(address, (conn) => { conn.authToken = null })
Expand All @@ -235,10 +238,6 @@ export default class PooledConnectionProvider extends ConnectionProvider {
connection.close().catch(() => undefined)
}

if (error.code === 'Neo.ClientError.Security.TokenExpired' && !isStaticAuthTokenManger(this._authTokenManager)) {
error.retriable = true
}

return error
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -116,12 +116,12 @@ export default class RoutingConnectionProvider extends PooledConnectionProvider
return error
}

_handleAuthorizationExpired (error, address, connection, database) {
_handleSecurityError (error, address, connection, database) {
this._log.warn(
`Routing driver ${this._id} will close connections to ${address} for database '${database}' because of an error ${error.code} '${error.message}'`
)

return super._handleAuthorizationExpired(error, address, connection, database)
return super._handleSecurityError(error, address, connection, database)
}

_handleWriteFailure (error, address, database) {
Expand Down Expand Up @@ -150,7 +150,7 @@ export default class RoutingConnectionProvider extends PooledConnectionProvider
(error, address) => this._handleUnavailability(error, address, context.database),
(error, address) => this._handleWriteFailure(error, address, context.database),
(error, address, conn) =>
this._handleAuthorizationExpired(error, address, conn, context.database)
this._handleSecurityError(error, address, conn, context.database)
)

const routingTable = await this._freshRoutingTable({
Expand Down Expand Up @@ -584,7 +584,7 @@ export default class RoutingConnectionProvider extends PooledConnectionProvider

const databaseSpecificErrorHandler = ConnectionErrorHandler.create({
errorCode: SESSION_EXPIRED,
handleAuthorizationExpired: (error, address, conn) => this._handleAuthorizationExpired(error, address, conn)
handleSecurityError: (error, address, conn) => this._handleSecurityError(error, address, conn)
})

const delegateConnection = !connection._sticky
Expand Down
21 changes: 10 additions & 11 deletions packages/bolt-connection/src/connection/connection-error-handler.js
Original file line number Diff line number Diff line change
Expand Up @@ -26,25 +26,25 @@ export default class ConnectionErrorHandler {
errorCode,
handleUnavailability,
handleWriteFailure,
handleAuthorizationExpired
handleSecurityError
) {
this._errorCode = errorCode
this._handleUnavailability = handleUnavailability || noOpHandler
this._handleWriteFailure = handleWriteFailure || noOpHandler
this._handleAuthorizationExpired = handleAuthorizationExpired || noOpHandler
this._handleSecurityError = handleSecurityError || noOpHandler
}

static create ({
errorCode,
handleUnavailability,
handleWriteFailure,
handleAuthorizationExpired
handleSecurityError
}) {
return new ConnectionErrorHandler(
errorCode,
handleUnavailability,
handleWriteFailure,
handleAuthorizationExpired
handleSecurityError
)
}

Expand All @@ -63,8 +63,8 @@ export default class ConnectionErrorHandler {
* @return {Neo4jError} new error that should be propagated to the user.
*/
handleAndTransformError (error, address, connection) {
if (isAutorizationExpiredError(error)) {
return this._handleAuthorizationExpired(error, address, connection)
if (isSecurityError(error)) {
return this._handleSecurityError(error, address, connection)
}
if (isAvailabilityError(error)) {
return this._handleUnavailability(error, address, connection)
Expand All @@ -76,11 +76,10 @@ export default class ConnectionErrorHandler {
}
}

function isAutorizationExpiredError (error) {
return error && (
error.code === 'Neo.ClientError.Security.AuthorizationExpired' ||
error.code === 'Neo.ClientError.Security.TokenExpired'
)
function isSecurityError (error) {
return error != null &&
error.code != null &&
error.code.startsWith('Neo.ClientError.Security.')
}

function isAvailabilityError (error) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import { expirationBasedAuthTokenManager } from 'neo4j-driver-core'
import { authTokenManagers } from 'neo4j-driver-core'
import AuthenticationProvider from '../../src/connection-provider/authentication-provider'

describe('AuthenticationProvider', () => {
Expand Down Expand Up @@ -642,8 +642,15 @@ describe('AuthenticationProvider', () => {
authenticationProvider
} = createScenario()

const handleSecurityExceptionSpy = jest.spyOn(authenticationProvider._authTokenManager, 'handleSecurityException')

authenticationProvider.handleError({ code, connection })

if (code.startsWith('Neo.ClientError.Security.')) {
expect(handleSecurityExceptionSpy).toBeCalledWith(connection.authToken, code)
} else {
expect(handleSecurityExceptionSpy).not.toBeCalled()
}
expect(authTokenProvider).not.toHaveBeenCalled()
})

Expand Down Expand Up @@ -785,7 +792,7 @@ describe('AuthenticationProvider', () => {
})

function createAuthenticationProvider (authTokenProvider, mocks) {
const authTokenManager = expirationBasedAuthTokenManager({ tokenProvider: authTokenProvider })
const authTokenManager = authTokenManagers.bearer({ tokenProvider: authTokenProvider })
const provider = new AuthenticationProvider({
authTokenManager,
userAgent: USER_AGENT,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import DirectConnectionProvider from '../../src/connection-provider/connection-provider-direct'
import { Pool } from '../../src/pool'
import { Connection, DelegateConnection } from '../../src/connection'
import { internal, newError, ServerInfo, staticAuthTokenManager, expirationBasedAuthTokenManager } from 'neo4j-driver-core'
import { authTokenManagers, internal, newError, ServerInfo, staticAuthTokenManager } from 'neo4j-driver-core'
import AuthenticationProvider from '../../src/connection-provider/authentication-provider'
import { functional } from '../../src/lang'

Expand Down Expand Up @@ -209,7 +209,7 @@ it('should call authenticationAuthProvider.handleError when TokenExpired happens
it('should change error to retriable when error when TokenExpired happens and staticAuthTokenManager is not being used', async () => {
const address = ServerAddress.fromUrl('localhost:123')
const pool = newPool()
const connectionProvider = newDirectConnectionProvider(address, pool, expirationBasedAuthTokenManager({ tokenProvider: () => null }))
const connectionProvider = newDirectConnectionProvider(address, pool, authTokenManagers.bearer({ tokenProvider: () => null }))

const conn = await connectionProvider.acquireConnection({
accessMode: 'READ',
Expand Down Expand Up @@ -246,6 +246,26 @@ it('should not change error to retriable when error when TokenExpired happens an
expect(error.retriable).toBe(false)
})

it('should not change error to retriable when error when TokenExpired happens and authTokenManagers.basic is being used', async () => {
const address = ServerAddress.fromUrl('localhost:123')
const pool = newPool()
const connectionProvider = newDirectConnectionProvider(address, pool, authTokenManagers.basic({ tokenProvider: () => null }))

const conn = await connectionProvider.acquireConnection({
accessMode: 'READ',
database: ''
})

const expectedError = newError(
'Message',
'Neo.ClientError.Security.TokenExpired'
)

const error = conn.handleAndTransformError(expectedError, address)

expect(error.retriable).toBe(false)
})

describe('constructor', () => {
describe('newPool', () => {
const server0 = ServerAddress.fromUrl('localhost:123')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import {
internal,
ServerInfo,
staticAuthTokenManager,
expirationBasedAuthTokenManager
authTokenManagers
} from 'neo4j-driver-core'
import { RoutingTable } from '../../src/rediscovery/'
import { Pool } from '../../src/pool'
Expand Down Expand Up @@ -1718,7 +1718,8 @@ describe.each([
],
pool
)
connectionProvider._authTokenManager = expirationBasedAuthTokenManager({ tokenProvider: () => null })

setupAuthTokenManager(connectionProvider, authTokenManagers.bearer({ tokenProvider: () => null }))

const error = newError(
'Message',
Expand Down Expand Up @@ -1756,8 +1757,51 @@ describe.each([
)
],
pool

)

setupAuthTokenManager(connectionProvider, staticAuthTokenManager({ authToken: null }))

const error = newError(
'Message',
'Neo.ClientError.Security.TokenExpired'
)

const server2Connection = await connectionProvider.acquireConnection({
accessMode: 'WRITE',
database: null,
impersonatedUser: user
})

const server3Connection = await connectionProvider.acquireConnection({
accessMode: 'READ',
database: null,
impersonatedUser: user
})

const error1 = server3Connection.handleAndTransformError(error, server3)
const error2 = server2Connection.handleAndTransformError(error, server2)

expect(error1.retriable).toBe(false)
expect(error2.retriable).toBe(false)
})

it.each(usersDataSet)('should not change error to retriable when error when TokenExpired happens and authTokenManagers.basic is being used [user=%s]', async (user) => {
const pool = newPool()
const connectionProvider = newRoutingConnectionProvider(
[
newRoutingTable(
null,
[server1, server2],
[server3, server2],
[server2, server4]
)
],
pool

)
connectionProvider._authTokenManager = staticAuthTokenManager({ authToken: null })

setupAuthTokenManager(connectionProvider, authTokenManagers.basic({ tokenProvider: () => {} }))

const error = newError(
'Message',
Expand Down Expand Up @@ -3944,3 +3988,7 @@ class FakeDnsResolver {
return Promise.resolve(this._addresses ? this._addresses : [seedRouter])
}
}

function setupAuthTokenManager (provider, authTokenManager) {
provider._authenticationProvider._authTokenManager = authTokenManager
}
Loading