From 94e4cc6b37f78fa874e4a34f8522a30a4b08e7e8 Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Fri, 9 Dec 2022 11:11:17 +0100 Subject: [PATCH 01/23] WIP todo: unit tests --- src/neo4j/_async/driver.py | 38 +- src/neo4j/_async/io/_bolt.py | 82 +++- src/neo4j/_async/io/_bolt3.py | 15 +- src/neo4j/_async/io/_bolt4.py | 10 + src/neo4j/_async/io/_bolt5.py | 123 +++++- src/neo4j/_async/io/_common.py | 7 +- src/neo4j/_async/io/_pool.py | 142 ++++-- src/neo4j/_async/work/session.py | 4 +- src/neo4j/_async/work/workspace.py | 4 +- src/neo4j/_conf.py | 6 + src/neo4j/_sync/driver.py | 36 +- src/neo4j/_sync/io/_bolt.py | 82 +++- src/neo4j/_sync/io/_bolt3.py | 15 +- src/neo4j/_sync/io/_bolt4.py | 10 + src/neo4j/_sync/io/_bolt5.py | 123 +++++- src/neo4j/_sync/io/_common.py | 7 +- src/neo4j/_sync/io/_pool.py | 142 ++++-- src/neo4j/_sync/work/session.py | 4 +- src/neo4j/_sync/work/workspace.py | 4 +- src/neo4j/api.py | 54 +++ src/neo4j/scratch_11.py | 28 ++ testkitbackend/_async/requests.py | 13 +- testkitbackend/_sync/requests.py | 13 +- testkitbackend/test_config.json | 4 + tests/unit/async_/fixtures/fake_connection.py | 1 + tests/unit/async_/io/test_class_bolt.py | 12 +- tests/unit/async_/io/test_class_bolt3.py | 90 ++++ tests/unit/async_/io/test_class_bolt4x0.py | 91 ++++ tests/unit/async_/io/test_class_bolt4x1.py | 91 ++++ tests/unit/async_/io/test_class_bolt4x2.py | 91 ++++ tests/unit/async_/io/test_class_bolt4x3.py | 90 ++++ tests/unit/async_/io/test_class_bolt4x4.py | 89 ++++ tests/unit/async_/io/test_class_bolt5x0.py | 358 ++++++++++++++++ tests/unit/async_/io/test_class_bolt5x1.py | 405 ++++++++++++++++++ tests/unit/async_/io/test_direct.py | 32 +- tests/unit/async_/io/test_neo4j_pool.py | 114 +++-- tests/unit/async_/test_driver.py | 1 + tests/unit/async_/work/test_session.py | 2 +- tests/unit/common/test_conf.py | 2 + tests/unit/mixed/io/test_direct.py | 8 +- tests/unit/sync/fixtures/fake_connection.py | 1 + tests/unit/sync/io/test_class_bolt.py | 12 +- 
tests/unit/sync/io/test_class_bolt3.py | 90 ++++ tests/unit/sync/io/test_class_bolt4x0.py | 91 ++++ tests/unit/sync/io/test_class_bolt4x1.py | 91 ++++ tests/unit/sync/io/test_class_bolt4x2.py | 91 ++++ tests/unit/sync/io/test_class_bolt4x3.py | 90 ++++ tests/unit/sync/io/test_class_bolt4x4.py | 89 ++++ tests/unit/sync/io/test_class_bolt5x0.py | 358 ++++++++++++++++ tests/unit/sync/io/test_class_bolt5x1.py | 405 ++++++++++++++++++ tests/unit/sync/io/test_direct.py | 32 +- tests/unit/sync/io/test_neo4j_pool.py | 114 +++-- tests/unit/sync/test_driver.py | 1 + tests/unit/sync/work/test_session.py | 2 +- 54 files changed, 3656 insertions(+), 254 deletions(-) create mode 100644 src/neo4j/scratch_11.py create mode 100644 tests/unit/async_/io/test_class_bolt5x0.py create mode 100644 tests/unit/async_/io/test_class_bolt5x1.py create mode 100644 tests/unit/sync/io/test_class_bolt5x0.py create mode 100644 tests/unit/sync/io/test_class_bolt5x1.py diff --git a/src/neo4j/_async/driver.py b/src/neo4j/_async/driver.py index 114e1e6d..faf01dc9 100644 --- a/src/neo4j/_async/driver.py +++ b/src/neo4j/_async/driver.py @@ -27,8 +27,6 @@ import ssl - - from .._async_compat.util import AsyncUtil from .._conf import ( Config, @@ -46,6 +44,7 @@ ) from ..addressing import Address from ..api import ( + _TAuthTokenProvider, AsyncBookmarkManager, Auth, BookmarkManager, @@ -86,7 +85,9 @@ def driver( cls, uri: str, *, - auth: t.Union[t.Tuple[t.Any, t.Any], Auth, None] = ..., + auth: t.Union[ + t.Tuple[t.Any, t.Any], Auth, _TAuthTokenProvider, None + ] = ..., max_connection_lifetime: float = ..., max_connection_pool_size: int = ..., connection_timeout: float = ..., @@ -124,7 +125,9 @@ def driver( @classmethod def driver( cls, uri: str, *, - auth: t.Union[t.Tuple[t.Any, t.Any], Auth, None] = None, + auth: t.Union[ + t.Tuple[t.Any, t.Any], Auth, _TAuthTokenProvider, None + ] = None, **config ) -> AsyncDriver: """Create a driver. 
@@ -140,6 +143,10 @@ def driver( driver_type, security_type, parsed = parse_neo4j_uri(uri) + if not callable(config.get("auth")): + auth = config.get("auth") + config["auth"] = lambda: auth + # TODO: 6.0 - remove "trust" config option if "trust" in config.keys(): if config["trust"] not in ( @@ -216,10 +223,10 @@ def driver( # 'Routing parameters are not supported with scheme ' # '"bolt". Given URI "{}".'.format(uri) # ) - return cls.bolt_driver(parsed.netloc, auth=auth, **config) + return cls.bolt_driver(parsed.netloc, **config) # else driver_type == DRIVER_NEO4J routing_context = parse_routing_context(parsed.query) - return cls.neo4j_driver(parsed.netloc, auth=auth, + return cls.neo4j_driver(parsed.netloc, routing_context=routing_context, **config) @classmethod @@ -311,7 +318,7 @@ def bookmark_manager( ) @classmethod - def bolt_driver(cls, target, *, auth=None, **config): + def bolt_driver(cls, target, **config): """ Create a driver for direct Bolt server access that uses socket I/O and thread-based concurrency. """ @@ -321,13 +328,13 @@ def bolt_driver(cls, target, *, auth=None, **config): ) try: - return AsyncBoltDriver.open(target, auth=auth, **config) + return AsyncBoltDriver.open(target, **config) except (BoltHandshakeError, BoltSecurityError) as error: from ..exceptions import ServiceUnavailable raise ServiceUnavailable(str(error)) from error @classmethod - def neo4j_driver(cls, *targets, auth=None, routing_context=None, **config): + def neo4j_driver(cls, *targets, routing_context=None, **config): """ Create a driver for routing-capable Neo4j service access that uses socket I/O and thread-based concurrency. 
""" @@ -337,7 +344,7 @@ def neo4j_driver(cls, *targets, auth=None, routing_context=None, **config): ) try: - return AsyncNeo4jDriver.open(*targets, auth=auth, routing_context=routing_context, **config) + return AsyncNeo4jDriver.open(*targets, routing_context=routing_context, **config) except (BoltHandshakeError, BoltSecurityError) as error: from ..exceptions import ServiceUnavailable raise ServiceUnavailable(str(error)) from error @@ -450,6 +457,7 @@ def session( default_access_mode: str = ..., bookmark_manager: t.Union[AsyncBookmarkManager, BookmarkManager, None] = ..., + auth: t.Union[Auth, t.Tuple[t.Any, t.Any]] = ..., # undocumented/unsupported options # they may be change or removed any time without prior notice @@ -498,6 +506,7 @@ async def verify_connectivity( default_access_mode: str = ..., bookmark_manager: t.Union[AsyncBookmarkManager, BookmarkManager, None] = ..., + auth: t.Union[Auth, t.Tuple[t.Any, t.Any]] = ..., # undocumented/unsupported options initial_retry_delay: float = ..., @@ -562,6 +571,7 @@ async def get_server_info( default_access_mode: str = ..., bookmark_manager: t.Union[AsyncBookmarkManager, BookmarkManager, None] = ..., + auth: t.Union[Auth, t.Tuple[t.Any, t.Any]] = ..., # undocumented/unsupported options initial_retry_delay: float = ..., @@ -640,7 +650,7 @@ class AsyncBoltDriver(_Direct, AsyncDriver): """ @classmethod - def open(cls, target, *, auth=None, **config): + def open(cls, target, **config): """ :param target: :param auth: @@ -652,7 +662,7 @@ def open(cls, target, *, auth=None, **config): from .io import AsyncBoltPool address = cls.parse_target(target) pool_config, default_workspace_config = Config.consume_chain(config, PoolConfig, WorkspaceConfig) - pool = AsyncBoltPool.open(address, auth=auth, pool_config=pool_config, workspace_config=default_workspace_config) + pool = AsyncBoltPool.open(address, pool_config=pool_config, workspace_config=default_workspace_config) return cls(pool, default_workspace_config) def __init__(self, 
pool, default_workspace_config): @@ -687,11 +697,11 @@ class AsyncNeo4jDriver(_Routing, AsyncDriver): """ @classmethod - def open(cls, *targets, auth=None, routing_context=None, **config): + def open(cls, *targets, routing_context=None, **config): from .io import AsyncNeo4jPool addresses = cls.parse_targets(*targets) pool_config, default_workspace_config = Config.consume_chain(config, PoolConfig, WorkspaceConfig) - pool = AsyncNeo4jPool.open(*addresses, auth=auth, routing_context=routing_context, pool_config=pool_config, workspace_config=default_workspace_config) + pool = AsyncNeo4jPool.open(*addresses, routing_context=routing_context, pool_config=pool_config, workspace_config=default_workspace_config) return cls(pool, default_workspace_config) def __init__(self, pool, default_workspace_config): diff --git a/src/neo4j/_async/io/_bolt.py b/src/neo4j/_async/io/_bolt.py index 8865e09c..50c5e7fe 100644 --- a/src/neo4j/_async/io/_bolt.py +++ b/src/neo4j/_async/io/_bolt.py @@ -40,6 +40,7 @@ ) from ...exceptions import ( AuthError, + ConfigurationError, DriverError, IncompleteCommit, ServiceUnavailable, @@ -134,17 +135,7 @@ def __init__(self, unresolved_address, sock, max_connection_lifetime, *, else: self.user_agent = get_user_agent() - # Determine auth details - if not auth: - self.auth_dict = {} - elif isinstance(auth, tuple) and 2 <= len(auth) <= 3: - from ...api import Auth - self.auth_dict = vars(Auth("basic", *auth)) - else: - try: - self.auth_dict = vars(auth) - except (KeyError, TypeError): - raise AuthError("Cannot determine auth details from %r" % auth) + self.auth_dict = self._to_auth_dict(auth) # Check for missing password try: @@ -159,6 +150,20 @@ def __del__(self): if not asyncio.iscoroutinefunction(self.close): self.close() + @classmethod + def _to_auth_dict(cls, auth): + # Determine auth details + if not auth: + return {} + elif isinstance(auth, tuple) and 2 <= len(auth) <= 3: + from ...api import Auth + return vars(Auth("basic", *auth)) + else: + try: + 
return vars(auth) + except (KeyError, TypeError): + raise AuthError("Cannot determine auth details from %r" % auth) + @property def connection_id(self): return self.server_info._metadata.get("connection_id", "") @@ -180,6 +185,20 @@ def supports_multiple_databases(self): """ pass + @property + @abc.abstractmethod + def supports_re_auth(self): + """TODO""" + pass + + def assert_re_auth_support(self): + if not self.supports_re_auth: + raise ConfigurationError( + "Session level authentication is not supported for Bolt " + f"Protocol {self.PROTOCOL_VERSION!r}. Server Agent " + f"{self.server_info.agent!r}" + ) + @classmethod def protocol_handlers(cls, protocol_version=None): """ Return a dictionary of available Bolt protocol handlers, @@ -203,7 +222,10 @@ def protocol_handlers(cls, protocol_version=None): AsyncBolt4x3, AsyncBolt4x4, ) - from ._bolt5 import AsyncBolt5x0 + from ._bolt5 import ( + AsyncBolt5x0, + AsyncBolt5x1, + ) handlers = { AsyncBolt3.PROTOCOL_VERSION: AsyncBolt3, @@ -213,6 +235,7 @@ def protocol_handlers(cls, protocol_version=None): AsyncBolt4x3.PROTOCOL_VERSION: AsyncBolt4x3, AsyncBolt4x4.PROTOCOL_VERSION: AsyncBolt4x4, AsyncBolt5x0.PROTOCOL_VERSION: AsyncBolt5x0, + AsyncBolt5x1.PROTOCOL_VERSION: AsyncBolt5x1, } if protocol_version is None: @@ -353,6 +376,9 @@ def time_remaining(): elif pool_config.protocol_version == (5, 0): from ._bolt5 import AsyncBolt5x0 bolt_cls = AsyncBolt5x0 + elif pool_config.protocol_version == (5, 1): + from ._bolt5 import AsyncBolt5x1 + bolt_cls = AsyncBolt5x1 else: log.debug("[#%04X] S: ", s.getsockname()[1]) await AsyncBoltSocket.close_socket(s) @@ -404,6 +430,38 @@ async def hello(self, dehydration_hooks=None, hydration_hooks=None): """ pass + @abc.abstractmethod + def logon(self, dehydration_hooks=None, hydration_hooks=None): + """Append a LOGON message to the outgoing queue.""" + pass + + @abc.abstractmethod + def logoff(self, dehydration_hooks=None, hydration_hooks=None): + """Append a LOGOFF message to the outgoing 
queue.""" + pass + + async def re_auth( + self, auth, dehydration_hooks=None, hydration_hooks=None + ): + """Append LOGON, LOGOFF to the outgoing queue, flush, then receive. + + If auth is the same as the current auth, this method does nothing. + + :returns: whether the auth was changed + """ + new_auth_dict = self._to_auth_dict(auth) + if new_auth_dict == self.auth_dict: + return False + self.logoff(dehydration_hooks=dehydration_hooks, + hydration_hooks=hydration_hooks) + self.auth_dict = new_auth_dict + self.logon(dehydration_hooks=dehydration_hooks, + hydration_hooks=hydration_hooks) + await self.send_all() + await self.fetch_all() + return True + + @abc.abstractmethod async def route( self, database=None, imp_user=None, bookmarks=None, diff --git a/src/neo4j/_async/io/_bolt3.py b/src/neo4j/_async/io/_bolt3.py index 2f81f526..76f856e3 100644 --- a/src/neo4j/_async/io/_bolt3.py +++ b/src/neo4j/_async/io/_bolt3.py @@ -16,6 +16,9 @@ # limitations under the License. +from __future__ import annotations + +import typing as t from enum import Enum from logging import getLogger from ssl import SSLSocket @@ -54,7 +57,7 @@ class ServerStates(Enum): class ServerStateManager: - _STATE_TRANSITIONS = { + _STATE_TRANSITIONS: t.Dict[Enum, t.Dict[str, Enum]] = { ServerStates.CONNECTED: { "hello": ServerStates.READY, }, @@ -104,6 +107,8 @@ class AsyncBolt3(AsyncBolt): supports_multiple_databases = False + supports_re_auth = False + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._server_state_manager = ServerStateManager( @@ -152,6 +157,14 @@ async def hello(self, dehydration_hooks=None, hydration_hooks=None): await self.fetch_all() check_supported_server_product(self.server_info.agent) + def logon(self, dehydration_hooks=None, hydration_hooks=None): + """Append a LOGON message to the outgoing queue.""" + self.assert_re_auth_support() + + def logoff(self, dehydration_hooks=None, hydration_hooks=None): + """Append a LOGOFF message to the outgoing 
queue.""" + self.assert_re_auth_support() + async def route( self, database=None, imp_user=None, bookmarks=None, dehydration_hooks=None, hydration_hooks=None diff --git a/src/neo4j/_async/io/_bolt4.py b/src/neo4j/_async/io/_bolt4.py index 70c3d6ad..2503db71 100644 --- a/src/neo4j/_async/io/_bolt4.py +++ b/src/neo4j/_async/io/_bolt4.py @@ -62,6 +62,8 @@ class AsyncBolt4x0(AsyncBolt): supports_multiple_databases = True + supports_re_auth = False + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._server_state_manager = ServerStateManager( @@ -112,6 +114,14 @@ async def hello(self, dehydration_hooks=None, hydration_hooks=None): await self.fetch_all() check_supported_server_product(self.server_info.agent) + def logon(self, dehydration_hooks=None, hydration_hooks=None): + """Append a LOGON message to the outgoing queue.""" + self.assert_re_auth_support() + + def logoff(self, dehydration_hooks=None, hydration_hooks=None): + """Append a LOGOFF message to the outgoing queue.""" + self.assert_re_auth_support() + async def route( self, database=None, imp_user=None, bookmarks=None, dehydration_hooks=None, hydration_hooks=None diff --git a/src/neo4j/_async/io/_bolt5.py b/src/neo4j/_async/io/_bolt5.py index 24953bb5..a18ee45c 100644 --- a/src/neo4j/_async/io/_bolt5.py +++ b/src/neo4j/_async/io/_bolt5.py @@ -16,6 +16,8 @@ # limitations under the License. +import typing as t +from enum import Enum from logging import getLogger from ssl import SSLSocket @@ -41,6 +43,7 @@ check_supported_server_product, CommitResponse, InitResponse, + LogonResponse, Response, ) @@ -49,7 +52,7 @@ class AsyncBolt5x0(AsyncBolt): - """Protocol handler for Bolt 5.0. 
""" + """Protocol handler for Bolt 5.0.""" PROTOCOL_VERSION = Version(5, 0) @@ -59,10 +62,15 @@ class AsyncBolt5x0(AsyncBolt): supports_multiple_databases = True + supports_re_auth = False + + server_states: t.Any = ServerStates + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._server_state_manager = ServerStateManager( - ServerStates.CONNECTED, on_change=self._on_server_state_change + self.server_states.CONNECTED, + on_change=self._on_server_state_change ) def _on_server_state_change(self, old_state, new_state): @@ -77,7 +85,7 @@ def is_reset(self): if (self.responses and self.responses[-1] and self.responses[-1].message == "reset"): return True - return self._server_state_manager.state == ServerStates.READY + return self._server_state_manager.state == self.server_states.READY @property def encrypted(self): @@ -124,6 +132,14 @@ def on_success(metadata): await self.fetch_all() check_supported_server_product(self.server_info.agent) + def logon(self, dehydration_hooks=None, hydration_hooks=None): + """Append a LOGON message to the outgoing queue.""" + self.assert_re_auth_support() + + def logoff(self, dehydration_hooks=None, hydration_hooks=None): + """Append a LOGOFF message to the outgoing queue.""" + self.assert_re_auth_support() + async def route(self, database=None, imp_user=None, bookmarks=None, dehydration_hooks=None, hydration_hooks=None): routing_context = self.routing_context or {} @@ -312,7 +328,7 @@ async def _process_message(self, tag, fields): elif summary_signature == b"\x7F": log.debug("[#%04X] S: FAILURE %r", self.local_port, summary_metadata) - self._server_state_manager.state = ServerStates.FAILED + self._server_state_manager.state = self.server_states.FAILED try: await response.on_failure(summary_metadata or {}) except (ServiceUnavailable, DatabaseUnavailable): @@ -335,3 +351,102 @@ async def _process_message(self, tag, fields): ) return len(details), 1 + + +class ServerStates5x1(Enum): + CONNECTED = "CONNECTED" + 
READY = "READY" + STREAMING = "STREAMING" + TX_READY_OR_TX_STREAMING = "TX_READY||TX_STREAMING" + FAILED = "FAILED" + AUTHENTICATION = "AUTHENTICATION" + + +class ServerStateManager5x1(ServerStateManager): + _STATE_TRANSITIONS = { # type: ignore + ServerStates5x1.CONNECTED: { + "hello": ServerStates5x1.AUTHENTICATION, + }, + ServerStates5x1.AUTHENTICATION: { + "logon": ServerStates5x1.READY, + }, + ServerStates5x1.READY: { + "run": ServerStates5x1.STREAMING, + "begin": ServerStates5x1.TX_READY_OR_TX_STREAMING, + "logoff": ServerStates5x1.AUTHENTICATION, + }, + ServerStates5x1.STREAMING: { + "pull": ServerStates5x1.READY, + "discard": ServerStates5x1.READY, + "reset": ServerStates5x1.READY, + }, + ServerStates5x1.TX_READY_OR_TX_STREAMING: { + "commit": ServerStates5x1.READY, + "rollback": ServerStates5x1.READY, + "reset": ServerStates5x1.READY, + }, + ServerStates5x1.FAILED: { + "reset": ServerStates5x1.READY, + } + } + + +class AsyncBolt5x1(AsyncBolt5x0): + """Protocol handler for Bolt 5.1.""" + + PROTOCOL_VERSION = Version(5, 1) + + supports_re_auth = True + + server_states = ServerStates5x1 + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._server_state_manager = ServerStateManager5x1( + ServerStates5x1.CONNECTED, on_change=self._on_server_state_change + ) + + async def hello(self, dehydration_hooks=None, hydration_hooks=None): + def on_success(metadata): + self.configuration_hints.update(metadata.pop("hints", {})) + self.server_info.update(metadata) + if "connection.recv_timeout_seconds" in self.configuration_hints: + recv_timeout = self.configuration_hints[ + "connection.recv_timeout_seconds" + ] + if isinstance(recv_timeout, int) and recv_timeout > 0: + self.socket.settimeout(recv_timeout) + else: + log.info("[#%04X] _: Server supplied an " + "invalid value for " + "connection.recv_timeout_seconds (%r). 
Make sure " + "the server and network is set up correctly.", + self.local_port, recv_timeout) + + headers = self.get_base_headers() + logged_headers = dict(headers) + log.debug("[#%04X] C: HELLO %r", self.local_port, logged_headers) + self._append(b"\x01", (headers,), + response=InitResponse(self, "hello", hydration_hooks, + on_success=on_success), + dehydration_hooks=dehydration_hooks) + self.logon(dehydration_hooks=dehydration_hooks, + hydration_hooks=hydration_hooks) + await self.send_all() + await self.fetch_all() + check_supported_server_product(self.server_info.agent) + + def logon(self, dehydration_hooks=None, hydration_hooks=None): + logged_auth_dict = dict(self.auth_dict) + if "credentials" in logged_auth_dict: + logged_auth_dict["credentials"] = "*******" + log.debug("[#%04X] C: LOGON %r", self.local_port, logged_auth_dict) + self._append(b"\x6A", (self.auth_dict,), + response=LogonResponse(self, "logon", hydration_hooks), + dehydration_hooks=dehydration_hooks) + + def logoff(self, dehydration_hooks=None, hydration_hooks=None): + log.debug("[#%04X] C: LOGOFF", self.local_port) + self._append(b"\x6B", + response=LogonResponse(self, "logoff", hydration_hooks), + dehydration_hooks=dehydration_hooks) diff --git a/src/neo4j/_async/io/_common.py b/src/neo4j/_async/io/_common.py index fa21d3b6..1888f0f5 100644 --- a/src/neo4j/_async/io/_common.py +++ b/src/neo4j/_async/io/_common.py @@ -256,10 +256,10 @@ async def on_ignored(self, metadata=None): class InitResponse(Response): - async def on_failure(self, metadata): code = metadata.get("code") if code == "Neo.ClientError.Security.Unauthorized": + # this branch is only needed as long as we support Bolt 5.0 raise Neo4jError.hydrate(**metadata) else: raise ServiceUnavailable( @@ -267,6 +267,11 @@ async def on_failure(self, metadata): ) +class LogonResponse(Response): + async def on_failure(self, metadata): + raise Neo4jError.hydrate(**metadata) + + class CommitResponse(Response): pass diff --git 
a/src/neo4j/_async/io/_pool.py b/src/neo4j/_async/io/_pool.py index 8376d561..55927496 100644 --- a/src/neo4j/_async/io/_pool.py +++ b/src/neo4j/_async/io/_pool.py @@ -16,9 +16,12 @@ # limitations under the License. +from __future__ import annotations + import abc import asyncio import logging +import typing as t from collections import ( defaultdict, deque, @@ -28,10 +31,13 @@ from ..._async_compat.concurrency import ( AsyncCondition, + AsyncCooperativeLock, AsyncCooperativeRLock, + AsyncLock, AsyncRLock, ) from ..._async_compat.network import AsyncNetworkUtil +from ..._async_compat.util import AsyncUtil from ..._conf import ( PoolConfig, WorkspaceConfig, @@ -44,6 +50,7 @@ from ..._routing import RoutingTable from ...api import ( READ_ACCESS, + RenewableAuth, WRITE_ACCESS, ) from ...exceptions import ( @@ -79,6 +86,10 @@ def __init__(self, opener, pool_config, workspace_config): self.connections_reservations = defaultdict(lambda: 0) self.lock = AsyncCooperativeRLock() self.cond = AsyncCondition(self.lock) + self.refreshing_auth = False + self.refreshing_auth_lock = AsyncCooperativeLock() + self.initializing_auth_lock = AsyncLock() + self.last_auth: t.Optional[RenewableAuth] = None async def __aenter__(self): return self @@ -86,6 +97,44 @@ async def __aenter__(self): async def __aexit__(self, exc_type, exc_value, traceback): await self.close() + async def get_auth(self): + if await self._initialize_auth(): + return self.last_auth.auth + + with self.refreshing_auth_lock: + if self.refreshing_auth: + # someone else is already getting a new auth + return self.last_auth + if self.last_auth is None or self.last_auth.expired: + with self.refreshing_auth_lock: + self.refreshing_auth = True + try: + new_auth = await self._get_new_auth() + with self.refreshing_auth_lock: + self.last_auth = new_auth + self.refreshing_auth = False + except: + with self.refreshing_auth_lock: + self.refreshing_auth = False + raise + return self.last_auth.auth + + async def 
_initialize_auth(self): + if self.last_auth is not None: + return False + async with self.initializing_auth_lock: + if self.last_auth is not None: + # someone else initialized the auth + return True + self.last_auth = await self._get_new_auth() + return True + + async def _get_new_auth(self): + new_auth = await AsyncUtil.callback(self.pool_config.auth) + if not isinstance(self.last_auth, RenewableAuth): + return RenewableAuth(new_auth) + return new_auth + async def _acquire_from_pool(self, address): with self.lock: for connection in list(self.connections.get(address, [])): @@ -130,17 +179,27 @@ async def _acquire_from_pool_checked( else: return connection - def _acquire_new_later(self, address, deadline): + def _acquire_new_later(self, address, auth, deadline): async def connection_creator(): released_reservation = False try: try: connection = await self.opener( - address, deadline.to_timeout() + address, auth or await self.get_auth(), + deadline.to_timeout() ) except ServiceUnavailable: await self.deactivate(address) raise + if auth: + # It's unfortunate that we have to create a connection + # first to determine if the session-level auth is supported + # by the protocol or not. + try: + connection.assert_re_auth_support() + except ConfigurationError: + await connection.close() + raise connection.pool = self connection.in_use = True with self.lock: @@ -166,7 +225,7 @@ async def connection_creator(): return connection_creator return None - async def _acquire(self, address, deadline, liveness_check_timeout): + async def _acquire(self, address, auth, deadline, liveness_check_timeout): """ Acquire a connection to a given address from the pool. The address supplied should always be an IP address, not a host name. 
@@ -178,6 +237,19 @@ async def health_check(connection_, deadline_): or connection_.defunct() or connection_.stale()): return False + try: + if await connection_.re_auth(auth or await self.get_auth()): + # no need for an extra liveness check if the connection was + # successfully re-authenticated + return True + except ConfigurationError: + # protocol does not support re-authentication + if auth: + # session-level not supported + raise + # expiring tokens supported by flushing the pool + # => give up this connection + return False if liveness_check_timeout is not None: if connection_.is_idle_for(liveness_check_timeout): with connection_deadline(connection_, deadline_): @@ -201,7 +273,9 @@ async def health_check(connection_, deadline_): return connection # all connections in pool are in-use with self.lock: - connection_creator = self._acquire_new_later(address, deadline) + connection_creator = self._acquire_new_later( + address, auth, deadline + ) if connection_creator: break @@ -222,7 +296,8 @@ async def health_check(connection_, deadline_): @abc.abstractmethod async def acquire( - self, access_mode, timeout, database, bookmarks, liveness_check_timeout + self, access_mode, timeout, database, bookmarks, auth, + liveness_check_timeout ): """ Acquire a connection to a server that can satisfy a set of parameters. @@ -231,6 +306,7 @@ async def acquire( (excluding potential preparation like fetching routing tables). :param database: :param bookmarks: + :param auth: :param liveness_check_timeout: """ ... 
@@ -372,17 +448,16 @@ async def close(self): class AsyncBoltPool(AsyncIOPool): @classmethod - def open(cls, address, *, auth, pool_config, workspace_config): + def open(cls, address, *, pool_config, workspace_config): """Create a new BoltPool :param address: - :param auth: :param pool_config: :param workspace_config: :returns: BoltPool """ - async def opener(addr, timeout): + async def opener(addr, auth, timeout): return await AsyncBolt.open( addr, auth=auth, timeout=timeout, routing_context=None, pool_config=pool_config @@ -401,7 +476,8 @@ def __repr__(self): self.address) async def acquire( - self, access_mode, timeout, database, bookmarks, liveness_check_timeout + self, access_mode, timeout, database, bookmarks, auth, + liveness_check_timeout ): # The access_mode and database is not needed for a direct connection, # it's just there for consistency. @@ -409,7 +485,7 @@ async def acquire( "access_mode=%r, database=%r", access_mode, database) deadline = Deadline.from_timeout_or_deadline(timeout) return await self._acquire( - self.address, deadline, liveness_check_timeout + self.address, auth, deadline, liveness_check_timeout ) @@ -418,12 +494,11 @@ class AsyncNeo4jPool(AsyncIOPool): """ @classmethod - def open(cls, *addresses, auth, pool_config, workspace_config, + def open(cls, *addresses, pool_config, workspace_config, routing_context=None): """Create a new Neo4jPool :param addresses: one or more address as positional argument - :param auth: :param pool_config: :param workspace_config: :param routing_context: @@ -437,7 +512,7 @@ def open(cls, *addresses, auth, pool_config, workspace_config, raise ConfigurationError("The key 'address' is reserved for routing context.") routing_context["address"] = str(address) - async def opener(addr, timeout): + async def opener(addr, auth, timeout): return await AsyncBolt.open( addr, auth=auth, timeout=timeout, routing_context=routing_context, pool_config=pool_config @@ -480,7 +555,7 @@ async def 
get_or_create_routing_table(self, database): return self.routing_tables[database] async def fetch_routing_info( - self, address, database, imp_user, bookmarks, acquisition_timeout + self, address, database, imp_user, bookmarks, auth, acquisition_timeout ): """ Fetch raw routing info from a given router address. @@ -491,6 +566,7 @@ async def fetch_routing_info( :type imp_user: str or None :param bookmarks: iterable of bookmark values after which the routing info should be fetched + :param auth: auth :param acquisition_timeout: connection acquisition timeout :returns: list of routing records, or None if no connection @@ -501,7 +577,7 @@ async def fetch_routing_info( deadline = Deadline.from_timeout_or_deadline(acquisition_timeout) log.debug("[#0000] _: _acquire router connection, " "database=%r, address=%r", database, address) - cx = await self._acquire(address, deadline, None) + cx = await self._acquire(address, auth, deadline, None) try: routing_table = await cx.route( database=database or self.workspace_config.database, @@ -513,7 +589,8 @@ async def fetch_routing_info( return routing_table async def fetch_routing_table( - self, *, address, acquisition_timeout, database, imp_user, bookmarks + self, *, address, acquisition_timeout, database, imp_user, + bookmarks, auth ): """ Fetch a routing table from a given router address. @@ -525,6 +602,7 @@ async def fetch_routing_table( table :type imp_user: str or None :param bookmarks: bookmarks used when fetching routing table + :param auth: auth :returns: a new RoutingTable instance or None if the given router is currently unable to provide routing information @@ -532,7 +610,8 @@ async def fetch_routing_table( new_routing_info = None try: new_routing_info = await self.fetch_routing_info( - address, database, imp_user, bookmarks, acquisition_timeout + address, database, imp_user, bookmarks, auth, + acquisition_timeout ) except Neo4jError as e: # checks if the code is an error that is caused by the client. 
In @@ -578,8 +657,8 @@ async def fetch_routing_table( return new_routing_table async def _update_routing_table_from( - self, *routers, database, imp_user, bookmarks, acquisition_timeout, - database_callback + self, *routers, database, imp_user, bookmarks, auth, + acquisition_timeout, database_callback ): """ Try to update routing tables with the given routers. @@ -595,7 +674,8 @@ async def _update_routing_table_from( ): new_routing_table = await self.fetch_routing_table( address=address, acquisition_timeout=acquisition_timeout, - database=database, imp_user=imp_user, bookmarks=bookmarks + database=database, imp_user=imp_user, bookmarks=bookmarks, + auth=auth ) if new_routing_table is not None: new_database = new_routing_table.database @@ -615,8 +695,8 @@ async def _update_routing_table_from( return False async def update_routing_table( - self, *, database, imp_user, bookmarks, acquisition_timeout=None, - database_callback=None + self, *, database, imp_user, bookmarks, auth=None, + acquisition_timeout=None, database_callback=None ): """ Update the routing table from the first router able to provide valid routing information. @@ -626,6 +706,7 @@ async def update_routing_table( table :type imp_user: str or None :param bookmarks: bookmarks used when fetching routing table + :param auth: auth :param acquisition_timeout: connection acquisition timeout :param database_callback: A callback function that will be called with the database name as only argument when a new routing table has been @@ -647,15 +728,15 @@ async def update_routing_table( # TODO: Test this state if await self._update_routing_table_from( self.address, database=database, - imp_user=imp_user, bookmarks=bookmarks, + imp_user=imp_user, bookmarks=bookmarks, auth=auth, acquisition_timeout=acquisition_timeout, database_callback=database_callback ): # Why is only the first initial routing address used? 
return if await self._update_routing_table_from( - *(existing_routers - {self.address}), - database=database, imp_user=imp_user, bookmarks=bookmarks, + *(existing_routers - {self.address}), database=database, + imp_user=imp_user, bookmarks=bookmarks, auth=auth, acquisition_timeout=acquisition_timeout, database_callback=database_callback ): @@ -664,7 +745,7 @@ async def update_routing_table( if not prefer_initial_routing_address: if await self._update_routing_table_from( self.address, database=database, - imp_user=imp_user, bookmarks=bookmarks, + imp_user=imp_user, bookmarks=bookmarks, auth=auth, acquisition_timeout=acquisition_timeout, database_callback=database_callback ): @@ -683,7 +764,7 @@ async def update_connection_pool(self, *, database): await super(AsyncNeo4jPool, self).deactivate(address) async def ensure_routing_table_is_fresh( - self, *, access_mode, database, imp_user, bookmarks, + self, *, access_mode, database, imp_user, bookmarks, auth=None, acquisition_timeout=None, database_callback=None ): """ Update the routing table if stale. 
@@ -720,7 +801,7 @@ async def ensure_routing_table_is_fresh( await self.update_routing_table( database=database, imp_user=imp_user, bookmarks=bookmarks, - acquisition_timeout=acquisition_timeout, + auth=auth, acquisition_timeout=acquisition_timeout, database_callback=database_callback ) await self.update_connection_pool(database=database) @@ -757,7 +838,8 @@ async def _select_address(self, *, access_mode, database): return choice(addresses_by_usage[min(addresses_by_usage)]) async def acquire( - self, access_mode, timeout, database, bookmarks, liveness_check_timeout + self, access_mode, timeout, database, bookmarks, auth, + liveness_check_timeout ): if access_mode not in (WRITE_ACCESS, READ_ACCESS): raise ClientError("Non valid 'access_mode'; {}".format(access_mode)) @@ -796,7 +878,7 @@ async def acquire( deadline = Deadline.from_timeout_or_deadline(timeout) # should always be a resolved address connection = await self._acquire( - address, deadline, liveness_check_timeout + address, auth, deadline, liveness_check_timeout ) except (ServiceUnavailable, SessionExpired): await self.deactivate(address=address) diff --git a/src/neo4j/_async/work/session.py b/src/neo4j/_async/work/session.py index e7e8820b..a0d3ceb9 100644 --- a/src/neo4j/_async/work/session.py +++ b/src/neo4j/_async/work/session.py @@ -98,6 +98,7 @@ class AsyncSession(AsyncWorkspace): def __init__(self, pool, session_config): assert isinstance(session_config, SessionConfig) super().__init__(pool, session_config) + self._config = session_config self._initialize_bookmarks(session_config.bookmarks) self._bookmark_manager = session_config.bookmark_manager @@ -117,7 +118,8 @@ async def _connect(self, access_mode, **access_kwargs): if access_mode is None: access_mode = self._config.default_access_mode try: - await super()._connect(access_mode, **access_kwargs) + await super()._connect(access_mode, auth=self._config.auth, + **access_kwargs) except asyncio.CancelledError: 
self._handle_cancellation(message="_connect") raise diff --git a/src/neo4j/_async/work/workspace.py b/src/neo4j/_async/work/workspace.py index 21a18d5b..75dc1d76 100644 --- a/src/neo4j/_async/work/workspace.py +++ b/src/neo4j/_async/work/workspace.py @@ -130,7 +130,7 @@ async def _update_bookmark(self, bookmark): return await self._update_bookmarks((bookmark,)) - async def _connect(self, access_mode, **acquire_kwargs): + async def _connect(self, access_mode, auth=None, **acquire_kwargs): acquisition_timeout = self._config.connection_acquisition_timeout if self._connection: # TODO: Investigate this @@ -154,6 +154,7 @@ async def _connect(self, access_mode, **acquire_kwargs): database=self._config.database, imp_user=self._config.impersonated_user, bookmarks=await self._get_bookmarks(), + auth=auth, acquisition_timeout=acquisition_timeout, database_callback=self._set_cached_database ) @@ -162,6 +163,7 @@ async def _connect(self, access_mode, **acquire_kwargs): "timeout": acquisition_timeout, "database": self._config.database, "bookmarks": await self._get_bookmarks(), + "auth": auth, "liveness_check_timeout": None, } acquire_kwargs_.update(acquire_kwargs) diff --git a/src/neo4j/_conf.py b/src/neo4j/_conf.py index 25e45944..9707334c 100644 --- a/src/neo4j/_conf.py +++ b/src/neo4j/_conf.py @@ -405,6 +405,9 @@ class PoolConfig(Config): keep_alive = True # Specify whether TCP keep-alive should be enabled. + #: Authentication provider + auth = None + def get_ssl_context(self): if self.ssl_context is not None: return self.ssl_context @@ -503,6 +506,9 @@ class SessionConfig(WorkspaceConfig): #: Default AccessMode default_access_mode = WRITE_ACCESS + #: Auth token to temporarily switch the user + auth = None + class TransactionConfig(Config): """ Transaction configuration. This is internal for now. 
diff --git a/src/neo4j/_sync/driver.py b/src/neo4j/_sync/driver.py index 9b64dbef..31cf7ff9 100644 --- a/src/neo4j/_sync/driver.py +++ b/src/neo4j/_sync/driver.py @@ -44,6 +44,7 @@ ) from ..addressing import Address from ..api import ( + _TAuthTokenProvider, Auth, BookmarkManager, Bookmarks, @@ -83,7 +84,9 @@ def driver( cls, uri: str, *, - auth: t.Union[t.Tuple[t.Any, t.Any], Auth, None] = ..., + auth: t.Union[ + t.Tuple[t.Any, t.Any], Auth, _TAuthTokenProvider, None + ] = ..., max_connection_lifetime: float = ..., max_connection_pool_size: int = ..., connection_timeout: float = ..., @@ -121,7 +124,9 @@ def driver( @classmethod def driver( cls, uri: str, *, - auth: t.Union[t.Tuple[t.Any, t.Any], Auth, None] = None, + auth: t.Union[ + t.Tuple[t.Any, t.Any], Auth, _TAuthTokenProvider, None + ] = None, **config ) -> Driver: """Create a driver. @@ -137,6 +142,10 @@ def driver( driver_type, security_type, parsed = parse_neo4j_uri(uri) + if not callable(config.get("auth")): + auth = config.get("auth") + config["auth"] = lambda: auth + # TODO: 6.0 - remove "trust" config option if "trust" in config.keys(): if config["trust"] not in ( @@ -213,10 +222,10 @@ def driver( # 'Routing parameters are not supported with scheme ' # '"bolt". Given URI "{}".'.format(uri) # ) - return cls.bolt_driver(parsed.netloc, auth=auth, **config) + return cls.bolt_driver(parsed.netloc, **config) # else driver_type == DRIVER_NEO4J routing_context = parse_routing_context(parsed.query) - return cls.neo4j_driver(parsed.netloc, auth=auth, + return cls.neo4j_driver(parsed.netloc, routing_context=routing_context, **config) @classmethod @@ -308,7 +317,7 @@ def bookmark_manager( ) @classmethod - def bolt_driver(cls, target, *, auth=None, **config): + def bolt_driver(cls, target, **config): """ Create a driver for direct Bolt server access that uses socket I/O and thread-based concurrency. 
""" @@ -318,13 +327,13 @@ def bolt_driver(cls, target, *, auth=None, **config): ) try: - return BoltDriver.open(target, auth=auth, **config) + return BoltDriver.open(target, **config) except (BoltHandshakeError, BoltSecurityError) as error: from ..exceptions import ServiceUnavailable raise ServiceUnavailable(str(error)) from error @classmethod - def neo4j_driver(cls, *targets, auth=None, routing_context=None, **config): + def neo4j_driver(cls, *targets, routing_context=None, **config): """ Create a driver for routing-capable Neo4j service access that uses socket I/O and thread-based concurrency. """ @@ -334,7 +343,7 @@ def neo4j_driver(cls, *targets, auth=None, routing_context=None, **config): ) try: - return Neo4jDriver.open(*targets, auth=auth, routing_context=routing_context, **config) + return Neo4jDriver.open(*targets, routing_context=routing_context, **config) except (BoltHandshakeError, BoltSecurityError) as error: from ..exceptions import ServiceUnavailable raise ServiceUnavailable(str(error)) from error @@ -447,6 +456,7 @@ def session( default_access_mode: str = ..., bookmark_manager: t.Union[BookmarkManager, BookmarkManager, None] = ..., + auth: t.Union[Auth, t.Tuple[t.Any, t.Any]] = ..., # undocumented/unsupported options # they may be change or removed any time without prior notice @@ -495,6 +505,7 @@ def verify_connectivity( default_access_mode: str = ..., bookmark_manager: t.Union[BookmarkManager, BookmarkManager, None] = ..., + auth: t.Union[Auth, t.Tuple[t.Any, t.Any]] = ..., # undocumented/unsupported options initial_retry_delay: float = ..., @@ -559,6 +570,7 @@ def get_server_info( default_access_mode: str = ..., bookmark_manager: t.Union[BookmarkManager, BookmarkManager, None] = ..., + auth: t.Union[Auth, t.Tuple[t.Any, t.Any]] = ..., # undocumented/unsupported options initial_retry_delay: float = ..., @@ -637,7 +649,7 @@ class BoltDriver(_Direct, Driver): """ @classmethod - def open(cls, target, *, auth=None, **config): + def open(cls, target, 
**config): """ :param target: :param auth: @@ -649,7 +661,7 @@ def open(cls, target, *, auth=None, **config): from .io import BoltPool address = cls.parse_target(target) pool_config, default_workspace_config = Config.consume_chain(config, PoolConfig, WorkspaceConfig) - pool = BoltPool.open(address, auth=auth, pool_config=pool_config, workspace_config=default_workspace_config) + pool = BoltPool.open(address, pool_config=pool_config, workspace_config=default_workspace_config) return cls(pool, default_workspace_config) def __init__(self, pool, default_workspace_config): @@ -684,11 +696,11 @@ class Neo4jDriver(_Routing, Driver): """ @classmethod - def open(cls, *targets, auth=None, routing_context=None, **config): + def open(cls, *targets, routing_context=None, **config): from .io import Neo4jPool addresses = cls.parse_targets(*targets) pool_config, default_workspace_config = Config.consume_chain(config, PoolConfig, WorkspaceConfig) - pool = Neo4jPool.open(*addresses, auth=auth, routing_context=routing_context, pool_config=pool_config, workspace_config=default_workspace_config) + pool = Neo4jPool.open(*addresses, routing_context=routing_context, pool_config=pool_config, workspace_config=default_workspace_config) return cls(pool, default_workspace_config) def __init__(self, pool, default_workspace_config): diff --git a/src/neo4j/_sync/io/_bolt.py b/src/neo4j/_sync/io/_bolt.py index 9980b8c7..2cc6faa0 100644 --- a/src/neo4j/_sync/io/_bolt.py +++ b/src/neo4j/_sync/io/_bolt.py @@ -40,6 +40,7 @@ ) from ...exceptions import ( AuthError, + ConfigurationError, DriverError, IncompleteCommit, ServiceUnavailable, @@ -134,17 +135,7 @@ def __init__(self, unresolved_address, sock, max_connection_lifetime, *, else: self.user_agent = get_user_agent() - # Determine auth details - if not auth: - self.auth_dict = {} - elif isinstance(auth, tuple) and 2 <= len(auth) <= 3: - from ...api import Auth - self.auth_dict = vars(Auth("basic", *auth)) - else: - try: - self.auth_dict = vars(auth) - 
except (KeyError, TypeError): - raise AuthError("Cannot determine auth details from %r" % auth) + self.auth_dict = self._to_auth_dict(auth) # Check for missing password try: @@ -159,6 +150,20 @@ def __del__(self): if not asyncio.iscoroutinefunction(self.close): self.close() + @classmethod + def _to_auth_dict(cls, auth): + # Determine auth details + if not auth: + return {} + elif isinstance(auth, tuple) and 2 <= len(auth) <= 3: + from ...api import Auth + return vars(Auth("basic", *auth)) + else: + try: + return vars(auth) + except (KeyError, TypeError): + raise AuthError("Cannot determine auth details from %r" % auth) + @property def connection_id(self): return self.server_info._metadata.get("connection_id", "") @@ -180,6 +185,20 @@ def supports_multiple_databases(self): """ pass + @property + @abc.abstractmethod + def supports_re_auth(self): + """TODO""" + pass + + def assert_re_auth_support(self): + if not self.supports_re_auth: + raise ConfigurationError( + "Session level authentication is not supported for Bolt " + f"Protocol {self.PROTOCOL_VERSION!r}. 
Server Agent " + f"{self.server_info.agent!r}" + ) + @classmethod def protocol_handlers(cls, protocol_version=None): """ Return a dictionary of available Bolt protocol handlers, @@ -203,7 +222,10 @@ def protocol_handlers(cls, protocol_version=None): Bolt4x3, Bolt4x4, ) - from ._bolt5 import Bolt5x0 + from ._bolt5 import ( + Bolt5x0, + Bolt5x1, + ) handlers = { Bolt3.PROTOCOL_VERSION: Bolt3, @@ -213,6 +235,7 @@ def protocol_handlers(cls, protocol_version=None): Bolt4x3.PROTOCOL_VERSION: Bolt4x3, Bolt4x4.PROTOCOL_VERSION: Bolt4x4, Bolt5x0.PROTOCOL_VERSION: Bolt5x0, + Bolt5x1.PROTOCOL_VERSION: Bolt5x1, } if protocol_version is None: @@ -353,6 +376,9 @@ def time_remaining(): elif pool_config.protocol_version == (5, 0): from ._bolt5 import Bolt5x0 bolt_cls = Bolt5x0 + elif pool_config.protocol_version == (5, 1): + from ._bolt5 import Bolt5x1 + bolt_cls = Bolt5x1 else: log.debug("[#%04X] S: ", s.getsockname()[1]) BoltSocket.close_socket(s) @@ -404,6 +430,38 @@ def hello(self, dehydration_hooks=None, hydration_hooks=None): """ pass + @abc.abstractmethod + def logon(self, dehydration_hooks=None, hydration_hooks=None): + """Append a LOGON message to the outgoing queue.""" + pass + + @abc.abstractmethod + def logoff(self, dehydration_hooks=None, hydration_hooks=None): + """Append a LOGOFF message to the outgoing queue.""" + pass + + def re_auth( + self, auth, dehydration_hooks=None, hydration_hooks=None + ): + """Append LOGON, LOGOFF to the outgoing queue, flush, then receive. + + If auth is the same as the current auth, this method does nothing. 
+ + :returns: whether the auth was changed + """ + new_auth_dict = self._to_auth_dict(auth) + if new_auth_dict == self.auth_dict: + return False + self.logoff(dehydration_hooks=dehydration_hooks, + hydration_hooks=hydration_hooks) + self.auth_dict = new_auth_dict + self.logon(dehydration_hooks=dehydration_hooks, + hydration_hooks=hydration_hooks) + self.send_all() + self.fetch_all() + return True + + @abc.abstractmethod def route( self, database=None, imp_user=None, bookmarks=None, diff --git a/src/neo4j/_sync/io/_bolt3.py b/src/neo4j/_sync/io/_bolt3.py index 249f3ecf..1bfacaf7 100644 --- a/src/neo4j/_sync/io/_bolt3.py +++ b/src/neo4j/_sync/io/_bolt3.py @@ -16,6 +16,9 @@ # limitations under the License. +from __future__ import annotations + +import typing as t from enum import Enum from logging import getLogger from ssl import SSLSocket @@ -54,7 +57,7 @@ class ServerStates(Enum): class ServerStateManager: - _STATE_TRANSITIONS = { + _STATE_TRANSITIONS: t.Dict[Enum, t.Dict[str, Enum]] = { ServerStates.CONNECTED: { "hello": ServerStates.READY, }, @@ -104,6 +107,8 @@ class Bolt3(Bolt): supports_multiple_databases = False + supports_re_auth = False + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._server_state_manager = ServerStateManager( @@ -152,6 +157,14 @@ def hello(self, dehydration_hooks=None, hydration_hooks=None): self.fetch_all() check_supported_server_product(self.server_info.agent) + def logon(self, dehydration_hooks=None, hydration_hooks=None): + """Append a LOGON message to the outgoing queue.""" + self.assert_re_auth_support() + + def logoff(self, dehydration_hooks=None, hydration_hooks=None): + """Append a LOGOFF message to the outgoing queue.""" + self.assert_re_auth_support() + def route( self, database=None, imp_user=None, bookmarks=None, dehydration_hooks=None, hydration_hooks=None diff --git a/src/neo4j/_sync/io/_bolt4.py b/src/neo4j/_sync/io/_bolt4.py index 3ae41138..741fb195 100644 --- a/src/neo4j/_sync/io/_bolt4.py +++ 
b/src/neo4j/_sync/io/_bolt4.py @@ -62,6 +62,8 @@ class Bolt4x0(Bolt): supports_multiple_databases = True + supports_re_auth = False + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._server_state_manager = ServerStateManager( @@ -112,6 +114,14 @@ def hello(self, dehydration_hooks=None, hydration_hooks=None): self.fetch_all() check_supported_server_product(self.server_info.agent) + def logon(self, dehydration_hooks=None, hydration_hooks=None): + """Append a LOGON message to the outgoing queue.""" + self.assert_re_auth_support() + + def logoff(self, dehydration_hooks=None, hydration_hooks=None): + """Append a LOGOFF message to the outgoing queue.""" + self.assert_re_auth_support() + def route( self, database=None, imp_user=None, bookmarks=None, dehydration_hooks=None, hydration_hooks=None diff --git a/src/neo4j/_sync/io/_bolt5.py b/src/neo4j/_sync/io/_bolt5.py index 1be6de95..e9ba82ad 100644 --- a/src/neo4j/_sync/io/_bolt5.py +++ b/src/neo4j/_sync/io/_bolt5.py @@ -16,6 +16,8 @@ # limitations under the License. +import typing as t +from enum import Enum from logging import getLogger from ssl import SSLSocket @@ -41,6 +43,7 @@ check_supported_server_product, CommitResponse, InitResponse, + LogonResponse, Response, ) @@ -49,7 +52,7 @@ class Bolt5x0(Bolt): - """Protocol handler for Bolt 5.0. 
""" + """Protocol handler for Bolt 5.0.""" PROTOCOL_VERSION = Version(5, 0) @@ -59,10 +62,15 @@ class Bolt5x0(Bolt): supports_multiple_databases = True + supports_re_auth = False + + server_states: t.Any = ServerStates + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._server_state_manager = ServerStateManager( - ServerStates.CONNECTED, on_change=self._on_server_state_change + self.server_states.CONNECTED, + on_change=self._on_server_state_change ) def _on_server_state_change(self, old_state, new_state): @@ -77,7 +85,7 @@ def is_reset(self): if (self.responses and self.responses[-1] and self.responses[-1].message == "reset"): return True - return self._server_state_manager.state == ServerStates.READY + return self._server_state_manager.state == self.server_states.READY @property def encrypted(self): @@ -124,6 +132,14 @@ def on_success(metadata): self.fetch_all() check_supported_server_product(self.server_info.agent) + def logon(self, dehydration_hooks=None, hydration_hooks=None): + """Append a LOGON message to the outgoing queue.""" + self.assert_re_auth_support() + + def logoff(self, dehydration_hooks=None, hydration_hooks=None): + """Append a LOGOFF message to the outgoing queue.""" + self.assert_re_auth_support() + def route(self, database=None, imp_user=None, bookmarks=None, dehydration_hooks=None, hydration_hooks=None): routing_context = self.routing_context or {} @@ -312,7 +328,7 @@ def _process_message(self, tag, fields): elif summary_signature == b"\x7F": log.debug("[#%04X] S: FAILURE %r", self.local_port, summary_metadata) - self._server_state_manager.state = ServerStates.FAILED + self._server_state_manager.state = self.server_states.FAILED try: response.on_failure(summary_metadata or {}) except (ServiceUnavailable, DatabaseUnavailable): @@ -335,3 +351,102 @@ def _process_message(self, tag, fields): ) return len(details), 1 + + +class ServerStates5x1(Enum): + CONNECTED = "CONNECTED" + READY = "READY" + STREAMING = "STREAMING" + 
TX_READY_OR_TX_STREAMING = "TX_READY||TX_STREAMING" + FAILED = "FAILED" + AUTHENTICATION = "AUTHENTICATION" + + +class ServerStateManager5x1(ServerStateManager): + _STATE_TRANSITIONS = { # type: ignore + ServerStates5x1.CONNECTED: { + "hello": ServerStates5x1.AUTHENTICATION, + }, + ServerStates5x1.AUTHENTICATION: { + "logon": ServerStates5x1.READY, + }, + ServerStates5x1.READY: { + "run": ServerStates5x1.STREAMING, + "begin": ServerStates5x1.TX_READY_OR_TX_STREAMING, + "logoff": ServerStates5x1.AUTHENTICATION, + }, + ServerStates5x1.STREAMING: { + "pull": ServerStates5x1.READY, + "discard": ServerStates5x1.READY, + "reset": ServerStates5x1.READY, + }, + ServerStates5x1.TX_READY_OR_TX_STREAMING: { + "commit": ServerStates5x1.READY, + "rollback": ServerStates5x1.READY, + "reset": ServerStates5x1.READY, + }, + ServerStates5x1.FAILED: { + "reset": ServerStates5x1.READY, + } + } + + +class Bolt5x1(Bolt5x0): + """Protocol handler for Bolt 5.1.""" + + PROTOCOL_VERSION = Version(5, 1) + + supports_re_auth = True + + server_states = ServerStates5x1 + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._server_state_manager = ServerStateManager5x1( + ServerStates5x1.CONNECTED, on_change=self._on_server_state_change + ) + + def hello(self, dehydration_hooks=None, hydration_hooks=None): + def on_success(metadata): + self.configuration_hints.update(metadata.pop("hints", {})) + self.server_info.update(metadata) + if "connection.recv_timeout_seconds" in self.configuration_hints: + recv_timeout = self.configuration_hints[ + "connection.recv_timeout_seconds" + ] + if isinstance(recv_timeout, int) and recv_timeout > 0: + self.socket.settimeout(recv_timeout) + else: + log.info("[#%04X] _: Server supplied an " + "invalid value for " + "connection.recv_timeout_seconds (%r). 
Make sure " + "the server and network is set up correctly.", + self.local_port, recv_timeout) + + headers = self.get_base_headers() + logged_headers = dict(headers) + log.debug("[#%04X] C: HELLO %r", self.local_port, logged_headers) + self._append(b"\x01", (headers,), + response=InitResponse(self, "hello", hydration_hooks, + on_success=on_success), + dehydration_hooks=dehydration_hooks) + self.logon(dehydration_hooks=dehydration_hooks, + hydration_hooks=hydration_hooks) + self.send_all() + self.fetch_all() + check_supported_server_product(self.server_info.agent) + + def logon(self, dehydration_hooks=None, hydration_hooks=None): + logged_auth_dict = dict(self.auth_dict) + if "credentials" in logged_auth_dict: + logged_auth_dict["credentials"] = "*******" + log.debug("[#%04X] C: LOGON %r", self.local_port, logged_auth_dict) + self._append(b"\x6A", (self.auth_dict,), + response=LogonResponse(self, "logon", hydration_hooks), + dehydration_hooks=dehydration_hooks) + + def logoff(self, dehydration_hooks=None, hydration_hooks=None): + log.debug("[#%04X] C: LOGOFF", self.local_port) + self._append(b"\x6B", + response=LogonResponse(self, "logoff", hydration_hooks), + dehydration_hooks=dehydration_hooks) diff --git a/src/neo4j/_sync/io/_common.py b/src/neo4j/_sync/io/_common.py index 89681a38..85cff520 100644 --- a/src/neo4j/_sync/io/_common.py +++ b/src/neo4j/_sync/io/_common.py @@ -256,10 +256,10 @@ def on_ignored(self, metadata=None): class InitResponse(Response): - def on_failure(self, metadata): code = metadata.get("code") if code == "Neo.ClientError.Security.Unauthorized": + # this branch is only needed as long as we support Bolt 5.0 raise Neo4jError.hydrate(**metadata) else: raise ServiceUnavailable( @@ -267,6 +267,11 @@ def on_failure(self, metadata): ) +class LogonResponse(Response): + def on_failure(self, metadata): + raise Neo4jError.hydrate(**metadata) + + class CommitResponse(Response): pass diff --git a/src/neo4j/_sync/io/_pool.py b/src/neo4j/_sync/io/_pool.py 
index 9fe58428..ef58970a 100644 --- a/src/neo4j/_sync/io/_pool.py +++ b/src/neo4j/_sync/io/_pool.py @@ -16,9 +16,12 @@ # limitations under the License. +from __future__ import annotations + import abc import asyncio import logging +import typing as t from collections import ( defaultdict, deque, @@ -28,10 +31,13 @@ from ..._async_compat.concurrency import ( Condition, + CooperativeLock, CooperativeRLock, + Lock, RLock, ) from ..._async_compat.network import NetworkUtil +from ..._async_compat.util import Util from ..._conf import ( PoolConfig, WorkspaceConfig, @@ -44,6 +50,7 @@ from ..._routing import RoutingTable from ...api import ( READ_ACCESS, + RenewableAuth, WRITE_ACCESS, ) from ...exceptions import ( @@ -79,6 +86,10 @@ def __init__(self, opener, pool_config, workspace_config): self.connections_reservations = defaultdict(lambda: 0) self.lock = CooperativeRLock() self.cond = Condition(self.lock) + self.refreshing_auth = False + self.refreshing_auth_lock = CooperativeLock() + self.initializing_auth_lock = Lock() + self.last_auth: t.Optional[RenewableAuth] = None def __enter__(self): return self @@ -86,6 +97,44 @@ def __enter__(self): def __exit__(self, exc_type, exc_value, traceback): self.close() + def get_auth(self): + if self._initialize_auth(): + return self.last_auth.auth + + with self.refreshing_auth_lock: + if self.refreshing_auth: + # someone else is already getting a new auth + return self.last_auth + if self.last_auth is None or self.last_auth.expired: + with self.refreshing_auth_lock: + self.refreshing_auth = True + try: + new_auth = self._get_new_auth() + with self.refreshing_auth_lock: + self.last_auth = new_auth + self.refreshing_auth = False + except: + with self.refreshing_auth_lock: + self.refreshing_auth = False + raise + return self.last_auth.auth + + def _initialize_auth(self): + if self.last_auth is not None: + return False + with self.initializing_auth_lock: + if self.last_auth is not None: + # someone else initialized the auth + return 
True + self.last_auth = self._get_new_auth() + return True + + def _get_new_auth(self): + new_auth = Util.callback(self.pool_config.auth) + if not isinstance(self.last_auth, RenewableAuth): + return RenewableAuth(new_auth) + return new_auth + def _acquire_from_pool(self, address): with self.lock: for connection in list(self.connections.get(address, [])): @@ -130,17 +179,27 @@ def _acquire_from_pool_checked( else: return connection - def _acquire_new_later(self, address, deadline): + def _acquire_new_later(self, address, auth, deadline): def connection_creator(): released_reservation = False try: try: connection = self.opener( - address, deadline.to_timeout() + address, auth or self.get_auth(), + deadline.to_timeout() ) except ServiceUnavailable: self.deactivate(address) raise + if auth: + # It's unfortunate that we have to create a connection + # first to determine if the session-level auth is supported + # by the protocol or not. + try: + connection.assert_re_auth_support() + except ConfigurationError: + connection.close() + raise connection.pool = self connection.in_use = True with self.lock: @@ -166,7 +225,7 @@ def connection_creator(): return connection_creator return None - def _acquire(self, address, deadline, liveness_check_timeout): + def _acquire(self, address, auth, deadline, liveness_check_timeout): """ Acquire a connection to a given address from the pool. The address supplied should always be an IP address, not a host name. 
@@ -178,6 +237,19 @@ def health_check(connection_, deadline_): or connection_.defunct() or connection_.stale()): return False + try: + if connection_.re_auth(auth or self.get_auth()): + # no need for an extra liveness check if the connection was + # successfully re-authenticated + return True + except ConfigurationError: + # protocol does not support re-authentication + if auth: + # session-level not supported + raise + # expiring tokens supported by flushing the pool + # => give up this connection + return False if liveness_check_timeout is not None: if connection_.is_idle_for(liveness_check_timeout): with connection_deadline(connection_, deadline_): @@ -201,7 +273,9 @@ def health_check(connection_, deadline_): return connection # all connections in pool are in-use with self.lock: - connection_creator = self._acquire_new_later(address, deadline) + connection_creator = self._acquire_new_later( + address, auth, deadline + ) if connection_creator: break @@ -222,7 +296,8 @@ def health_check(connection_, deadline_): @abc.abstractmethod def acquire( - self, access_mode, timeout, database, bookmarks, liveness_check_timeout + self, access_mode, timeout, database, bookmarks, auth, + liveness_check_timeout ): """ Acquire a connection to a server that can satisfy a set of parameters. @@ -231,6 +306,7 @@ def acquire( (excluding potential preparation like fetching routing tables). :param database: :param bookmarks: + :param auth: :param liveness_check_timeout: """ ... 
@@ -372,17 +448,16 @@ def close(self): class BoltPool(IOPool): @classmethod - def open(cls, address, *, auth, pool_config, workspace_config): + def open(cls, address, *, pool_config, workspace_config): """Create a new BoltPool :param address: - :param auth: :param pool_config: :param workspace_config: :returns: BoltPool """ - def opener(addr, timeout): + def opener(addr, auth, timeout): return Bolt.open( addr, auth=auth, timeout=timeout, routing_context=None, pool_config=pool_config @@ -401,7 +476,8 @@ def __repr__(self): self.address) def acquire( - self, access_mode, timeout, database, bookmarks, liveness_check_timeout + self, access_mode, timeout, database, bookmarks, auth, + liveness_check_timeout ): # The access_mode and database is not needed for a direct connection, # it's just there for consistency. @@ -409,7 +485,7 @@ def acquire( "access_mode=%r, database=%r", access_mode, database) deadline = Deadline.from_timeout_or_deadline(timeout) return self._acquire( - self.address, deadline, liveness_check_timeout + self.address, auth, deadline, liveness_check_timeout ) @@ -418,12 +494,11 @@ class Neo4jPool(IOPool): """ @classmethod - def open(cls, *addresses, auth, pool_config, workspace_config, + def open(cls, *addresses, pool_config, workspace_config, routing_context=None): """Create a new Neo4jPool :param addresses: one or more address as positional argument - :param auth: :param pool_config: :param workspace_config: :param routing_context: @@ -437,7 +512,7 @@ def open(cls, *addresses, auth, pool_config, workspace_config, raise ConfigurationError("The key 'address' is reserved for routing context.") routing_context["address"] = str(address) - def opener(addr, timeout): + def opener(addr, auth, timeout): return Bolt.open( addr, auth=auth, timeout=timeout, routing_context=routing_context, pool_config=pool_config @@ -480,7 +555,7 @@ def get_or_create_routing_table(self, database): return self.routing_tables[database] def fetch_routing_info( - self, address, 
database, imp_user, bookmarks, acquisition_timeout + self, address, database, imp_user, bookmarks, auth, acquisition_timeout ): """ Fetch raw routing info from a given router address. @@ -491,6 +566,7 @@ def fetch_routing_info( :type imp_user: str or None :param bookmarks: iterable of bookmark values after which the routing info should be fetched + :param auth: auth :param acquisition_timeout: connection acquisition timeout :returns: list of routing records, or None if no connection @@ -501,7 +577,7 @@ def fetch_routing_info( deadline = Deadline.from_timeout_or_deadline(acquisition_timeout) log.debug("[#0000] _: _acquire router connection, " "database=%r, address=%r", database, address) - cx = self._acquire(address, deadline, None) + cx = self._acquire(address, auth, deadline, None) try: routing_table = cx.route( database=database or self.workspace_config.database, @@ -513,7 +589,8 @@ def fetch_routing_info( return routing_table def fetch_routing_table( - self, *, address, acquisition_timeout, database, imp_user, bookmarks + self, *, address, acquisition_timeout, database, imp_user, + bookmarks, auth ): """ Fetch a routing table from a given router address. @@ -525,6 +602,7 @@ def fetch_routing_table( table :type imp_user: str or None :param bookmarks: bookmarks used when fetching routing table + :param auth: auth :returns: a new RoutingTable instance or None if the given router is currently unable to provide routing information @@ -532,7 +610,8 @@ def fetch_routing_table( new_routing_info = None try: new_routing_info = self.fetch_routing_info( - address, database, imp_user, bookmarks, acquisition_timeout + address, database, imp_user, bookmarks, auth, + acquisition_timeout ) except Neo4jError as e: # checks if the code is an error that is caused by the client. 
In @@ -578,8 +657,8 @@ def fetch_routing_table( return new_routing_table def _update_routing_table_from( - self, *routers, database, imp_user, bookmarks, acquisition_timeout, - database_callback + self, *routers, database, imp_user, bookmarks, auth, + acquisition_timeout, database_callback ): """ Try to update routing tables with the given routers. @@ -595,7 +674,8 @@ def _update_routing_table_from( ): new_routing_table = self.fetch_routing_table( address=address, acquisition_timeout=acquisition_timeout, - database=database, imp_user=imp_user, bookmarks=bookmarks + database=database, imp_user=imp_user, bookmarks=bookmarks, + auth=auth ) if new_routing_table is not None: new_database = new_routing_table.database @@ -615,8 +695,8 @@ def _update_routing_table_from( return False def update_routing_table( - self, *, database, imp_user, bookmarks, acquisition_timeout=None, - database_callback=None + self, *, database, imp_user, bookmarks, auth=None, + acquisition_timeout=None, database_callback=None ): """ Update the routing table from the first router able to provide valid routing information. @@ -626,6 +706,7 @@ def update_routing_table( table :type imp_user: str or None :param bookmarks: bookmarks used when fetching routing table + :param auth: auth :param acquisition_timeout: connection acquisition timeout :param database_callback: A callback function that will be called with the database name as only argument when a new routing table has been @@ -647,15 +728,15 @@ def update_routing_table( # TODO: Test this state if self._update_routing_table_from( self.address, database=database, - imp_user=imp_user, bookmarks=bookmarks, + imp_user=imp_user, bookmarks=bookmarks, auth=auth, acquisition_timeout=acquisition_timeout, database_callback=database_callback ): # Why is only the first initial routing address used? 
return if self._update_routing_table_from( - *(existing_routers - {self.address}), - database=database, imp_user=imp_user, bookmarks=bookmarks, + *(existing_routers - {self.address}), database=database, + imp_user=imp_user, bookmarks=bookmarks, auth=auth, acquisition_timeout=acquisition_timeout, database_callback=database_callback ): @@ -664,7 +745,7 @@ def update_routing_table( if not prefer_initial_routing_address: if self._update_routing_table_from( self.address, database=database, - imp_user=imp_user, bookmarks=bookmarks, + imp_user=imp_user, bookmarks=bookmarks, auth=auth, acquisition_timeout=acquisition_timeout, database_callback=database_callback ): @@ -683,7 +764,7 @@ def update_connection_pool(self, *, database): super(Neo4jPool, self).deactivate(address) def ensure_routing_table_is_fresh( - self, *, access_mode, database, imp_user, bookmarks, + self, *, access_mode, database, imp_user, bookmarks, auth=None, acquisition_timeout=None, database_callback=None ): """ Update the routing table if stale. 
@@ -720,7 +801,7 @@ def ensure_routing_table_is_fresh( self.update_routing_table( database=database, imp_user=imp_user, bookmarks=bookmarks, - acquisition_timeout=acquisition_timeout, + auth=auth, acquisition_timeout=acquisition_timeout, database_callback=database_callback ) self.update_connection_pool(database=database) @@ -757,7 +838,8 @@ def _select_address(self, *, access_mode, database): return choice(addresses_by_usage[min(addresses_by_usage)]) def acquire( - self, access_mode, timeout, database, bookmarks, liveness_check_timeout + self, access_mode, timeout, database, bookmarks, auth, + liveness_check_timeout ): if access_mode not in (WRITE_ACCESS, READ_ACCESS): raise ClientError("Non valid 'access_mode'; {}".format(access_mode)) @@ -796,7 +878,7 @@ def acquire( deadline = Deadline.from_timeout_or_deadline(timeout) # should always be a resolved address connection = self._acquire( - address, deadline, liveness_check_timeout + address, auth, deadline, liveness_check_timeout ) except (ServiceUnavailable, SessionExpired): self.deactivate(address=address) diff --git a/src/neo4j/_sync/work/session.py b/src/neo4j/_sync/work/session.py index 41da7973..2ee8c852 100644 --- a/src/neo4j/_sync/work/session.py +++ b/src/neo4j/_sync/work/session.py @@ -98,6 +98,7 @@ class Session(Workspace): def __init__(self, pool, session_config): assert isinstance(session_config, SessionConfig) super().__init__(pool, session_config) + self._config = session_config self._initialize_bookmarks(session_config.bookmarks) self._bookmark_manager = session_config.bookmark_manager @@ -117,7 +118,8 @@ def _connect(self, access_mode, **access_kwargs): if access_mode is None: access_mode = self._config.default_access_mode try: - super()._connect(access_mode, **access_kwargs) + super()._connect(access_mode, auth=self._config.auth, + **access_kwargs) except asyncio.CancelledError: self._handle_cancellation(message="_connect") raise diff --git a/src/neo4j/_sync/work/workspace.py 
b/src/neo4j/_sync/work/workspace.py index 2cb91648..6fd37028 100644 --- a/src/neo4j/_sync/work/workspace.py +++ b/src/neo4j/_sync/work/workspace.py @@ -130,7 +130,7 @@ def _update_bookmark(self, bookmark): return self._update_bookmarks((bookmark,)) - def _connect(self, access_mode, **acquire_kwargs): + def _connect(self, access_mode, auth=None, **acquire_kwargs): acquisition_timeout = self._config.connection_acquisition_timeout if self._connection: # TODO: Investigate this @@ -154,6 +154,7 @@ def _connect(self, access_mode, **acquire_kwargs): database=self._config.database, imp_user=self._config.impersonated_user, bookmarks=self._get_bookmarks(), + auth=auth, acquisition_timeout=acquisition_timeout, database_callback=self._set_cached_database ) @@ -162,6 +163,7 @@ def _connect(self, access_mode, **acquire_kwargs): "timeout": acquisition_timeout, "database": self._config.database, "bookmarks": self._get_bookmarks(), + "auth": auth, "liveness_check_timeout": None, } acquire_kwargs_.update(acquire_kwargs) diff --git a/src/neo4j/api.py b/src/neo4j/api.py index e91aed6e..06d75a99 100644 --- a/src/neo4j/api.py +++ b/src/neo4j/api.py @@ -21,6 +21,7 @@ from __future__ import annotations import abc +import time import typing as t from urllib.parse import ( parse_qs, @@ -103,6 +104,11 @@ def __init__( if parameters: self.parameters = parameters + def __eq__(self, other): + if not isinstance(other, Auth): + return NotImplemented + return vars(self) == vars(other) + # For backwards compatibility AuthToken = Auth @@ -175,6 +181,54 @@ def custom_auth( return Auth(scheme, principal, credentials, realm, **parameters) +class RenewableAuth: + """Container for details which potentially expire. + + This is meant to be used with auth token provider which is a callable + that returns... + + .. warning:: + + This function **must not** interact with the driver in any way as this + can cause a deadlock or undefined behaviour. + + :param auth: The auth token. 
:param expires_in: The expected expiry time + of the auth token in seconds from now. It is recommended to set this + a little before the actual expiry time to give the driver time to + renew the auth token before connections start to fail. If set to + :data:`None`, the token is assumed to never expire. + """ + + def __init__( + self, + auth: t.Union[Auth, t.Tuple[t.Any, t.Any], None], + expires_in: t.Optional[float] = None, + ) -> None: + self.auth = auth + self.created_at: t.Optional[float] + self.expires_in: t.Optional[float] + self.expires_at: t.Optional[float] + if expires_in is not None: + self.expires_in = expires_in + self.created_at = time.monotonic() + self.expires_at = self.created_at + expires_in + else: + self.expires_in = None + self.created_at = None + self.expires_at = None + + @property + def expired(self): + return (self.expires_at is not None + and self.expires_at < time.monotonic()) + + +_TAuthTokenProvider = t.Callable[[], t.Union[ + RenewableAuth, Auth, t.Tuple[t.Any, t.Any], + t.Awaitable[t.Union[RenewableAuth, Auth, t.Tuple[t.Any, t.Any]]] +]] + + # TODO: 6.0 - remove this class class Bookmark: """A Bookmark object contains an immutable list of bookmark string values. 
diff --git a/src/neo4j/scratch_11.py b/src/neo4j/scratch_11.py new file mode 100644 index 00000000..6a80b32c --- /dev/null +++ b/src/neo4j/scratch_11.py @@ -0,0 +1,28 @@ +from time import time + +import neo4j +from neo4j.debug import watch + + +# watch("neo4j") + + +URI = "neo4j://localhost:7687" +USER = "neo4j" +PASSWORD = "pass" + + + +def main(): + start = time() + with neo4j.GraphDatabase.driver(URI, auth=(USER, PASSWORD)) as driver: + for _ in range(10000): + with driver.session() as session: + value = list(range(100)) + list(session.run("RETURN $value", value=value)) + end = time() + print(f"Time taken: {end - start}s") + + +if __name__ == '__main__': + main() diff --git a/testkitbackend/_async/requests.py b/testkitbackend/_async/requests.py index c3b28216..2fb4b94d 100644 --- a/testkitbackend/_async/requests.py +++ b/testkitbackend/_async/requests.py @@ -93,9 +93,9 @@ async def GetFeatures(backend, data): await backend.send_response("FeatureList", {"features": FEATURES}) -async def NewDriver(backend, data): - auth_token = data["authorizationToken"]["data"] - data["authorizationToken"].mark_item_as_read_if_equals( +def _convert_auth_token(data, key): + auth_token = data[key]["data"] + data[key].mark_item_as_read_if_equals( "name", "AuthorizationToken" ) scheme = auth_token["scheme"] @@ -115,6 +115,11 @@ async def NewDriver(backend, data): **auth_token.get("parameters", {}) ) auth_token.mark_item_as_read("parameters", recursive=True) + return auth + + +async def NewDriver(backend, data): + auth = _convert_auth_token(data, "authorizationToken") kwargs = {} if data["resolverRegistered"] or data["domainNameResolverRegistered"]: kwargs["resolver"] = resolution_func( @@ -385,6 +390,8 @@ async def NewSession(backend, data): ): if data_name in data: config[conf_name] = data[data_name] + if data.get("authorizationToken"): + config["auth"] = _convert_auth_token(data, "authorizationToken") if "bookmark_manager" in config: with warning_check( neo4j.ExperimentalWarning, 
diff --git a/testkitbackend/_sync/requests.py b/testkitbackend/_sync/requests.py index ab6ba792..65342a6f 100644 --- a/testkitbackend/_sync/requests.py +++ b/testkitbackend/_sync/requests.py @@ -93,9 +93,9 @@ def GetFeatures(backend, data): backend.send_response("FeatureList", {"features": FEATURES}) -def NewDriver(backend, data): - auth_token = data["authorizationToken"]["data"] - data["authorizationToken"].mark_item_as_read_if_equals( +def _convert_auth_token(data, key): + auth_token = data[key]["data"] + data[key].mark_item_as_read_if_equals( "name", "AuthorizationToken" ) scheme = auth_token["scheme"] @@ -115,6 +115,11 @@ def NewDriver(backend, data): **auth_token.get("parameters", {}) ) auth_token.mark_item_as_read("parameters", recursive=True) + return auth + + +def NewDriver(backend, data): + auth = _convert_auth_token(data, "authorizationToken") kwargs = {} if data["resolverRegistered"] or data["domainNameResolverRegistered"]: kwargs["resolver"] = resolution_func( @@ -385,6 +390,8 @@ def NewSession(backend, data): ): if data_name in data: config[conf_name] = data[data_name] + if data.get("authorizationToken"): + config["auth"] = _convert_auth_token(data, "authorizationToken") if "bookmark_manager" in config: with warning_check( neo4j.ExperimentalWarning, diff --git a/testkitbackend/test_config.json b/testkitbackend/test_config.json index 18710693..9ae7e7d7 100644 --- a/testkitbackend/test_config.json +++ b/testkitbackend/test_config.json @@ -21,11 +21,13 @@ "Feature:API:Driver:GetServerInfo": true, "Feature:API:Driver.IsEncrypted": true, "Feature:API:Driver.VerifyConnectivity": true, + "Feature:API:Driver.SupportsSessionAuth": true, "Feature:API:Liveness.Check": false, "Feature:API:Result.List": true, "Feature:API:Result.Peek": true, "Feature:API:Result.Single": true, "Feature:API:Result.SingleOptional": true, + "Feature:API:Session:AuthConfig": true, "Feature:API:SSLConfig": true, "Feature:API:SSLSchemes": true, "Feature:API:Type.Spatial": true, @@ -39,6 
+41,7 @@ "Feature:Bolt:4.3": true, "Feature:Bolt:4.4": true, "Feature:Bolt:5.0": true, + "Feature:Bolt:5.1": true, "Feature:Bolt:Patch:UTC": true, "Feature:Impersonation": true, "Feature:TLS:1.1": "Driver blocks TLS 1.1 for security reasons.", @@ -51,6 +54,7 @@ "Optimization:ImplicitDefaultArguments": true, "Optimization:MinimalBookmarksSet": true, "Optimization:MinimalResets": true, + "Optimization:AuthPipelining": true, "Optimization:PullPipelining": true, "Optimization:ResultListFetchAll": "The idiomatic way to cast to list is indistinguishable from iterating over the result.", diff --git a/tests/unit/async_/fixtures/fake_connection.py b/tests/unit/async_/fixtures/fake_connection.py index f0d0070c..ae4b2dc9 100644 --- a/tests/unit/async_/fixtures/fake_connection.py +++ b/tests/unit/async_/fixtures/fake_connection.py @@ -50,6 +50,7 @@ def __init__(self, *args, **kwargs): self.attach_mock(mock.Mock(return_value=False), "stale") self.attach_mock(mock.Mock(return_value=False), "closed") self.attach_mock(mock.Mock(return_value=False), "socket") + self.attach_mock(mock.AsyncMock(return_value=False), "re_auth") self.attach_mock(mock.Mock(), "unresolved_address") def close_side_effect(): diff --git a/tests/unit/async_/io/test_class_bolt.py b/tests/unit/async_/io/test_class_bolt.py index 771c4d73..60adfd28 100644 --- a/tests/unit/async_/io/test_class_bolt.py +++ b/tests/unit/async_/io/test_class_bolt.py @@ -31,12 +31,13 @@ def test_class_method_protocol_handlers(): protocol_handlers = AsyncBolt.protocol_handlers() - assert len(protocol_handlers) == 6 - assert protocol_handlers.keys() == { + expected_versions = { (3, 0), (4, 1), (4, 2), (4, 3), (4, 4), - (5, 0), + (5, 0), (5, 1), } + assert len(protocol_handlers) == len(expected_versions) + assert protocol_handlers.keys() == expected_versions @pytest.mark.parametrize( @@ -52,7 +53,8 @@ def test_class_method_protocol_handlers(): ((4, 3), 1), ((4, 4), 1), ((5, 0), 1), - ((5, 1), 0), + ((5, 1), 1), + ((5, 2), 0), ((6, 0), 
0), ] ) @@ -71,7 +73,7 @@ def test_class_method_protocol_handlers_with_invalid_protocol_version(): def test_class_method_get_handshake(): handshake = AsyncBolt.get_handshake() - assert (b"\x00\x00\x00\x05\x00\x02\x04\x04\x00\x00\x01\x04\x00\x00\x00\x03" + assert (b"\x00\x01\x01\x05\x00\x02\x04\x04\x00\x00\x01\x04\x00\x00\x00\x03" == handshake) diff --git a/tests/unit/async_/io/test_class_bolt3.py b/tests/unit/async_/io/test_class_bolt3.py index aa6aac10..8f3bda27 100644 --- a/tests/unit/async_/io/test_class_bolt3.py +++ b/tests/unit/async_/io/test_class_bolt3.py @@ -16,8 +16,12 @@ # limitations under the License. +import logging +from itertools import permutations + import pytest +import neo4j from neo4j._async.io._bolt3 import AsyncBolt3 from neo4j._conf import PoolConfig from neo4j.exceptions import ConfigurationError @@ -112,3 +116,89 @@ async def test_hint_recv_timeout_seconds_gets_ignored( ) await connection.hello() sockets.client.settimeout.assert_not_called() + + +CREDENTIALS = "+++super-secret-sauce+++" + + +@pytest.mark.parametrize("auth", ( + ("user", CREDENTIALS), + neo4j.basic_auth("user", CREDENTIALS), + neo4j.kerberos_auth(CREDENTIALS), + neo4j.bearer_auth(CREDENTIALS), + neo4j.custom_auth("user", CREDENTIALS, "realm", "scheme"), + neo4j.Auth("scheme", "principal", CREDENTIALS, "realm", foo="bar"), +)) +@mark_async_test +async def test_credentials_are_not_logged( + auth, fake_socket_pair, mocker, caplog +): + address = ("127.0.0.1", 7687) + sockets = fake_socket_pair(address, + packer_cls=AsyncBolt3.PACKER_CLS, + unpacker_cls=AsyncBolt3.UNPACKER_CLS) + sockets.client.settimeout = mocker.Mock() + await sockets.server.send_message(b"\x70", {"server": "Neo4j/4.3.4"}) + connection = AsyncBolt3( + address, sockets.client, PoolConfig.max_connection_lifetime, auth=auth + ) + with caplog.at_level(logging.DEBUG): + await connection.hello() + + if isinstance(auth, tuple): + auth = neo4j.basic_auth(*auth) + for field in ("scheme", "principal", "realm", 
"parameters"): + value = getattr(auth, field, None) + if value: + assert repr(value) in caplog.text + assert CREDENTIALS not in caplog.text + + +@pytest.mark.parametrize("message", ("logon", "logoff")) +def test_auth_message_raises_configuration_error(message, fake_socket): + address = ("127.0.0.1", 7687) + connection = AsyncBolt3(address, fake_socket(address), + PoolConfig.max_connection_lifetime) + with pytest.raises(ConfigurationError, + match="Session level authentication is not supported"): + getattr(connection, message)() + + +@pytest.mark.parametrize("auth", ( + None, + neo4j.Auth("scheme", "principal", "credentials", "realm"), + ("user", "password"), +)) +@mark_async_test +async def test_re_auth_noop(auth, fake_socket, mocker): + address = ("127.0.0.1", 7687) + connection = AsyncBolt3(address, fake_socket(address), + PoolConfig.max_connection_lifetime, auth=auth) + logon_spy = mocker.spy(connection, "logon") + logoff_spy = mocker.spy(connection, "logoff") + res = await connection.re_auth(auth) + + assert res is False + logon_spy.assert_not_called() + logoff_spy.assert_not_called() + + +@pytest.mark.parametrize( + ("auth1", "auth2"), + permutations( + ( + None, + neo4j.Auth("scheme", "principal", "credentials", "realm"), + ("user", "password"), + ), + 2 + ) +) +@mark_async_test +async def test_re_auth(auth1, auth2, fake_socket): + address = ("127.0.0.1", 7687) + connection = AsyncBolt3(address, fake_socket(address), + PoolConfig.max_connection_lifetime, auth=auth1) + with pytest.raises(ConfigurationError, + match="Session level authentication is not supported"): + await connection.re_auth(auth2) diff --git a/tests/unit/async_/io/test_class_bolt4x0.py b/tests/unit/async_/io/test_class_bolt4x0.py index 56b15f4d..659db673 100644 --- a/tests/unit/async_/io/test_class_bolt4x0.py +++ b/tests/unit/async_/io/test_class_bolt4x0.py @@ -16,10 +16,15 @@ # limitations under the License. 
+import logging +from itertools import permutations + import pytest +import neo4j from neo4j._async.io._bolt4 import AsyncBolt4x0 from neo4j._conf import PoolConfig +from neo4j.exceptions import ConfigurationError from ...._async_compat import mark_async_test @@ -207,3 +212,89 @@ async def test_hint_recv_timeout_seconds_gets_ignored( ) await connection.hello() sockets.client.settimeout.assert_not_called() + + +CREDENTIALS = "+++super-secret-sauce+++" + + +@pytest.mark.parametrize("auth", ( + ("user", CREDENTIALS), + neo4j.basic_auth("user", CREDENTIALS), + neo4j.kerberos_auth(CREDENTIALS), + neo4j.bearer_auth(CREDENTIALS), + neo4j.custom_auth("user", CREDENTIALS, "realm", "scheme"), + neo4j.Auth("scheme", "principal", CREDENTIALS, "realm", foo="bar"), +)) +@mark_async_test +async def test_credentials_are_not_logged( + auth, fake_socket_pair, mocker, caplog +): + address = ("127.0.0.1", 7687) + sockets = fake_socket_pair(address, + packer_cls=AsyncBolt4x0.PACKER_CLS, + unpacker_cls=AsyncBolt4x0.UNPACKER_CLS) + sockets.client.settimeout = mocker.Mock() + await sockets.server.send_message(b"\x70", {"server": "Neo4j/4.3.4"}) + connection = AsyncBolt4x0( + address, sockets.client, PoolConfig.max_connection_lifetime, auth=auth + ) + with caplog.at_level(logging.DEBUG): + await connection.hello() + + if isinstance(auth, tuple): + auth = neo4j.basic_auth(*auth) + for field in ("scheme", "principal", "realm", "parameters"): + value = getattr(auth, field, None) + if value: + assert repr(value) in caplog.text + assert CREDENTIALS not in caplog.text + + +@pytest.mark.parametrize("message", ("logon", "logoff")) +def test_auth_message_raises_configuration_error(message, fake_socket): + address = ("127.0.0.1", 7687) + connection = AsyncBolt4x0(address, fake_socket(address), + PoolConfig.max_connection_lifetime) + with pytest.raises(ConfigurationError, + match="Session level authentication is not supported"): + getattr(connection, message)() + + +@pytest.mark.parametrize("auth", ( 
+ None, + neo4j.Auth("scheme", "principal", "credentials", "realm"), + ("user", "password"), +)) +@mark_async_test +async def test_re_auth_noop(auth, fake_socket, mocker): + address = ("127.0.0.1", 7687) + connection = AsyncBolt4x0(address, fake_socket(address), + PoolConfig.max_connection_lifetime, auth=auth) + logon_spy = mocker.spy(connection, "logon") + logoff_spy = mocker.spy(connection, "logoff") + res = await connection.re_auth(auth) + + assert res is False + logon_spy.assert_not_called() + logoff_spy.assert_not_called() + + +@pytest.mark.parametrize( + ("auth1", "auth2"), + permutations( + ( + None, + neo4j.Auth("scheme", "principal", "credentials", "realm"), + ("user", "password"), + ), + 2 + ) +) +@mark_async_test +async def test_re_auth(auth1, auth2, fake_socket): + address = ("127.0.0.1", 7687) + connection = AsyncBolt4x0(address, fake_socket(address), + PoolConfig.max_connection_lifetime, auth=auth1) + with pytest.raises(ConfigurationError, + match="Session level authentication is not supported"): + await connection.re_auth(auth2) diff --git a/tests/unit/async_/io/test_class_bolt4x1.py b/tests/unit/async_/io/test_class_bolt4x1.py index 4371f005..d236cac4 100644 --- a/tests/unit/async_/io/test_class_bolt4x1.py +++ b/tests/unit/async_/io/test_class_bolt4x1.py @@ -16,10 +16,15 @@ # limitations under the License. 
+import logging +from itertools import permutations + import pytest +import neo4j from neo4j._async.io._bolt4 import AsyncBolt4x1 from neo4j._conf import PoolConfig +from neo4j.exceptions import ConfigurationError from ...._async_compat import mark_async_test @@ -226,3 +231,89 @@ async def test_hint_recv_timeout_seconds_gets_ignored( PoolConfig.max_connection_lifetime) await connection.hello() sockets.client.settimeout.assert_not_called() + + +CREDENTIALS = "+++super-secret-sauce+++" + + +@pytest.mark.parametrize("auth", ( + ("user", CREDENTIALS), + neo4j.basic_auth("user", CREDENTIALS), + neo4j.kerberos_auth(CREDENTIALS), + neo4j.bearer_auth(CREDENTIALS), + neo4j.custom_auth("user", CREDENTIALS, "realm", "scheme"), + neo4j.Auth("scheme", "principal", CREDENTIALS, "realm", foo="bar"), +)) +@mark_async_test +async def test_credentials_are_not_logged( + auth, fake_socket_pair, mocker, caplog +): + address = ("127.0.0.1", 7687) + sockets = fake_socket_pair(address, + packer_cls=AsyncBolt4x1.PACKER_CLS, + unpacker_cls=AsyncBolt4x1.UNPACKER_CLS) + sockets.client.settimeout = mocker.Mock() + await sockets.server.send_message(b"\x70", {"server": "Neo4j/4.3.4"}) + connection = AsyncBolt4x1( + address, sockets.client, PoolConfig.max_connection_lifetime, auth=auth + ) + with caplog.at_level(logging.DEBUG): + await connection.hello() + + if isinstance(auth, tuple): + auth = neo4j.basic_auth(*auth) + for field in ("scheme", "principal", "realm", "parameters"): + value = getattr(auth, field, None) + if value: + assert repr(value) in caplog.text + assert CREDENTIALS not in caplog.text + + +@pytest.mark.parametrize("message", ("logon", "logoff")) +def test_auth_message_raises_configuration_error(message, fake_socket): + address = ("127.0.0.1", 7687) + connection = AsyncBolt4x1(address, fake_socket(address), + PoolConfig.max_connection_lifetime) + with pytest.raises(ConfigurationError, + match="Session level authentication is not supported"): + getattr(connection, message)() + + 
+@pytest.mark.parametrize("auth", ( + None, + neo4j.Auth("scheme", "principal", "credentials", "realm"), + ("user", "password"), +)) +@mark_async_test +async def test_re_auth_noop(auth, fake_socket, mocker): + address = ("127.0.0.1", 7687) + connection = AsyncBolt4x1(address, fake_socket(address), + PoolConfig.max_connection_lifetime, auth=auth) + logon_spy = mocker.spy(connection, "logon") + logoff_spy = mocker.spy(connection, "logoff") + res = await connection.re_auth(auth) + + assert res is False + logon_spy.assert_not_called() + logoff_spy.assert_not_called() + + +@pytest.mark.parametrize( + ("auth1", "auth2"), + permutations( + ( + None, + neo4j.Auth("scheme", "principal", "credentials", "realm"), + ("user", "password"), + ), + 2 + ) +) +@mark_async_test +async def test_re_auth(auth1, auth2, fake_socket): + address = ("127.0.0.1", 7687) + connection = AsyncBolt4x1(address, fake_socket(address), + PoolConfig.max_connection_lifetime, auth=auth1) + with pytest.raises(ConfigurationError, + match="Session level authentication is not supported"): + await connection.re_auth(auth2) diff --git a/tests/unit/async_/io/test_class_bolt4x2.py b/tests/unit/async_/io/test_class_bolt4x2.py index 804038cb..92f4617e 100644 --- a/tests/unit/async_/io/test_class_bolt4x2.py +++ b/tests/unit/async_/io/test_class_bolt4x2.py @@ -16,10 +16,15 @@ # limitations under the License. 
+import logging +from itertools import permutations + import pytest +import neo4j from neo4j._async.io._bolt4 import AsyncBolt4x2 from neo4j._conf import PoolConfig +from neo4j.exceptions import ConfigurationError from ...._async_compat import mark_async_test @@ -227,3 +232,89 @@ async def test_hint_recv_timeout_seconds_gets_ignored( ) await connection.hello() sockets.client.settimeout.assert_not_called() + + +CREDENTIALS = "+++super-secret-sauce+++" + + +@pytest.mark.parametrize("auth", ( + ("user", CREDENTIALS), + neo4j.basic_auth("user", CREDENTIALS), + neo4j.kerberos_auth(CREDENTIALS), + neo4j.bearer_auth(CREDENTIALS), + neo4j.custom_auth("user", CREDENTIALS, "realm", "scheme"), + neo4j.Auth("scheme", "principal", CREDENTIALS, "realm", foo="bar"), +)) +@mark_async_test +async def test_credentials_are_not_logged( + auth, fake_socket_pair, mocker, caplog +): + address = ("127.0.0.1", 7687) + sockets = fake_socket_pair(address, + packer_cls=AsyncBolt4x2.PACKER_CLS, + unpacker_cls=AsyncBolt4x2.UNPACKER_CLS) + sockets.client.settimeout = mocker.Mock() + await sockets.server.send_message(b"\x70", {"server": "Neo4j/4.3.4"}) + connection = AsyncBolt4x2( + address, sockets.client, PoolConfig.max_connection_lifetime, auth=auth + ) + with caplog.at_level(logging.DEBUG): + await connection.hello() + + if isinstance(auth, tuple): + auth = neo4j.basic_auth(*auth) + for field in ("scheme", "principal", "realm", "parameters"): + value = getattr(auth, field, None) + if value: + assert repr(value) in caplog.text + assert CREDENTIALS not in caplog.text + + +@pytest.mark.parametrize("message", ("logon", "logoff")) +def test_auth_message_raises_configuration_error(message, fake_socket): + address = ("127.0.0.1", 7687) + connection = AsyncBolt4x2(address, fake_socket(address), + PoolConfig.max_connection_lifetime) + with pytest.raises(ConfigurationError, + match="Session level authentication is not supported"): + getattr(connection, message)() + + +@pytest.mark.parametrize("auth", ( 
+ None, + neo4j.Auth("scheme", "principal", "credentials", "realm"), + ("user", "password"), +)) +@mark_async_test +async def test_re_auth_noop(auth, fake_socket, mocker): + address = ("127.0.0.1", 7687) + connection = AsyncBolt4x2(address, fake_socket(address), + PoolConfig.max_connection_lifetime, auth=auth) + logon_spy = mocker.spy(connection, "logon") + logoff_spy = mocker.spy(connection, "logoff") + res = await connection.re_auth(auth) + + assert res is False + logon_spy.assert_not_called() + logoff_spy.assert_not_called() + + +@pytest.mark.parametrize( + ("auth1", "auth2"), + permutations( + ( + None, + neo4j.Auth("scheme", "principal", "credentials", "realm"), + ("user", "password"), + ), + 2 + ) +) +@mark_async_test +async def test_re_auth(auth1, auth2, fake_socket): + address = ("127.0.0.1", 7687) + connection = AsyncBolt4x2(address, fake_socket(address), + PoolConfig.max_connection_lifetime, auth=auth1) + with pytest.raises(ConfigurationError, + match="Session level authentication is not supported"): + await connection.re_auth(auth2) diff --git a/tests/unit/async_/io/test_class_bolt4x3.py b/tests/unit/async_/io/test_class_bolt4x3.py index eb74241c..47d10ee9 100644 --- a/tests/unit/async_/io/test_class_bolt4x3.py +++ b/tests/unit/async_/io/test_class_bolt4x3.py @@ -17,11 +17,14 @@ import logging +from itertools import permutations import pytest +import neo4j from neo4j._async.io._bolt4 import AsyncBolt4x3 from neo4j._conf import PoolConfig +from neo4j.exceptions import ConfigurationError from ...._async_compat import mark_async_test @@ -255,3 +258,90 @@ async def test_hint_recv_timeout_seconds( and "recv_timeout_seconds" in msg and "invalid" in msg for msg in caplog.messages) + + +CREDENTIALS = "+++super-secret-sauce+++" + + +@pytest.mark.parametrize("auth", ( + ("user", CREDENTIALS), + neo4j.basic_auth("user", CREDENTIALS), + neo4j.kerberos_auth(CREDENTIALS), + neo4j.bearer_auth(CREDENTIALS), + neo4j.custom_auth("user", CREDENTIALS, "realm", "scheme"), + 
neo4j.Auth("scheme", "principal", CREDENTIALS, "realm", foo="bar"), +)) +@mark_async_test +async def test_credentials_are_not_logged( + auth, fake_socket_pair, mocker, caplog +): + address = ("127.0.0.1", 7687) + sockets = fake_socket_pair(address, + packer_cls=AsyncBolt4x3.PACKER_CLS, + unpacker_cls=AsyncBolt4x3.UNPACKER_CLS) + sockets.client.settimeout = mocker.Mock() + await sockets.server.send_message(b"\x70", {"server": "Neo4j/4.3.4"}) + connection = AsyncBolt4x3( + address, sockets.client, PoolConfig.max_connection_lifetime, auth=auth + ) + with caplog.at_level(logging.DEBUG): + await connection.hello() + + if isinstance(auth, tuple): + auth = neo4j.basic_auth(*auth) + for field in ("scheme", "principal", "realm", "parameters"): + value = getattr(auth, field, None) + if value: + assert repr(value) in caplog.text + assert CREDENTIALS not in caplog.text + + + +@pytest.mark.parametrize("message", ("logon", "logoff")) +def test_auth_message_raises_configuration_error(message, fake_socket): + address = ("127.0.0.1", 7687) + connection = AsyncBolt4x3(address, fake_socket(address), + PoolConfig.max_connection_lifetime) + with pytest.raises(ConfigurationError, + match="Session level authentication is not supported"): + getattr(connection, message)() + + +@pytest.mark.parametrize("auth", ( + None, + neo4j.Auth("scheme", "principal", "credentials", "realm"), + ("user", "password"), +)) +@mark_async_test +async def test_re_auth_noop(auth, fake_socket, mocker): + address = ("127.0.0.1", 7687) + connection = AsyncBolt4x3(address, fake_socket(address), + PoolConfig.max_connection_lifetime, auth=auth) + logon_spy = mocker.spy(connection, "logon") + logoff_spy = mocker.spy(connection, "logoff") + res = await connection.re_auth(auth) + + assert res is False + logon_spy.assert_not_called() + logoff_spy.assert_not_called() + + +@pytest.mark.parametrize( + ("auth1", "auth2"), + permutations( + ( + None, + neo4j.Auth("scheme", "principal", "credentials", "realm"), + ("user", 
"password"), + ), + 2 + ) +) +@mark_async_test +async def test_re_auth(auth1, auth2, fake_socket): + address = ("127.0.0.1", 7687) + connection = AsyncBolt4x3(address, fake_socket(address), + PoolConfig.max_connection_lifetime, auth=auth1) + with pytest.raises(ConfigurationError, + match="Session level authentication is not supported"): + await connection.re_auth(auth2) diff --git a/tests/unit/async_/io/test_class_bolt4x4.py b/tests/unit/async_/io/test_class_bolt4x4.py index c88b4af1..918b1498 100644 --- a/tests/unit/async_/io/test_class_bolt4x4.py +++ b/tests/unit/async_/io/test_class_bolt4x4.py @@ -17,11 +17,14 @@ import logging +from itertools import permutations import pytest +import neo4j from neo4j._async.io._bolt4 import AsyncBolt4x4 from neo4j._conf import PoolConfig +from neo4j.exceptions import ConfigurationError from ...._async_compat import mark_async_test @@ -269,3 +272,89 @@ async def test_hint_recv_timeout_seconds( and "recv_timeout_seconds" in msg and "invalid" in msg for msg in caplog.messages) + + +CREDENTIALS = "+++super-secret-sauce+++" + + +@pytest.mark.parametrize("auth", ( + ("user", CREDENTIALS), + neo4j.basic_auth("user", CREDENTIALS), + neo4j.kerberos_auth(CREDENTIALS), + neo4j.bearer_auth(CREDENTIALS), + neo4j.custom_auth("user", CREDENTIALS, "realm", "scheme"), + neo4j.Auth("scheme", "principal", CREDENTIALS, "realm", foo="bar"), +)) +@mark_async_test +async def test_credentials_are_not_logged( + auth, fake_socket_pair, mocker, caplog +): + address = ("127.0.0.1", 7687) + sockets = fake_socket_pair(address, + packer_cls=AsyncBolt4x4.PACKER_CLS, + unpacker_cls=AsyncBolt4x4.UNPACKER_CLS) + sockets.client.settimeout = mocker.Mock() + await sockets.server.send_message(b"\x70", {"server": "Neo4j/4.3.4"}) + connection = AsyncBolt4x4( + address, sockets.client, PoolConfig.max_connection_lifetime, auth=auth + ) + with caplog.at_level(logging.DEBUG): + await connection.hello() + + if isinstance(auth, tuple): + auth = neo4j.basic_auth(*auth) + for 
field in ("scheme", "principal", "realm", "parameters"): + value = getattr(auth, field, None) + if value: + assert repr(value) in caplog.text + assert CREDENTIALS not in caplog.text + + +@pytest.mark.parametrize("message", ("logon", "logoff")) +def test_auth_message_raises_configuration_error(message, fake_socket): + address = ("127.0.0.1", 7687) + connection = AsyncBolt4x4(address, fake_socket(address), + PoolConfig.max_connection_lifetime) + with pytest.raises(ConfigurationError, + match="Session level authentication is not supported"): + getattr(connection, message)() + + +@pytest.mark.parametrize("auth", ( + None, + neo4j.Auth("scheme", "principal", "credentials", "realm"), + ("user", "password"), +)) +@mark_async_test +async def test_re_auth_noop(auth, fake_socket, mocker): + address = ("127.0.0.1", 7687) + connection = AsyncBolt4x4(address, fake_socket(address), + PoolConfig.max_connection_lifetime, auth=auth) + logon_spy = mocker.spy(connection, "logon") + logoff_spy = mocker.spy(connection, "logoff") + res = await connection.re_auth(auth) + + assert res is False + logon_spy.assert_not_called() + logoff_spy.assert_not_called() + + +@pytest.mark.parametrize( + ("auth1", "auth2"), + permutations( + ( + None, + neo4j.Auth("scheme", "principal", "credentials", "realm"), + ("user", "password"), + ), + 2 + ) +) +@mark_async_test +async def test_re_auth(auth1, auth2, fake_socket): + address = ("127.0.0.1", 7687) + connection = AsyncBolt4x4(address, fake_socket(address), + PoolConfig.max_connection_lifetime, auth=auth1) + with pytest.raises(ConfigurationError, + match="Session level authentication is not supported"): + await connection.re_auth(auth2) diff --git a/tests/unit/async_/io/test_class_bolt5x0.py b/tests/unit/async_/io/test_class_bolt5x0.py new file mode 100644 index 00000000..f8c0c61e --- /dev/null +++ b/tests/unit/async_/io/test_class_bolt5x0.py @@ -0,0 +1,358 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of 
Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import logging +from itertools import permutations + +import pytest + +import neo4j +from neo4j._async.io._bolt5 import AsyncBolt5x0 +from neo4j._conf import PoolConfig +from neo4j.exceptions import ConfigurationError + +from ...._async_compat import mark_async_test + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_stale(fake_socket, set_stale): + address = ("127.0.0.1", 7687) + max_connection_lifetime = 0 + connection = AsyncBolt5x0(address, fake_socket(address), max_connection_lifetime) + if set_stale: + connection.set_stale() + assert connection.stale() is True + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): + address = ("127.0.0.1", 7687) + max_connection_lifetime = -1 + connection = AsyncBolt5x0(address, fake_socket(address), max_connection_lifetime) + if set_stale: + connection.set_stale() + assert connection.stale() is set_stale + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_not_stale(fake_socket, set_stale): + address = ("127.0.0.1", 7687) + max_connection_lifetime = 999999999 + connection = AsyncBolt5x0(address, fake_socket(address), max_connection_lifetime) + if set_stale: + connection.set_stale() + assert connection.stale() is set_stale + + +@pytest.mark.parametrize(("args", "kwargs", "expected_fields"), ( + (("", {}), {"db": "something"}, ({"db": "something"},)), + 
(("", {}), {"imp_user": "imposter"}, ({"imp_user": "imposter"},)), + ( + ("", {}), + {"db": "something", "imp_user": "imposter"}, + ({"db": "something", "imp_user": "imposter"},) + ), +)) +@mark_async_test +async def test_extra_in_begin(fake_socket, args, kwargs, expected_fields): + address = ("127.0.0.1", 7687) + socket = fake_socket(address, AsyncBolt5x0.UNPACKER_CLS) + connection = AsyncBolt5x0(address, socket, PoolConfig.max_connection_lifetime) + connection.begin(*args, **kwargs) + await connection.send_all() + tag, is_fields = await socket.pop_message() + assert tag == b"\x11" + assert tuple(is_fields) == expected_fields + + +@pytest.mark.parametrize(("args", "kwargs", "expected_fields"), ( + (("", {}), {"db": "something"}, ("", {}, {"db": "something"})), + (("", {}), {"imp_user": "imposter"}, ("", {}, {"imp_user": "imposter"})), + ( + ("", {}), + {"db": "something", "imp_user": "imposter"}, + ("", {}, {"db": "something", "imp_user": "imposter"}) + ), +)) +@mark_async_test +async def test_extra_in_run(fake_socket, args, kwargs, expected_fields): + address = ("127.0.0.1", 7687) + socket = fake_socket(address, AsyncBolt5x0.UNPACKER_CLS) + connection = AsyncBolt5x0(address, socket, PoolConfig.max_connection_lifetime) + connection.run(*args, **kwargs) + await connection.send_all() + tag, is_fields = await socket.pop_message() + assert tag == b"\x10" + assert tuple(is_fields) == expected_fields + + +@mark_async_test +async def test_n_extra_in_discard(fake_socket): + address = ("127.0.0.1", 7687) + socket = fake_socket(address, AsyncBolt5x0.UNPACKER_CLS) + connection = AsyncBolt5x0(address, socket, PoolConfig.max_connection_lifetime) + connection.discard(n=666) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x2F" + assert len(fields) == 1 + assert fields[0] == {"n": 666} + + +@pytest.mark.parametrize( + "test_input, expected", + [ + (666, {"n": -1, "qid": 666}), + (-1, {"n": -1}), + ] +) +@mark_async_test +async def 
test_qid_extra_in_discard(fake_socket, test_input, expected): + address = ("127.0.0.1", 7687) + socket = fake_socket(address, AsyncBolt5x0.UNPACKER_CLS) + connection = AsyncBolt5x0(address, socket, PoolConfig.max_connection_lifetime) + connection.discard(qid=test_input) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x2F" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + "test_input, expected", + [ + (777, {"n": 666, "qid": 777}), + (-1, {"n": 666}), + ] +) +@mark_async_test +async def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): + address = ("127.0.0.1", 7687) + socket = fake_socket(address, AsyncBolt5x0.UNPACKER_CLS) + connection = AsyncBolt5x0(address, socket, PoolConfig.max_connection_lifetime) + connection.discard(n=666, qid=test_input) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x2F" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + "test_input, expected", + [ + (666, {"n": 666}), + (-1, {"n": -1}), + ] +) +@mark_async_test +async def test_n_extra_in_pull(fake_socket, test_input, expected): + address = ("127.0.0.1", 7687) + socket = fake_socket(address, AsyncBolt5x0.UNPACKER_CLS) + connection = AsyncBolt5x0(address, socket, PoolConfig.max_connection_lifetime) + connection.pull(n=test_input) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x3F" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + "test_input, expected", + [ + (777, {"n": -1, "qid": 777}), + (-1, {"n": -1}), + ] +) +@mark_async_test +async def test_qid_extra_in_pull(fake_socket, test_input, expected): + address = ("127.0.0.1", 7687) + socket = fake_socket(address, AsyncBolt5x0.UNPACKER_CLS) + connection = AsyncBolt5x0(address, socket, PoolConfig.max_connection_lifetime) + connection.pull(qid=test_input) + await 
connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x3F" + assert len(fields) == 1 + assert fields[0] == expected + + +@mark_async_test +async def test_n_and_qid_extras_in_pull(fake_socket): + address = ("127.0.0.1", 7687) + socket = fake_socket(address, AsyncBolt5x0.UNPACKER_CLS) + connection = AsyncBolt5x0(address, socket, PoolConfig.max_connection_lifetime) + connection.pull(n=666, qid=777) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x3F" + assert len(fields) == 1 + assert fields[0] == {"n": 666, "qid": 777} + + +@mark_async_test +async def test_hello_passes_routing_metadata(fake_socket_pair): + address = ("127.0.0.1", 7687) + sockets = fake_socket_pair(address, + packer_cls=AsyncBolt5x0.PACKER_CLS, + unpacker_cls=AsyncBolt5x0.UNPACKER_CLS) + await sockets.server.send_message(b"\x70", {"server": "Neo4j/4.4.0"}) + connection = AsyncBolt5x0( + address, sockets.client, PoolConfig.max_connection_lifetime, + routing_context={"foo": "bar"} + ) + await connection.hello() + tag, fields = await sockets.server.pop_message() + assert tag == b"\x01" + assert len(fields) == 1 + assert fields[0]["routing"] == {"foo": "bar"} + + +@pytest.mark.parametrize(("hints", "valid"), ( + ({"connection.recv_timeout_seconds": 1}, True), + ({"connection.recv_timeout_seconds": 42}, True), + ({}, True), + ({"whatever_this_is": "ignore me!"}, True), + ({"connection.recv_timeout_seconds": -1}, False), + ({"connection.recv_timeout_seconds": 0}, False), + ({"connection.recv_timeout_seconds": 2.5}, False), + ({"connection.recv_timeout_seconds": None}, False), + ({"connection.recv_timeout_seconds": False}, False), + ({"connection.recv_timeout_seconds": "1"}, False), +)) +@mark_async_test +async def test_hint_recv_timeout_seconds( + fake_socket_pair, hints, valid, caplog, mocker +): + address = ("127.0.0.1", 7687) + sockets = fake_socket_pair(address, + packer_cls=AsyncBolt5x0.PACKER_CLS, + 
unpacker_cls=AsyncBolt5x0.UNPACKER_CLS) + sockets.client.settimeout = mocker.Mock() + await sockets.server.send_message( + b"\x70", {"server": "Neo4j/4.3.4", "hints": hints} + ) + connection = AsyncBolt5x0( + address, sockets.client, PoolConfig.max_connection_lifetime + ) + with caplog.at_level(logging.INFO): + await connection.hello() + if valid: + if "connection.recv_timeout_seconds" in hints: + sockets.client.settimeout.assert_called_once_with( + hints["connection.recv_timeout_seconds"] + ) + else: + sockets.client.settimeout.assert_not_called() + assert not any("recv_timeout_seconds" in msg + and "invalid" in msg + for msg in caplog.messages) + else: + sockets.client.settimeout.assert_not_called() + assert any(repr(hints["connection.recv_timeout_seconds"]) in msg + and "recv_timeout_seconds" in msg + and "invalid" in msg + for msg in caplog.messages) + + +CREDENTIALS = "+++super-secret-sauce+++" + + +@pytest.mark.parametrize("auth", ( + ("user", CREDENTIALS), + neo4j.basic_auth("user", CREDENTIALS), + neo4j.kerberos_auth(CREDENTIALS), + neo4j.bearer_auth(CREDENTIALS), + neo4j.custom_auth("user", CREDENTIALS, "realm", "scheme"), + neo4j.Auth("scheme", "principal", CREDENTIALS, "realm", foo="bar"), +)) +@mark_async_test +async def test_credentials_are_not_logged( + auth, fake_socket_pair, mocker, caplog +): + address = ("127.0.0.1", 7687) + sockets = fake_socket_pair(address, + packer_cls=AsyncBolt5x0.PACKER_CLS, + unpacker_cls=AsyncBolt5x0.UNPACKER_CLS) + sockets.client.settimeout = mocker.Mock() + await sockets.server.send_message(b"\x70", {"server": "Neo4j/4.3.4"}) + connection = AsyncBolt5x0( + address, sockets.client, PoolConfig.max_connection_lifetime, auth=auth + ) + with caplog.at_level(logging.DEBUG): + await connection.hello() + + if isinstance(auth, tuple): + auth = neo4j.basic_auth(*auth) + for field in ("scheme", "principal", "realm", "parameters"): + value = getattr(auth, field, None) + if value: + assert repr(value) in caplog.text + assert 
CREDENTIALS not in caplog.text + + +@pytest.mark.parametrize("message", ("logon", "logoff")) +def test_auth_message_raises_configuration_error(message, fake_socket): + address = ("127.0.0.1", 7687) + connection = AsyncBolt5x0(address, fake_socket(address), + PoolConfig.max_connection_lifetime) + with pytest.raises(ConfigurationError, + match="Session level authentication is not supported"): + getattr(connection, message)() + + +@pytest.mark.parametrize("auth", ( + None, + neo4j.Auth("scheme", "principal", "credentials", "realm"), + ("user", "password"), +)) +@mark_async_test +async def test_re_auth_noop(auth, fake_socket, mocker): + address = ("127.0.0.1", 7687) + connection = AsyncBolt5x0(address, fake_socket(address), + PoolConfig.max_connection_lifetime, auth=auth) + logon_spy = mocker.spy(connection, "logon") + logoff_spy = mocker.spy(connection, "logoff") + res = await connection.re_auth(auth) + + assert res is False + logon_spy.assert_not_called() + logoff_spy.assert_not_called() + + +@pytest.mark.parametrize( + ("auth1", "auth2"), + permutations( + ( + None, + neo4j.Auth("scheme", "principal", "credentials", "realm"), + ("user", "password"), + ), + 2 + ) +) +@mark_async_test +async def test_re_auth(auth1, auth2, fake_socket): + address = ("127.0.0.1", 7687) + connection = AsyncBolt5x0(address, fake_socket(address), + PoolConfig.max_connection_lifetime, auth=auth1) + with pytest.raises(ConfigurationError, + match="Session level authentication is not supported"): + await connection.re_auth(auth2) diff --git a/tests/unit/async_/io/test_class_bolt5x1.py b/tests/unit/async_/io/test_class_bolt5x1.py new file mode 100644 index 00000000..a2fdc280 --- /dev/null +++ b/tests/unit/async_/io/test_class_bolt5x1.py @@ -0,0 +1,405 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import logging + +import pytest + +import neo4j +import neo4j.exceptions +from neo4j._async.io._bolt5 import AsyncBolt5x1 +from neo4j._conf import PoolConfig + +from ...._async_compat import mark_async_test + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_stale(fake_socket, set_stale): + address = ("127.0.0.1", 7687) + max_connection_lifetime = 0 + connection = AsyncBolt5x1(address, fake_socket(address), + max_connection_lifetime) + if set_stale: + connection.set_stale() + assert connection.stale() is True + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): + address = ("127.0.0.1", 7687) + max_connection_lifetime = -1 + connection = AsyncBolt5x1(address, fake_socket(address), + max_connection_lifetime) + if set_stale: + connection.set_stale() + assert connection.stale() is set_stale + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_not_stale(fake_socket, set_stale): + address = ("127.0.0.1", 7687) + max_connection_lifetime = 999999999 + connection = AsyncBolt5x1(address, fake_socket(address), + max_connection_lifetime) + if set_stale: + connection.set_stale() + assert connection.stale() is set_stale + + +@pytest.mark.parametrize(("args", "kwargs", "expected_fields"), ( + (("", {}), {"db": "something"}, ({"db": "something"},)), + (("", {}), {"imp_user": "imposter"}, ({"imp_user": "imposter"},)), + ( + ("", {}), + {"db": "something", "imp_user": "imposter"}, + ({"db": "something", "imp_user": "imposter"},) + ), +)) 
+@mark_async_test +async def test_extra_in_begin(fake_socket, args, kwargs, expected_fields): + address = ("127.0.0.1", 7687) + socket = fake_socket(address, AsyncBolt5x1.UNPACKER_CLS) + connection = AsyncBolt5x1(address, socket, + PoolConfig.max_connection_lifetime) + connection.begin(*args, **kwargs) + await connection.send_all() + tag, is_fields = await socket.pop_message() + assert tag == b"\x11" + assert tuple(is_fields) == expected_fields + + +@pytest.mark.parametrize(("args", "kwargs", "expected_fields"), ( + (("", {}), {"db": "something"}, ("", {}, {"db": "something"})), + (("", {}), {"imp_user": "imposter"}, ("", {}, {"imp_user": "imposter"})), + ( + ("", {}), + {"db": "something", "imp_user": "imposter"}, + ("", {}, {"db": "something", "imp_user": "imposter"}) + ), +)) +@mark_async_test +async def test_extra_in_run(fake_socket, args, kwargs, expected_fields): + address = ("127.0.0.1", 7687) + socket = fake_socket(address, AsyncBolt5x1.UNPACKER_CLS) + connection = AsyncBolt5x1(address, socket, + PoolConfig.max_connection_lifetime) + connection.run(*args, **kwargs) + await connection.send_all() + tag, is_fields = await socket.pop_message() + assert tag == b"\x10" + assert tuple(is_fields) == expected_fields + + +@mark_async_test +async def test_n_extra_in_discard(fake_socket): + address = ("127.0.0.1", 7687) + socket = fake_socket(address, AsyncBolt5x1.UNPACKER_CLS) + connection = AsyncBolt5x1(address, socket, + PoolConfig.max_connection_lifetime) + connection.discard(n=666) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x2F" + assert len(fields) == 1 + assert fields[0] == {"n": 666} + + +@pytest.mark.parametrize( + "test_input, expected", + [ + (666, {"n": -1, "qid": 666}), + (-1, {"n": -1}), + ] +) +@mark_async_test +async def test_qid_extra_in_discard(fake_socket, test_input, expected): + address = ("127.0.0.1", 7687) + socket = fake_socket(address, AsyncBolt5x1.UNPACKER_CLS) + connection = 
AsyncBolt5x1(address, socket, + PoolConfig.max_connection_lifetime) + connection.discard(qid=test_input) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x2F" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + "test_input, expected", + [ + (777, {"n": 666, "qid": 777}), + (-1, {"n": 666}), + ] +) +@mark_async_test +async def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): + address = ("127.0.0.1", 7687) + socket = fake_socket(address, AsyncBolt5x1.UNPACKER_CLS) + connection = AsyncBolt5x1(address, socket, + PoolConfig.max_connection_lifetime) + connection.discard(n=666, qid=test_input) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x2F" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + "test_input, expected", + [ + (666, {"n": 666}), + (-1, {"n": -1}), + ] +) +@mark_async_test +async def test_n_extra_in_pull(fake_socket, test_input, expected): + address = ("127.0.0.1", 7687) + socket = fake_socket(address, AsyncBolt5x1.UNPACKER_CLS) + connection = AsyncBolt5x1(address, socket, + PoolConfig.max_connection_lifetime) + connection.pull(n=test_input) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x3F" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + "test_input, expected", + [ + (777, {"n": -1, "qid": 777}), + (-1, {"n": -1}), + ] +) +@mark_async_test +async def test_qid_extra_in_pull(fake_socket, test_input, expected): + address = ("127.0.0.1", 7687) + socket = fake_socket(address, AsyncBolt5x1.UNPACKER_CLS) + connection = AsyncBolt5x1(address, socket, + PoolConfig.max_connection_lifetime) + connection.pull(qid=test_input) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x3F" + assert len(fields) == 1 + assert fields[0] == expected + + +@mark_async_test 
+async def test_n_and_qid_extras_in_pull(fake_socket): + address = ("127.0.0.1", 7687) + socket = fake_socket(address, AsyncBolt5x1.UNPACKER_CLS) + connection = AsyncBolt5x1(address, socket, + PoolConfig.max_connection_lifetime) + connection.pull(n=666, qid=777) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x3F" + assert len(fields) == 1 + assert fields[0] == {"n": 666, "qid": 777} + + +@mark_async_test +async def test_hello_passes_routing_metadata(fake_socket_pair): + address = ("127.0.0.1", 7687) + sockets = fake_socket_pair(address, + packer_cls=AsyncBolt5x1.PACKER_CLS, + unpacker_cls=AsyncBolt5x1.UNPACKER_CLS) + await sockets.server.send_message(b"\x70", {"server": "Neo4j/4.4.0"}) + await sockets.server.send_message(b"\x70", {}) + connection = AsyncBolt5x1( + address, sockets.client, PoolConfig.max_connection_lifetime, + routing_context={"foo": "bar"} + ) + await connection.hello() + tag, fields = await sockets.server.pop_message() + assert tag == b"\x01" + assert len(fields) == 1 + assert fields[0]["routing"] == {"foo": "bar"} + + +async def _assert_logon_message(sockets, auth): + tag, fields = await sockets.server.pop_message() + assert tag == b"\x6A" # LOGON + assert len(fields) == 1 + keys = ["scheme", "principal", "credentials"] + assert list(fields[0].keys()) == keys + for key in keys: + assert fields[0][key] == getattr(auth, key) + + +@mark_async_test +async def test_hello_pipelines_logon(fake_socket_pair): + auth = neo4j.Auth("basic", "alice123", "supersecret123") + address = ("127.0.0.1", 7687) + sockets = fake_socket_pair(address, + packer_cls=AsyncBolt5x1.PACKER_CLS, + unpacker_cls=AsyncBolt5x1.UNPACKER_CLS) + await sockets.server.send_message( + b"\x7F", {"code": "Neo.DatabaseError.General.MadeUpError", + "message": "kthxbye"} + ) + connection = AsyncBolt5x1( + address, sockets.client, PoolConfig.max_connection_lifetime, auth=auth + ) + with pytest.raises(neo4j.exceptions.ServiceUnavailable): + await 
connection.hello() + tag, fields = await sockets.server.pop_message() + assert tag == b"\x01" # HELLO + assert len(fields) == 1 + assert list(fields[0].keys()) == ["user_agent"] + assert auth.credentials not in repr(fields) + await _assert_logon_message(sockets, auth) + + +@mark_async_test +async def test_logon(fake_socket_pair): + auth = neo4j.Auth("basic", "alice123", "supersecret123") + address = ("127.0.0.1", 7687) + sockets = fake_socket_pair(address, + packer_cls=AsyncBolt5x1.PACKER_CLS, + unpacker_cls=AsyncBolt5x1.UNPACKER_CLS) + connection = AsyncBolt5x1(address, sockets.client, + PoolConfig.max_connection_lifetime, auth=auth) + connection.logon() + await connection.send_all() + await _assert_logon_message(sockets, auth) + + +@mark_async_test +async def test_re_auth(fake_socket_pair): + auth = neo4j.Auth("basic", "alice123", "supersecret123") + address = ("127.0.0.1", 7687) + sockets = fake_socket_pair(address, + packer_cls=AsyncBolt5x1.PACKER_CLS, + unpacker_cls=AsyncBolt5x1.UNPACKER_CLS) + await sockets.server.send_message( + b"\x7F", {"code": "Neo.DatabaseError.General.MadeUpError", + "message": "kthxbye"} + ) + connection = AsyncBolt5x1(address, sockets.client, + PoolConfig.max_connection_lifetime) + with pytest.raises(neo4j.exceptions.Neo4jError): + await connection.re_auth(auth) + tag, fields = await sockets.server.pop_message() + assert tag == b"\x6B" # LOGOFF + assert len(fields) == 0 + await _assert_logon_message(sockets, auth) + + +@mark_async_test +async def test_logoff(fake_socket_pair): + address = ("127.0.0.1", 7687) + sockets = fake_socket_pair(address, + packer_cls=AsyncBolt5x1.PACKER_CLS, + unpacker_cls=AsyncBolt5x1.UNPACKER_CLS) + await sockets.server.send_message(b"\x70", {}) + connection = AsyncBolt5x1(address, sockets.client, + PoolConfig.max_connection_lifetime) + connection.logoff() + assert not sockets.server.recv_buffer # pipelined, so no response yet + await connection.send_all() + assert sockets.server.recv_buffer # now! 
+ tag, fields = await sockets.server.pop_message() + assert tag == b"\x6B" # LOGOFF + assert len(fields) == 0 + + +@pytest.mark.parametrize(("hints", "valid"), ( + ({"connection.recv_timeout_seconds": 1}, True), + ({"connection.recv_timeout_seconds": 42}, True), + ({}, True), + ({"whatever_this_is": "ignore me!"}, True), + ({"connection.recv_timeout_seconds": -1}, False), + ({"connection.recv_timeout_seconds": 0}, False), + ({"connection.recv_timeout_seconds": 2.5}, False), + ({"connection.recv_timeout_seconds": None}, False), + ({"connection.recv_timeout_seconds": False}, False), + ({"connection.recv_timeout_seconds": "1"}, False), +)) +@mark_async_test +async def test_hint_recv_timeout_seconds( + fake_socket_pair, hints, valid, caplog, mocker +): + address = ("127.0.0.1", 7687) + sockets = fake_socket_pair(address, + packer_cls=AsyncBolt5x1.PACKER_CLS, + unpacker_cls=AsyncBolt5x1.UNPACKER_CLS) + sockets.client.settimeout = mocker.Mock() + await sockets.server.send_message( + b"\x70", {"server": "Neo4j/4.3.4", "hints": hints} + ) + await sockets.server.send_message(b"\x70", {}) + connection = AsyncBolt5x1( + address, sockets.client, PoolConfig.max_connection_lifetime + ) + with caplog.at_level(logging.INFO): + await connection.hello() + if valid: + if "connection.recv_timeout_seconds" in hints: + sockets.client.settimeout.assert_called_once_with( + hints["connection.recv_timeout_seconds"] + ) + else: + sockets.client.settimeout.assert_not_called() + assert not any("recv_timeout_seconds" in msg + and "invalid" in msg + for msg in caplog.messages) + else: + sockets.client.settimeout.assert_not_called() + assert any(repr(hints["connection.recv_timeout_seconds"]) in msg + and "recv_timeout_seconds" in msg + and "invalid" in msg + for msg in caplog.messages) + + +CREDENTIALS = "+++super-secret-sauce+++" + + +@pytest.mark.parametrize("auth", ( + ("user", CREDENTIALS), + neo4j.basic_auth("user", CREDENTIALS), + neo4j.kerberos_auth(CREDENTIALS), + 
neo4j.bearer_auth(CREDENTIALS), + neo4j.custom_auth("user", CREDENTIALS, "realm", "scheme"), + neo4j.Auth("scheme", "principal", CREDENTIALS, "realm", foo="bar"), +)) +@mark_async_test +async def test_credentials_are_not_logged(auth, fake_socket_pair, caplog): + address = ("127.0.0.1", 7687) + sockets = fake_socket_pair(address, + packer_cls=AsyncBolt5x1.PACKER_CLS, + unpacker_cls=AsyncBolt5x1.UNPACKER_CLS) + await sockets.server.send_message(b"\x70", {"server": "Neo4j/4.3.4"}) + await sockets.server.send_message(b"\x70", {}) + connection = AsyncBolt5x1( + address, sockets.client, PoolConfig.max_connection_lifetime, auth=auth + ) + with caplog.at_level(logging.DEBUG): + await connection.hello() + + if isinstance(auth, tuple): + auth = neo4j.basic_auth(*auth) + for field in ("scheme", "principal", "realm", "parameters"): + value = getattr(auth, field, None) + if value: + assert repr(value) in caplog.text + assert CREDENTIALS not in caplog.text diff --git a/tests/unit/async_/io/test_direct.py b/tests/unit/async_/io/test_direct.py index 7729036e..15e3c709 100644 --- a/tests/unit/async_/io/test_direct.py +++ b/tests/unit/async_/io/test_direct.py @@ -65,6 +65,9 @@ def stale(self): async def reset(self): pass + async def re_auth(self, auth): + return False + def close(self): self.socket.close() @@ -85,17 +88,18 @@ def __init__(self, address, *, auth=None, **config): if config: raise ValueError("Unexpected config keys: %s" % ", ".join(config.keys())) - async def opener(addr, timeout): + async def opener(addr, auth, timeout): return AsyncQuickConnection(AsyncFakeSocket(addr)) super().__init__(opener, self.pool_config, self.workspace_config) self.address = address async def acquire( - self, access_mode, timeout, database, bookmarks, liveness_check_timeout + self, access_mode, timeout, database, bookmarks, auth, + liveness_check_timeout ): return await self._acquire( - self.address, timeout, liveness_check_timeout + self.address, auth, timeout, liveness_check_timeout ) @@ 
-145,7 +149,7 @@ def assert_pool_size( address, expected_active, expected_inactive, pool): @mark_async_test async def test_pool_can_acquire(pool): address = ("127.0.0.1", 7687) - connection = await pool._acquire(address, Deadline(3), None) + connection = await pool._acquire(address, None, Deadline(3), None) assert connection.address == address assert_pool_size(address, 1, 0, pool) @@ -153,8 +157,8 @@ async def test_pool_can_acquire(pool): @mark_async_test async def test_pool_can_acquire_twice(pool): address = ("127.0.0.1", 7687) - connection_1 = await pool._acquire(address, Deadline(3), None) - connection_2 = await pool._acquire(address, Deadline(3), None) + connection_1 = await pool._acquire(address, None, Deadline(3), None) + connection_2 = await pool._acquire(address, None, Deadline(3), None) assert connection_1.address == address assert connection_2.address == address assert connection_1 is not connection_2 @@ -165,8 +169,8 @@ async def test_pool_can_acquire_twice(pool): async def test_pool_can_acquire_two_addresses(pool): address_1 = ("127.0.0.1", 7687) address_2 = ("127.0.0.1", 7474) - connection_1 = await pool._acquire(address_1, Deadline(3), None) - connection_2 = await pool._acquire(address_2, Deadline(3), None) + connection_1 = await pool._acquire(address_1, None, Deadline(3), None) + connection_2 = await pool._acquire(address_2, None, Deadline(3), None) assert connection_1.address == address_1 assert connection_2.address == address_2 assert_pool_size(address_1, 1, 0, pool) @@ -176,7 +180,7 @@ async def test_pool_can_acquire_two_addresses(pool): @mark_async_test async def test_pool_can_acquire_and_release(pool): address = ("127.0.0.1", 7687) - connection = await pool._acquire(address, Deadline(3), None) + connection = await pool._acquire(address, None, Deadline(3), None) assert_pool_size(address, 1, 0, pool) await pool.release(connection) assert_pool_size(address, 0, 1, pool) @@ -185,7 +189,7 @@ async def test_pool_can_acquire_and_release(pool): 
@mark_async_test async def test_pool_releasing_twice(pool): address = ("127.0.0.1", 7687) - connection = await pool._acquire(address, Deadline(3), None) + connection = await pool._acquire(address, None, Deadline(3), None) await pool.release(connection) assert_pool_size(address, 0, 1, pool) await pool.release(connection) @@ -196,7 +200,7 @@ async def test_pool_releasing_twice(pool): async def test_pool_in_use_count(pool): address = ("127.0.0.1", 7687) assert pool.in_use_connection_count(address) == 0 - connection = await pool._acquire(address, Deadline(3), None) + connection = await pool._acquire(address, None, Deadline(3), None) assert pool.in_use_connection_count(address) == 1 await pool.release(connection) assert pool.in_use_connection_count(address) == 0 @@ -206,10 +210,10 @@ async def test_pool_in_use_count(pool): async def test_pool_max_conn_pool_size(pool): async with AsyncFakeBoltPool((), max_connection_pool_size=1) as pool: address = ("127.0.0.1", 7687) - await pool._acquire(address, Deadline(0), None) + await pool._acquire(address, None, Deadline(0), None) assert pool.in_use_connection_count(address) == 1 with pytest.raises(ClientError): - await pool._acquire(address, Deadline(0), None) + await pool._acquire(address, None, Deadline(0), None) assert pool.in_use_connection_count(address) == 1 @@ -227,7 +231,7 @@ async def test_pool_reset_when_released(is_reset, pool, mocker): new_callable=mocker.AsyncMock ) is_reset_mock.return_value = is_reset - connection = await pool._acquire(address, Deadline(3), None) + connection = await pool._acquire(address, None, Deadline(3), None) assert isinstance(connection, AsyncQuickConnection) assert is_reset_mock.call_count == 0 assert reset_mock.call_count == 0 diff --git a/tests/unit/async_/io/test_neo4j_pool.py b/tests/unit/async_/io/test_neo4j_pool.py index 3bf510df..cb47c159 100644 --- a/tests/unit/async_/io/test_neo4j_pool.py +++ b/tests/unit/async_/io/test_neo4j_pool.py @@ -69,10 +69,11 @@ def 
routing_side_effect(*args, **kwargs): }] raise res - async def open_(addr, timeout): + async def open_(addr, auth, timeout): connection = async_fake_connection_generator() connection.addr = addr connection.timeout = timeout + connection.auth = auth route_mock = mocker.AsyncMock() route_mock.side_effect = routing_side_effect @@ -99,13 +100,13 @@ async def test_acquires_new_routing_table_if_deleted(opener): pool = AsyncNeo4jPool( opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS ) - cx = await pool.acquire(READ_ACCESS, 30, "test_db", None, None) + cx = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) await pool.release(cx) assert pool.routing_tables.get("test_db") del pool.routing_tables["test_db"] - cx = await pool.acquire(READ_ACCESS, 30, "test_db", None, None) + cx = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) await pool.release(cx) assert pool.routing_tables.get("test_db") @@ -115,14 +116,14 @@ async def test_acquires_new_routing_table_if_stale(opener): pool = AsyncNeo4jPool( opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS ) - cx = await pool.acquire(READ_ACCESS, 30, "test_db", None, None) + cx = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) await pool.release(cx) assert pool.routing_tables.get("test_db") old_value = pool.routing_tables["test_db"].last_updated_time pool.routing_tables["test_db"].ttl = 0 - cx = await pool.acquire(READ_ACCESS, 30, "test_db", None, None) + cx = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) await pool.release(cx) assert pool.routing_tables["test_db"].last_updated_time > old_value @@ -132,10 +133,10 @@ async def test_removes_old_routing_table(opener): pool = AsyncNeo4jPool( opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS ) - cx = await pool.acquire(READ_ACCESS, 30, "test_db1", None, None) + cx = await pool.acquire(READ_ACCESS, 30, "test_db1", None, None, None) await pool.release(cx) assert pool.routing_tables.get("test_db1") - cx = 
await pool.acquire(READ_ACCESS, 30, "test_db2", None, None) + cx = await pool.acquire(READ_ACCESS, 30, "test_db2", None, None, None) await pool.release(cx) assert pool.routing_tables.get("test_db2") @@ -144,7 +145,7 @@ async def test_removes_old_routing_table(opener): pool.routing_tables["test_db2"].ttl = \ -RoutingConfig.routing_table_purge_delay - cx = await pool.acquire(READ_ACCESS, 30, "test_db1", None, None) + cx = await pool.acquire(READ_ACCESS, 30, "test_db1", None, None, None) await pool.release(cx) assert pool.routing_tables["test_db1"].last_updated_time > old_value assert "test_db2" not in pool.routing_tables @@ -158,7 +159,7 @@ async def test_chooses_right_connection_type(opener, type_): ) cx1 = await pool.acquire( READ_ACCESS if type_ == "r" else WRITE_ACCESS, - 30, "test_db", None, None + 30, "test_db", None, None, None ) await pool.release(cx1) if type_ == "r": @@ -172,9 +173,9 @@ async def test_reuses_connection(opener): pool = AsyncNeo4jPool( opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS ) - cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None) + cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) await pool.release(cx1) - cx2 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None) + cx2 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) assert cx1 is cx2 @@ -192,7 +193,7 @@ async def break_connection(): pool = AsyncNeo4jPool( opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS ) - cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None) + cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) await pool.release(cx1) assert cx1 in pool.connections[cx1.addr] # simulate connection going stale (e.g. 
exceeding) and then breaking when @@ -202,7 +203,7 @@ async def break_connection(): if break_on_close: cx_close_mock_side_effect = cx_close_mock.side_effect cx_close_mock.side_effect = break_connection - cx2 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None) + cx2 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) await pool.release(cx2) if break_on_close: cx1.close.assert_called() @@ -219,11 +220,11 @@ async def test_does_not_close_stale_connections_in_use(opener): pool = AsyncNeo4jPool( opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS ) - cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None) + cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) assert cx1 in pool.connections[cx1.addr] # simulate connection going stale (e.g. exceeding) while being in use cx1.stale.return_value = True - cx2 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None) + cx2 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) await pool.release(cx2) cx1.close.assert_not_called() assert cx2 is not cx1 @@ -236,7 +237,7 @@ async def test_does_not_close_stale_connections_in_use(opener): # it should be closed when trying to acquire the next connection cx1.close.assert_not_called() - cx3 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None) + cx3 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) await pool.release(cx3) cx1.close.assert_called_once() assert cx2 is cx3 @@ -250,7 +251,7 @@ async def test_release_resets_connections(opener): pool = AsyncNeo4jPool( opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS ) - cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None) + cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) cx1.is_reset_mock.return_value = False cx1.is_reset_mock.reset_mock() await pool.release(cx1) @@ -263,7 +264,7 @@ async def test_release_does_not_resets_closed_connections(opener): pool = AsyncNeo4jPool( opener, PoolConfig(), 
WorkspaceConfig(), ROUTER1_ADDRESS ) - cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None) + cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) cx1.closed.return_value = True cx1.closed.reset_mock() cx1.is_reset_mock.reset_mock() @@ -278,7 +279,7 @@ async def test_release_does_not_resets_defunct_connections(opener): pool = AsyncNeo4jPool( opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS ) - cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None) + cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) cx1.defunct.return_value = True cx1.defunct.reset_mock() cx1.is_reset_mock.reset_mock() @@ -296,7 +297,8 @@ async def test_acquire_performs_no_liveness_check_on_fresh_connection( pool = AsyncNeo4jPool( opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS ) - cx1 = await pool._acquire(READER_ADDRESS, Deadline(30), liveness_timeout) + cx1 = await pool._acquire(READER_ADDRESS, None, Deadline(30), + liveness_timeout) assert cx1.addr == READER_ADDRESS cx1.reset.assert_not_called() @@ -310,7 +312,8 @@ async def test_acquire_performs_liveness_check_on_existing_connection( opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS ) # populate the pool with a connection - cx1 = await pool._acquire(READER_ADDRESS, Deadline(30), liveness_timeout) + cx1 = await pool._acquire(READER_ADDRESS, None, Deadline(30), + liveness_timeout) # make sure we assume the right state assert cx1.addr == READER_ADDRESS @@ -324,7 +327,8 @@ async def test_acquire_performs_liveness_check_on_existing_connection( cx1.reset.assert_not_called() # then acquire it again and assert the liveness check was performed - cx2 = await pool._acquire(READER_ADDRESS, Deadline(30), liveness_timeout) + cx2 = await pool._acquire(READER_ADDRESS, None, Deadline(30), + liveness_timeout) assert cx1 is cx2 cx1.is_idle_for.assert_called_once_with(liveness_timeout) cx2.reset.assert_awaited_once() @@ -344,7 +348,8 @@ def liveness_side_effect(*args, **kwargs): 
opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS ) # populate the pool with a connection - cx1 = await pool._acquire(READER_ADDRESS, Deadline(30), liveness_timeout) + cx1 = await pool._acquire(READER_ADDRESS, None, Deadline(30), + liveness_timeout) # make sure we assume the right state assert cx1.addr == READER_ADDRESS @@ -360,7 +365,8 @@ def liveness_side_effect(*args, **kwargs): cx1.reset.assert_not_called() # then acquire it again and assert the liveness check was performed - cx2 = await pool._acquire(READER_ADDRESS, Deadline(30), liveness_timeout) + cx2 = await pool._acquire(READER_ADDRESS, None, Deadline(30), + liveness_timeout) assert cx1 is not cx2 assert cx1.addr == cx2.addr cx1.is_idle_for.assert_called_once_with(liveness_timeout) @@ -383,8 +389,10 @@ def liveness_side_effect(*args, **kwargs): opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS ) # populate the pool with a connection - cx1 = await pool._acquire(READER_ADDRESS, Deadline(30), liveness_timeout) - cx2 = await pool._acquire(READER_ADDRESS, Deadline(30), liveness_timeout) + cx1 = await pool._acquire(READER_ADDRESS, None, Deadline(30), + liveness_timeout) + cx2 = await pool._acquire(READER_ADDRESS, None, Deadline(30), + liveness_timeout) # make sure we assume the right state assert cx1.addr == READER_ADDRESS @@ -406,7 +414,8 @@ def liveness_side_effect(*args, **kwargs): cx2.reset.assert_not_called() # then acquire it again and assert the liveness check was performed - cx3 = await pool._acquire(READER_ADDRESS, Deadline(30), liveness_timeout) + cx3 = await pool._acquire(READER_ADDRESS, None, Deadline(30), + liveness_timeout) assert cx3 is cx2 cx1.is_idle_for.assert_called_once_with(liveness_timeout) cx1.reset.assert_awaited_once() @@ -416,6 +425,33 @@ def liveness_side_effect(*args, **kwargs): assert cx3 in pool.connections[cx1.addr] +@mark_async_test +async def test_acquire_accepts_re_auth_as_liveness_check(opener): + pool = AsyncNeo4jPool( + opener, PoolConfig(), WorkspaceConfig(), 
ROUTER1_ADDRESS + ) + # populate the pool with a connection + cx1 = await pool._acquire(READER_ADDRESS, None, Deadline(30), 1) + # make sure we assume the right state + assert cx1.addr == READER_ADDRESS + cx1.is_idle_for.assert_not_called() + cx1.reset.assert_not_called() + + # simulate connections successfully re-authenticating + cx1.re_auth.return_value = True + + # release the connection + await pool.release(cx1) + cx1.reset.assert_not_called() + + # then acquire it again and assert the liveness check was performed + cx2 = await pool._acquire(READER_ADDRESS, None, Deadline(30), 1) + assert cx2 is cx1 + cx1.is_idle_for.assert_not_called() + cx1.reset.assert_not_called() + assert cx1 in pool.connections[cx1.addr] + + @mark_async_test async def test_multiple_broken_connections_on_close(opener, mocker): def mock_connection_breaks_on_close(cx): @@ -431,8 +467,8 @@ async def close_side_effect(): pool = AsyncNeo4jPool( opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS ) - cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None) - cx2 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None) + cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + cx2 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) await pool.release(cx1) await pool.release(cx2) @@ -444,7 +480,7 @@ async def close_side_effect(): # unreachable cx1.stale.return_value = True - cx3 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None) + cx3 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) assert cx3 is not cx1 assert cx3 is not cx2 @@ -455,11 +491,11 @@ async def test_failing_opener_leaves_connections_in_use_alone(opener): pool = AsyncNeo4jPool( opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS ) - cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None) + cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) opener.side_effect = ServiceUnavailable("Server overloaded") with 
pytest.raises((ServiceUnavailable, SessionExpired)): - await pool.acquire(READ_ACCESS, 30, "test_db", None, None) + await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) assert not cx1.closed() @@ -471,7 +507,7 @@ async def test__acquire_new_later_with_room(opener): opener, config, WorkspaceConfig(), ROUTER1_ADDRESS ) assert pool.connections_reservations[READER_ADDRESS] == 0 - creator = pool._acquire_new_later(READER_ADDRESS, Deadline(1)) + creator = pool._acquire_new_later(READER_ADDRESS, None, Deadline(1)) assert pool.connections_reservations[READER_ADDRESS] == 1 assert callable(creator) if AsyncUtil.is_async_code: @@ -485,10 +521,10 @@ async def test__acquire_new_later_without_room(opener): pool = AsyncNeo4jPool( opener, config, WorkspaceConfig(), ROUTER1_ADDRESS ) - _ = await pool.acquire(READ_ACCESS, 30, "test_db", None, None) + _ = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) # pool is full now assert pool.connections_reservations[READER_ADDRESS] == 0 - creator = pool._acquire_new_later(READER_ADDRESS, Deadline(1)) + creator = pool._acquire_new_later(READER_ADDRESS, None, Deadline(1)) assert pool.connections_reservations[READER_ADDRESS] == 0 assert creator is None @@ -509,11 +545,11 @@ async def test_discovery_is_retried(routing_failure_opener, error): opener, PoolConfig(), WorkspaceConfig(), ResolvedAddress(("1.2.3.1", 9999), host_name="host") ) - cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None) + cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) await pool.release(cx1) pool.routing_tables.get("test_db").ttl = 0 - cx2 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None) + cx2 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) await pool.release(cx2) assert pool.routing_tables.get("test_db") @@ -553,12 +589,12 @@ async def test_fast_failing_discovery(routing_failure_opener, error): opener, PoolConfig(), WorkspaceConfig(), ResolvedAddress(("1.2.3.1", 9999), 
host_name="host") ) - cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None) + cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) await pool.release(cx1) pool.routing_tables.get("test_db").ttl = 0 with pytest.raises(error.__class__) as exc: - await pool.acquire(READ_ACCESS, 30, "test_db", None, None) + await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) assert exc.value is error diff --git a/tests/unit/async_/test_driver.py b/tests/unit/async_/test_driver.py index f9007a40..31b4f21b 100644 --- a/tests/unit/async_/test_driver.py +++ b/tests/unit/async_/test_driver.py @@ -198,6 +198,7 @@ async def test_driver_opens_write_session_by_default(uri, fake_pool, mocker): timeout=mocker.ANY, database=mocker.ANY, bookmarks=mocker.ANY, + auth=mocker.ANY, liveness_check_timeout=mocker.ANY ) tx._begin.assert_awaited_once_with( diff --git a/tests/unit/async_/work/test_session.py b/tests/unit/async_/work/test_session.py index 5a0d0ad2..75439229 100644 --- a/tests/unit/async_/work/test_session.py +++ b/tests/unit/async_/work/test_session.py @@ -409,7 +409,7 @@ async def test_with_bookmark_manager( additional_session_bookmarks, mocker ): async def update_routing_table_side_effect( - database, imp_user, bookmarks, acquisition_timeout=None, + database, imp_user, bookmarks, auth=None, acquisition_timeout=None, database_callback=None ): if home_db_gets_resolved: diff --git a/tests/unit/common/test_conf.py b/tests/unit/common/test_conf.py index 9e6ef232..88e67197 100644 --- a/tests/unit/common/test_conf.py +++ b/tests/unit/common/test_conf.py @@ -56,6 +56,7 @@ "user_agent": "test", "trusted_certificates": TrustSystemCAs(), "ssl_context": None, + "auth": None, } test_session_config = { @@ -70,6 +71,7 @@ "impersonated_user": None, "fetch_size": 100, "bookmark_manager": object(), + "auth": None, } diff --git a/tests/unit/mixed/io/test_direct.py b/tests/unit/mixed/io/test_direct.py index be28f810..2eac6aa6 100644 --- 
a/tests/unit/mixed/io/test_direct.py +++ b/tests/unit/mixed/io/test_direct.py @@ -156,7 +156,7 @@ def test_multithread(self, pre_populated): def acquire_release_conn(pool_, address_, acquired_counter_, release_event_): nonlocal connections, connections_lock - conn_ = pool_._acquire(address_, Deadline(3), None) + conn_ = pool_._acquire(address_, None, Deadline(3), None) with connections_lock: if connections is not None: connections.append(conn_) @@ -171,7 +171,7 @@ def acquire_release_conn(pool_, address_, acquired_counter_, # pre-populate the pool with connections for _ in range(pre_populated): - conn = pool._acquire(address, Deadline(3), None) + conn = pool._acquire(address, None, Deadline(3), None) pre_populated_connections.append(conn) for conn in pre_populated_connections: pool.release(conn) @@ -217,7 +217,7 @@ async def test_multi_coroutine(self, pre_populated): async def acquire_release_conn(pool_, address_, acquired_counter_, release_event_): nonlocal connections - conn_ = await pool_._acquire(address_, Deadline(3), None) + conn_ = await pool_._acquire(address_, None, Deadline(3), None) if connections is not None: connections.append(conn_) await acquired_counter_.increment() @@ -251,7 +251,7 @@ async def waiter(pool_, acquired_counter_, release_event_): # pre-populate the pool with connections for _ in range(pre_populated): - conn = await pool._acquire(address, Deadline(3), None) + conn = await pool._acquire(address, None, Deadline(3), None) pre_populated_connections.append(conn) for conn in pre_populated_connections: await pool.release(conn) diff --git a/tests/unit/sync/fixtures/fake_connection.py b/tests/unit/sync/fixtures/fake_connection.py index f3a8e695..fdf5ad90 100644 --- a/tests/unit/sync/fixtures/fake_connection.py +++ b/tests/unit/sync/fixtures/fake_connection.py @@ -50,6 +50,7 @@ def __init__(self, *args, **kwargs): self.attach_mock(mock.Mock(return_value=False), "stale") self.attach_mock(mock.Mock(return_value=False), "closed") 
self.attach_mock(mock.Mock(return_value=False), "socket") + self.attach_mock(mock.Mock(return_value=False), "re_auth") self.attach_mock(mock.Mock(), "unresolved_address") def close_side_effect(): diff --git a/tests/unit/sync/io/test_class_bolt.py b/tests/unit/sync/io/test_class_bolt.py index dc323021..ed96bef5 100644 --- a/tests/unit/sync/io/test_class_bolt.py +++ b/tests/unit/sync/io/test_class_bolt.py @@ -31,12 +31,13 @@ def test_class_method_protocol_handlers(): protocol_handlers = Bolt.protocol_handlers() - assert len(protocol_handlers) == 6 - assert protocol_handlers.keys() == { + expected_versions = { (3, 0), (4, 1), (4, 2), (4, 3), (4, 4), - (5, 0), + (5, 0), (5, 1), } + assert len(protocol_handlers) == len(expected_versions) + assert protocol_handlers.keys() == expected_versions @pytest.mark.parametrize( @@ -52,7 +53,8 @@ def test_class_method_protocol_handlers(): ((4, 3), 1), ((4, 4), 1), ((5, 0), 1), - ((5, 1), 0), + ((5, 1), 1), + ((5, 2), 0), ((6, 0), 0), ] ) @@ -71,7 +73,7 @@ def test_class_method_protocol_handlers_with_invalid_protocol_version(): def test_class_method_get_handshake(): handshake = Bolt.get_handshake() - assert (b"\x00\x00\x00\x05\x00\x02\x04\x04\x00\x00\x01\x04\x00\x00\x00\x03" + assert (b"\x00\x01\x01\x05\x00\x02\x04\x04\x00\x00\x01\x04\x00\x00\x00\x03" == handshake) diff --git a/tests/unit/sync/io/test_class_bolt3.py b/tests/unit/sync/io/test_class_bolt3.py index 87f477d8..0c39167d 100644 --- a/tests/unit/sync/io/test_class_bolt3.py +++ b/tests/unit/sync/io/test_class_bolt3.py @@ -16,8 +16,12 @@ # limitations under the License. 
+import logging +from itertools import permutations + import pytest +import neo4j from neo4j._conf import PoolConfig from neo4j._sync.io._bolt3 import Bolt3 from neo4j.exceptions import ConfigurationError @@ -112,3 +116,89 @@ def test_hint_recv_timeout_seconds_gets_ignored( ) connection.hello() sockets.client.settimeout.assert_not_called() + + +CREDENTIALS = "+++super-secret-sauce+++" + + +@pytest.mark.parametrize("auth", ( + ("user", CREDENTIALS), + neo4j.basic_auth("user", CREDENTIALS), + neo4j.kerberos_auth(CREDENTIALS), + neo4j.bearer_auth(CREDENTIALS), + neo4j.custom_auth("user", CREDENTIALS, "realm", "scheme"), + neo4j.Auth("scheme", "principal", CREDENTIALS, "realm", foo="bar"), +)) +@mark_sync_test +def test_credentials_are_not_logged( + auth, fake_socket_pair, mocker, caplog +): + address = ("127.0.0.1", 7687) + sockets = fake_socket_pair(address, + packer_cls=Bolt3.PACKER_CLS, + unpacker_cls=Bolt3.UNPACKER_CLS) + sockets.client.settimeout = mocker.Mock() + sockets.server.send_message(b"\x70", {"server": "Neo4j/4.3.4"}) + connection = Bolt3( + address, sockets.client, PoolConfig.max_connection_lifetime, auth=auth + ) + with caplog.at_level(logging.DEBUG): + connection.hello() + + if isinstance(auth, tuple): + auth = neo4j.basic_auth(*auth) + for field in ("scheme", "principal", "realm", "parameters"): + value = getattr(auth, field, None) + if value: + assert repr(value) in caplog.text + assert CREDENTIALS not in caplog.text + + +@pytest.mark.parametrize("message", ("logon", "logoff")) +def test_auth_message_raises_configuration_error(message, fake_socket): + address = ("127.0.0.1", 7687) + connection = Bolt3(address, fake_socket(address), + PoolConfig.max_connection_lifetime) + with pytest.raises(ConfigurationError, + match="Session level authentication is not supported"): + getattr(connection, message)() + + +@pytest.mark.parametrize("auth", ( + None, + neo4j.Auth("scheme", "principal", "credentials", "realm"), + ("user", "password"), +)) +@mark_sync_test 
+def test_re_auth_noop(auth, fake_socket, mocker): + address = ("127.0.0.1", 7687) + connection = Bolt3(address, fake_socket(address), + PoolConfig.max_connection_lifetime, auth=auth) + logon_spy = mocker.spy(connection, "logon") + logoff_spy = mocker.spy(connection, "logoff") + res = connection.re_auth(auth) + + assert res is False + logon_spy.assert_not_called() + logoff_spy.assert_not_called() + + +@pytest.mark.parametrize( + ("auth1", "auth2"), + permutations( + ( + None, + neo4j.Auth("scheme", "principal", "credentials", "realm"), + ("user", "password"), + ), + 2 + ) +) +@mark_sync_test +def test_re_auth(auth1, auth2, fake_socket): + address = ("127.0.0.1", 7687) + connection = Bolt3(address, fake_socket(address), + PoolConfig.max_connection_lifetime, auth=auth1) + with pytest.raises(ConfigurationError, + match="Session level authentication is not supported"): + connection.re_auth(auth2) diff --git a/tests/unit/sync/io/test_class_bolt4x0.py b/tests/unit/sync/io/test_class_bolt4x0.py index 88f54993..ff3a595c 100644 --- a/tests/unit/sync/io/test_class_bolt4x0.py +++ b/tests/unit/sync/io/test_class_bolt4x0.py @@ -16,10 +16,15 @@ # limitations under the License. 
+import logging +from itertools import permutations + import pytest +import neo4j from neo4j._conf import PoolConfig from neo4j._sync.io._bolt4 import Bolt4x0 +from neo4j.exceptions import ConfigurationError from ...._async_compat import mark_sync_test @@ -207,3 +212,89 @@ def test_hint_recv_timeout_seconds_gets_ignored( ) connection.hello() sockets.client.settimeout.assert_not_called() + + +CREDENTIALS = "+++super-secret-sauce+++" + + +@pytest.mark.parametrize("auth", ( + ("user", CREDENTIALS), + neo4j.basic_auth("user", CREDENTIALS), + neo4j.kerberos_auth(CREDENTIALS), + neo4j.bearer_auth(CREDENTIALS), + neo4j.custom_auth("user", CREDENTIALS, "realm", "scheme"), + neo4j.Auth("scheme", "principal", CREDENTIALS, "realm", foo="bar"), +)) +@mark_sync_test +def test_credentials_are_not_logged( + auth, fake_socket_pair, mocker, caplog +): + address = ("127.0.0.1", 7687) + sockets = fake_socket_pair(address, + packer_cls=Bolt4x0.PACKER_CLS, + unpacker_cls=Bolt4x0.UNPACKER_CLS) + sockets.client.settimeout = mocker.Mock() + sockets.server.send_message(b"\x70", {"server": "Neo4j/4.3.4"}) + connection = Bolt4x0( + address, sockets.client, PoolConfig.max_connection_lifetime, auth=auth + ) + with caplog.at_level(logging.DEBUG): + connection.hello() + + if isinstance(auth, tuple): + auth = neo4j.basic_auth(*auth) + for field in ("scheme", "principal", "realm", "parameters"): + value = getattr(auth, field, None) + if value: + assert repr(value) in caplog.text + assert CREDENTIALS not in caplog.text + + +@pytest.mark.parametrize("message", ("logon", "logoff")) +def test_auth_message_raises_configuration_error(message, fake_socket): + address = ("127.0.0.1", 7687) + connection = Bolt4x0(address, fake_socket(address), + PoolConfig.max_connection_lifetime) + with pytest.raises(ConfigurationError, + match="Session level authentication is not supported"): + getattr(connection, message)() + + +@pytest.mark.parametrize("auth", ( + None, + neo4j.Auth("scheme", "principal", 
"credentials", "realm"), + ("user", "password"), +)) +@mark_sync_test +def test_re_auth_noop(auth, fake_socket, mocker): + address = ("127.0.0.1", 7687) + connection = Bolt4x0(address, fake_socket(address), + PoolConfig.max_connection_lifetime, auth=auth) + logon_spy = mocker.spy(connection, "logon") + logoff_spy = mocker.spy(connection, "logoff") + res = connection.re_auth(auth) + + assert res is False + logon_spy.assert_not_called() + logoff_spy.assert_not_called() + + +@pytest.mark.parametrize( + ("auth1", "auth2"), + permutations( + ( + None, + neo4j.Auth("scheme", "principal", "credentials", "realm"), + ("user", "password"), + ), + 2 + ) +) +@mark_sync_test +def test_re_auth(auth1, auth2, fake_socket): + address = ("127.0.0.1", 7687) + connection = Bolt4x0(address, fake_socket(address), + PoolConfig.max_connection_lifetime, auth=auth1) + with pytest.raises(ConfigurationError, + match="Session level authentication is not supported"): + connection.re_auth(auth2) diff --git a/tests/unit/sync/io/test_class_bolt4x1.py b/tests/unit/sync/io/test_class_bolt4x1.py index e656cc34..81b25336 100644 --- a/tests/unit/sync/io/test_class_bolt4x1.py +++ b/tests/unit/sync/io/test_class_bolt4x1.py @@ -16,10 +16,15 @@ # limitations under the License. 
+import logging +from itertools import permutations + import pytest +import neo4j from neo4j._conf import PoolConfig from neo4j._sync.io._bolt4 import Bolt4x1 +from neo4j.exceptions import ConfigurationError from ...._async_compat import mark_sync_test @@ -226,3 +231,89 @@ def test_hint_recv_timeout_seconds_gets_ignored( PoolConfig.max_connection_lifetime) connection.hello() sockets.client.settimeout.assert_not_called() + + +CREDENTIALS = "+++super-secret-sauce+++" + + +@pytest.mark.parametrize("auth", ( + ("user", CREDENTIALS), + neo4j.basic_auth("user", CREDENTIALS), + neo4j.kerberos_auth(CREDENTIALS), + neo4j.bearer_auth(CREDENTIALS), + neo4j.custom_auth("user", CREDENTIALS, "realm", "scheme"), + neo4j.Auth("scheme", "principal", CREDENTIALS, "realm", foo="bar"), +)) +@mark_sync_test +def test_credentials_are_not_logged( + auth, fake_socket_pair, mocker, caplog +): + address = ("127.0.0.1", 7687) + sockets = fake_socket_pair(address, + packer_cls=Bolt4x1.PACKER_CLS, + unpacker_cls=Bolt4x1.UNPACKER_CLS) + sockets.client.settimeout = mocker.Mock() + sockets.server.send_message(b"\x70", {"server": "Neo4j/4.3.4"}) + connection = Bolt4x1( + address, sockets.client, PoolConfig.max_connection_lifetime, auth=auth + ) + with caplog.at_level(logging.DEBUG): + connection.hello() + + if isinstance(auth, tuple): + auth = neo4j.basic_auth(*auth) + for field in ("scheme", "principal", "realm", "parameters"): + value = getattr(auth, field, None) + if value: + assert repr(value) in caplog.text + assert CREDENTIALS not in caplog.text + + +@pytest.mark.parametrize("message", ("logon", "logoff")) +def test_auth_message_raises_configuration_error(message, fake_socket): + address = ("127.0.0.1", 7687) + connection = Bolt4x1(address, fake_socket(address), + PoolConfig.max_connection_lifetime) + with pytest.raises(ConfigurationError, + match="Session level authentication is not supported"): + getattr(connection, message)() + + +@pytest.mark.parametrize("auth", ( + None, + 
neo4j.Auth("scheme", "principal", "credentials", "realm"), + ("user", "password"), +)) +@mark_sync_test +def test_re_auth_noop(auth, fake_socket, mocker): + address = ("127.0.0.1", 7687) + connection = Bolt4x1(address, fake_socket(address), + PoolConfig.max_connection_lifetime, auth=auth) + logon_spy = mocker.spy(connection, "logon") + logoff_spy = mocker.spy(connection, "logoff") + res = connection.re_auth(auth) + + assert res is False + logon_spy.assert_not_called() + logoff_spy.assert_not_called() + + +@pytest.mark.parametrize( + ("auth1", "auth2"), + permutations( + ( + None, + neo4j.Auth("scheme", "principal", "credentials", "realm"), + ("user", "password"), + ), + 2 + ) +) +@mark_sync_test +def test_re_auth(auth1, auth2, fake_socket): + address = ("127.0.0.1", 7687) + connection = Bolt4x1(address, fake_socket(address), + PoolConfig.max_connection_lifetime, auth=auth1) + with pytest.raises(ConfigurationError, + match="Session level authentication is not supported"): + connection.re_auth(auth2) diff --git a/tests/unit/sync/io/test_class_bolt4x2.py b/tests/unit/sync/io/test_class_bolt4x2.py index d6bff9c2..d5217c6d 100644 --- a/tests/unit/sync/io/test_class_bolt4x2.py +++ b/tests/unit/sync/io/test_class_bolt4x2.py @@ -16,10 +16,15 @@ # limitations under the License. 
+import logging +from itertools import permutations + import pytest +import neo4j from neo4j._conf import PoolConfig from neo4j._sync.io._bolt4 import Bolt4x2 +from neo4j.exceptions import ConfigurationError from ...._async_compat import mark_sync_test @@ -227,3 +232,89 @@ def test_hint_recv_timeout_seconds_gets_ignored( ) connection.hello() sockets.client.settimeout.assert_not_called() + + +CREDENTIALS = "+++super-secret-sauce+++" + + +@pytest.mark.parametrize("auth", ( + ("user", CREDENTIALS), + neo4j.basic_auth("user", CREDENTIALS), + neo4j.kerberos_auth(CREDENTIALS), + neo4j.bearer_auth(CREDENTIALS), + neo4j.custom_auth("user", CREDENTIALS, "realm", "scheme"), + neo4j.Auth("scheme", "principal", CREDENTIALS, "realm", foo="bar"), +)) +@mark_sync_test +def test_credentials_are_not_logged( + auth, fake_socket_pair, mocker, caplog +): + address = ("127.0.0.1", 7687) + sockets = fake_socket_pair(address, + packer_cls=Bolt4x2.PACKER_CLS, + unpacker_cls=Bolt4x2.UNPACKER_CLS) + sockets.client.settimeout = mocker.Mock() + sockets.server.send_message(b"\x70", {"server": "Neo4j/4.3.4"}) + connection = Bolt4x2( + address, sockets.client, PoolConfig.max_connection_lifetime, auth=auth + ) + with caplog.at_level(logging.DEBUG): + connection.hello() + + if isinstance(auth, tuple): + auth = neo4j.basic_auth(*auth) + for field in ("scheme", "principal", "realm", "parameters"): + value = getattr(auth, field, None) + if value: + assert repr(value) in caplog.text + assert CREDENTIALS not in caplog.text + + +@pytest.mark.parametrize("message", ("logon", "logoff")) +def test_auth_message_raises_configuration_error(message, fake_socket): + address = ("127.0.0.1", 7687) + connection = Bolt4x2(address, fake_socket(address), + PoolConfig.max_connection_lifetime) + with pytest.raises(ConfigurationError, + match="Session level authentication is not supported"): + getattr(connection, message)() + + +@pytest.mark.parametrize("auth", ( + None, + neo4j.Auth("scheme", "principal", 
"credentials", "realm"), + ("user", "password"), +)) +@mark_sync_test +def test_re_auth_noop(auth, fake_socket, mocker): + address = ("127.0.0.1", 7687) + connection = Bolt4x2(address, fake_socket(address), + PoolConfig.max_connection_lifetime, auth=auth) + logon_spy = mocker.spy(connection, "logon") + logoff_spy = mocker.spy(connection, "logoff") + res = connection.re_auth(auth) + + assert res is False + logon_spy.assert_not_called() + logoff_spy.assert_not_called() + + +@pytest.mark.parametrize( + ("auth1", "auth2"), + permutations( + ( + None, + neo4j.Auth("scheme", "principal", "credentials", "realm"), + ("user", "password"), + ), + 2 + ) +) +@mark_sync_test +def test_re_auth(auth1, auth2, fake_socket): + address = ("127.0.0.1", 7687) + connection = Bolt4x2(address, fake_socket(address), + PoolConfig.max_connection_lifetime, auth=auth1) + with pytest.raises(ConfigurationError, + match="Session level authentication is not supported"): + connection.re_auth(auth2) diff --git a/tests/unit/sync/io/test_class_bolt4x3.py b/tests/unit/sync/io/test_class_bolt4x3.py index 474b1585..4de581ea 100644 --- a/tests/unit/sync/io/test_class_bolt4x3.py +++ b/tests/unit/sync/io/test_class_bolt4x3.py @@ -17,11 +17,14 @@ import logging +from itertools import permutations import pytest +import neo4j from neo4j._conf import PoolConfig from neo4j._sync.io._bolt4 import Bolt4x3 +from neo4j.exceptions import ConfigurationError from ...._async_compat import mark_sync_test @@ -255,3 +258,90 @@ def test_hint_recv_timeout_seconds( and "recv_timeout_seconds" in msg and "invalid" in msg for msg in caplog.messages) + + +CREDENTIALS = "+++super-secret-sauce+++" + + +@pytest.mark.parametrize("auth", ( + ("user", CREDENTIALS), + neo4j.basic_auth("user", CREDENTIALS), + neo4j.kerberos_auth(CREDENTIALS), + neo4j.bearer_auth(CREDENTIALS), + neo4j.custom_auth("user", CREDENTIALS, "realm", "scheme"), + neo4j.Auth("scheme", "principal", CREDENTIALS, "realm", foo="bar"), +)) +@mark_sync_test +def 
test_credentials_are_not_logged( + auth, fake_socket_pair, mocker, caplog +): + address = ("127.0.0.1", 7687) + sockets = fake_socket_pair(address, + packer_cls=Bolt4x3.PACKER_CLS, + unpacker_cls=Bolt4x3.UNPACKER_CLS) + sockets.client.settimeout = mocker.Mock() + sockets.server.send_message(b"\x70", {"server": "Neo4j/4.3.4"}) + connection = Bolt4x3( + address, sockets.client, PoolConfig.max_connection_lifetime, auth=auth + ) + with caplog.at_level(logging.DEBUG): + connection.hello() + + if isinstance(auth, tuple): + auth = neo4j.basic_auth(*auth) + for field in ("scheme", "principal", "realm", "parameters"): + value = getattr(auth, field, None) + if value: + assert repr(value) in caplog.text + assert CREDENTIALS not in caplog.text + + + +@pytest.mark.parametrize("message", ("logon", "logoff")) +def test_auth_message_raises_configuration_error(message, fake_socket): + address = ("127.0.0.1", 7687) + connection = Bolt4x3(address, fake_socket(address), + PoolConfig.max_connection_lifetime) + with pytest.raises(ConfigurationError, + match="Session level authentication is not supported"): + getattr(connection, message)() + + +@pytest.mark.parametrize("auth", ( + None, + neo4j.Auth("scheme", "principal", "credentials", "realm"), + ("user", "password"), +)) +@mark_sync_test +def test_re_auth_noop(auth, fake_socket, mocker): + address = ("127.0.0.1", 7687) + connection = Bolt4x3(address, fake_socket(address), + PoolConfig.max_connection_lifetime, auth=auth) + logon_spy = mocker.spy(connection, "logon") + logoff_spy = mocker.spy(connection, "logoff") + res = connection.re_auth(auth) + + assert res is False + logon_spy.assert_not_called() + logoff_spy.assert_not_called() + + +@pytest.mark.parametrize( + ("auth1", "auth2"), + permutations( + ( + None, + neo4j.Auth("scheme", "principal", "credentials", "realm"), + ("user", "password"), + ), + 2 + ) +) +@mark_sync_test +def test_re_auth(auth1, auth2, fake_socket): + address = ("127.0.0.1", 7687) + connection = Bolt4x3(address, 
fake_socket(address), + PoolConfig.max_connection_lifetime, auth=auth1) + with pytest.raises(ConfigurationError, + match="Session level authentication is not supported"): + connection.re_auth(auth2) diff --git a/tests/unit/sync/io/test_class_bolt4x4.py b/tests/unit/sync/io/test_class_bolt4x4.py index d15bed04..6305b95c 100644 --- a/tests/unit/sync/io/test_class_bolt4x4.py +++ b/tests/unit/sync/io/test_class_bolt4x4.py @@ -17,11 +17,14 @@ import logging +from itertools import permutations import pytest +import neo4j from neo4j._conf import PoolConfig from neo4j._sync.io._bolt4 import Bolt4x4 +from neo4j.exceptions import ConfigurationError from ...._async_compat import mark_sync_test @@ -269,3 +272,89 @@ def test_hint_recv_timeout_seconds( and "recv_timeout_seconds" in msg and "invalid" in msg for msg in caplog.messages) + + +CREDENTIALS = "+++super-secret-sauce+++" + + +@pytest.mark.parametrize("auth", ( + ("user", CREDENTIALS), + neo4j.basic_auth("user", CREDENTIALS), + neo4j.kerberos_auth(CREDENTIALS), + neo4j.bearer_auth(CREDENTIALS), + neo4j.custom_auth("user", CREDENTIALS, "realm", "scheme"), + neo4j.Auth("scheme", "principal", CREDENTIALS, "realm", foo="bar"), +)) +@mark_sync_test +def test_credentials_are_not_logged( + auth, fake_socket_pair, mocker, caplog +): + address = ("127.0.0.1", 7687) + sockets = fake_socket_pair(address, + packer_cls=Bolt4x4.PACKER_CLS, + unpacker_cls=Bolt4x4.UNPACKER_CLS) + sockets.client.settimeout = mocker.Mock() + sockets.server.send_message(b"\x70", {"server": "Neo4j/4.3.4"}) + connection = Bolt4x4( + address, sockets.client, PoolConfig.max_connection_lifetime, auth=auth + ) + with caplog.at_level(logging.DEBUG): + connection.hello() + + if isinstance(auth, tuple): + auth = neo4j.basic_auth(*auth) + for field in ("scheme", "principal", "realm", "parameters"): + value = getattr(auth, field, None) + if value: + assert repr(value) in caplog.text + assert CREDENTIALS not in caplog.text + + +@pytest.mark.parametrize("message", 
("logon", "logoff")) +def test_auth_message_raises_configuration_error(message, fake_socket): + address = ("127.0.0.1", 7687) + connection = Bolt4x4(address, fake_socket(address), + PoolConfig.max_connection_lifetime) + with pytest.raises(ConfigurationError, + match="Session level authentication is not supported"): + getattr(connection, message)() + + +@pytest.mark.parametrize("auth", ( + None, + neo4j.Auth("scheme", "principal", "credentials", "realm"), + ("user", "password"), +)) +@mark_sync_test +def test_re_auth_noop(auth, fake_socket, mocker): + address = ("127.0.0.1", 7687) + connection = Bolt4x4(address, fake_socket(address), + PoolConfig.max_connection_lifetime, auth=auth) + logon_spy = mocker.spy(connection, "logon") + logoff_spy = mocker.spy(connection, "logoff") + res = connection.re_auth(auth) + + assert res is False + logon_spy.assert_not_called() + logoff_spy.assert_not_called() + + +@pytest.mark.parametrize( + ("auth1", "auth2"), + permutations( + ( + None, + neo4j.Auth("scheme", "principal", "credentials", "realm"), + ("user", "password"), + ), + 2 + ) +) +@mark_sync_test +def test_re_auth(auth1, auth2, fake_socket): + address = ("127.0.0.1", 7687) + connection = Bolt4x4(address, fake_socket(address), + PoolConfig.max_connection_lifetime, auth=auth1) + with pytest.raises(ConfigurationError, + match="Session level authentication is not supported"): + connection.re_auth(auth2) diff --git a/tests/unit/sync/io/test_class_bolt5x0.py b/tests/unit/sync/io/test_class_bolt5x0.py new file mode 100644 index 00000000..9c7e4654 --- /dev/null +++ b/tests/unit/sync/io/test_class_bolt5x0.py @@ -0,0 +1,358 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import logging +from itertools import permutations + +import pytest + +import neo4j +from neo4j._conf import PoolConfig +from neo4j._sync.io._bolt5 import Bolt5x0 +from neo4j.exceptions import ConfigurationError + +from ...._async_compat import mark_sync_test + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_stale(fake_socket, set_stale): + address = ("127.0.0.1", 7687) + max_connection_lifetime = 0 + connection = Bolt5x0(address, fake_socket(address), max_connection_lifetime) + if set_stale: + connection.set_stale() + assert connection.stale() is True + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): + address = ("127.0.0.1", 7687) + max_connection_lifetime = -1 + connection = Bolt5x0(address, fake_socket(address), max_connection_lifetime) + if set_stale: + connection.set_stale() + assert connection.stale() is set_stale + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_not_stale(fake_socket, set_stale): + address = ("127.0.0.1", 7687) + max_connection_lifetime = 999999999 + connection = Bolt5x0(address, fake_socket(address), max_connection_lifetime) + if set_stale: + connection.set_stale() + assert connection.stale() is set_stale + + +@pytest.mark.parametrize(("args", "kwargs", "expected_fields"), ( + (("", {}), {"db": "something"}, ({"db": "something"},)), + (("", {}), {"imp_user": "imposter"}, ({"imp_user": "imposter"},)), + ( + ("", {}), + {"db": "something", "imp_user": "imposter"}, + ({"db": "something", "imp_user": 
"imposter"},) + ), +)) +@mark_sync_test +def test_extra_in_begin(fake_socket, args, kwargs, expected_fields): + address = ("127.0.0.1", 7687) + socket = fake_socket(address, Bolt5x0.UNPACKER_CLS) + connection = Bolt5x0(address, socket, PoolConfig.max_connection_lifetime) + connection.begin(*args, **kwargs) + connection.send_all() + tag, is_fields = socket.pop_message() + assert tag == b"\x11" + assert tuple(is_fields) == expected_fields + + +@pytest.mark.parametrize(("args", "kwargs", "expected_fields"), ( + (("", {}), {"db": "something"}, ("", {}, {"db": "something"})), + (("", {}), {"imp_user": "imposter"}, ("", {}, {"imp_user": "imposter"})), + ( + ("", {}), + {"db": "something", "imp_user": "imposter"}, + ("", {}, {"db": "something", "imp_user": "imposter"}) + ), +)) +@mark_sync_test +def test_extra_in_run(fake_socket, args, kwargs, expected_fields): + address = ("127.0.0.1", 7687) + socket = fake_socket(address, Bolt5x0.UNPACKER_CLS) + connection = Bolt5x0(address, socket, PoolConfig.max_connection_lifetime) + connection.run(*args, **kwargs) + connection.send_all() + tag, is_fields = socket.pop_message() + assert tag == b"\x10" + assert tuple(is_fields) == expected_fields + + +@mark_sync_test +def test_n_extra_in_discard(fake_socket): + address = ("127.0.0.1", 7687) + socket = fake_socket(address, Bolt5x0.UNPACKER_CLS) + connection = Bolt5x0(address, socket, PoolConfig.max_connection_lifetime) + connection.discard(n=666) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x2F" + assert len(fields) == 1 + assert fields[0] == {"n": 666} + + +@pytest.mark.parametrize( + "test_input, expected", + [ + (666, {"n": -1, "qid": 666}), + (-1, {"n": -1}), + ] +) +@mark_sync_test +def test_qid_extra_in_discard(fake_socket, test_input, expected): + address = ("127.0.0.1", 7687) + socket = fake_socket(address, Bolt5x0.UNPACKER_CLS) + connection = Bolt5x0(address, socket, PoolConfig.max_connection_lifetime) + connection.discard(qid=test_input) + 
connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x2F" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + "test_input, expected", + [ + (777, {"n": 666, "qid": 777}), + (-1, {"n": 666}), + ] +) +@mark_sync_test +def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): + address = ("127.0.0.1", 7687) + socket = fake_socket(address, Bolt5x0.UNPACKER_CLS) + connection = Bolt5x0(address, socket, PoolConfig.max_connection_lifetime) + connection.discard(n=666, qid=test_input) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x2F" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + "test_input, expected", + [ + (666, {"n": 666}), + (-1, {"n": -1}), + ] +) +@mark_sync_test +def test_n_extra_in_pull(fake_socket, test_input, expected): + address = ("127.0.0.1", 7687) + socket = fake_socket(address, Bolt5x0.UNPACKER_CLS) + connection = Bolt5x0(address, socket, PoolConfig.max_connection_lifetime) + connection.pull(n=test_input) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x3F" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + "test_input, expected", + [ + (777, {"n": -1, "qid": 777}), + (-1, {"n": -1}), + ] +) +@mark_sync_test +def test_qid_extra_in_pull(fake_socket, test_input, expected): + address = ("127.0.0.1", 7687) + socket = fake_socket(address, Bolt5x0.UNPACKER_CLS) + connection = Bolt5x0(address, socket, PoolConfig.max_connection_lifetime) + connection.pull(qid=test_input) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x3F" + assert len(fields) == 1 + assert fields[0] == expected + + +@mark_sync_test +def test_n_and_qid_extras_in_pull(fake_socket): + address = ("127.0.0.1", 7687) + socket = fake_socket(address, Bolt5x0.UNPACKER_CLS) + connection = Bolt5x0(address, socket, PoolConfig.max_connection_lifetime) + 
connection.pull(n=666, qid=777) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x3F" + assert len(fields) == 1 + assert fields[0] == {"n": 666, "qid": 777} + + +@mark_sync_test +def test_hello_passes_routing_metadata(fake_socket_pair): + address = ("127.0.0.1", 7687) + sockets = fake_socket_pair(address, + packer_cls=Bolt5x0.PACKER_CLS, + unpacker_cls=Bolt5x0.UNPACKER_CLS) + sockets.server.send_message(b"\x70", {"server": "Neo4j/4.4.0"}) + connection = Bolt5x0( + address, sockets.client, PoolConfig.max_connection_lifetime, + routing_context={"foo": "bar"} + ) + connection.hello() + tag, fields = sockets.server.pop_message() + assert tag == b"\x01" + assert len(fields) == 1 + assert fields[0]["routing"] == {"foo": "bar"} + + +@pytest.mark.parametrize(("hints", "valid"), ( + ({"connection.recv_timeout_seconds": 1}, True), + ({"connection.recv_timeout_seconds": 42}, True), + ({}, True), + ({"whatever_this_is": "ignore me!"}, True), + ({"connection.recv_timeout_seconds": -1}, False), + ({"connection.recv_timeout_seconds": 0}, False), + ({"connection.recv_timeout_seconds": 2.5}, False), + ({"connection.recv_timeout_seconds": None}, False), + ({"connection.recv_timeout_seconds": False}, False), + ({"connection.recv_timeout_seconds": "1"}, False), +)) +@mark_sync_test +def test_hint_recv_timeout_seconds( + fake_socket_pair, hints, valid, caplog, mocker +): + address = ("127.0.0.1", 7687) + sockets = fake_socket_pair(address, + packer_cls=Bolt5x0.PACKER_CLS, + unpacker_cls=Bolt5x0.UNPACKER_CLS) + sockets.client.settimeout = mocker.Mock() + sockets.server.send_message( + b"\x70", {"server": "Neo4j/4.3.4", "hints": hints} + ) + connection = Bolt5x0( + address, sockets.client, PoolConfig.max_connection_lifetime + ) + with caplog.at_level(logging.INFO): + connection.hello() + if valid: + if "connection.recv_timeout_seconds" in hints: + sockets.client.settimeout.assert_called_once_with( + hints["connection.recv_timeout_seconds"] + ) + else: + 
sockets.client.settimeout.assert_not_called() + assert not any("recv_timeout_seconds" in msg + and "invalid" in msg + for msg in caplog.messages) + else: + sockets.client.settimeout.assert_not_called() + assert any(repr(hints["connection.recv_timeout_seconds"]) in msg + and "recv_timeout_seconds" in msg + and "invalid" in msg + for msg in caplog.messages) + + +CREDENTIALS = "+++super-secret-sauce+++" + + +@pytest.mark.parametrize("auth", ( + ("user", CREDENTIALS), + neo4j.basic_auth("user", CREDENTIALS), + neo4j.kerberos_auth(CREDENTIALS), + neo4j.bearer_auth(CREDENTIALS), + neo4j.custom_auth("user", CREDENTIALS, "realm", "scheme"), + neo4j.Auth("scheme", "principal", CREDENTIALS, "realm", foo="bar"), +)) +@mark_sync_test +def test_credentials_are_not_logged( + auth, fake_socket_pair, mocker, caplog +): + address = ("127.0.0.1", 7687) + sockets = fake_socket_pair(address, + packer_cls=Bolt5x0.PACKER_CLS, + unpacker_cls=Bolt5x0.UNPACKER_CLS) + sockets.client.settimeout = mocker.Mock() + sockets.server.send_message(b"\x70", {"server": "Neo4j/4.3.4"}) + connection = Bolt5x0( + address, sockets.client, PoolConfig.max_connection_lifetime, auth=auth + ) + with caplog.at_level(logging.DEBUG): + connection.hello() + + if isinstance(auth, tuple): + auth = neo4j.basic_auth(*auth) + for field in ("scheme", "principal", "realm", "parameters"): + value = getattr(auth, field, None) + if value: + assert repr(value) in caplog.text + assert CREDENTIALS not in caplog.text + + +@pytest.mark.parametrize("message", ("logon", "logoff")) +def test_auth_message_raises_configuration_error(message, fake_socket): + address = ("127.0.0.1", 7687) + connection = Bolt5x0(address, fake_socket(address), + PoolConfig.max_connection_lifetime) + with pytest.raises(ConfigurationError, + match="Session level authentication is not supported"): + getattr(connection, message)() + + +@pytest.mark.parametrize("auth", ( + None, + neo4j.Auth("scheme", "principal", "credentials", "realm"), + ("user", 
"password"), +)) +@mark_sync_test +def test_re_auth_noop(auth, fake_socket, mocker): + address = ("127.0.0.1", 7687) + connection = Bolt5x0(address, fake_socket(address), + PoolConfig.max_connection_lifetime, auth=auth) + logon_spy = mocker.spy(connection, "logon") + logoff_spy = mocker.spy(connection, "logoff") + res = connection.re_auth(auth) + + assert res is False + logon_spy.assert_not_called() + logoff_spy.assert_not_called() + + +@pytest.mark.parametrize( + ("auth1", "auth2"), + permutations( + ( + None, + neo4j.Auth("scheme", "principal", "credentials", "realm"), + ("user", "password"), + ), + 2 + ) +) +@mark_sync_test +def test_re_auth(auth1, auth2, fake_socket): + address = ("127.0.0.1", 7687) + connection = Bolt5x0(address, fake_socket(address), + PoolConfig.max_connection_lifetime, auth=auth1) + with pytest.raises(ConfigurationError, + match="Session level authentication is not supported"): + connection.re_auth(auth2) diff --git a/tests/unit/sync/io/test_class_bolt5x1.py b/tests/unit/sync/io/test_class_bolt5x1.py new file mode 100644 index 00000000..011a4a35 --- /dev/null +++ b/tests/unit/sync/io/test_class_bolt5x1.py @@ -0,0 +1,405 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import logging + +import pytest + +import neo4j +import neo4j.exceptions +from neo4j._conf import PoolConfig +from neo4j._sync.io._bolt5 import Bolt5x1 + +from ...._async_compat import mark_sync_test + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_stale(fake_socket, set_stale): + address = ("127.0.0.1", 7687) + max_connection_lifetime = 0 + connection = Bolt5x1(address, fake_socket(address), + max_connection_lifetime) + if set_stale: + connection.set_stale() + assert connection.stale() is True + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): + address = ("127.0.0.1", 7687) + max_connection_lifetime = -1 + connection = Bolt5x1(address, fake_socket(address), + max_connection_lifetime) + if set_stale: + connection.set_stale() + assert connection.stale() is set_stale + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_not_stale(fake_socket, set_stale): + address = ("127.0.0.1", 7687) + max_connection_lifetime = 999999999 + connection = Bolt5x1(address, fake_socket(address), + max_connection_lifetime) + if set_stale: + connection.set_stale() + assert connection.stale() is set_stale + + +@pytest.mark.parametrize(("args", "kwargs", "expected_fields"), ( + (("", {}), {"db": "something"}, ({"db": "something"},)), + (("", {}), {"imp_user": "imposter"}, ({"imp_user": "imposter"},)), + ( + ("", {}), + {"db": "something", "imp_user": "imposter"}, + ({"db": "something", "imp_user": "imposter"},) + ), +)) +@mark_sync_test +def test_extra_in_begin(fake_socket, args, kwargs, expected_fields): + address = ("127.0.0.1", 7687) + socket = fake_socket(address, Bolt5x1.UNPACKER_CLS) + connection = Bolt5x1(address, socket, + PoolConfig.max_connection_lifetime) + connection.begin(*args, **kwargs) + connection.send_all() + tag, is_fields = socket.pop_message() + assert tag == b"\x11" + assert tuple(is_fields) == expected_fields + + 
+@pytest.mark.parametrize(("args", "kwargs", "expected_fields"), ( + (("", {}), {"db": "something"}, ("", {}, {"db": "something"})), + (("", {}), {"imp_user": "imposter"}, ("", {}, {"imp_user": "imposter"})), + ( + ("", {}), + {"db": "something", "imp_user": "imposter"}, + ("", {}, {"db": "something", "imp_user": "imposter"}) + ), +)) +@mark_sync_test +def test_extra_in_run(fake_socket, args, kwargs, expected_fields): + address = ("127.0.0.1", 7687) + socket = fake_socket(address, Bolt5x1.UNPACKER_CLS) + connection = Bolt5x1(address, socket, + PoolConfig.max_connection_lifetime) + connection.run(*args, **kwargs) + connection.send_all() + tag, is_fields = socket.pop_message() + assert tag == b"\x10" + assert tuple(is_fields) == expected_fields + + +@mark_sync_test +def test_n_extra_in_discard(fake_socket): + address = ("127.0.0.1", 7687) + socket = fake_socket(address, Bolt5x1.UNPACKER_CLS) + connection = Bolt5x1(address, socket, + PoolConfig.max_connection_lifetime) + connection.discard(n=666) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x2F" + assert len(fields) == 1 + assert fields[0] == {"n": 666} + + +@pytest.mark.parametrize( + "test_input, expected", + [ + (666, {"n": -1, "qid": 666}), + (-1, {"n": -1}), + ] +) +@mark_sync_test +def test_qid_extra_in_discard(fake_socket, test_input, expected): + address = ("127.0.0.1", 7687) + socket = fake_socket(address, Bolt5x1.UNPACKER_CLS) + connection = Bolt5x1(address, socket, + PoolConfig.max_connection_lifetime) + connection.discard(qid=test_input) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x2F" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + "test_input, expected", + [ + (777, {"n": 666, "qid": 777}), + (-1, {"n": 666}), + ] +) +@mark_sync_test +def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): + address = ("127.0.0.1", 7687) + socket = fake_socket(address, Bolt5x1.UNPACKER_CLS) + 
connection = Bolt5x1(address, socket, + PoolConfig.max_connection_lifetime) + connection.discard(n=666, qid=test_input) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x2F" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + "test_input, expected", + [ + (666, {"n": 666}), + (-1, {"n": -1}), + ] +) +@mark_sync_test +def test_n_extra_in_pull(fake_socket, test_input, expected): + address = ("127.0.0.1", 7687) + socket = fake_socket(address, Bolt5x1.UNPACKER_CLS) + connection = Bolt5x1(address, socket, + PoolConfig.max_connection_lifetime) + connection.pull(n=test_input) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x3F" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + "test_input, expected", + [ + (777, {"n": -1, "qid": 777}), + (-1, {"n": -1}), + ] +) +@mark_sync_test +def test_qid_extra_in_pull(fake_socket, test_input, expected): + address = ("127.0.0.1", 7687) + socket = fake_socket(address, Bolt5x1.UNPACKER_CLS) + connection = Bolt5x1(address, socket, + PoolConfig.max_connection_lifetime) + connection.pull(qid=test_input) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x3F" + assert len(fields) == 1 + assert fields[0] == expected + + +@mark_sync_test +def test_n_and_qid_extras_in_pull(fake_socket): + address = ("127.0.0.1", 7687) + socket = fake_socket(address, Bolt5x1.UNPACKER_CLS) + connection = Bolt5x1(address, socket, + PoolConfig.max_connection_lifetime) + connection.pull(n=666, qid=777) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x3F" + assert len(fields) == 1 + assert fields[0] == {"n": 666, "qid": 777} + + +@mark_sync_test +def test_hello_passes_routing_metadata(fake_socket_pair): + address = ("127.0.0.1", 7687) + sockets = fake_socket_pair(address, + packer_cls=Bolt5x1.PACKER_CLS, + unpacker_cls=Bolt5x1.UNPACKER_CLS) + 
sockets.server.send_message(b"\x70", {"server": "Neo4j/4.4.0"}) + sockets.server.send_message(b"\x70", {}) + connection = Bolt5x1( + address, sockets.client, PoolConfig.max_connection_lifetime, + routing_context={"foo": "bar"} + ) + connection.hello() + tag, fields = sockets.server.pop_message() + assert tag == b"\x01" + assert len(fields) == 1 + assert fields[0]["routing"] == {"foo": "bar"} + + +def _assert_logon_message(sockets, auth): + tag, fields = sockets.server.pop_message() + assert tag == b"\x6A" # LOGON + assert len(fields) == 1 + keys = ["scheme", "principal", "credentials"] + assert list(fields[0].keys()) == keys + for key in keys: + assert fields[0][key] == getattr(auth, key) + + +@mark_sync_test +def test_hello_pipelines_logon(fake_socket_pair): + auth = neo4j.Auth("basic", "alice123", "supersecret123") + address = ("127.0.0.1", 7687) + sockets = fake_socket_pair(address, + packer_cls=Bolt5x1.PACKER_CLS, + unpacker_cls=Bolt5x1.UNPACKER_CLS) + sockets.server.send_message( + b"\x7F", {"code": "Neo.DatabaseError.General.MadeUpError", + "message": "kthxbye"} + ) + connection = Bolt5x1( + address, sockets.client, PoolConfig.max_connection_lifetime, auth=auth + ) + with pytest.raises(neo4j.exceptions.ServiceUnavailable): + connection.hello() + tag, fields = sockets.server.pop_message() + assert tag == b"\x01" # HELLO + assert len(fields) == 1 + assert list(fields[0].keys()) == ["user_agent"] + assert auth.credentials not in repr(fields) + _assert_logon_message(sockets, auth) + + +@mark_sync_test +def test_logon(fake_socket_pair): + auth = neo4j.Auth("basic", "alice123", "supersecret123") + address = ("127.0.0.1", 7687) + sockets = fake_socket_pair(address, + packer_cls=Bolt5x1.PACKER_CLS, + unpacker_cls=Bolt5x1.UNPACKER_CLS) + connection = Bolt5x1(address, sockets.client, + PoolConfig.max_connection_lifetime, auth=auth) + connection.logon() + connection.send_all() + _assert_logon_message(sockets, auth) + + +@mark_sync_test +def 
test_re_auth(fake_socket_pair): + auth = neo4j.Auth("basic", "alice123", "supersecret123") + address = ("127.0.0.1", 7687) + sockets = fake_socket_pair(address, + packer_cls=Bolt5x1.PACKER_CLS, + unpacker_cls=Bolt5x1.UNPACKER_CLS) + sockets.server.send_message( + b"\x7F", {"code": "Neo.DatabaseError.General.MadeUpError", + "message": "kthxbye"} + ) + connection = Bolt5x1(address, sockets.client, + PoolConfig.max_connection_lifetime) + with pytest.raises(neo4j.exceptions.Neo4jError): + connection.re_auth(auth) + tag, fields = sockets.server.pop_message() + assert tag == b"\x6B" # LOGOFF + assert len(fields) == 0 + _assert_logon_message(sockets, auth) + + +@mark_sync_test +def test_logoff(fake_socket_pair): + address = ("127.0.0.1", 7687) + sockets = fake_socket_pair(address, + packer_cls=Bolt5x1.PACKER_CLS, + unpacker_cls=Bolt5x1.UNPACKER_CLS) + sockets.server.send_message(b"\x70", {}) + connection = Bolt5x1(address, sockets.client, + PoolConfig.max_connection_lifetime) + connection.logoff() + assert not sockets.server.recv_buffer # pipelined, so no response yet + connection.send_all() + assert sockets.server.recv_buffer # now! 
+ tag, fields = sockets.server.pop_message() + assert tag == b"\x6B" # LOGOFF + assert len(fields) == 0 + + +@pytest.mark.parametrize(("hints", "valid"), ( + ({"connection.recv_timeout_seconds": 1}, True), + ({"connection.recv_timeout_seconds": 42}, True), + ({}, True), + ({"whatever_this_is": "ignore me!"}, True), + ({"connection.recv_timeout_seconds": -1}, False), + ({"connection.recv_timeout_seconds": 0}, False), + ({"connection.recv_timeout_seconds": 2.5}, False), + ({"connection.recv_timeout_seconds": None}, False), + ({"connection.recv_timeout_seconds": False}, False), + ({"connection.recv_timeout_seconds": "1"}, False), +)) +@mark_sync_test +def test_hint_recv_timeout_seconds( + fake_socket_pair, hints, valid, caplog, mocker +): + address = ("127.0.0.1", 7687) + sockets = fake_socket_pair(address, + packer_cls=Bolt5x1.PACKER_CLS, + unpacker_cls=Bolt5x1.UNPACKER_CLS) + sockets.client.settimeout = mocker.Mock() + sockets.server.send_message( + b"\x70", {"server": "Neo4j/4.3.4", "hints": hints} + ) + sockets.server.send_message(b"\x70", {}) + connection = Bolt5x1( + address, sockets.client, PoolConfig.max_connection_lifetime + ) + with caplog.at_level(logging.INFO): + connection.hello() + if valid: + if "connection.recv_timeout_seconds" in hints: + sockets.client.settimeout.assert_called_once_with( + hints["connection.recv_timeout_seconds"] + ) + else: + sockets.client.settimeout.assert_not_called() + assert not any("recv_timeout_seconds" in msg + and "invalid" in msg + for msg in caplog.messages) + else: + sockets.client.settimeout.assert_not_called() + assert any(repr(hints["connection.recv_timeout_seconds"]) in msg + and "recv_timeout_seconds" in msg + and "invalid" in msg + for msg in caplog.messages) + + +CREDENTIALS = "+++super-secret-sauce+++" + + +@pytest.mark.parametrize("auth", ( + ("user", CREDENTIALS), + neo4j.basic_auth("user", CREDENTIALS), + neo4j.kerberos_auth(CREDENTIALS), + neo4j.bearer_auth(CREDENTIALS), + neo4j.custom_auth("user", 
CREDENTIALS, "realm", "scheme"), + neo4j.Auth("scheme", "principal", CREDENTIALS, "realm", foo="bar"), +)) +@mark_sync_test +def test_credentials_are_not_logged(auth, fake_socket_pair, caplog): + address = ("127.0.0.1", 7687) + sockets = fake_socket_pair(address, + packer_cls=Bolt5x1.PACKER_CLS, + unpacker_cls=Bolt5x1.UNPACKER_CLS) + sockets.server.send_message(b"\x70", {"server": "Neo4j/4.3.4"}) + sockets.server.send_message(b"\x70", {}) + connection = Bolt5x1( + address, sockets.client, PoolConfig.max_connection_lifetime, auth=auth + ) + with caplog.at_level(logging.DEBUG): + connection.hello() + + if isinstance(auth, tuple): + auth = neo4j.basic_auth(*auth) + for field in ("scheme", "principal", "realm", "parameters"): + value = getattr(auth, field, None) + if value: + assert repr(value) in caplog.text + assert CREDENTIALS not in caplog.text diff --git a/tests/unit/sync/io/test_direct.py b/tests/unit/sync/io/test_direct.py index d54dfd73..e61d1e49 100644 --- a/tests/unit/sync/io/test_direct.py +++ b/tests/unit/sync/io/test_direct.py @@ -65,6 +65,9 @@ def stale(self): def reset(self): pass + def re_auth(self, auth): + return False + def close(self): self.socket.close() @@ -85,17 +88,18 @@ def __init__(self, address, *, auth=None, **config): if config: raise ValueError("Unexpected config keys: %s" % ", ".join(config.keys())) - def opener(addr, timeout): + def opener(addr, auth, timeout): return QuickConnection(FakeSocket(addr)) super().__init__(opener, self.pool_config, self.workspace_config) self.address = address def acquire( - self, access_mode, timeout, database, bookmarks, liveness_check_timeout + self, access_mode, timeout, database, bookmarks, auth, + liveness_check_timeout ): return self._acquire( - self.address, timeout, liveness_check_timeout + self.address, auth, timeout, liveness_check_timeout ) @@ -145,7 +149,7 @@ def assert_pool_size( address, expected_active, expected_inactive, pool): @mark_sync_test def test_pool_can_acquire(pool): address = 
("127.0.0.1", 7687) - connection = pool._acquire(address, Deadline(3), None) + connection = pool._acquire(address, None, Deadline(3), None) assert connection.address == address assert_pool_size(address, 1, 0, pool) @@ -153,8 +157,8 @@ def test_pool_can_acquire(pool): @mark_sync_test def test_pool_can_acquire_twice(pool): address = ("127.0.0.1", 7687) - connection_1 = pool._acquire(address, Deadline(3), None) - connection_2 = pool._acquire(address, Deadline(3), None) + connection_1 = pool._acquire(address, None, Deadline(3), None) + connection_2 = pool._acquire(address, None, Deadline(3), None) assert connection_1.address == address assert connection_2.address == address assert connection_1 is not connection_2 @@ -165,8 +169,8 @@ def test_pool_can_acquire_twice(pool): def test_pool_can_acquire_two_addresses(pool): address_1 = ("127.0.0.1", 7687) address_2 = ("127.0.0.1", 7474) - connection_1 = pool._acquire(address_1, Deadline(3), None) - connection_2 = pool._acquire(address_2, Deadline(3), None) + connection_1 = pool._acquire(address_1, None, Deadline(3), None) + connection_2 = pool._acquire(address_2, None, Deadline(3), None) assert connection_1.address == address_1 assert connection_2.address == address_2 assert_pool_size(address_1, 1, 0, pool) @@ -176,7 +180,7 @@ def test_pool_can_acquire_two_addresses(pool): @mark_sync_test def test_pool_can_acquire_and_release(pool): address = ("127.0.0.1", 7687) - connection = pool._acquire(address, Deadline(3), None) + connection = pool._acquire(address, None, Deadline(3), None) assert_pool_size(address, 1, 0, pool) pool.release(connection) assert_pool_size(address, 0, 1, pool) @@ -185,7 +189,7 @@ def test_pool_can_acquire_and_release(pool): @mark_sync_test def test_pool_releasing_twice(pool): address = ("127.0.0.1", 7687) - connection = pool._acquire(address, Deadline(3), None) + connection = pool._acquire(address, None, Deadline(3), None) pool.release(connection) assert_pool_size(address, 0, 1, pool) 
pool.release(connection) @@ -196,7 +200,7 @@ def test_pool_releasing_twice(pool): def test_pool_in_use_count(pool): address = ("127.0.0.1", 7687) assert pool.in_use_connection_count(address) == 0 - connection = pool._acquire(address, Deadline(3), None) + connection = pool._acquire(address, None, Deadline(3), None) assert pool.in_use_connection_count(address) == 1 pool.release(connection) assert pool.in_use_connection_count(address) == 0 @@ -206,10 +210,10 @@ def test_pool_in_use_count(pool): def test_pool_max_conn_pool_size(pool): with FakeBoltPool((), max_connection_pool_size=1) as pool: address = ("127.0.0.1", 7687) - pool._acquire(address, Deadline(0), None) + pool._acquire(address, None, Deadline(0), None) assert pool.in_use_connection_count(address) == 1 with pytest.raises(ClientError): - pool._acquire(address, Deadline(0), None) + pool._acquire(address, None, Deadline(0), None) assert pool.in_use_connection_count(address) == 1 @@ -227,7 +231,7 @@ def test_pool_reset_when_released(is_reset, pool, mocker): new_callable=mocker.Mock ) is_reset_mock.return_value = is_reset - connection = pool._acquire(address, Deadline(3), None) + connection = pool._acquire(address, None, Deadline(3), None) assert isinstance(connection, QuickConnection) assert is_reset_mock.call_count == 0 assert reset_mock.call_count == 0 diff --git a/tests/unit/sync/io/test_neo4j_pool.py b/tests/unit/sync/io/test_neo4j_pool.py index 0c13a78d..424228fe 100644 --- a/tests/unit/sync/io/test_neo4j_pool.py +++ b/tests/unit/sync/io/test_neo4j_pool.py @@ -69,10 +69,11 @@ def routing_side_effect(*args, **kwargs): }] raise res - def open_(addr, timeout): + def open_(addr, auth, timeout): connection = fake_connection_generator() connection.addr = addr connection.timeout = timeout + connection.auth = auth route_mock = mocker.Mock() route_mock.side_effect = routing_side_effect @@ -99,13 +100,13 @@ def test_acquires_new_routing_table_if_deleted(opener): pool = Neo4jPool( opener, PoolConfig(), 
WorkspaceConfig(), ROUTER1_ADDRESS ) - cx = pool.acquire(READ_ACCESS, 30, "test_db", None, None) + cx = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) pool.release(cx) assert pool.routing_tables.get("test_db") del pool.routing_tables["test_db"] - cx = pool.acquire(READ_ACCESS, 30, "test_db", None, None) + cx = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) pool.release(cx) assert pool.routing_tables.get("test_db") @@ -115,14 +116,14 @@ def test_acquires_new_routing_table_if_stale(opener): pool = Neo4jPool( opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS ) - cx = pool.acquire(READ_ACCESS, 30, "test_db", None, None) + cx = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) pool.release(cx) assert pool.routing_tables.get("test_db") old_value = pool.routing_tables["test_db"].last_updated_time pool.routing_tables["test_db"].ttl = 0 - cx = pool.acquire(READ_ACCESS, 30, "test_db", None, None) + cx = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) pool.release(cx) assert pool.routing_tables["test_db"].last_updated_time > old_value @@ -132,10 +133,10 @@ def test_removes_old_routing_table(opener): pool = Neo4jPool( opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS ) - cx = pool.acquire(READ_ACCESS, 30, "test_db1", None, None) + cx = pool.acquire(READ_ACCESS, 30, "test_db1", None, None, None) pool.release(cx) assert pool.routing_tables.get("test_db1") - cx = pool.acquire(READ_ACCESS, 30, "test_db2", None, None) + cx = pool.acquire(READ_ACCESS, 30, "test_db2", None, None, None) pool.release(cx) assert pool.routing_tables.get("test_db2") @@ -144,7 +145,7 @@ def test_removes_old_routing_table(opener): pool.routing_tables["test_db2"].ttl = \ -RoutingConfig.routing_table_purge_delay - cx = pool.acquire(READ_ACCESS, 30, "test_db1", None, None) + cx = pool.acquire(READ_ACCESS, 30, "test_db1", None, None, None) pool.release(cx) assert pool.routing_tables["test_db1"].last_updated_time > old_value assert "test_db2" not in 
pool.routing_tables @@ -158,7 +159,7 @@ def test_chooses_right_connection_type(opener, type_): ) cx1 = pool.acquire( READ_ACCESS if type_ == "r" else WRITE_ACCESS, - 30, "test_db", None, None + 30, "test_db", None, None, None ) pool.release(cx1) if type_ == "r": @@ -172,9 +173,9 @@ def test_reuses_connection(opener): pool = Neo4jPool( opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS ) - cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None, None) + cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) pool.release(cx1) - cx2 = pool.acquire(READ_ACCESS, 30, "test_db", None, None) + cx2 = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) assert cx1 is cx2 @@ -192,7 +193,7 @@ def break_connection(): pool = Neo4jPool( opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS ) - cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None, None) + cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) pool.release(cx1) assert cx1 in pool.connections[cx1.addr] # simulate connection going stale (e.g. exceeding) and then breaking when @@ -202,7 +203,7 @@ def break_connection(): if break_on_close: cx_close_mock_side_effect = cx_close_mock.side_effect cx_close_mock.side_effect = break_connection - cx2 = pool.acquire(READ_ACCESS, 30, "test_db", None, None) + cx2 = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) pool.release(cx2) if break_on_close: cx1.close.assert_called() @@ -219,11 +220,11 @@ def test_does_not_close_stale_connections_in_use(opener): pool = Neo4jPool( opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS ) - cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None, None) + cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) assert cx1 in pool.connections[cx1.addr] # simulate connection going stale (e.g. 
exceeding) while being in use cx1.stale.return_value = True - cx2 = pool.acquire(READ_ACCESS, 30, "test_db", None, None) + cx2 = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) pool.release(cx2) cx1.close.assert_not_called() assert cx2 is not cx1 @@ -236,7 +237,7 @@ def test_does_not_close_stale_connections_in_use(opener): # it should be closed when trying to acquire the next connection cx1.close.assert_not_called() - cx3 = pool.acquire(READ_ACCESS, 30, "test_db", None, None) + cx3 = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) pool.release(cx3) cx1.close.assert_called_once() assert cx2 is cx3 @@ -250,7 +251,7 @@ def test_release_resets_connections(opener): pool = Neo4jPool( opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS ) - cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None, None) + cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) cx1.is_reset_mock.return_value = False cx1.is_reset_mock.reset_mock() pool.release(cx1) @@ -263,7 +264,7 @@ def test_release_does_not_resets_closed_connections(opener): pool = Neo4jPool( opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS ) - cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None, None) + cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) cx1.closed.return_value = True cx1.closed.reset_mock() cx1.is_reset_mock.reset_mock() @@ -278,7 +279,7 @@ def test_release_does_not_resets_defunct_connections(opener): pool = Neo4jPool( opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS ) - cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None, None) + cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) cx1.defunct.return_value = True cx1.defunct.reset_mock() cx1.is_reset_mock.reset_mock() @@ -296,7 +297,8 @@ def test_acquire_performs_no_liveness_check_on_fresh_connection( pool = Neo4jPool( opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS ) - cx1 = pool._acquire(READER_ADDRESS, Deadline(30), liveness_timeout) + cx1 = 
pool._acquire(READER_ADDRESS, None, Deadline(30), + liveness_timeout) assert cx1.addr == READER_ADDRESS cx1.reset.assert_not_called() @@ -310,7 +312,8 @@ def test_acquire_performs_liveness_check_on_existing_connection( opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS ) # populate the pool with a connection - cx1 = pool._acquire(READER_ADDRESS, Deadline(30), liveness_timeout) + cx1 = pool._acquire(READER_ADDRESS, None, Deadline(30), + liveness_timeout) # make sure we assume the right state assert cx1.addr == READER_ADDRESS @@ -324,7 +327,8 @@ def test_acquire_performs_liveness_check_on_existing_connection( cx1.reset.assert_not_called() # then acquire it again and assert the liveness check was performed - cx2 = pool._acquire(READER_ADDRESS, Deadline(30), liveness_timeout) + cx2 = pool._acquire(READER_ADDRESS, None, Deadline(30), + liveness_timeout) assert cx1 is cx2 cx1.is_idle_for.assert_called_once_with(liveness_timeout) cx2.reset.assert_called_once() @@ -344,7 +348,8 @@ def liveness_side_effect(*args, **kwargs): opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS ) # populate the pool with a connection - cx1 = pool._acquire(READER_ADDRESS, Deadline(30), liveness_timeout) + cx1 = pool._acquire(READER_ADDRESS, None, Deadline(30), + liveness_timeout) # make sure we assume the right state assert cx1.addr == READER_ADDRESS @@ -360,7 +365,8 @@ def liveness_side_effect(*args, **kwargs): cx1.reset.assert_not_called() # then acquire it again and assert the liveness check was performed - cx2 = pool._acquire(READER_ADDRESS, Deadline(30), liveness_timeout) + cx2 = pool._acquire(READER_ADDRESS, None, Deadline(30), + liveness_timeout) assert cx1 is not cx2 assert cx1.addr == cx2.addr cx1.is_idle_for.assert_called_once_with(liveness_timeout) @@ -383,8 +389,10 @@ def liveness_side_effect(*args, **kwargs): opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS ) # populate the pool with a connection - cx1 = pool._acquire(READER_ADDRESS, Deadline(30), 
liveness_timeout) - cx2 = pool._acquire(READER_ADDRESS, Deadline(30), liveness_timeout) + cx1 = pool._acquire(READER_ADDRESS, None, Deadline(30), + liveness_timeout) + cx2 = pool._acquire(READER_ADDRESS, None, Deadline(30), + liveness_timeout) # make sure we assume the right state assert cx1.addr == READER_ADDRESS @@ -406,7 +414,8 @@ def liveness_side_effect(*args, **kwargs): cx2.reset.assert_not_called() # then acquire it again and assert the liveness check was performed - cx3 = pool._acquire(READER_ADDRESS, Deadline(30), liveness_timeout) + cx3 = pool._acquire(READER_ADDRESS, None, Deadline(30), + liveness_timeout) assert cx3 is cx2 cx1.is_idle_for.assert_called_once_with(liveness_timeout) cx1.reset.assert_called_once() @@ -416,6 +425,33 @@ def liveness_side_effect(*args, **kwargs): assert cx3 in pool.connections[cx1.addr] +@mark_sync_test +def test_acquire_accepts_re_auth_as_liveness_check(opener): + pool = Neo4jPool( + opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS + ) + # populate the pool with a connection + cx1 = pool._acquire(READER_ADDRESS, None, Deadline(30), 1) + # make sure we assume the right state + assert cx1.addr == READER_ADDRESS + cx1.is_idle_for.assert_not_called() + cx1.reset.assert_not_called() + + # simulate connections successfully re-authenticating + cx1.re_auth.return_value = True + + # release the connection + pool.release(cx1) + cx1.reset.assert_not_called() + + # then acquire it again and assert the liveness check was performed + cx2 = pool._acquire(READER_ADDRESS, None, Deadline(30), 1) + assert cx2 is cx1 + cx1.is_idle_for.assert_not_called() + cx1.reset.assert_not_called() + assert cx1 in pool.connections[cx1.addr] + + @mark_sync_test def test_multiple_broken_connections_on_close(opener, mocker): def mock_connection_breaks_on_close(cx): @@ -431,8 +467,8 @@ def close_side_effect(): pool = Neo4jPool( opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS ) - cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None, None) - cx2 = 
pool.acquire(READ_ACCESS, 30, "test_db", None, None) + cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + cx2 = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) pool.release(cx1) pool.release(cx2) @@ -444,7 +480,7 @@ def close_side_effect(): # unreachable cx1.stale.return_value = True - cx3 = pool.acquire(READ_ACCESS, 30, "test_db", None, None) + cx3 = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) assert cx3 is not cx1 assert cx3 is not cx2 @@ -455,11 +491,11 @@ def test_failing_opener_leaves_connections_in_use_alone(opener): pool = Neo4jPool( opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS ) - cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None, None) + cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) opener.side_effect = ServiceUnavailable("Server overloaded") with pytest.raises((ServiceUnavailable, SessionExpired)): - pool.acquire(READ_ACCESS, 30, "test_db", None, None) + pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) assert not cx1.closed() @@ -471,7 +507,7 @@ def test__acquire_new_later_with_room(opener): opener, config, WorkspaceConfig(), ROUTER1_ADDRESS ) assert pool.connections_reservations[READER_ADDRESS] == 0 - creator = pool._acquire_new_later(READER_ADDRESS, Deadline(1)) + creator = pool._acquire_new_later(READER_ADDRESS, None, Deadline(1)) assert pool.connections_reservations[READER_ADDRESS] == 1 assert callable(creator) if Util.is_async_code: @@ -485,10 +521,10 @@ def test__acquire_new_later_without_room(opener): pool = Neo4jPool( opener, config, WorkspaceConfig(), ROUTER1_ADDRESS ) - _ = pool.acquire(READ_ACCESS, 30, "test_db", None, None) + _ = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) # pool is full now assert pool.connections_reservations[READER_ADDRESS] == 0 - creator = pool._acquire_new_later(READER_ADDRESS, Deadline(1)) + creator = pool._acquire_new_later(READER_ADDRESS, None, Deadline(1)) assert pool.connections_reservations[READER_ADDRESS] == 0 assert 
creator is None @@ -509,11 +545,11 @@ def test_discovery_is_retried(routing_failure_opener, error): opener, PoolConfig(), WorkspaceConfig(), ResolvedAddress(("1.2.3.1", 9999), host_name="host") ) - cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None, None) + cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) pool.release(cx1) pool.routing_tables.get("test_db").ttl = 0 - cx2 = pool.acquire(READ_ACCESS, 30, "test_db", None, None) + cx2 = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) pool.release(cx2) assert pool.routing_tables.get("test_db") @@ -553,12 +589,12 @@ def test_fast_failing_discovery(routing_failure_opener, error): opener, PoolConfig(), WorkspaceConfig(), ResolvedAddress(("1.2.3.1", 9999), host_name="host") ) - cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None, None) + cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) pool.release(cx1) pool.routing_tables.get("test_db").ttl = 0 with pytest.raises(error.__class__) as exc: - pool.acquire(READ_ACCESS, 30, "test_db", None, None) + pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) assert exc.value is error diff --git a/tests/unit/sync/test_driver.py b/tests/unit/sync/test_driver.py index 0f785c64..597c6581 100644 --- a/tests/unit/sync/test_driver.py +++ b/tests/unit/sync/test_driver.py @@ -197,6 +197,7 @@ def test_driver_opens_write_session_by_default(uri, fake_pool, mocker): timeout=mocker.ANY, database=mocker.ANY, bookmarks=mocker.ANY, + auth=mocker.ANY, liveness_check_timeout=mocker.ANY ) tx._begin.assert_called_once_with( diff --git a/tests/unit/sync/work/test_session.py b/tests/unit/sync/work/test_session.py index 9ee9def0..b6764104 100644 --- a/tests/unit/sync/work/test_session.py +++ b/tests/unit/sync/work/test_session.py @@ -409,7 +409,7 @@ def test_with_bookmark_manager( additional_session_bookmarks, mocker ): def update_routing_table_side_effect( - database, imp_user, bookmarks, acquisition_timeout=None, + database, imp_user, bookmarks, auth=None, 
acquisition_timeout=None, database_callback=None ): if home_db_gets_resolved: From b29eafbbaaa3d9802ad2a40c90f2b74c50227b02 Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Mon, 19 Dec 2022 15:32:08 +0100 Subject: [PATCH 02/23] WIP: TestKit backend: adding renewable auth token support --- src/neo4j/__init__.py | 2 ++ src/neo4j/_async/io/_pool.py | 4 +-- src/neo4j/_sync/io/_pool.py | 4 +-- src/neo4j/api.py | 4 +-- testkitbackend/_async/backend.py | 2 ++ testkitbackend/_async/requests.py | 56 +++++++++++++++++++++++++++++-- testkitbackend/_sync/backend.py | 2 ++ testkitbackend/_sync/requests.py | 56 +++++++++++++++++++++++++++++-- 8 files changed, 118 insertions(+), 12 deletions(-) diff --git a/src/neo4j/__init__.py b/src/neo4j/__init__.py index 3483e537..edb12229 100644 --- a/src/neo4j/__init__.py +++ b/src/neo4j/__init__.py @@ -77,6 +77,7 @@ DEFAULT_DATABASE, kerberos_auth, READ_ACCESS, + RenewableAuth, ServerInfo, SYSTEM_DATABASE, TRUST_ALL_CERTIFICATES, @@ -129,6 +130,7 @@ "Record", "Result", "ResultSummary", + "RenewableAuth", "ServerInfo", "Session", "SessionConfig", diff --git a/src/neo4j/_async/io/_pool.py b/src/neo4j/_async/io/_pool.py index 55927496..381c69ee 100644 --- a/src/neo4j/_async/io/_pool.py +++ b/src/neo4j/_async/io/_pool.py @@ -104,7 +104,7 @@ async def get_auth(self): with self.refreshing_auth_lock: if self.refreshing_auth: # someone else is already getting a new auth - return self.last_auth + return self.last_auth.auth if self.last_auth is None or self.last_auth.expired: with self.refreshing_auth_lock: self.refreshing_auth = True @@ -131,7 +131,7 @@ async def _initialize_auth(self): async def _get_new_auth(self): new_auth = await AsyncUtil.callback(self.pool_config.auth) - if not isinstance(self.last_auth, RenewableAuth): + if not isinstance(new_auth, RenewableAuth): return RenewableAuth(new_auth) return new_auth diff --git a/src/neo4j/_sync/io/_pool.py b/src/neo4j/_sync/io/_pool.py index ef58970a..8020f7b6 100644 --- 
a/src/neo4j/_sync/io/_pool.py +++ b/src/neo4j/_sync/io/_pool.py @@ -104,7 +104,7 @@ def get_auth(self): with self.refreshing_auth_lock: if self.refreshing_auth: # someone else is already getting a new auth - return self.last_auth + return self.last_auth.auth if self.last_auth is None or self.last_auth.expired: with self.refreshing_auth_lock: self.refreshing_auth = True @@ -131,7 +131,7 @@ def _initialize_auth(self): def _get_new_auth(self): new_auth = Util.callback(self.pool_config.auth) - if not isinstance(self.last_auth, RenewableAuth): + if not isinstance(new_auth, RenewableAuth): return RenewableAuth(new_auth) return new_auth diff --git a/src/neo4j/api.py b/src/neo4j/api.py index 06d75a99..a502e5bf 100644 --- a/src/neo4j/api.py +++ b/src/neo4j/api.py @@ -224,8 +224,8 @@ def expired(self): _TAuthTokenProvider = t.Callable[[], t.Union[ - RenewableAuth, Auth, t.Tuple[t.Any, t.Any], - t.Awaitable[t.Union[RenewableAuth, Auth, t.Tuple[t.Any, t.Any]]] + RenewableAuth, Auth, t.Tuple[t.Any, t.Any], None, + t.Awaitable[t.Union[RenewableAuth, Auth, t.Tuple[t.Any, t.Any], None]] ]] diff --git a/testkitbackend/_async/backend.py b/testkitbackend/_async/backend.py index 12891676..eab20a42 100644 --- a/testkitbackend/_async/backend.py +++ b/testkitbackend/_async/backend.py @@ -55,6 +55,8 @@ def __init__(self, rd, wr): self.drivers = {} self.custom_resolutions = {} self.dns_resolutions = {} + self.auth_token_providers = {} + self.renewable_auth_token_supplies = {} self.bookmark_managers = {} self.bookmarks_consumptions = {} self.bookmarks_supplies = {} diff --git a/testkitbackend/_async/requests.py b/testkitbackend/_async/requests.py index 2fb4b94d..7a8950d4 100644 --- a/testkitbackend/_async/requests.py +++ b/testkitbackend/_async/requests.py @@ -94,10 +94,10 @@ async def GetFeatures(backend, data): def _convert_auth_token(data, key): + if data[key] is None: + return None auth_token = data[key]["data"] - data[key].mark_item_as_read_if_equals( - "name", "AuthorizationToken" - ) 
+ data[key].mark_item_as_read_if_equals("name", "AuthorizationToken") scheme = auth_token["scheme"] if scheme == "basic": auth = neo4j.basic_auth( @@ -120,6 +120,8 @@ def _convert_auth_token(data, key): async def NewDriver(backend, data): auth = _convert_auth_token(data, "authorizationToken") + if auth is None and data.get("authTokenProviderId") is not None: + auth = backend.auth_token_providers[data["authTokenProviderId"]] kwargs = {} if data["resolverRegistered"] or data["domainNameResolverRegistered"]: kwargs["resolver"] = resolution_func( @@ -161,6 +163,54 @@ async def NewDriver(backend, data): await backend.send_response("Driver", {"id": key}) +async def NewAuthTokenProvider(backend, data): + auth_token_provider_id = backend.next_key() + + async def auth_token_provider(): + key = backend.next_key() + await backend.send_response("AuthTokenProviderRequest", { + "id": key, + "authTokenProviderId": auth_token_provider_id, + }) + if not await backend.process_request(): + # connection was closed before end of next message + return None + if key not in backend.renewable_auth_token_supplies: + raise RuntimeError( + "Backend did not receive expected " + f"AuthTokenProviderCompleted message for id {key}" + ) + return backend.renewable_auth_token_supplies.pop(key) + + + backend.auth_token_providers[auth_token_provider_id] = auth_token_provider + await backend.send_response("AuthTokenProvider", + {"id": auth_token_provider_id}) + + +async def AuthTokenProviderClose(backend, data): + auth_token_provider_id = data["id"] + del backend.auth_token_providers[auth_token_provider_id] + await backend.send_response("AuthTokenProvider", + {"id": auth_token_provider_id}) + + +async def AuthTokenProviderCompleted(backend, data): + backend.renewable_auth_token_supplies[data["requestId"]] = \ + parse_renewable_auth(data["auth"]) + + +def parse_renewable_auth(data): + data.mark_item_as_read_if_equals("name", "RenewableAuthToken") + data = data["data"] + auth_token = 
_convert_auth_token(data, "auth") + if data["expiresInMs"] is not None: + expires_in = data["expiresInMs"] / 1000 + else: + expires_in = None + return neo4j.RenewableAuth(auth_token, expires_in) + + async def VerifyConnectivity(backend, data): driver_id = data["driverId"] driver = backend.drivers[driver_id] diff --git a/testkitbackend/_sync/backend.py b/testkitbackend/_sync/backend.py index d04d7c36..94147522 100644 --- a/testkitbackend/_sync/backend.py +++ b/testkitbackend/_sync/backend.py @@ -55,6 +55,8 @@ def __init__(self, rd, wr): self.drivers = {} self.custom_resolutions = {} self.dns_resolutions = {} + self.auth_token_providers = {} + self.renewable_auth_token_supplies = {} self.bookmark_managers = {} self.bookmarks_consumptions = {} self.bookmarks_supplies = {} diff --git a/testkitbackend/_sync/requests.py b/testkitbackend/_sync/requests.py index 65342a6f..a6026e6e 100644 --- a/testkitbackend/_sync/requests.py +++ b/testkitbackend/_sync/requests.py @@ -94,10 +94,10 @@ def GetFeatures(backend, data): def _convert_auth_token(data, key): + if data[key] is None: + return None auth_token = data[key]["data"] - data[key].mark_item_as_read_if_equals( - "name", "AuthorizationToken" - ) + data[key].mark_item_as_read_if_equals("name", "AuthorizationToken") scheme = auth_token["scheme"] if scheme == "basic": auth = neo4j.basic_auth( @@ -120,6 +120,8 @@ def _convert_auth_token(data, key): def NewDriver(backend, data): auth = _convert_auth_token(data, "authorizationToken") + if auth is None and data.get("authTokenProviderId") is not None: + auth = backend.auth_token_providers[data["authTokenProviderId"]] kwargs = {} if data["resolverRegistered"] or data["domainNameResolverRegistered"]: kwargs["resolver"] = resolution_func( @@ -161,6 +163,54 @@ def NewDriver(backend, data): backend.send_response("Driver", {"id": key}) +def NewAuthTokenProvider(backend, data): + auth_token_provider_id = backend.next_key() + + def auth_token_provider(): + key = backend.next_key() + 
backend.send_response("AuthTokenProviderRequest", { + "id": key, + "authTokenProviderId": auth_token_provider_id, + }) + if not backend.process_request(): + # connection was closed before end of next message + return None + if key not in backend.renewable_auth_token_supplies: + raise RuntimeError( + "Backend did not receive expected " + f"AuthTokenProviderCompleted message for id {key}" + ) + return backend.renewable_auth_token_supplies.pop(key) + + + backend.auth_token_providers[auth_token_provider_id] = auth_token_provider + backend.send_response("AuthTokenProvider", + {"id": auth_token_provider_id}) + + +def AuthTokenProviderClose(backend, data): + auth_token_provider_id = data["id"] + del backend.auth_token_providers[auth_token_provider_id] + backend.send_response("AuthTokenProvider", + {"id": auth_token_provider_id}) + + +def AuthTokenProviderCompleted(backend, data): + backend.renewable_auth_token_supplies[data["requestId"]] = \ + parse_renewable_auth(data["auth"]) + + +def parse_renewable_auth(data): + data.mark_item_as_read_if_equals("name", "RenewableAuthToken") + data = data["data"] + auth_token = _convert_auth_token(data, "auth") + if data["expiresInMs"] is not None: + expires_in = data["expiresInMs"] / 1000 + else: + expires_in = None + return neo4j.RenewableAuth(auth_token, expires_in) + + def VerifyConnectivity(backend, data): driver_id = data["driverId"] driver = backend.drivers[driver_id] From f0363fb8a3d1599f41c37af4881a1860d73d083b Mon Sep 17 00:00:00 2001 From: Antonio Barcelos Date: Mon, 19 Dec 2022 16:48:54 +0100 Subject: [PATCH 03/23] Fake Time Co-Authored-By: Rouven Bauer --- requirements-dev.txt | 1 + testkitbackend/_async/backend.py | 2 ++ testkitbackend/_async/requests.py | 31 ++++++++++++++++++++++++++++++- testkitbackend/_sync/backend.py | 2 ++ testkitbackend/_sync/requests.py | 31 ++++++++++++++++++++++++++++++- 5 files changed, 65 insertions(+), 2 deletions(-) diff --git a/requirements-dev.txt b/requirements-dev.txt index 
b070002c..15271e00 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -17,6 +17,7 @@ tomlkit~=0.11.6 # needed for running tests coverage[toml]>=5.5 +freezegun >= 1.2.2 mock>=4.0.3 numpy>=1.7.0 pandas>=1.0.0 diff --git a/testkitbackend/_async/backend.py b/testkitbackend/_async/backend.py index eab20a42..17724a54 100644 --- a/testkitbackend/_async/backend.py +++ b/testkitbackend/_async/backend.py @@ -66,6 +66,8 @@ def __init__(self, rd, wr): self.transactions = {} self.errors = {} self.key = 0 + self.fake_time = None + self.fake_time_ticker = None # Collect all request handlers self._requestHandlers = dict( [m for m in getmembers(requests, isfunction)]) diff --git a/testkitbackend/_async/requests.py b/testkitbackend/_async/requests.py index 7a8950d4..09f5c0d4 100644 --- a/testkitbackend/_async/requests.py +++ b/testkitbackend/_async/requests.py @@ -15,12 +15,14 @@ # See the License for the specific language governing permissions and # limitations under the License. - +import datetime import json import re import warnings from os import path +from freezegun import freeze_time + import neo4j import neo4j.api from neo4j._async_compat.util import AsyncUtil @@ -704,3 +706,30 @@ async def GetRoutingTable(backend, data): addresses = routing_table.__getattribute__(role) response_data[role] = list(map(str, addresses)) await backend.send_response("RoutingTable", response_data) + +async def FakeTimeInstall(backend, _data): + assert backend.fake_time is None + assert backend.fake_time_ticker is None + + backend.fake_time = freeze_time() + backend.fake_time_ticker = backend.fake_time.start() + await backend.send_response("FakeTimeAck", {}) + +async def FakeTimeTick(backend, data): + assert backend.fake_time is not None + assert backend.fake_time_ticker is not None + + increment_ms = data["incrementMs"] + delta = datetime.timedelta(milliseconds=increment_ms) + backend.fake_time_ticker.tick(delta=delta) + await backend.send_response("FakeTimeAck", {}) + + +async def 
FakeTimeUninstall(backend, _data): + assert backend.fake_time is not None + assert backend.fake_time_ticker is not None + + backend.fake_time.stop() + backend.fake_time_ticker = None + backend.fake_time = None + await backend.send_response("FakeTimeAck", {}) diff --git a/testkitbackend/_sync/backend.py b/testkitbackend/_sync/backend.py index 94147522..052bce25 100644 --- a/testkitbackend/_sync/backend.py +++ b/testkitbackend/_sync/backend.py @@ -66,6 +66,8 @@ def __init__(self, rd, wr): self.transactions = {} self.errors = {} self.key = 0 + self.fake_time = None + self.fake_time_ticker = None # Collect all request handlers self._requestHandlers = dict( [m for m in getmembers(requests, isfunction)]) diff --git a/testkitbackend/_sync/requests.py b/testkitbackend/_sync/requests.py index a6026e6e..43470bdd 100644 --- a/testkitbackend/_sync/requests.py +++ b/testkitbackend/_sync/requests.py @@ -15,12 +15,14 @@ # See the License for the specific language governing permissions and # limitations under the License. 
- +import datetime import json import re import warnings from os import path +from freezegun import freeze_time + import neo4j import neo4j.api from neo4j._async_compat.util import Util @@ -704,3 +706,30 @@ def GetRoutingTable(backend, data): addresses = routing_table.__getattribute__(role) response_data[role] = list(map(str, addresses)) backend.send_response("RoutingTable", response_data) + +def FakeTimeInstall(backend, _data): + assert backend.fake_time is None + assert backend.fake_time_ticker is None + + backend.fake_time = freeze_time() + backend.fake_time_ticker = backend.fake_time.start() + backend.send_response("FakeTimeAck", {}) + +def FakeTimeTick(backend, data): + assert backend.fake_time is not None + assert backend.fake_time_ticker is not None + + increment_ms = data["incrementMs"] + delta = datetime.timedelta(milliseconds=increment_ms) + backend.fake_time_ticker.tick(delta=delta) + backend.send_response("FakeTimeAck", {}) + + +def FakeTimeUninstall(backend, _data): + assert backend.fake_time is not None + assert backend.fake_time_ticker is not None + + backend.fake_time.stop() + backend.fake_time_ticker = None + backend.fake_time = None + backend.send_response("FakeTimeAck", {}) From c2314b328cc0c1050799bd647c8c23df73604410 Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Thu, 12 Jan 2023 09:52:41 +0100 Subject: [PATCH 04/23] Error handling + TestKit support --- src/neo4j/_async/driver.py | 4 +- src/neo4j/_async/io/_bolt.py | 21 +++++++- src/neo4j/_async/io/_bolt3.py | 16 ++++-- src/neo4j/_async/io/_bolt4.py | 12 +++-- src/neo4j/_async/io/_bolt5.py | 15 ++++-- src/neo4j/_async/io/_common.py | 1 - src/neo4j/_async/io/_pool.py | 81 ++++++++++++++++++++----------- src/neo4j/_sync/driver.py | 4 +- src/neo4j/_sync/io/_bolt.py | 21 +++++++- src/neo4j/_sync/io/_bolt3.py | 16 ++++-- src/neo4j/_sync/io/_bolt4.py | 12 +++-- src/neo4j/_sync/io/_bolt5.py | 15 ++++-- src/neo4j/_sync/io/_common.py | 1 - src/neo4j/_sync/io/_pool.py | 81 
++++++++++++++++++++----------- src/neo4j/exceptions.py | 7 ++- testkitbackend/_async/requests.py | 2 + testkitbackend/_sync/requests.py | 2 + testkitbackend/test_config.json | 1 + 18 files changed, 220 insertions(+), 92 deletions(-) diff --git a/src/neo4j/_async/driver.py b/src/neo4j/_async/driver.py index faf01dc9..fc1b8360 100644 --- a/src/neo4j/_async/driver.py +++ b/src/neo4j/_async/driver.py @@ -143,9 +143,7 @@ def driver( driver_type, security_type, parsed = parse_neo4j_uri(uri) - if not callable(config.get("auth")): - auth = config.get("auth") - config["auth"] = lambda: auth + config["auth"] = auth if callable(auth) else lambda: auth # TODO: 6.0 - remove "trust" config option if "trust" in config.keys(): diff --git a/src/neo4j/_async/io/_bolt.py b/src/neo4j/_async/io/_bolt.py index 50c5e7fe..e9b46185 100644 --- a/src/neo4j/_async/io/_bolt.py +++ b/src/neo4j/_async/io/_bolt.py @@ -20,6 +20,7 @@ import abc import asyncio +import typing as t from collections import deque from logging import getLogger from time import perf_counter @@ -57,6 +58,20 @@ log = getLogger("neo4j") +class ServerStateManagerBase(abc.ABC): + @abc.abstractmethod + def __init__(self, init_state, on_change=None): + ... + + @abc.abstractmethod + def transition(self, message, metadata): + ... + + @abc.abstractmethod + def failed(self): + ... + + class AsyncBolt: """ Server connection for Bolt protocol. @@ -150,6 +165,10 @@ def __del__(self): if not asyncio.iscoroutinefunction(self.close): self.close() + @abc.abstractmethod + def _get_server_state_manager(self) -> ServerStateManagerBase: + ... + @classmethod def _to_auth_dict(cls, auth): # Determine auth details @@ -767,7 +786,7 @@ async def _set_defunct(self, message, error=None, silent=False): # remove the connection from the pool, nor to try to close the # connection again. 
await self.close() - if self.pool: + if self.pool and not self._get_server_state_manager().failed(): await self.pool.deactivate(address=self.unresolved_address) # Iterate through the outstanding responses, and if any correspond diff --git a/src/neo4j/_async/io/_bolt3.py b/src/neo4j/_async/io/_bolt3.py index 76f856e3..e481de31 100644 --- a/src/neo4j/_async/io/_bolt3.py +++ b/src/neo4j/_async/io/_bolt3.py @@ -36,7 +36,10 @@ NotALeader, ServiceUnavailable, ) -from ._bolt import AsyncBolt +from ._bolt import ( + AsyncBolt, + ServerStateManagerBase, +) from ._common import ( check_supported_server_product, CommitResponse, @@ -56,7 +59,7 @@ class ServerStates(Enum): FAILED = "FAILED" -class ServerStateManager: +class ServerStateManager(ServerStateManagerBase): _STATE_TRANSITIONS: t.Dict[Enum, t.Dict[str, Enum]] = { ServerStates.CONNECTED: { "hello": ServerStates.READY, @@ -94,6 +97,9 @@ def transition(self, message, metadata): if state_before != self.state and callable(self._on_change): self._on_change(state_before, self.state) + def failed(self): + return self.state == ServerStates.FAILED + class AsyncBolt3(AsyncBolt): """ Protocol handler for Bolt 3. 
@@ -119,6 +125,9 @@ def _on_server_state_change(self, old_state, new_state): log.debug("[#%04X] _: state: %s > %s", self.local_port, old_state.name, new_state.name) + def _get_server_state_manager(self) -> ServerStateManagerBase: + return self._server_state_manager + @property def is_reset(self): # We can't be sure of the server's state if there are still pending @@ -390,8 +399,7 @@ async def _process_message(self, tag, fields): self.pool.on_write_failure(address=self.unresolved_address) raise except Neo4jError as e: - if self.pool and e._invalidates_all_connections(): - await self.pool.mark_all_stale() + await self.pool.on_neo4j_error(e, self.server_info.address) raise else: raise BoltProtocolError("Unexpected response message with signature %02X" % summary_signature, address=self.unresolved_address) diff --git a/src/neo4j/_async/io/_bolt4.py b/src/neo4j/_async/io/_bolt4.py index 2503db71..cd99bca4 100644 --- a/src/neo4j/_async/io/_bolt4.py +++ b/src/neo4j/_async/io/_bolt4.py @@ -19,7 +19,6 @@ from logging import getLogger from ssl import SSLSocket -from ..._async_compat.util import AsyncUtil from ..._exceptions import BoltProtocolError from ...api import ( READ_ACCESS, @@ -34,7 +33,10 @@ NotALeader, ServiceUnavailable, ) -from ._bolt import AsyncBolt +from ._bolt import ( + AsyncBolt, + ServerStateManagerBase, +) from ._bolt3 import ( ServerStateManager, ServerStates, @@ -74,6 +76,9 @@ def _on_server_state_change(self, old_state, new_state): log.debug("[#%04X] _: state: %s > %s", self.local_port, old_state.name, new_state.name) + def _get_server_state_manager(self) -> ServerStateManagerBase: + return self._server_state_manager + @property def is_reset(self): # We can't be sure of the server's state if there are still pending @@ -343,8 +348,7 @@ async def _process_message(self, tag, fields): self.pool.on_write_failure(address=self.unresolved_address) raise except Neo4jError as e: - if self.pool and e._invalidates_all_connections(): - await 
self.pool.mark_all_stale() + await self.pool.on_neo4j_error(e, self.server_info.address) raise else: raise BoltProtocolError("Unexpected response message with signature " diff --git a/src/neo4j/_async/io/_bolt5.py b/src/neo4j/_async/io/_bolt5.py index a18ee45c..53b2f41c 100644 --- a/src/neo4j/_async/io/_bolt5.py +++ b/src/neo4j/_async/io/_bolt5.py @@ -34,7 +34,10 @@ NotALeader, ServiceUnavailable, ) -from ._bolt import AsyncBolt +from ._bolt import ( + AsyncBolt, + ServerStateManagerBase, +) from ._bolt3 import ( ServerStateManager, ServerStates, @@ -77,6 +80,9 @@ def _on_server_state_change(self, old_state, new_state): log.debug("[#%04X] _: state: %s > %s", self.local_port, old_state.name, new_state.name) + def _get_server_state_manager(self) -> ServerStateManagerBase: + return self._server_state_manager + @property def is_reset(self): # We can't be sure of the server's state if there are still pending @@ -340,8 +346,7 @@ async def _process_message(self, tag, fields): self.pool.on_write_failure(address=self.unresolved_address) raise except Neo4jError as e: - if self.pool and e._invalidates_all_connections(): - await self.pool.mark_all_stale() + await self.pool.on_neo4j_error(e, self.server_info.address) raise else: raise BoltProtocolError( @@ -391,6 +396,10 @@ class ServerStateManager5x1(ServerStateManager): } + def failed(self): + return self.state == ServerStates5x1.FAILED + + class AsyncBolt5x1(AsyncBolt5x0): """Protocol handler for Bolt 5.1.""" diff --git a/src/neo4j/_async/io/_common.py b/src/neo4j/_async/io/_common.py index 1888f0f5..278a41b7 100644 --- a/src/neo4j/_async/io/_common.py +++ b/src/neo4j/_async/io/_common.py @@ -273,7 +273,6 @@ async def on_failure(self, metadata): class CommitResponse(Response): - pass diff --git a/src/neo4j/_async/io/_pool.py b/src/neo4j/_async/io/_pool.py index 381c69ee..dcd459f2 100644 --- a/src/neo4j/_async/io/_pool.py +++ b/src/neo4j/_async/io/_pool.py @@ -87,8 +87,7 @@ def __init__(self, opener, pool_config, 
workspace_config): self.lock = AsyncCooperativeRLock() self.cond = AsyncCondition(self.lock) self.refreshing_auth = False - self.refreshing_auth_lock = AsyncCooperativeLock() - self.initializing_auth_lock = AsyncLock() + self.auth_condition = AsyncCondition() self.last_auth: t.Optional[RenewableAuth] = None async def __aenter__(self): @@ -98,38 +97,49 @@ async def __aexit__(self, exc_type, exc_value, traceback): await self.close() async def get_auth(self): - if await self._initialize_auth(): - return self.last_auth.auth + async with self.auth_condition: + auth_missing = self.last_auth is None + needs_refresh = auth_missing or self.last_auth.expired + if not needs_refresh: + return self.last_auth.auth - with self.refreshing_auth_lock: if self.refreshing_auth: - # someone else is already getting a new auth - return self.last_auth.auth - if self.last_auth is None or self.last_auth.expired: - with self.refreshing_auth_lock: + # someone else is already getting new auth info + if not auth_missing: + # there is old auth info we can use in the meantime + return self.last_auth.auth + else: + while self.last_auth is None: + await self.auth_condition.wait() + return self.last_auth.auth + else: + # we need to get new auth info self.refreshing_auth = True - try: - new_auth = await self._get_new_auth() - with self.refreshing_auth_lock: - self.last_auth = new_auth - self.refreshing_auth = False - except: - with self.refreshing_auth_lock: - self.refreshing_auth = False - raise - return self.last_auth.auth - - async def _initialize_auth(self): - if self.last_auth is not None: - return False - async with self.initializing_auth_lock: - if self.last_auth is not None: - # someone else initialized the auth - return True - self.last_auth = await self._get_new_auth() - return True + + auth = await self._get_new_auth() + async with self.auth_condition: + self.last_auth = auth + self.refreshing_auth = False + self.auth_condition.notify_all() + return self.last_auth.auth + + async def 
force_new_auth(self): + async with self.auth_condition: + log.debug("[#0000] _: force new auth info") + if self.refreshing_auth: + return + self.last_auth = None + self.refreshing_auth = True + + auth = await self._get_new_auth() + async with self.auth_condition: + self.last_auth = auth + self.refreshing_auth = False + self.auth_condition.notify_all() async def _get_new_auth(self): + log.debug("[#0000] _: requesting new auth info from %r", + self.pool_config.auth) new_auth = await AsyncUtil.callback(self.pool_config.auth) if not isinstance(new_auth, RenewableAuth): return RenewableAuth(new_auth) @@ -429,6 +439,19 @@ def on_write_failure(self, address): "No write service available for pool {}".format(self) ) + async def on_neo4j_error(self, error, address): + assert isinstance(error, Neo4jError) + if error._unauthenticates_all_connections(): + log.debug( + "[#0000] _: mark all connections to %r as " + "unauthenticated", address + ) + with self.lock: + for connection in self.connections.get(address, ()): + connection.auth_dict = {} + if error._requires_new_credentials(): + await self.force_new_auth() + async def close(self): """ Close all connections and empty the pool. This method is thread safe. 
diff --git a/src/neo4j/_sync/driver.py b/src/neo4j/_sync/driver.py index 31cf7ff9..28aa1712 100644 --- a/src/neo4j/_sync/driver.py +++ b/src/neo4j/_sync/driver.py @@ -142,9 +142,7 @@ def driver( driver_type, security_type, parsed = parse_neo4j_uri(uri) - if not callable(config.get("auth")): - auth = config.get("auth") - config["auth"] = lambda: auth + config["auth"] = auth if callable(auth) else lambda: auth # TODO: 6.0 - remove "trust" config option if "trust" in config.keys(): diff --git a/src/neo4j/_sync/io/_bolt.py b/src/neo4j/_sync/io/_bolt.py index 2cc6faa0..cb860e06 100644 --- a/src/neo4j/_sync/io/_bolt.py +++ b/src/neo4j/_sync/io/_bolt.py @@ -20,6 +20,7 @@ import abc import asyncio +import typing as t from collections import deque from logging import getLogger from time import perf_counter @@ -57,6 +58,20 @@ log = getLogger("neo4j") +class ServerStateManagerBase(abc.ABC): + @abc.abstractmethod + def __init__(self, init_state, on_change=None): + ... + + @abc.abstractmethod + def transition(self, message, metadata): + ... + + @abc.abstractmethod + def failed(self): + ... + + class Bolt: """ Server connection for Bolt protocol. @@ -150,6 +165,10 @@ def __del__(self): if not asyncio.iscoroutinefunction(self.close): self.close() + @abc.abstractmethod + def _get_server_state_manager(self) -> ServerStateManagerBase: + ... + @classmethod def _to_auth_dict(cls, auth): # Determine auth details @@ -767,7 +786,7 @@ def _set_defunct(self, message, error=None, silent=False): # remove the connection from the pool, nor to try to close the # connection again. 
self.close() - if self.pool: + if self.pool and not self._get_server_state_manager().failed(): self.pool.deactivate(address=self.unresolved_address) # Iterate through the outstanding responses, and if any correspond diff --git a/src/neo4j/_sync/io/_bolt3.py b/src/neo4j/_sync/io/_bolt3.py index 1bfacaf7..7229fc99 100644 --- a/src/neo4j/_sync/io/_bolt3.py +++ b/src/neo4j/_sync/io/_bolt3.py @@ -36,7 +36,10 @@ NotALeader, ServiceUnavailable, ) -from ._bolt import Bolt +from ._bolt import ( + Bolt, + ServerStateManagerBase, +) from ._common import ( check_supported_server_product, CommitResponse, @@ -56,7 +59,7 @@ class ServerStates(Enum): FAILED = "FAILED" -class ServerStateManager: +class ServerStateManager(ServerStateManagerBase): _STATE_TRANSITIONS: t.Dict[Enum, t.Dict[str, Enum]] = { ServerStates.CONNECTED: { "hello": ServerStates.READY, @@ -94,6 +97,9 @@ def transition(self, message, metadata): if state_before != self.state and callable(self._on_change): self._on_change(state_before, self.state) + def failed(self): + return self.state == ServerStates.FAILED + class Bolt3(Bolt): """ Protocol handler for Bolt 3. 
@@ -119,6 +125,9 @@ def _on_server_state_change(self, old_state, new_state): log.debug("[#%04X] _: state: %s > %s", self.local_port, old_state.name, new_state.name) + def _get_server_state_manager(self) -> ServerStateManagerBase: + return self._server_state_manager + @property def is_reset(self): # We can't be sure of the server's state if there are still pending @@ -390,8 +399,7 @@ def _process_message(self, tag, fields): self.pool.on_write_failure(address=self.unresolved_address) raise except Neo4jError as e: - if self.pool and e._invalidates_all_connections(): - self.pool.mark_all_stale() + self.pool.on_neo4j_error(e, self.server_info.address) raise else: raise BoltProtocolError("Unexpected response message with signature %02X" % summary_signature, address=self.unresolved_address) diff --git a/src/neo4j/_sync/io/_bolt4.py b/src/neo4j/_sync/io/_bolt4.py index 741fb195..c374822d 100644 --- a/src/neo4j/_sync/io/_bolt4.py +++ b/src/neo4j/_sync/io/_bolt4.py @@ -19,7 +19,6 @@ from logging import getLogger from ssl import SSLSocket -from ..._async_compat.util import Util from ..._exceptions import BoltProtocolError from ...api import ( READ_ACCESS, @@ -34,7 +33,10 @@ NotALeader, ServiceUnavailable, ) -from ._bolt import Bolt +from ._bolt import ( + Bolt, + ServerStateManagerBase, +) from ._bolt3 import ( ServerStateManager, ServerStates, @@ -74,6 +76,9 @@ def _on_server_state_change(self, old_state, new_state): log.debug("[#%04X] _: state: %s > %s", self.local_port, old_state.name, new_state.name) + def _get_server_state_manager(self) -> ServerStateManagerBase: + return self._server_state_manager + @property def is_reset(self): # We can't be sure of the server's state if there are still pending @@ -343,8 +348,7 @@ def _process_message(self, tag, fields): self.pool.on_write_failure(address=self.unresolved_address) raise except Neo4jError as e: - if self.pool and e._invalidates_all_connections(): - self.pool.mark_all_stale() + self.pool.on_neo4j_error(e, 
self.server_info.address) raise else: raise BoltProtocolError("Unexpected response message with signature " diff --git a/src/neo4j/_sync/io/_bolt5.py b/src/neo4j/_sync/io/_bolt5.py index e9ba82ad..0dfc922d 100644 --- a/src/neo4j/_sync/io/_bolt5.py +++ b/src/neo4j/_sync/io/_bolt5.py @@ -34,7 +34,10 @@ NotALeader, ServiceUnavailable, ) -from ._bolt import Bolt +from ._bolt import ( + Bolt, + ServerStateManagerBase, +) from ._bolt3 import ( ServerStateManager, ServerStates, @@ -77,6 +80,9 @@ def _on_server_state_change(self, old_state, new_state): log.debug("[#%04X] _: state: %s > %s", self.local_port, old_state.name, new_state.name) + def _get_server_state_manager(self) -> ServerStateManagerBase: + return self._server_state_manager + @property def is_reset(self): # We can't be sure of the server's state if there are still pending @@ -340,8 +346,7 @@ def _process_message(self, tag, fields): self.pool.on_write_failure(address=self.unresolved_address) raise except Neo4jError as e: - if self.pool and e._invalidates_all_connections(): - self.pool.mark_all_stale() + self.pool.on_neo4j_error(e, self.server_info.address) raise else: raise BoltProtocolError( @@ -391,6 +396,10 @@ class ServerStateManager5x1(ServerStateManager): } + def failed(self): + return self.state == ServerStates5x1.FAILED + + class Bolt5x1(Bolt5x0): """Protocol handler for Bolt 5.1.""" diff --git a/src/neo4j/_sync/io/_common.py b/src/neo4j/_sync/io/_common.py index 85cff520..a0250071 100644 --- a/src/neo4j/_sync/io/_common.py +++ b/src/neo4j/_sync/io/_common.py @@ -273,7 +273,6 @@ def on_failure(self, metadata): class CommitResponse(Response): - pass diff --git a/src/neo4j/_sync/io/_pool.py b/src/neo4j/_sync/io/_pool.py index 8020f7b6..ab382d2f 100644 --- a/src/neo4j/_sync/io/_pool.py +++ b/src/neo4j/_sync/io/_pool.py @@ -87,8 +87,7 @@ def __init__(self, opener, pool_config, workspace_config): self.lock = CooperativeRLock() self.cond = Condition(self.lock) self.refreshing_auth = False - 
self.refreshing_auth_lock = CooperativeLock() - self.initializing_auth_lock = Lock() + self.auth_condition = Condition() self.last_auth: t.Optional[RenewableAuth] = None def __enter__(self): @@ -98,38 +97,49 @@ def __exit__(self, exc_type, exc_value, traceback): self.close() def get_auth(self): - if self._initialize_auth(): - return self.last_auth.auth + with self.auth_condition: + auth_missing = self.last_auth is None + needs_refresh = auth_missing or self.last_auth.expired + if not needs_refresh: + return self.last_auth.auth - with self.refreshing_auth_lock: if self.refreshing_auth: - # someone else is already getting a new auth - return self.last_auth.auth - if self.last_auth is None or self.last_auth.expired: - with self.refreshing_auth_lock: + # someone else is already getting new auth info + if not auth_missing: + # there is old auth info we can use in the meantime + return self.last_auth.auth + else: + while self.last_auth is None: + self.auth_condition.wait() + return self.last_auth.auth + else: + # we need to get new auth info self.refreshing_auth = True - try: - new_auth = self._get_new_auth() - with self.refreshing_auth_lock: - self.last_auth = new_auth - self.refreshing_auth = False - except: - with self.refreshing_auth_lock: - self.refreshing_auth = False - raise - return self.last_auth.auth - - def _initialize_auth(self): - if self.last_auth is not None: - return False - with self.initializing_auth_lock: - if self.last_auth is not None: - # someone else initialized the auth - return True - self.last_auth = self._get_new_auth() - return True + + auth = self._get_new_auth() + with self.auth_condition: + self.last_auth = auth + self.refreshing_auth = False + self.auth_condition.notify_all() + return self.last_auth.auth + + def force_new_auth(self): + with self.auth_condition: + log.debug("[#0000] _: force new auth info") + if self.refreshing_auth: + return + self.last_auth = None + self.refreshing_auth = True + + auth = self._get_new_auth() + with 
self.auth_condition: + self.last_auth = auth + self.refreshing_auth = False + self.auth_condition.notify_all() def _get_new_auth(self): + log.debug("[#0000] _: requesting new auth info from %r", + self.pool_config.auth) new_auth = Util.callback(self.pool_config.auth) if not isinstance(new_auth, RenewableAuth): return RenewableAuth(new_auth) @@ -429,6 +439,19 @@ def on_write_failure(self, address): "No write service available for pool {}".format(self) ) + def on_neo4j_error(self, error, address): + assert isinstance(error, Neo4jError) + if error._unauthenticates_all_connections(): + log.debug( + "[#0000] _: mark all connections to %r as " + "unauthenticated", address + ) + with self.lock: + for connection in self.connections.get(address, ()): + connection.auth_dict = {} + if error._requires_new_credentials(): + self.force_new_auth() + def close(self): """ Close all connections and empty the pool. This method is thread safe. diff --git a/src/neo4j/exceptions.py b/src/neo4j/exceptions.py index a0ce9ed0..22f770ea 100644 --- a/src/neo4j/exceptions.py +++ b/src/neo4j/exceptions.py @@ -230,15 +230,18 @@ def is_retryable(self) -> bool: """ return False - def _invalidates_all_connections(self) -> bool: + def _unauthenticates_all_connections(self) -> bool: return self.code == "Neo.ClientError.Security.AuthorizationExpired" + def _requires_new_credentials(self) -> bool: + return self.code == "Neo.ClientError.Security.TokenExpired" + # TODO: 6.0 - Remove this alias invalidates_all_connections = deprecated( "Neo4jError.invalidates_all_connections is deprecated and will be " "removed in a future version. It is an internal method and not meant " "for external use." - )(_invalidates_all_connections) + )(_unauthenticates_all_connections) def _is_fatal_during_discovery(self) -> bool: # checks if the code is an error that is caused by the client. 
In this diff --git a/testkitbackend/_async/requests.py b/testkitbackend/_async/requests.py index 09f5c0d4..23764a1d 100644 --- a/testkitbackend/_async/requests.py +++ b/testkitbackend/_async/requests.py @@ -124,6 +124,8 @@ async def NewDriver(backend, data): auth = _convert_auth_token(data, "authorizationToken") if auth is None and data.get("authTokenProviderId") is not None: auth = backend.auth_token_providers[data["authTokenProviderId"]] + else: + data.mark_item_as_read_if_equals("authTokenProviderId", None) kwargs = {} if data["resolverRegistered"] or data["domainNameResolverRegistered"]: kwargs["resolver"] = resolution_func( diff --git a/testkitbackend/_sync/requests.py b/testkitbackend/_sync/requests.py index 43470bdd..38e11576 100644 --- a/testkitbackend/_sync/requests.py +++ b/testkitbackend/_sync/requests.py @@ -124,6 +124,8 @@ def NewDriver(backend, data): auth = _convert_auth_token(data, "authorizationToken") if auth is None and data.get("authTokenProviderId") is not None: auth = backend.auth_token_providers[data["authTokenProviderId"]] + else: + data.mark_item_as_read_if_equals("authTokenProviderId", None) kwargs = {} if data["resolverRegistered"] or data["domainNameResolverRegistered"]: kwargs["resolver"] = resolution_func( diff --git a/testkitbackend/test_config.json b/testkitbackend/test_config.json index 9ae7e7d7..635e757b 100644 --- a/testkitbackend/test_config.json +++ b/testkitbackend/test_config.json @@ -60,6 +60,7 @@ "ConfHint:connection.recv_timeout_seconds": true, + "Backend:MockTime": true, "Backend:RTFetch": true, "Backend:RTForceUpdate": true } From 5ca2a5158213dc1b6eaa0e331736299dbc38b586 Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Thu, 12 Jan 2023 17:10:13 +0100 Subject: [PATCH 05/23] Refactoring + unit tests --- src/neo4j/_async/io/_bolt.py | 6 +- src/neo4j/_async/io/_pool.py | 4 +- src/neo4j/_sync/io/_bolt.py | 6 +- src/neo4j/_sync/io/_pool.py | 4 +- tests/unit/async_/io/test_class_bolt5x1.py | 3 +- 
tests/unit/async_/io/test_neo4j_pool.py | 59 ++++++++++ tests/unit/mixed/io/_common.py | 124 +++++++++++++++++++++ tests/unit/mixed/io/test_direct.py | 107 +----------------- tests/unit/mixed/io/test_pool_async.py | 72 ++++++++++++ tests/unit/mixed/io/test_pool_sync.py | 79 +++++++++++++ tests/unit/sync/io/test_class_bolt5x1.py | 3 +- tests/unit/sync/io/test_neo4j_pool.py | 59 ++++++++++ 12 files changed, 414 insertions(+), 112 deletions(-) create mode 100644 tests/unit/mixed/io/_common.py create mode 100644 tests/unit/mixed/io/test_pool_async.py create mode 100644 tests/unit/mixed/io/test_pool_sync.py diff --git a/src/neo4j/_async/io/_bolt.py b/src/neo4j/_async/io/_bolt.py index e9b46185..89facf97 100644 --- a/src/neo4j/_async/io/_bolt.py +++ b/src/neo4j/_async/io/_bolt.py @@ -207,7 +207,7 @@ def supports_multiple_databases(self): @property @abc.abstractmethod def supports_re_auth(self): - """TODO""" + """Whether the connection version supports re-authentication.""" pass def assert_re_auth_support(self): @@ -459,6 +459,10 @@ def logoff(self, dehydration_hooks=None, hydration_hooks=None): """Append a LOGOFF message to the outgoing queue.""" pass + def mark_unauthenticated(self): + """Mark the connection as unauthenticated.""" + self.auth_dict = {} + async def re_auth( self, auth, dehydration_hooks=None, hydration_hooks=None ): diff --git a/src/neo4j/_async/io/_pool.py b/src/neo4j/_async/io/_pool.py index dcd459f2..20b32656 100644 --- a/src/neo4j/_async/io/_pool.py +++ b/src/neo4j/_async/io/_pool.py @@ -31,9 +31,7 @@ from ..._async_compat.concurrency import ( AsyncCondition, - AsyncCooperativeLock, AsyncCooperativeRLock, - AsyncLock, AsyncRLock, ) from ..._async_compat.network import AsyncNetworkUtil @@ -448,7 +446,7 @@ async def on_neo4j_error(self, error, address): ) with self.lock: for connection in self.connections.get(address, ()): - connection.auth_dict = {} + connection.mark_unauthenticated() if error._requires_new_credentials(): await self.force_new_auth() 
diff --git a/src/neo4j/_sync/io/_bolt.py b/src/neo4j/_sync/io/_bolt.py index cb860e06..6b9b8c6f 100644 --- a/src/neo4j/_sync/io/_bolt.py +++ b/src/neo4j/_sync/io/_bolt.py @@ -207,7 +207,7 @@ def supports_multiple_databases(self): @property @abc.abstractmethod def supports_re_auth(self): - """TODO""" + """Whether the connection version supports re-authentication.""" pass def assert_re_auth_support(self): @@ -459,6 +459,10 @@ def logoff(self, dehydration_hooks=None, hydration_hooks=None): """Append a LOGOFF message to the outgoing queue.""" pass + def mark_unauthenticated(self): + """Mark the connection as unauthenticated.""" + self.auth_dict = {} + def re_auth( self, auth, dehydration_hooks=None, hydration_hooks=None ): diff --git a/src/neo4j/_sync/io/_pool.py b/src/neo4j/_sync/io/_pool.py index ab382d2f..6495a337 100644 --- a/src/neo4j/_sync/io/_pool.py +++ b/src/neo4j/_sync/io/_pool.py @@ -31,9 +31,7 @@ from ..._async_compat.concurrency import ( Condition, - CooperativeLock, CooperativeRLock, - Lock, RLock, ) from ..._async_compat.network import NetworkUtil @@ -448,7 +446,7 @@ def on_neo4j_error(self, error, address): ) with self.lock: for connection in self.connections.get(address, ()): - connection.auth_dict = {} + connection.mark_unauthenticated() if error._requires_new_credentials(): self.force_new_auth() diff --git a/tests/unit/async_/io/test_class_bolt5x1.py b/tests/unit/async_/io/test_class_bolt5x1.py index a2fdc280..696a5348 100644 --- a/tests/unit/async_/io/test_class_bolt5x1.py +++ b/tests/unit/async_/io/test_class_bolt5x1.py @@ -285,7 +285,7 @@ async def test_logon(fake_socket_pair): @mark_async_test -async def test_re_auth(fake_socket_pair): +async def test_re_auth(fake_socket_pair, mocker): auth = neo4j.Auth("basic", "alice123", "supersecret123") address = ("127.0.0.1", 7687) sockets = fake_socket_pair(address, @@ -297,6 +297,7 @@ async def test_re_auth(fake_socket_pair): ) connection = AsyncBolt5x1(address, sockets.client, 
PoolConfig.max_connection_lifetime) + connection.pool = mocker.AsyncMock() with pytest.raises(neo4j.exceptions.Neo4jError): await connection.re_auth(auth) tag, fields = await sockets.server.pop_message() diff --git a/tests/unit/async_/io/test_neo4j_pool.py b/tests/unit/async_/io/test_neo4j_pool.py index cb47c159..35253b58 100644 --- a/tests/unit/async_/io/test_neo4j_pool.py +++ b/tests/unit/async_/io/test_neo4j_pool.py @@ -602,3 +602,62 @@ async def test_fast_failing_discovery(routing_failure_opener, error): # reader # failed router assert len(opener.connections) == 3 + + + +@pytest.mark.parametrize( + ("error", "marks_unauthenticated", "fetches_new"), + ( + (Neo4jError.hydrate("message", args[0]), *args[1:]) + for args in ( + ("Neo.ClientError.Database.DatabaseNotFound", False, False), + ("Neo.ClientError.Statement.TypeError", False, False), + ("Neo.ClientError.Statement.ArgumentError", False, False), + ("Neo.ClientError.Request.Invalid", False, False), + ("Neo.ClientError.Security.AuthenticationRateLimit", False, False), + ("Neo.ClientError.Security.CredentialsExpired", False, False), + ("Neo.ClientError.Security.Forbidden", False, False), + ("Neo.ClientError.Security.Unauthorized", False, False), + ("Neo.ClientError.Security.MadeUpError", False, False), + ("Neo.ClientError.Security.TokenExpired", False, True), + ("Neo.ClientError.Security.AuthorizationExpired", True, False), + ) + ) +) +@mark_async_test +async def test_connection_error_callback( + opener, error, marks_unauthenticated, fetches_new, mocker +): + pool = AsyncNeo4jPool( + opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS + ) + force_new_auth_mock = mocker.patch.object( + pool, "force_new_auth", autospec=True + ) + cxs_read = [ + await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + for _ in range(5) + ] + cxs_write = [ + await pool.acquire(WRITE_ACCESS, 30, "test_db", None, None, None) + for _ in range(5) + ] + + force_new_auth_mock.assert_not_called() + for cx in cxs_read + 
cxs_write: + cx.mark_unauthenticated.assert_not_called() + + await pool.on_neo4j_error(error, cxs_read[0].addr) + + if fetches_new: + force_new_auth_mock.assert_awaited_once() + else: + force_new_auth_mock.assert_not_called() + + for cx in cxs_read: + if marks_unauthenticated: + cx.mark_unauthenticated.assert_called_once() + else: + cx.mark_unauthenticated.assert_not_called() + for cx in cxs_write: + cx.mark_unauthenticated.assert_not_called() diff --git a/tests/unit/mixed/io/_common.py b/tests/unit/mixed/io/_common.py new file mode 100644 index 00000000..7cba228a --- /dev/null +++ b/tests/unit/mixed/io/_common.py @@ -0,0 +1,124 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import asyncio +import time +from asyncio import ( + Condition as AsyncCondition, + Lock as AsyncLock, +) +from threading import ( + Condition, + Lock, +) + +from neo4j._async_compat.shims import wait_for + + +class MultiEvent: + # Adopted from threading.Event + + def __init__(self): + super().__init__() + self._cond = Condition(Lock()) + self._counter = 0 + + def _reset_internal_locks(self): + # private! 
called by Thread._reset_internal_locks by _after_fork() + self._cond.__init__(Lock()) + + def counter(self): + return self._counter + + def increment(self): + with self._cond: + self._counter += 1 + self._cond.notify_all() + + def decrement(self): + with self._cond: + self._counter -= 1 + self._cond.notify_all() + + def clear(self): + with self._cond: + self._counter = 0 + self._cond.notify_all() + + def wait(self, value=0, timeout=None): + with self._cond: + t_start = time.time() + while True: + if value == self._counter: + return True + if timeout is None: + time_left = None + else: + time_left = timeout - (time.time() - t_start) + if time_left <= 0: + return False + if not self._cond.wait(time_left): + return False + + +class AsyncMultiEvent: + # Adopted from threading.Event + + def __init__(self): + super().__init__() + self._cond = AsyncCondition() + self._counter = 0 + + def _reset_internal_locks(self): + # private! called by Thread._reset_internal_locks by _after_fork() + self._cond.__init__(AsyncLock()) + + def counter(self): + return self._counter + + async def increment(self): + async with self._cond: + self._counter += 1 + self._cond.notify_all() + + async def decrement(self): + async with self._cond: + self._counter -= 1 + self._cond.notify_all() + + async def clear(self): + async with self._cond: + self._counter = 0 + self._cond.notify_all() + + async def wait(self, value=0, timeout=None): + async with self._cond: + t_start = time.time() + while True: + if value == self._counter: + return True + if timeout is None: + time_left = None + else: + time_left = timeout - (time.time() - t_start) + if time_left <= 0: + return False + try: + await wait_for(self._cond.wait(), time_left) + except asyncio.TimeoutError: + return False diff --git a/tests/unit/mixed/io/test_direct.py b/tests/unit/mixed/io/test_direct.py index 2eac6aa6..63fba1e7 100644 --- a/tests/unit/mixed/io/test_direct.py +++ b/tests/unit/mixed/io/test_direct.py @@ -17,14 +17,8 @@ import asyncio 
-import time -from asyncio import ( - Condition as AsyncCondition, - Event as AsyncEvent, - Lock as AsyncLock, -) +from asyncio import Event as AsyncEvent from threading import ( - Condition, Event, Lock, Thread, @@ -32,105 +26,14 @@ import pytest -from neo4j._async_compat.shims import wait_for from neo4j._deadline import Deadline from ...async_.io.test_direct import AsyncFakeBoltPool from ...sync.io.test_direct import FakeBoltPool - - -class MultiEvent: - # Adopted from threading.Event - - def __init__(self): - super().__init__() - self._cond = Condition(Lock()) - self._counter = 0 - - def _reset_internal_locks(self): - # private! called by Thread._reset_internal_locks by _after_fork() - self._cond.__init__(Lock()) - - def counter(self): - return self._counter - - def increment(self): - with self._cond: - self._counter += 1 - self._cond.notify_all() - - def decrement(self): - with self._cond: - self._counter -= 1 - self._cond.notify_all() - - def clear(self): - with self._cond: - self._counter = 0 - self._cond.notify_all() - - def wait(self, value=0, timeout=None): - with self._cond: - t_start = time.time() - while True: - if value == self._counter: - return True - if timeout is None: - time_left = None - else: - time_left = timeout - (time.time() - t_start) - if time_left <= 0: - return False - if not self._cond.wait(time_left): - return False - - -class AsyncMultiEvent: - # Adopted from threading.Event - - def __init__(self): - super().__init__() - self._cond = AsyncCondition() - self._counter = 0 - - def _reset_internal_locks(self): - # private! 
called by Thread._reset_internal_locks by _after_fork() - self._cond.__init__(AsyncLock()) - - def counter(self): - return self._counter - - async def increment(self): - async with self._cond: - self._counter += 1 - self._cond.notify_all() - - async def decrement(self): - async with self._cond: - self._counter -= 1 - self._cond.notify_all() - - async def clear(self): - async with self._cond: - self._counter = 0 - self._cond.notify_all() - - async def wait(self, value=0, timeout=None): - async with self._cond: - t_start = time.time() - while True: - if value == self._counter: - return True - if timeout is None: - time_left = None - else: - time_left = timeout - (time.time() - t_start) - if time_left <= 0: - return False - try: - await wait_for(self._cond.wait(), time_left) - except asyncio.TimeoutError: - return False +from ._common import ( + AsyncMultiEvent, + MultiEvent, +) class TestMixedConnectionPoolTestCase: diff --git a/tests/unit/mixed/io/test_pool_async.py b/tests/unit/mixed/io/test_pool_async.py new file mode 100644 index 00000000..58f17616 --- /dev/null +++ b/tests/unit/mixed/io/test_pool_async.py @@ -0,0 +1,72 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import asyncio +from asyncio import Condition + +from ...async_.fixtures import * # fixtures necessary for pytest +from ...async_.io.test_neo4j_pool import * +from ._common import AsyncMultiEvent + + +@pytest.mark.asyncio +async def test_force_new_auth_blocks(opener): + count = 0 + done = False + condition = Condition() + event = AsyncMultiEvent() + + async def auth_provider(): + nonlocal done, count + count += 1 + if count == 1: + return "user1", "pass1" + await event.increment() + async with condition: + await event.wait(2) + await condition.wait() + await asyncio.sleep(0.1) # block + done = True + return "user", "password" + + config = PoolConfig() + config.auth = auth_provider + pool = AsyncNeo4jPool( + opener, config, WorkspaceConfig(), ROUTER1_ADDRESS + ) + + assert count == 0 + cx = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + await pool.release(cx) + assert count == 1 + + async def task1(): + assert count == 1 + await pool.force_new_auth() + assert count == 2 + + async def task2(): + await event.increment() + await event.wait(2) + async with condition: + condition.notify() + cx = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + assert done # assert waited for blocking auth provider + await pool.release(cx) + + await asyncio.gather(task1(), task2()) diff --git a/tests/unit/mixed/io/test_pool_sync.py b/tests/unit/mixed/io/test_pool_sync.py new file mode 100644 index 00000000..58d98d8c --- /dev/null +++ b/tests/unit/mixed/io/test_pool_sync.py @@ -0,0 +1,79 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from threading import ( + Condition, + Thread, +) +from time import sleep + +from ...sync.fixtures import * # fixtures necessary for pytest +from ...sync.io.test_neo4j_pool import * +from ._common import MultiEvent + + +def test_force_new_auth_blocks(opener): + count = 0 + done = False + condition = Condition() + event = MultiEvent() + + def auth_provider(): + nonlocal done, condition, count + count += 1 + if count == 1: + return "user1", "pass1" + event.wait(1) + with condition: + event.increment() + condition.wait() + sleep(0.1) # block + done = True + return "user", "password" + + config = PoolConfig() + config.auth = auth_provider + pool = Neo4jPool( + opener, config, WorkspaceConfig(), ROUTER1_ADDRESS + ) + + assert count == 0 + cx = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + pool.release(cx) + assert count == 1 + + def task1(): + assert count == 1 + pool.force_new_auth() + assert count == 2 + + def task2(): + event.increment() + event.wait(2) + with condition: + condition.notify() + cx = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + assert done # assert waited for blocking auth provider + pool.release(cx) + + t1 = Thread(target=task1) + t2 = Thread(target=task2) + t1.start() + t2.start() + t1.join() + t2.join() diff --git a/tests/unit/sync/io/test_class_bolt5x1.py b/tests/unit/sync/io/test_class_bolt5x1.py index 011a4a35..dedffb97 100644 --- a/tests/unit/sync/io/test_class_bolt5x1.py +++ b/tests/unit/sync/io/test_class_bolt5x1.py @@ -285,7 +285,7 @@ def test_logon(fake_socket_pair): @mark_sync_test -def 
test_re_auth(fake_socket_pair): +def test_re_auth(fake_socket_pair, mocker): auth = neo4j.Auth("basic", "alice123", "supersecret123") address = ("127.0.0.1", 7687) sockets = fake_socket_pair(address, @@ -297,6 +297,7 @@ def test_re_auth(fake_socket_pair): ) connection = Bolt5x1(address, sockets.client, PoolConfig.max_connection_lifetime) + connection.pool = mocker.Mock() with pytest.raises(neo4j.exceptions.Neo4jError): connection.re_auth(auth) tag, fields = sockets.server.pop_message() diff --git a/tests/unit/sync/io/test_neo4j_pool.py b/tests/unit/sync/io/test_neo4j_pool.py index 424228fe..1a060ea1 100644 --- a/tests/unit/sync/io/test_neo4j_pool.py +++ b/tests/unit/sync/io/test_neo4j_pool.py @@ -602,3 +602,62 @@ def test_fast_failing_discovery(routing_failure_opener, error): # reader # failed router assert len(opener.connections) == 3 + + + +@pytest.mark.parametrize( + ("error", "marks_unauthenticated", "fetches_new"), + ( + (Neo4jError.hydrate("message", args[0]), *args[1:]) + for args in ( + ("Neo.ClientError.Database.DatabaseNotFound", False, False), + ("Neo.ClientError.Statement.TypeError", False, False), + ("Neo.ClientError.Statement.ArgumentError", False, False), + ("Neo.ClientError.Request.Invalid", False, False), + ("Neo.ClientError.Security.AuthenticationRateLimit", False, False), + ("Neo.ClientError.Security.CredentialsExpired", False, False), + ("Neo.ClientError.Security.Forbidden", False, False), + ("Neo.ClientError.Security.Unauthorized", False, False), + ("Neo.ClientError.Security.MadeUpError", False, False), + ("Neo.ClientError.Security.TokenExpired", False, True), + ("Neo.ClientError.Security.AuthorizationExpired", True, False), + ) + ) +) +@mark_sync_test +def test_connection_error_callback( + opener, error, marks_unauthenticated, fetches_new, mocker +): + pool = Neo4jPool( + opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS + ) + force_new_auth_mock = mocker.patch.object( + pool, "force_new_auth", autospec=True + ) + cxs_read = [ + 
pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + for _ in range(5) + ] + cxs_write = [ + pool.acquire(WRITE_ACCESS, 30, "test_db", None, None, None) + for _ in range(5) + ] + + force_new_auth_mock.assert_not_called() + for cx in cxs_read + cxs_write: + cx.mark_unauthenticated.assert_not_called() + + pool.on_neo4j_error(error, cxs_read[0].addr) + + if fetches_new: + force_new_auth_mock.assert_called_once() + else: + force_new_auth_mock.assert_not_called() + + for cx in cxs_read: + if marks_unauthenticated: + cx.mark_unauthenticated.assert_called_once() + else: + cx.mark_unauthenticated.assert_not_called() + for cx in cxs_write: + cx.mark_unauthenticated.assert_not_called() From dab4e51304754fec9d9de3bd9aa84dffc4681071 Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Fri, 13 Jan 2023 11:57:28 +0100 Subject: [PATCH 06/23] Fully pipelined re-auth --- src/neo4j/_async/io/_bolt.py | 6 ++--- src/neo4j/_async/io/_pool.py | 5 +--- src/neo4j/_sync/io/_bolt.py | 4 +-- src/neo4j/_sync/io/_pool.py | 5 +--- testkitbackend/__main__.py | 2 ++ tests/unit/async_/fixtures/fake_connection.py | 2 +- tests/unit/async_/io/test_class_bolt3.py | 4 +-- tests/unit/async_/io/test_class_bolt4x0.py | 4 +-- tests/unit/async_/io/test_class_bolt4x1.py | 4 +-- tests/unit/async_/io/test_class_bolt4x2.py | 4 +-- tests/unit/async_/io/test_class_bolt4x3.py | 4 +-- tests/unit/async_/io/test_class_bolt4x4.py | 4 +-- tests/unit/async_/io/test_class_bolt5x0.py | 4 +-- tests/unit/async_/io/test_class_bolt5x1.py | 4 ++- tests/unit/async_/io/test_direct.py | 2 +- tests/unit/async_/io/test_neo4j_pool.py | 27 ------------------- tests/unit/sync/io/test_class_bolt5x1.py | 4 ++- tests/unit/sync/io/test_neo4j_pool.py | 27 ------------------- 18 files changed, 29 insertions(+), 87 deletions(-) diff --git a/src/neo4j/_async/io/_bolt.py b/src/neo4j/_async/io/_bolt.py index 89facf97..f07e5020 100644 --- a/src/neo4j/_async/io/_bolt.py +++ b/src/neo4j/_async/io/_bolt.py @@ -463,10 +463,10 @@ def 
mark_unauthenticated(self): """Mark the connection as unauthenticated.""" self.auth_dict = {} - async def re_auth( + def re_auth( self, auth, dehydration_hooks=None, hydration_hooks=None ): - """Append LOGON, LOGOFF to the outgoing queue, flush, then receive. + """Append LOGON, LOGOFF to the outgoing queue. If auth is the same as the current auth, this method does nothing. @@ -480,8 +480,6 @@ async def re_auth( self.auth_dict = new_auth_dict self.logon(dehydration_hooks=dehydration_hooks, hydration_hooks=hydration_hooks) - await self.send_all() - await self.fetch_all() return True diff --git a/src/neo4j/_async/io/_pool.py b/src/neo4j/_async/io/_pool.py index 20b32656..b2166155 100644 --- a/src/neo4j/_async/io/_pool.py +++ b/src/neo4j/_async/io/_pool.py @@ -246,10 +246,7 @@ async def health_check(connection_, deadline_): or connection_.stale()): return False try: - if await connection_.re_auth(auth or await self.get_auth()): - # no need for an extra liveness check if the connection was - # successfully re-authenticated - return True + connection_.re_auth(auth or await self.get_auth()) except ConfigurationError: # protocol does not support re-authentication if auth: diff --git a/src/neo4j/_sync/io/_bolt.py b/src/neo4j/_sync/io/_bolt.py index 6b9b8c6f..2c97bbd2 100644 --- a/src/neo4j/_sync/io/_bolt.py +++ b/src/neo4j/_sync/io/_bolt.py @@ -466,7 +466,7 @@ def mark_unauthenticated(self): def re_auth( self, auth, dehydration_hooks=None, hydration_hooks=None ): - """Append LOGON, LOGOFF to the outgoing queue, flush, then receive. + """Append LOGON, LOGOFF to the outgoing queue. If auth is the same as the current auth, this method does nothing. 
@@ -480,8 +480,6 @@ def re_auth( self.auth_dict = new_auth_dict self.logon(dehydration_hooks=dehydration_hooks, hydration_hooks=hydration_hooks) - self.send_all() - self.fetch_all() return True diff --git a/src/neo4j/_sync/io/_pool.py b/src/neo4j/_sync/io/_pool.py index 6495a337..8b9c4598 100644 --- a/src/neo4j/_sync/io/_pool.py +++ b/src/neo4j/_sync/io/_pool.py @@ -246,10 +246,7 @@ def health_check(connection_, deadline_): or connection_.stale()): return False try: - if connection_.re_auth(auth or self.get_auth()): - # no need for an extra liveness check if the connection was - # successfully re-authenticated - return True + connection_.re_auth(auth or self.get_auth()) except ConfigurationError: # protocol does not support re-authentication if auth: diff --git a/testkitbackend/__main__.py b/testkitbackend/__main__.py index 91a4a6e1..94739b13 100644 --- a/testkitbackend/__main__.py +++ b/testkitbackend/__main__.py @@ -30,6 +30,7 @@ def sync_main(): server = Server(("0.0.0.0", 9876)) + print("Start serving") while True: server.handle_request() @@ -39,6 +40,7 @@ async def main(): server = AsyncServer(("0.0.0.0", 9876)) await server.start() try: + print("Start serving") await server.serve_forever() finally: server.stop() diff --git a/tests/unit/async_/fixtures/fake_connection.py b/tests/unit/async_/fixtures/fake_connection.py index ae4b2dc9..48e357d8 100644 --- a/tests/unit/async_/fixtures/fake_connection.py +++ b/tests/unit/async_/fixtures/fake_connection.py @@ -50,7 +50,7 @@ def __init__(self, *args, **kwargs): self.attach_mock(mock.Mock(return_value=False), "stale") self.attach_mock(mock.Mock(return_value=False), "closed") self.attach_mock(mock.Mock(return_value=False), "socket") - self.attach_mock(mock.AsyncMock(return_value=False), "re_auth") + self.attach_mock(mock.Mock(return_value=False), "re_auth") self.attach_mock(mock.Mock(), "unresolved_address") def close_side_effect(): diff --git a/tests/unit/async_/io/test_class_bolt3.py 
b/tests/unit/async_/io/test_class_bolt3.py index 8f3bda27..a1500f44 100644 --- a/tests/unit/async_/io/test_class_bolt3.py +++ b/tests/unit/async_/io/test_class_bolt3.py @@ -176,7 +176,7 @@ async def test_re_auth_noop(auth, fake_socket, mocker): PoolConfig.max_connection_lifetime, auth=auth) logon_spy = mocker.spy(connection, "logon") logoff_spy = mocker.spy(connection, "logoff") - res = await connection.re_auth(auth) + res = connection.re_auth(auth) assert res is False logon_spy.assert_not_called() @@ -201,4 +201,4 @@ async def test_re_auth(auth1, auth2, fake_socket): PoolConfig.max_connection_lifetime, auth=auth1) with pytest.raises(ConfigurationError, match="Session level authentication is not supported"): - await connection.re_auth(auth2) + connection.re_auth(auth2) diff --git a/tests/unit/async_/io/test_class_bolt4x0.py b/tests/unit/async_/io/test_class_bolt4x0.py index 659db673..f0df6c40 100644 --- a/tests/unit/async_/io/test_class_bolt4x0.py +++ b/tests/unit/async_/io/test_class_bolt4x0.py @@ -272,7 +272,7 @@ async def test_re_auth_noop(auth, fake_socket, mocker): PoolConfig.max_connection_lifetime, auth=auth) logon_spy = mocker.spy(connection, "logon") logoff_spy = mocker.spy(connection, "logoff") - res = await connection.re_auth(auth) + res = connection.re_auth(auth) assert res is False logon_spy.assert_not_called() @@ -297,4 +297,4 @@ async def test_re_auth(auth1, auth2, fake_socket): PoolConfig.max_connection_lifetime, auth=auth1) with pytest.raises(ConfigurationError, match="Session level authentication is not supported"): - await connection.re_auth(auth2) + connection.re_auth(auth2) diff --git a/tests/unit/async_/io/test_class_bolt4x1.py b/tests/unit/async_/io/test_class_bolt4x1.py index d236cac4..1509ddc7 100644 --- a/tests/unit/async_/io/test_class_bolt4x1.py +++ b/tests/unit/async_/io/test_class_bolt4x1.py @@ -291,7 +291,7 @@ async def test_re_auth_noop(auth, fake_socket, mocker): PoolConfig.max_connection_lifetime, auth=auth) logon_spy = 
mocker.spy(connection, "logon") logoff_spy = mocker.spy(connection, "logoff") - res = await connection.re_auth(auth) + res = connection.re_auth(auth) assert res is False logon_spy.assert_not_called() @@ -316,4 +316,4 @@ async def test_re_auth(auth1, auth2, fake_socket): PoolConfig.max_connection_lifetime, auth=auth1) with pytest.raises(ConfigurationError, match="Session level authentication is not supported"): - await connection.re_auth(auth2) + connection.re_auth(auth2) diff --git a/tests/unit/async_/io/test_class_bolt4x2.py b/tests/unit/async_/io/test_class_bolt4x2.py index 92f4617e..ad1dec8a 100644 --- a/tests/unit/async_/io/test_class_bolt4x2.py +++ b/tests/unit/async_/io/test_class_bolt4x2.py @@ -292,7 +292,7 @@ async def test_re_auth_noop(auth, fake_socket, mocker): PoolConfig.max_connection_lifetime, auth=auth) logon_spy = mocker.spy(connection, "logon") logoff_spy = mocker.spy(connection, "logoff") - res = await connection.re_auth(auth) + res = connection.re_auth(auth) assert res is False logon_spy.assert_not_called() @@ -317,4 +317,4 @@ async def test_re_auth(auth1, auth2, fake_socket): PoolConfig.max_connection_lifetime, auth=auth1) with pytest.raises(ConfigurationError, match="Session level authentication is not supported"): - await connection.re_auth(auth2) + connection.re_auth(auth2) diff --git a/tests/unit/async_/io/test_class_bolt4x3.py b/tests/unit/async_/io/test_class_bolt4x3.py index 47d10ee9..436f2baf 100644 --- a/tests/unit/async_/io/test_class_bolt4x3.py +++ b/tests/unit/async_/io/test_class_bolt4x3.py @@ -319,7 +319,7 @@ async def test_re_auth_noop(auth, fake_socket, mocker): PoolConfig.max_connection_lifetime, auth=auth) logon_spy = mocker.spy(connection, "logon") logoff_spy = mocker.spy(connection, "logoff") - res = await connection.re_auth(auth) + res = connection.re_auth(auth) assert res is False logon_spy.assert_not_called() @@ -344,4 +344,4 @@ async def test_re_auth(auth1, auth2, fake_socket): PoolConfig.max_connection_lifetime, 
auth=auth1) with pytest.raises(ConfigurationError, match="Session level authentication is not supported"): - await connection.re_auth(auth2) + connection.re_auth(auth2) diff --git a/tests/unit/async_/io/test_class_bolt4x4.py b/tests/unit/async_/io/test_class_bolt4x4.py index 918b1498..a6b0eae3 100644 --- a/tests/unit/async_/io/test_class_bolt4x4.py +++ b/tests/unit/async_/io/test_class_bolt4x4.py @@ -332,7 +332,7 @@ async def test_re_auth_noop(auth, fake_socket, mocker): PoolConfig.max_connection_lifetime, auth=auth) logon_spy = mocker.spy(connection, "logon") logoff_spy = mocker.spy(connection, "logoff") - res = await connection.re_auth(auth) + res = connection.re_auth(auth) assert res is False logon_spy.assert_not_called() @@ -357,4 +357,4 @@ async def test_re_auth(auth1, auth2, fake_socket): PoolConfig.max_connection_lifetime, auth=auth1) with pytest.raises(ConfigurationError, match="Session level authentication is not supported"): - await connection.re_auth(auth2) + connection.re_auth(auth2) diff --git a/tests/unit/async_/io/test_class_bolt5x0.py b/tests/unit/async_/io/test_class_bolt5x0.py index f8c0c61e..f1f549d5 100644 --- a/tests/unit/async_/io/test_class_bolt5x0.py +++ b/tests/unit/async_/io/test_class_bolt5x0.py @@ -330,7 +330,7 @@ async def test_re_auth_noop(auth, fake_socket, mocker): PoolConfig.max_connection_lifetime, auth=auth) logon_spy = mocker.spy(connection, "logon") logoff_spy = mocker.spy(connection, "logoff") - res = await connection.re_auth(auth) + res = connection.re_auth(auth) assert res is False logon_spy.assert_not_called() @@ -355,4 +355,4 @@ async def test_re_auth(auth1, auth2, fake_socket): PoolConfig.max_connection_lifetime, auth=auth1) with pytest.raises(ConfigurationError, match="Session level authentication is not supported"): - await connection.re_auth(auth2) + connection.re_auth(auth2) diff --git a/tests/unit/async_/io/test_class_bolt5x1.py b/tests/unit/async_/io/test_class_bolt5x1.py index 696a5348..4ba9b47a 100644 --- 
a/tests/unit/async_/io/test_class_bolt5x1.py +++ b/tests/unit/async_/io/test_class_bolt5x1.py @@ -298,8 +298,10 @@ async def test_re_auth(fake_socket_pair, mocker): connection = AsyncBolt5x1(address, sockets.client, PoolConfig.max_connection_lifetime) connection.pool = mocker.AsyncMock() + connection.re_auth(auth) + await connection.send_all() with pytest.raises(neo4j.exceptions.Neo4jError): - await connection.re_auth(auth) + await connection.fetch_all() tag, fields = await sockets.server.pop_message() assert tag == b"\x6B" # LOGOFF assert len(fields) == 0 diff --git a/tests/unit/async_/io/test_direct.py b/tests/unit/async_/io/test_direct.py index 15e3c709..f6c57ccb 100644 --- a/tests/unit/async_/io/test_direct.py +++ b/tests/unit/async_/io/test_direct.py @@ -65,7 +65,7 @@ def stale(self): async def reset(self): pass - async def re_auth(self, auth): + def re_auth(self, auth): return False def close(self): diff --git a/tests/unit/async_/io/test_neo4j_pool.py b/tests/unit/async_/io/test_neo4j_pool.py index 35253b58..a417423f 100644 --- a/tests/unit/async_/io/test_neo4j_pool.py +++ b/tests/unit/async_/io/test_neo4j_pool.py @@ -425,33 +425,6 @@ def liveness_side_effect(*args, **kwargs): assert cx3 in pool.connections[cx1.addr] -@mark_async_test -async def test_acquire_accepts_re_auth_as_liveness_check(opener): - pool = AsyncNeo4jPool( - opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS - ) - # populate the pool with a connection - cx1 = await pool._acquire(READER_ADDRESS, None, Deadline(30), 1) - # make sure we assume the right state - assert cx1.addr == READER_ADDRESS - cx1.is_idle_for.assert_not_called() - cx1.reset.assert_not_called() - - # simulate connections successfully re-authenticating - cx1.re_auth.return_value = True - - # release the connection - await pool.release(cx1) - cx1.reset.assert_not_called() - - # then acquire it again and assert the liveness check was performed - cx2 = await pool._acquire(READER_ADDRESS, None, Deadline(30), 1) - assert cx2 
is cx1 - cx1.is_idle_for.assert_not_called() - cx1.reset.assert_not_called() - assert cx1 in pool.connections[cx1.addr] - - @mark_async_test async def test_multiple_broken_connections_on_close(opener, mocker): def mock_connection_breaks_on_close(cx): diff --git a/tests/unit/sync/io/test_class_bolt5x1.py b/tests/unit/sync/io/test_class_bolt5x1.py index dedffb97..8e9b6f11 100644 --- a/tests/unit/sync/io/test_class_bolt5x1.py +++ b/tests/unit/sync/io/test_class_bolt5x1.py @@ -298,8 +298,10 @@ def test_re_auth(fake_socket_pair, mocker): connection = Bolt5x1(address, sockets.client, PoolConfig.max_connection_lifetime) connection.pool = mocker.Mock() + connection.re_auth(auth) + connection.send_all() with pytest.raises(neo4j.exceptions.Neo4jError): - connection.re_auth(auth) + connection.fetch_all() tag, fields = sockets.server.pop_message() assert tag == b"\x6B" # LOGOFF assert len(fields) == 0 diff --git a/tests/unit/sync/io/test_neo4j_pool.py b/tests/unit/sync/io/test_neo4j_pool.py index 1a060ea1..24957c72 100644 --- a/tests/unit/sync/io/test_neo4j_pool.py +++ b/tests/unit/sync/io/test_neo4j_pool.py @@ -425,33 +425,6 @@ def liveness_side_effect(*args, **kwargs): assert cx3 in pool.connections[cx1.addr] -@mark_sync_test -def test_acquire_accepts_re_auth_as_liveness_check(opener): - pool = Neo4jPool( - opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS - ) - # populate the pool with a connection - cx1 = pool._acquire(READER_ADDRESS, None, Deadline(30), 1) - # make sure we assume the right state - assert cx1.addr == READER_ADDRESS - cx1.is_idle_for.assert_not_called() - cx1.reset.assert_not_called() - - # simulate connections successfully re-authenticating - cx1.re_auth.return_value = True - - # release the connection - pool.release(cx1) - cx1.reset.assert_not_called() - - # then acquire it again and assert the liveness check was performed - cx2 = pool._acquire(READER_ADDRESS, None, Deadline(30), 1) - assert cx2 is cx1 - cx1.is_idle_for.assert_not_called() - 
cx1.reset.assert_not_called() - assert cx1 in pool.connections[cx1.addr] - - @mark_sync_test def test_multiple_broken_connections_on_close(opener, mocker): def mock_connection_breaks_on_close(cx): From 9a05ca9fc5a3bdef6498029592e1972153cd1de1 Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Fri, 13 Jan 2023 13:10:22 +0100 Subject: [PATCH 07/23] clean-up --- src/neo4j/scratch_11.py | 28 ---------------------------- 1 file changed, 28 deletions(-) delete mode 100644 src/neo4j/scratch_11.py diff --git a/src/neo4j/scratch_11.py b/src/neo4j/scratch_11.py deleted file mode 100644 index 6a80b32c..00000000 --- a/src/neo4j/scratch_11.py +++ /dev/null @@ -1,28 +0,0 @@ -from time import time - -import neo4j -from neo4j.debug import watch - - -# watch("neo4j") - - -URI = "neo4j://localhost:7687" -USER = "neo4j" -PASSWORD = "pass" - - - -def main(): - start = time() - with neo4j.GraphDatabase.driver(URI, auth=(USER, PASSWORD)) as driver: - for _ in range(10000): - with driver.session() as session: - value = list(range(100)) - list(session.run("RETURN $value", value=value)) - end = time() - print(f"Time taken: {end - start}s") - - -if __name__ == '__main__': - main() From e7078fa7e7e6d6a57aa09134dd43892847217ffb Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Mon, 30 Jan 2023 12:37:21 +0100 Subject: [PATCH 08/23] Implement driver.verify_authentication + refine docs --- docs/source/api.rst | 22 +++++- docs/source/async_api.rst | 8 ++- src/neo4j/_async/driver.py | 138 +++++++++++++++++++++++++++++++++---- src/neo4j/_sync/driver.py | 138 +++++++++++++++++++++++++++++++++---- src/neo4j/api.py | 50 ++++++++++---- src/neo4j/time/__init__.py | 4 +- 6 files changed, 317 insertions(+), 43 deletions(-) diff --git a/docs/source/api.rst b/docs/source/api.rst index 82f7241c..94183b37 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -107,10 +107,11 @@ Auth To authenticate with Neo4j the authentication details are supplied at driver creation. 
-The auth token is an object of the class :class:`neo4j.Auth` containing the details. +The auth token is an object of the class :class:`neo4j.Auth` containing static details or a a callable that returns a :class:`neo4j.RenewableAuth` object. .. autoclass:: neo4j.Auth +.. autoclass:: neo4j.RenewableAuth Example: @@ -154,7 +155,8 @@ Closing a driver will immediately shut down all connections in the pool. .. autoclass:: neo4j.Driver() :members: session, query_bookmark_manager, encrypted, close, - verify_connectivity, get_server_info + verify_connectivity, get_server_info, verify_authentication, + supports_session_auth, supports_multi_db .. method:: execute_query(query, parameters_=None,routing_=neo4j.RoutingControl.WRITERS, database_=None, impersonated_user_=None, bookmark_manager_=self.query_bookmark_manager, result_transformer_=Result.to_eager_result, **kwargs) @@ -726,6 +728,7 @@ Session .. automethod:: execute_write + Query ===== @@ -746,6 +749,7 @@ To construct a :class:`neo4j.Session` use the :meth:`neo4j.Driver.session` metho + :ref:`default-access-mode-ref` + :ref:`fetch-size-ref` + :ref:`bookmark-manager-ref` ++ :ref:`session-auth-ref` .. _bookmarks-ref: @@ -939,6 +943,20 @@ See :class:`.BookmarkManager` for more information. It might be changed or removed any time even without prior notice. +.. _session-auth-ref: + +``auth`` +-------- +Optional :class:`neo4j.Auth` or ``(user, password)``-tuple. Use this overwrite the +authentication information for the session. +This requires the server to support re-authentication on the protocol level. You can +check this by calling :meth:`.Driver.supports_session_auth` / :meth:`.AsyncDriver.supports_session_auth`. + +:Type: :data:`None`, :class:`.Auth` or ``(user, password)``-tuple +:Default: :data:`None` - use the authentication information provided during driver creation. + +.. 
versionadded:: 5.x + *********** diff --git a/docs/source/async_api.rst b/docs/source/async_api.rst index 08e0b3d2..c2b8e413 100644 --- a/docs/source/async_api.rst +++ b/docs/source/async_api.rst @@ -136,7 +136,8 @@ Closing a driver will immediately shut down all connections in the pool. .. autoclass:: neo4j.AsyncDriver() :members: session, query_bookmark_manager, encrypted, close, - verify_connectivity, get_server_info + verify_connectivity, get_server_info, verify_authentication, + supports_session_auth, supports_multi_db .. method:: execute_query(query, parameters_=None, routing_=neo4j.RoutingControl.WRITERS, database_=None, impersonated_user_=None, bookmark_manager_=self.query_bookmark_manager, result_transformer_=AsyncResult.to_eager_result, **kwargs) :async: @@ -334,8 +335,9 @@ Async Driver Configuration ========================== :class:`neo4j.AsyncDriver` is configured exactly like :class:`neo4j.Driver` -(see :ref:`driver-configuration-ref`). The only difference is that the async -driver accepts an async custom resolver function: +(see :ref:`driver-configuration-ref`). The only differences are that the async +driver accepts an async custom resolver function (see :ref:`async-resolver-ref`) +as well as an async auth token provider (see :class:`neo4j.RenewableAuth`). .. 
_async-resolver-ref: diff --git a/src/neo4j/_async/driver.py b/src/neo4j/_async/driver.py index 5f65d1b3..0345ce61 100644 --- a/src/neo4j/_async/driver.py +++ b/src/neo4j/_async/driver.py @@ -49,7 +49,6 @@ from .._work import EagerResult from ..addressing import Address from ..api import ( - _TAuthTokenProvider, AsyncBookmarkManager, Auth, BookmarkManager, @@ -59,6 +58,7 @@ parse_neo4j_uri, parse_routing_context, READ_ACCESS, + RenewableAuth, SECURITY_TYPE_SECURE, SECURITY_TYPE_SELF_SIGNED_CERTIFICATE, ServerInfo, @@ -71,6 +71,7 @@ URI_SCHEME_NEO4J_SECURE, URI_SCHEME_NEO4J_SELF_SIGNED_CERTIFICATE, ) +from ..exceptions import Neo4jError from .bookmark_manager import ( AsyncNeo4jBookmarkManager, TBmConsumer as _TBmConsumer, @@ -103,6 +104,12 @@ class _DefaultEnum(Enum): _T = t.TypeVar("_T") +_TAuthTokenProvider = t.Callable[[], t.Union[ + RenewableAuth, Auth, t.Tuple[t.Any, t.Any], None, + t.Awaitable[t.Union[RenewableAuth, Auth, t.Tuple[t.Any, t.Any], None]] +]] + + class AsyncGraphDatabase: """Accessor for :class:`neo4j.AsyncDriver` construction. """ @@ -858,7 +865,6 @@ async def verify_connectivity( else: - # TODO: 6.0 - remove config argument async def verify_connectivity(self, **config) -> None: """Verify that the driver can establish a connection to the server. @@ -877,7 +883,7 @@ async def verify_connectivity(self, **config) -> None: They might be changed or removed in any future version without prior notice. - :raises DriverError: if the driver cannot connect to the remote. + :raises Exception: if the driver cannot connect to the remote. Use the exception to further understand the cause of the connectivity problem. @@ -893,8 +899,7 @@ async def verify_connectivity(self, **config) -> None: "changed or removed in any future version without prior " "notice." 
) - async with self.session(**config) as session: - await session._get_server_info() + await self._get_server_info() if t.TYPE_CHECKING: @@ -945,7 +950,7 @@ async def get_server_info(self, **config) -> ServerInfo: They might be changed or removed in any future version without prior notice. - :raises DriverError: if the driver cannot connect to the remote. + :raises Exception: if the driver cannot connect to the remote. Use the exception to further understand the cause of the connectivity problem. @@ -954,14 +959,12 @@ async def get_server_info(self, **config) -> ServerInfo: if config: experimental_warn( "All configuration key-word arguments to " - "verify_connectivity() are experimental. They might be " + "get_server_info() are experimental. They might be " "changed or removed in any future version without prior " "notice." ) - async with self.session(**config) as session: - return await session._get_server_info() + return await self._get_server_info() - @experimental("Feature support query, based on Bolt protocol version and Neo4j server version will change in the future.") async def supports_multi_db(self) -> bool: """ Check if the server or cluster supports multi-databases. @@ -969,14 +972,125 @@ async def supports_multi_db(self) -> bool: supports multi-databases, otherwise false. .. note:: - Feature support query, based on Bolt Protocol Version and Neo4j - server version will change in the future. + Feature support query based solely on the Bolt protocol version. + The feature might still be disabled on the server side even if this + function return :const:`True`. It just guarantees that the driver + won't throw a :exc:`ConfigurationError` when trying to use this + driver feature. 
""" async with self.session() as session: await session._connect(READ_ACCESS) assert session._connection return session._connection.supports_multiple_databases + if t.TYPE_CHECKING: + + async def verify_authentication( + self, + auth: t.Union[Auth, t.Tuple[t.Any, t.Any]], + # all other arguments are experimental + # they may be change or removed any time without prior notice + session_connection_timeout: float = ..., + connection_acquisition_timeout: float = ..., + max_transaction_retry_time: float = ..., + database: t.Optional[str] = ..., + fetch_size: int = ..., + impersonated_user: t.Optional[str] = ..., + bookmarks: t.Union[t.Iterable[str], Bookmarks, None] = ..., + default_access_mode: str = ..., + bookmark_manager: t.Union[ + AsyncBookmarkManager, BookmarkManager, None + ] = ..., + + # undocumented/unsupported options + initial_retry_delay: float = ..., + retry_delay_multiplier: float = ..., + retry_delay_jitter_factor: float = ... + ) -> bool: + ... + + else: + + async def verify_authentication( + self, auth: t.Union[Auth, t.Tuple[t.Any, t.Any]], **config + ) -> bool: + """Verify that the authentication information is valid. + + Like :meth:`.verify_connectivity`, but for checking authentication. + + Try to establish a working read connection to the remote server or + a member of a cluster and exchange some data. In a cluster, there + is no guarantee about which server will be contacted. If the data + exchange is successful, the authentication information is valid and + :const:`True` is returned. Otherwise, the error will be matched + against a list of known authentication errors. If the error is on + that list, :const:`False` is returned indicating that the + authentication information is invalid. Otherwise, the error is + re-raised. + + :param auth: authentication information to verify. + Same as the session config :ref:`auth-ref`. + :param config: accepts the same configuration key-word arguments as + :meth:`session`. + + .. 
warning:: + All configuration key-word arguments (except ``auth``) are + experimental. They might be changed or removed in any + future version without prior notice. + + :raises Exception: if the driver cannot connect to the remote. + Use the exception to further understand the cause of the + connectivity problem. + + .. versionadded:: 5.x + """ + if config: + experimental_warn( + "All configuration key-word arguments but auth to " + "verify_authentication() are experimental. They might be " + "changed or removed in any future version without prior " + "notice." + ) + config["auth"] = auth + try: + await self._get_server_info(**config) + except Neo4jError as exc: + if exc.code in ( + "Neo.ClientError.Security.CredentialsExpired", + "Neo.ClientError.Security.Forbidden", + "Neo.ClientError.Security.TokenExpired", + "Neo.ClientError.Security.Unauthorized", + ): + return False + raise + return True + + + async def supports_session_auth(self) -> bool: + """Check if the remote supports connection re-authentication. + + :returns: Returns true if the server or cluster the driver connects to + supports re-authentication of existing connections, otherwise + false. + + .. note:: + Feature support query based solely on the Bolt protocol version. + The feature might still be disabled on the server side even if this + function return :const:`True`. It just guarantees that the driver + won't throw a :exc:`ConfigurationError` when trying to use this + driver feature. + + .. 
versionadded:: 5.x + """ + async with self.session() as session: + await session._connect(READ_ACCESS) + assert session._connection + return session._connection.supports_re_auth + + async def _get_server_info(self, **config) -> ServerInfo: + async with self.session(**config) as session: + return await session._get_server_info() + async def _work( tx: AsyncManagedTransaction, diff --git a/src/neo4j/_sync/driver.py b/src/neo4j/_sync/driver.py index 6fc9b2aa..731668fe 100644 --- a/src/neo4j/_sync/driver.py +++ b/src/neo4j/_sync/driver.py @@ -48,7 +48,6 @@ from .._work import EagerResult from ..addressing import Address from ..api import ( - _TAuthTokenProvider, Auth, BookmarkManager, Bookmarks, @@ -57,6 +56,7 @@ parse_neo4j_uri, parse_routing_context, READ_ACCESS, + RenewableAuth, SECURITY_TYPE_SECURE, SECURITY_TYPE_SELF_SIGNED_CERTIFICATE, ServerInfo, @@ -69,6 +69,7 @@ URI_SCHEME_NEO4J_SECURE, URI_SCHEME_NEO4J_SELF_SIGNED_CERTIFICATE, ) +from ..exceptions import Neo4jError from .bookmark_manager import ( Neo4jBookmarkManager, TBmConsumer as _TBmConsumer, @@ -101,6 +102,12 @@ class _DefaultEnum(Enum): _T = t.TypeVar("_T") +_TAuthTokenProvider = t.Callable[[], t.Union[ + RenewableAuth, Auth, t.Tuple[t.Any, t.Any], None, + t.Union[t.Union[RenewableAuth, Auth, t.Tuple[t.Any, t.Any], None]] +]] + + class GraphDatabase: """Accessor for :class:`neo4j.Driver` construction. """ @@ -856,7 +863,6 @@ def verify_connectivity( else: - # TODO: 6.0 - remove config argument def verify_connectivity(self, **config) -> None: """Verify that the driver can establish a connection to the server. @@ -875,7 +881,7 @@ def verify_connectivity(self, **config) -> None: They might be changed or removed in any future version without prior notice. - :raises DriverError: if the driver cannot connect to the remote. + :raises Exception: if the driver cannot connect to the remote. Use the exception to further understand the cause of the connectivity problem. 
@@ -891,8 +897,7 @@ def verify_connectivity(self, **config) -> None: "changed or removed in any future version without prior " "notice." ) - with self.session(**config) as session: - session._get_server_info() + self._get_server_info() if t.TYPE_CHECKING: @@ -943,7 +948,7 @@ def get_server_info(self, **config) -> ServerInfo: They might be changed or removed in any future version without prior notice. - :raises DriverError: if the driver cannot connect to the remote. + :raises Exception: if the driver cannot connect to the remote. Use the exception to further understand the cause of the connectivity problem. @@ -952,14 +957,12 @@ def get_server_info(self, **config) -> ServerInfo: if config: experimental_warn( "All configuration key-word arguments to " - "verify_connectivity() are experimental. They might be " + "get_server_info() are experimental. They might be " "changed or removed in any future version without prior " "notice." ) - with self.session(**config) as session: - return session._get_server_info() + return self._get_server_info() - @experimental("Feature support query, based on Bolt protocol version and Neo4j server version will change in the future.") def supports_multi_db(self) -> bool: """ Check if the server or cluster supports multi-databases. @@ -967,14 +970,125 @@ def supports_multi_db(self) -> bool: supports multi-databases, otherwise false. .. note:: - Feature support query, based on Bolt Protocol Version and Neo4j - server version will change in the future. + Feature support query based solely on the Bolt protocol version. + The feature might still be disabled on the server side even if this + function return :const:`True`. It just guarantees that the driver + won't throw a :exc:`ConfigurationError` when trying to use this + driver feature. 
""" with self.session() as session: session._connect(READ_ACCESS) assert session._connection return session._connection.supports_multiple_databases + if t.TYPE_CHECKING: + + def verify_authentication( + self, + auth: t.Union[Auth, t.Tuple[t.Any, t.Any]], + # all other arguments are experimental + # they may be changed or removed any time without prior notice + session_connection_timeout: float = ..., + connection_acquisition_timeout: float = ..., + max_transaction_retry_time: float = ..., + database: t.Optional[str] = ..., + fetch_size: int = ..., + impersonated_user: t.Optional[str] = ..., + bookmarks: t.Union[t.Iterable[str], Bookmarks, None] = ..., + default_access_mode: str = ..., + bookmark_manager: t.Union[ + BookmarkManager, BookmarkManager, None + ] = ..., + + # undocumented/unsupported options + initial_retry_delay: float = ..., + retry_delay_multiplier: float = ..., + retry_delay_jitter_factor: float = ... + ) -> bool: + ... + + else: + + def verify_authentication( + self, auth: t.Union[Auth, t.Tuple[t.Any, t.Any]], **config + ) -> bool: + """Verify that the authentication information is valid. + + Like :meth:`.verify_connectivity`, but for checking authentication. + + Try to establish a working read connection to the remote server or + a member of a cluster and exchange some data. In a cluster, there + is no guarantee about which server will be contacted. If the data + exchange is successful, the authentication information is valid and + :const:`True` is returned. Otherwise, the error will be matched + against a list of known authentication errors. If the error is on + that list, :const:`False` is returned indicating that the + authentication information is invalid. Otherwise, the error is + re-raised. + + :param auth: authentication information to verify. + Same as the session config :ref:`auth-ref`. + :param config: accepts the same configuration key-word arguments as + :meth:`session`. + + ..
warning:: + All configuration key-word arguments (except ``auth``) are + experimental. They might be changed or removed in any + future version without prior notice. + + :raises Exception: if the driver cannot connect to the remote. + Use the exception to further understand the cause of the + connectivity problem. + + .. versionadded:: 5.x + """ + if config: + experimental_warn( + "All configuration key-word arguments but auth to " + "verify_authentication() are experimental. They might be " + "changed or removed in any future version without prior " + "notice." + ) + config["auth"] = auth + try: + self._get_server_info(**config) + except Neo4jError as exc: + if exc.code in ( + "Neo.ClientError.Security.CredentialsExpired", + "Neo.ClientError.Security.Forbidden", + "Neo.ClientError.Security.TokenExpired", + "Neo.ClientError.Security.Unauthorized", + ): + return False + raise + return True + + + def supports_session_auth(self) -> bool: + """Check if the remote supports connection re-authentication. + + :returns: Returns true if the server or cluster the driver connects to + supports re-authentication of existing connections, otherwise + false. + + .. note:: + Feature support query based solely on the Bolt protocol version. + The feature might still be disabled on the server side even if this + function return :const:`True`. It just guarantees that the driver + won't throw a :exc:`ConfigurationError` when trying to use this + driver feature. + + .. 
versionadded:: 5.x + """ + with self.session() as session: + session._connect(READ_ACCESS) + assert session._connection + return session._connection.supports_re_auth + + def _get_server_info(self, **config) -> ServerInfo: + with self.session(**config) as session: + return session._get_server_info() + def _work( tx: ManagedTransaction, diff --git a/src/neo4j/api.py b/src/neo4j/api.py index a502e5bf..b4809b4e 100644 --- a/src/neo4j/api.py +++ b/src/neo4j/api.py @@ -182,21 +182,53 @@ def custom_auth( class RenewableAuth: - """Container for details which potentially expire. + """Container for authentication details which potentially expire. - This is meant to be used with auth token provider which is a callable - that returns... + This is meant to be used as a return value for a callable auth token + provider to accommodate for expiring authentication information. + + For Example:: + + import neo4j + + + def auth_provider(): + sso_token = get_sso_token() # some way to getting a fresh token + expires_in = 60 # we know our tokens expire every 60 seconds + + return neo4j.RenewableAuth( + neo4j.bearer_auth(sso_token), + # The driver will continue to use the old token until a new one + # has been fetched. So we want the auth provider to be called + # a little before the token expires. + expires_in - 10 + ) + + + with neo4j.GraphDatabase.driver( + "neo4j://example.com:7687", + auth=auth_provider + ) as driver: + ... # do stuff .. warning:: - This function **must not** interact with the driver in any way as this - can cause a deadlock or undefined behaviour. + The auth provider **must not** interact with the driver in any way as + this can cause deadlocks and undefined behaviour. + + The driver will call the auth provider when either the last token it + provided has expired (see :attr:`expires_in`) or when the driver has + received an authentication error from the server that indicates new + authentication information is required. - :param auth: The auth token. 
:param expires_in: The expected expiry time + :param auth: The auth token. + :param expires_in: The expected expiry time of the auth token in seconds from now. It is recommended to set this a little before the actual expiry time to give the driver time to renew the auth token before connections start to fail. If set to :data:`None`, the token is assumed to never expire. + + .. versionadded:: 5.x """ def __init__( @@ -223,12 +255,6 @@ def expired(self): and self.expires_at < time.monotonic()) -_TAuthTokenProvider = t.Callable[[], t.Union[ - RenewableAuth, Auth, t.Tuple[t.Any, t.Any], None, - t.Awaitable[t.Union[RenewableAuth, Auth, t.Tuple[t.Any, t.Any], None]] -]] - - # TODO: 6.0 - remove this class class Bookmark: """A Bookmark object contains an immutable list of bookmark string values. diff --git a/src/neo4j/time/__init__.py b/src/neo4j/time/__init__.py index faf8061c..c6b70cf2 100644 --- a/src/neo4j/time/__init__.py +++ b/src/neo4j/time/__init__.py @@ -1391,7 +1391,7 @@ class Time(time_base_class, metaclass=TimeType): :raises ValueError: if one of the parameters is out of range. - ..versionchanged:: 5.0 + .. versionchanged:: 5.0 The parameter ``second`` no longer accepts :class:`float` values. """ @@ -1511,7 +1511,7 @@ def from_ticks(cls, ticks: int, tz: t.Optional[_tzinfo] = None) -> Time: :raises ValueError: if ticks is out of bounds (0 <= ticks < 86400000000000) - ..versionchanged:: 5.0 + .. versionchanged:: 5.0 The parameter ``ticks`` no longer accepts :class:`float` values but only :class:`int`. It's now nanoseconds since midnight instead of seconds. 
From b265476a36807922d14c4f76bf6e4c6ed1dd79b0 Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Thu, 2 Feb 2023 10:17:16 +0100 Subject: [PATCH 09/23] Update Driver.verify_authentication Update according ADR change + add TestKit backend support and unit tests --- src/neo4j/_async/driver.py | 11 ++++++++--- src/neo4j/_async/io/_bolt.py | 4 ++-- src/neo4j/_async/io/_bolt5.py | 3 ++- src/neo4j/_async/io/_pool.py | 24 ++++++++++++++++-------- src/neo4j/_async/work/session.py | 9 +++++++-- src/neo4j/_sync/driver.py | 11 ++++++++--- src/neo4j/_sync/io/_bolt.py | 4 ++-- src/neo4j/_sync/io/_bolt5.py | 3 ++- src/neo4j/_sync/io/_pool.py | 24 ++++++++++++++++-------- src/neo4j/_sync/work/session.py | 9 +++++++-- testkitbackend/_async/requests.py | 21 +++++++++++++++++++++ testkitbackend/_sync/requests.py | 21 +++++++++++++++++++++ testkitbackend/test_config.json | 1 + tests/unit/async_/io/test_direct.py | 6 +++--- tests/unit/async_/test_driver.py | 15 +++++++++++++++ tests/unit/mixed/io/test_direct.py | 10 ++++++---- tests/unit/sync/io/test_direct.py | 6 +++--- tests/unit/sync/test_driver.py | 15 +++++++++++++++ 18 files changed, 155 insertions(+), 42 deletions(-) diff --git a/src/neo4j/_async/driver.py b/src/neo4j/_async/driver.py index 0345ce61..b1a8b940 100644 --- a/src/neo4j/_async/driver.py +++ b/src/neo4j/_async/driver.py @@ -987,7 +987,7 @@ async def supports_multi_db(self) -> bool: async def verify_authentication( self, - auth: t.Union[Auth, t.Tuple[t.Any, t.Any]], + auth: t.Union[Auth, t.Tuple[t.Any, t.Any], None] = None, # all other arguments are experimental # they may be change or removed any time without prior notice session_connection_timeout: float = ..., @@ -1012,7 +1012,9 @@ async def verify_authentication( else: async def verify_authentication( - self, auth: t.Union[Auth, t.Tuple[t.Any, t.Any]], **config + self, + auth: t.Union[Auth, t.Tuple[t.Any, t.Any], None] = None, + **config ) -> bool: """Verify that the authentication information is valid. 
@@ -1052,8 +1054,11 @@ async def verify_authentication( "notice." ) config["auth"] = auth + if "database" not in config: + config["database"] = "system" try: - await self._get_server_info(**config) + async with self.session(**config) as session: + await session._verify_authentication() except Neo4jError as exc: if exc.code in ( "Neo.ClientError.Security.CredentialsExpired", diff --git a/src/neo4j/_async/io/_bolt.py b/src/neo4j/_async/io/_bolt.py index f07e5020..93ba696d 100644 --- a/src/neo4j/_async/io/_bolt.py +++ b/src/neo4j/_async/io/_bolt.py @@ -464,7 +464,7 @@ def mark_unauthenticated(self): self.auth_dict = {} def re_auth( - self, auth, dehydration_hooks=None, hydration_hooks=None + self, auth, dehydration_hooks=None, hydration_hooks=None, force=False ): """Append LOGON, LOGOFF to the outgoing queue. @@ -473,7 +473,7 @@ def re_auth( :returns: whether the auth was changed """ new_auth_dict = self._to_auth_dict(auth) - if new_auth_dict == self.auth_dict: + if not force and new_auth_dict == self.auth_dict: return False self.logoff(dehydration_hooks=dehydration_hooks, hydration_hooks=hydration_hooks) diff --git a/src/neo4j/_async/io/_bolt5.py b/src/neo4j/_async/io/_bolt5.py index 53b2f41c..d1df4c0c 100644 --- a/src/neo4j/_async/io/_bolt5.py +++ b/src/neo4j/_async/io/_bolt5.py @@ -346,7 +346,8 @@ async def _process_message(self, tag, fields): self.pool.on_write_failure(address=self.unresolved_address) raise except Neo4jError as e: - await self.pool.on_neo4j_error(e, self.server_info.address) + if self.pool: + await self.pool.on_neo4j_error(e, self.server_info.address) raise else: raise BoltProtocolError( diff --git a/src/neo4j/_async/io/_pool.py b/src/neo4j/_async/io/_pool.py index b2166155..6e4343a3 100644 --- a/src/neo4j/_async/io/_pool.py +++ b/src/neo4j/_async/io/_pool.py @@ -233,7 +233,9 @@ async def connection_creator(): return connection_creator return None - async def _acquire(self, address, auth, deadline, liveness_check_timeout): + async def _acquire( + 
self, address, auth, deadline, liveness_check_timeout, force_auth=False + ): """ Acquire a connection to a given address from the pool. The address supplied should always be an IP address, not a host name. @@ -246,7 +248,8 @@ async def health_check(connection_, deadline_): or connection_.stale()): return False try: - connection_.re_auth(auth or await self.get_auth()) + connection_.re_auth(auth or await self.get_auth(), + force=force_auth) except ConfigurationError: # protocol does not support re-authentication if auth: @@ -255,6 +258,9 @@ async def health_check(connection_, deadline_): # expiring tokens supported by flushing the pool # => give up this connection return False + if force_auth: + await connection_.send_all() + await connection_.fetch_all() if liveness_check_timeout is not None: if connection_.is_idle_for(liveness_check_timeout): with connection_deadline(connection_, deadline_): @@ -302,7 +308,7 @@ async def health_check(connection_, deadline_): @abc.abstractmethod async def acquire( self, access_mode, timeout, database, bookmarks, auth, - liveness_check_timeout + liveness_check_timeout, force_re_auth=False ): """ Acquire a connection to a server that can satisfy a set of parameters. @@ -313,6 +319,7 @@ async def acquire( :param bookmarks: :param auth: :param liveness_check_timeout: + :param force_re_auth: """ ... @@ -495,7 +502,7 @@ def __repr__(self): async def acquire( self, access_mode, timeout, database, bookmarks, auth, - liveness_check_timeout + liveness_check_timeout, force_re_auth=False ): # The access_mode and database is not needed for a direct connection, # it's just there for consistency. 
@@ -503,7 +510,7 @@ async def acquire( "access_mode=%r, database=%r", access_mode, database) deadline = Deadline.from_timeout_or_deadline(timeout) return await self._acquire( - self.address, auth, deadline, liveness_check_timeout + self.address, auth, deadline, liveness_check_timeout, force_re_auth ) @@ -857,7 +864,7 @@ async def _select_address(self, *, access_mode, database): async def acquire( self, access_mode, timeout, database, bookmarks, auth, - liveness_check_timeout + liveness_check_timeout, force_re_auth=False ): if access_mode not in (WRITE_ACCESS, READ_ACCESS): raise ClientError("Non valid 'access_mode'; {}".format(access_mode)) @@ -877,7 +884,7 @@ async def acquire( "access_mode=%r, database=%r", access_mode, database) await self.ensure_routing_table_is_fresh( access_mode=access_mode, database=database, - imp_user=None, bookmarks=bookmarks, + imp_user=None, bookmarks=bookmarks, auth=auth, acquisition_timeout=timeout ) @@ -896,7 +903,8 @@ async def acquire( deadline = Deadline.from_timeout_or_deadline(timeout) # should always be a resolved address connection = await self._acquire( - address, auth, deadline, liveness_check_timeout + address, auth, deadline, liveness_check_timeout, + force_re_auth ) except (ServiceUnavailable, SessionExpired): await self.deactivate(address=address) diff --git a/src/neo4j/_async/work/session.py b/src/neo4j/_async/work/session.py index a0d3ceb9..58f4164b 100644 --- a/src/neo4j/_async/work/session.py +++ b/src/neo4j/_async/work/session.py @@ -118,8 +118,9 @@ async def _connect(self, access_mode, **access_kwargs): if access_mode is None: access_mode = self._config.default_access_mode try: - await super()._connect(access_mode, auth=self._config.auth, - **access_kwargs) + await super()._connect( + access_mode, auth=self._config.auth, **access_kwargs + ) except asyncio.CancelledError: self._handle_cancellation(message="_connect") raise @@ -164,6 +165,10 @@ async def _get_server_info(self): await self._disconnect() return 
server_info + async def _verify_authentication(self): + assert not self._connection + await self._connect(READ_ACCESS, force_re_auth=True) + async def close(self) -> None: """Close the session. diff --git a/src/neo4j/_sync/driver.py b/src/neo4j/_sync/driver.py index 731668fe..88cdd4dd 100644 --- a/src/neo4j/_sync/driver.py +++ b/src/neo4j/_sync/driver.py @@ -985,7 +985,7 @@ def supports_multi_db(self) -> bool: def verify_authentication( self, - auth: t.Union[Auth, t.Tuple[t.Any, t.Any]], + auth: t.Union[Auth, t.Tuple[t.Any, t.Any], None] = None, # all other arguments are experimental # they may be change or removed any time without prior notice session_connection_timeout: float = ..., @@ -1010,7 +1010,9 @@ def verify_authentication( else: def verify_authentication( - self, auth: t.Union[Auth, t.Tuple[t.Any, t.Any]], **config + self, + auth: t.Union[Auth, t.Tuple[t.Any, t.Any], None] = None, + **config ) -> bool: """Verify that the authentication information is valid. @@ -1050,8 +1052,11 @@ def verify_authentication( "notice." ) config["auth"] = auth + if "database" not in config: + config["database"] = "system" try: - self._get_server_info(**config) + with self.session(**config) as session: + session._verify_authentication() except Neo4jError as exc: if exc.code in ( "Neo.ClientError.Security.CredentialsExpired", diff --git a/src/neo4j/_sync/io/_bolt.py b/src/neo4j/_sync/io/_bolt.py index 2c97bbd2..b2c7e27f 100644 --- a/src/neo4j/_sync/io/_bolt.py +++ b/src/neo4j/_sync/io/_bolt.py @@ -464,7 +464,7 @@ def mark_unauthenticated(self): self.auth_dict = {} def re_auth( - self, auth, dehydration_hooks=None, hydration_hooks=None + self, auth, dehydration_hooks=None, hydration_hooks=None, force=False ): """Append LOGON, LOGOFF to the outgoing queue. 
@@ -473,7 +473,7 @@ def re_auth( :returns: whether the auth was changed """ new_auth_dict = self._to_auth_dict(auth) - if new_auth_dict == self.auth_dict: + if not force and new_auth_dict == self.auth_dict: return False self.logoff(dehydration_hooks=dehydration_hooks, hydration_hooks=hydration_hooks) diff --git a/src/neo4j/_sync/io/_bolt5.py b/src/neo4j/_sync/io/_bolt5.py index 0dfc922d..fcb68491 100644 --- a/src/neo4j/_sync/io/_bolt5.py +++ b/src/neo4j/_sync/io/_bolt5.py @@ -346,7 +346,8 @@ def _process_message(self, tag, fields): self.pool.on_write_failure(address=self.unresolved_address) raise except Neo4jError as e: - self.pool.on_neo4j_error(e, self.server_info.address) + if self.pool: + self.pool.on_neo4j_error(e, self.server_info.address) raise else: raise BoltProtocolError( diff --git a/src/neo4j/_sync/io/_pool.py b/src/neo4j/_sync/io/_pool.py index 8b9c4598..f1fb23b9 100644 --- a/src/neo4j/_sync/io/_pool.py +++ b/src/neo4j/_sync/io/_pool.py @@ -233,7 +233,9 @@ def connection_creator(): return connection_creator return None - def _acquire(self, address, auth, deadline, liveness_check_timeout): + def _acquire( + self, address, auth, deadline, liveness_check_timeout, force_auth=False + ): """ Acquire a connection to a given address from the pool. The address supplied should always be an IP address, not a host name. 
@@ -246,7 +248,8 @@ def health_check(connection_, deadline_): or connection_.stale()): return False try: - connection_.re_auth(auth or self.get_auth()) + connection_.re_auth(auth or self.get_auth(), + force=force_auth) except ConfigurationError: # protocol does not support re-authentication if auth: @@ -255,6 +258,9 @@ def health_check(connection_, deadline_): # expiring tokens supported by flushing the pool # => give up this connection return False + if force_auth: + connection_.send_all() + connection_.fetch_all() if liveness_check_timeout is not None: if connection_.is_idle_for(liveness_check_timeout): with connection_deadline(connection_, deadline_): @@ -302,7 +308,7 @@ def health_check(connection_, deadline_): @abc.abstractmethod def acquire( self, access_mode, timeout, database, bookmarks, auth, - liveness_check_timeout + liveness_check_timeout, force_re_auth=False ): """ Acquire a connection to a server that can satisfy a set of parameters. @@ -313,6 +319,7 @@ def acquire( :param bookmarks: :param auth: :param liveness_check_timeout: + :param force_re_auth: """ ... @@ -495,7 +502,7 @@ def __repr__(self): def acquire( self, access_mode, timeout, database, bookmarks, auth, - liveness_check_timeout + liveness_check_timeout, force_re_auth=False ): # The access_mode and database is not needed for a direct connection, # it's just there for consistency. 
@@ -503,7 +510,7 @@ def acquire( "access_mode=%r, database=%r", access_mode, database) deadline = Deadline.from_timeout_or_deadline(timeout) return self._acquire( - self.address, auth, deadline, liveness_check_timeout + self.address, auth, deadline, liveness_check_timeout, force_re_auth ) @@ -857,7 +864,7 @@ def _select_address(self, *, access_mode, database): def acquire( self, access_mode, timeout, database, bookmarks, auth, - liveness_check_timeout + liveness_check_timeout, force_re_auth=False ): if access_mode not in (WRITE_ACCESS, READ_ACCESS): raise ClientError("Non valid 'access_mode'; {}".format(access_mode)) @@ -877,7 +884,7 @@ def acquire( "access_mode=%r, database=%r", access_mode, database) self.ensure_routing_table_is_fresh( access_mode=access_mode, database=database, - imp_user=None, bookmarks=bookmarks, + imp_user=None, bookmarks=bookmarks, auth=auth, acquisition_timeout=timeout ) @@ -896,7 +903,8 @@ def acquire( deadline = Deadline.from_timeout_or_deadline(timeout) # should always be a resolved address connection = self._acquire( - address, auth, deadline, liveness_check_timeout + address, auth, deadline, liveness_check_timeout, + force_re_auth ) except (ServiceUnavailable, SessionExpired): self.deactivate(address=address) diff --git a/src/neo4j/_sync/work/session.py b/src/neo4j/_sync/work/session.py index 2ee8c852..4a5dcefb 100644 --- a/src/neo4j/_sync/work/session.py +++ b/src/neo4j/_sync/work/session.py @@ -118,8 +118,9 @@ def _connect(self, access_mode, **access_kwargs): if access_mode is None: access_mode = self._config.default_access_mode try: - super()._connect(access_mode, auth=self._config.auth, - **access_kwargs) + super()._connect( + access_mode, auth=self._config.auth, **access_kwargs + ) except asyncio.CancelledError: self._handle_cancellation(message="_connect") raise @@ -164,6 +165,10 @@ def _get_server_info(self): self._disconnect() return server_info + def _verify_authentication(self): + assert not self._connection + 
self._connect(READ_ACCESS, force_re_auth=True) + def close(self) -> None: """Close the session. diff --git a/testkitbackend/_async/requests.py b/testkitbackend/_async/requests.py index ab4ee475..aec611c3 100644 --- a/testkitbackend/_async/requests.py +++ b/testkitbackend/_async/requests.py @@ -247,6 +247,27 @@ async def CheckMultiDBSupport(backend, data): }) +async def VerifyAuthentication(backend, data): + driver_id = data["driverId"] + driver = backend.drivers[driver_id] + auth = None + if data.get("auth_token"): + auth = _convert_auth_token(data, "auth_token") + authenticated = await driver.verify_authentication(auth=auth) + await backend.send_response("DriverIsAuthenticated", { + "id": backend.next_key(), "authenticated": authenticated + }) + + +async def CheckSessionAuthSupport(backend, data): + driver_id = data["driverId"] + driver = backend.drivers[driver_id] + available = await driver.supports_session_auth() + await backend.send_response("SessionAuthSupport", { + "id": backend.next_key(), "available": available + }) + + async def ExecuteQuery(backend, data): driver = backend.drivers[data["driverId"]] cypher, params = fromtestkit.to_cypher_and_params(data) diff --git a/testkitbackend/_sync/requests.py b/testkitbackend/_sync/requests.py index 930b58f6..2307d541 100644 --- a/testkitbackend/_sync/requests.py +++ b/testkitbackend/_sync/requests.py @@ -247,6 +247,27 @@ def CheckMultiDBSupport(backend, data): }) +def VerifyAuthentication(backend, data): + driver_id = data["driverId"] + driver = backend.drivers[driver_id] + auth = None + if data.get("auth_token"): + auth = _convert_auth_token(data, "auth_token") + authenticated = driver.verify_authentication(auth=auth) + backend.send_response("DriverIsAuthenticated", { + "id": backend.next_key(), "authenticated": authenticated + }) + + +def CheckSessionAuthSupport(backend, data): + driver_id = data["driverId"] + driver = backend.drivers[driver_id] + available = driver.supports_session_auth() + 
backend.send_response("SessionAuthSupport", { + "id": backend.next_key(), "available": available + }) + + def ExecuteQuery(backend, data): driver = backend.drivers[data["driverId"]] cypher, params = fromtestkit.to_cypher_and_params(data) diff --git a/testkitbackend/test_config.json b/testkitbackend/test_config.json index 2ac18848..fd29479b 100644 --- a/testkitbackend/test_config.json +++ b/testkitbackend/test_config.json @@ -21,6 +21,7 @@ "Feature:API:Driver.ExecuteQuery": true, "Feature:API:Driver:GetServerInfo": true, "Feature:API:Driver.IsEncrypted": true, + "Feature:API:Driver.VerifyAuthentication": true, "Feature:API:Driver.VerifyConnectivity": true, "Feature:API:Driver.SupportsSessionAuth": true, "Feature:API:Liveness.Check": false, diff --git a/tests/unit/async_/io/test_direct.py b/tests/unit/async_/io/test_direct.py index f6c57ccb..28b5773d 100644 --- a/tests/unit/async_/io/test_direct.py +++ b/tests/unit/async_/io/test_direct.py @@ -65,7 +65,7 @@ def stale(self): async def reset(self): pass - def re_auth(self, auth): + def re_auth(self, auth, force=False): return False def close(self): @@ -96,10 +96,10 @@ async def opener(addr, auth, timeout): async def acquire( self, access_mode, timeout, database, bookmarks, auth, - liveness_check_timeout + liveness_check_timeout, force_re_auth=False ): return await self._acquire( - self.address, auth, timeout, liveness_check_timeout + self.address, auth, timeout, liveness_check_timeout, force_re_auth ) diff --git a/tests/unit/async_/test_driver.py b/tests/unit/async_/test_driver.py index fbca7ce6..340a30ac 100644 --- a/tests/unit/async_/test_driver.py +++ b/tests/unit/async_/test_driver.py @@ -787,3 +787,18 @@ async def test_execute_query_result_transformer( _work, mocker.ANY, mocker.ANY, result_transformer ) assert res is session_executor_mock.return_value + + +@mark_async_test +async def test_supports_session_auth(mocker) -> None: + driver = AsyncGraphDatabase.driver("bolt://localhost") + session_cls_mock = 
mocker.patch("neo4j._async.driver.AsyncSession", + autospec=True) + async with driver as driver: + res = await driver.supports_session_auth() + + session_cls_mock.assert_called_once() + session_cls_mock.return_value.__aenter__.assert_awaited_once() + session_mock = session_cls_mock.return_value.__aenter__.return_value + connection_mock = session_mock._connection + assert res is connection_mock.supports_re_auth diff --git a/tests/unit/mixed/io/test_direct.py b/tests/unit/mixed/io/test_direct.py index 63fba1e7..3446e87a 100644 --- a/tests/unit/mixed/io/test_direct.py +++ b/tests/unit/mixed/io/test_direct.py @@ -59,7 +59,7 @@ def test_multithread(self, pre_populated): def acquire_release_conn(pool_, address_, acquired_counter_, release_event_): nonlocal connections, connections_lock - conn_ = pool_._acquire(address_, None, Deadline(3), None) + conn_ = pool_._acquire(address_, None, Deadline(3), None, False) with connections_lock: if connections is not None: connections.append(conn_) @@ -74,7 +74,7 @@ def acquire_release_conn(pool_, address_, acquired_counter_, # pre-populate the pool with connections for _ in range(pre_populated): - conn = pool._acquire(address, None, Deadline(3), None) + conn = pool._acquire(address, None, Deadline(3), None, False) pre_populated_connections.append(conn) for conn in pre_populated_connections: pool.release(conn) @@ -120,7 +120,8 @@ async def test_multi_coroutine(self, pre_populated): async def acquire_release_conn(pool_, address_, acquired_counter_, release_event_): nonlocal connections - conn_ = await pool_._acquire(address_, None, Deadline(3), None) + conn_ = await pool_._acquire(address_, None, Deadline(3), None, + False) if connections is not None: connections.append(conn_) await acquired_counter_.increment() @@ -154,7 +155,8 @@ async def waiter(pool_, acquired_counter_, release_event_): # pre-populate the pool with connections for _ in range(pre_populated): - conn = await pool._acquire(address, None, Deadline(3), None) + conn = 
await pool._acquire(address, None, Deadline(3), None, + False) pre_populated_connections.append(conn) for conn in pre_populated_connections: await pool.release(conn) diff --git a/tests/unit/sync/io/test_direct.py b/tests/unit/sync/io/test_direct.py index e61d1e49..ef732942 100644 --- a/tests/unit/sync/io/test_direct.py +++ b/tests/unit/sync/io/test_direct.py @@ -65,7 +65,7 @@ def stale(self): def reset(self): pass - def re_auth(self, auth): + def re_auth(self, auth, force=False): return False def close(self): @@ -96,10 +96,10 @@ def opener(addr, auth, timeout): def acquire( self, access_mode, timeout, database, bookmarks, auth, - liveness_check_timeout + liveness_check_timeout, force_re_auth=False ): return self._acquire( - self.address, auth, timeout, liveness_check_timeout + self.address, auth, timeout, liveness_check_timeout, force_re_auth ) diff --git a/tests/unit/sync/test_driver.py b/tests/unit/sync/test_driver.py index 25b3cf74..9a4d254a 100644 --- a/tests/unit/sync/test_driver.py +++ b/tests/unit/sync/test_driver.py @@ -786,3 +786,18 @@ def test_execute_query_result_transformer( _work, mocker.ANY, mocker.ANY, result_transformer ) assert res is session_executor_mock.return_value + + +@mark_sync_test +def test_supports_session_auth(mocker) -> None: + driver = GraphDatabase.driver("bolt://localhost") + session_cls_mock = mocker.patch("neo4j._sync.driver.Session", + autospec=True) + with driver as driver: + res = driver.supports_session_auth() + + session_cls_mock.assert_called_once() + session_cls_mock.return_value.__enter__.assert_called_once() + session_mock = session_cls_mock.return_value.__enter__.return_value + connection_mock = session_mock._connection + assert res is connection_mock.supports_re_auth From 3710e245c842695b167af35375d94669467ce8d0 Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Thu, 2 Feb 2023 13:36:37 +0100 Subject: [PATCH 10/23] API docs: limitation of session auth: no downgrade --- docs/source/api.rst | 4 ++++ 1 file changed, 4
insertions(+) diff --git a/docs/source/api.rst b/docs/source/api.rst index 94183b37..32e821be 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -952,6 +952,10 @@ authentication information for the session. This requires the server to support re-authentication on the protocol level. You can check this by calling :meth:`.Driver.supports_session_auth` / :meth:`.AsyncDriver.supports_session_auth`. +It is not possible to overwrite the authentication information for the session with no authentication, +i.e., downgrade the authentication at session level. +Instead, you should create a driver with no authentication and upgrade the authentication at session level as needed. + :Type: :data:`None`, :class:`.Auth` or ``(user, password)``-tuple :Default: :data:`None` - use the authentication information provided during driver creation. From db2ca7217c26345d2ba98dddbb7e57ab447e7890 Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Wed, 8 Feb 2023 18:01:34 +0100 Subject: [PATCH 11/23] Implemented ADR updated * move backwards compat option to driver level config * adjust `verify_authentication` and its backwards compat mode * more tests * Extended TestKit support * Fixed bugs found through TestKit --- src/neo4j/_async/driver.py | 1 - src/neo4j/_async/io/__init__.py | 2 + src/neo4j/_async/io/_bolt.py | 3 +- src/neo4j/_async/io/_bolt4.py | 3 +- src/neo4j/_async/io/_common.py | 29 ++-- src/neo4j/_async/io/_pool.py | 198 +++++++++++++++++++++------- src/neo4j/_async/work/session.py | 13 +- src/neo4j/_async/work/workspace.py | 11 +- src/neo4j/_conf.py | 7 + src/neo4j/_sync/io/__init__.py | 2 + src/neo4j/_sync/io/_bolt.py | 3 +- src/neo4j/_sync/io/_bolt4.py | 3 +- src/neo4j/_sync/io/_common.py | 29 ++-- src/neo4j/_sync/io/_pool.py | 198 +++++++++++++++++++++------- src/neo4j/_sync/work/session.py | 13 +- src/neo4j/_sync/work/workspace.py | 11 +- testkitbackend/_async/requests.py | 29 ++-- testkitbackend/_sync/requests.py | 29 ++-- tests/unit/async_/io/test_direct.py | 4 +- 
tests/unit/sync/io/test_direct.py | 4 +- 20 files changed, 431 insertions(+), 161 deletions(-) diff --git a/src/neo4j/_async/driver.py b/src/neo4j/_async/driver.py index b1a8b940..c75bcee6 100644 --- a/src/neo4j/_async/driver.py +++ b/src/neo4j/_async/driver.py @@ -28,7 +28,6 @@ import ssl - from .._api import RoutingControl from .._async_compat.util import AsyncUtil from .._conf import ( diff --git a/src/neo4j/_async/io/__init__.py b/src/neo4j/_async/io/__init__.py index 3b28553e..f3f628f1 100644 --- a/src/neo4j/_async/io/__init__.py +++ b/src/neo4j/_async/io/__init__.py @@ -24,6 +24,7 @@ __all__ = [ + "AcquireAuth", "AsyncBolt", "AsyncBoltPool", "AsyncNeo4jPool", @@ -38,6 +39,7 @@ ConnectionErrorHandler, ) from ._pool import ( + AcquireAuth, AsyncBoltPool, AsyncNeo4jPool, ) diff --git a/src/neo4j/_async/io/_bolt.py b/src/neo4j/_async/io/_bolt.py index 93ba696d..66fc124b 100644 --- a/src/neo4j/_async/io/_bolt.py +++ b/src/neo4j/_async/io/_bolt.py @@ -20,7 +20,6 @@ import abc import asyncio -import typing as t from collections import deque from logging import getLogger from time import perf_counter @@ -99,6 +98,8 @@ class AsyncBolt: # The socket in_use = False + throwaway = False + # When the connection was last put back into the pool idle_since = float("-inf") diff --git a/src/neo4j/_async/io/_bolt4.py b/src/neo4j/_async/io/_bolt4.py index cd99bca4..c96fc91d 100644 --- a/src/neo4j/_async/io/_bolt4.py +++ b/src/neo4j/_async/io/_bolt4.py @@ -348,7 +348,8 @@ async def _process_message(self, tag, fields): self.pool.on_write_failure(address=self.unresolved_address) raise except Neo4jError as e: - await self.pool.on_neo4j_error(e, self.server_info.address) + if self.pool: + await self.pool.on_neo4j_error(e, self.server_info.address) raise else: raise BoltProtocolError("Unexpected response message with signature " diff --git a/src/neo4j/_async/io/_common.py b/src/neo4j/_async/io/_common.py index 278a41b7..a83ab09c 100644 --- a/src/neo4j/_async/io/_common.py +++ 
b/src/neo4j/_async/io/_common.py @@ -257,18 +257,29 @@ async def on_ignored(self, metadata=None): class InitResponse(Response): async def on_failure(self, metadata): - code = metadata.get("code") - if code == "Neo.ClientError.Security.Unauthorized": - # this branch is only needed as long as we support Bolt 5.0 - raise Neo4jError.hydrate(**metadata) - else: - raise ServiceUnavailable( - metadata.get("message", "Connection initialisation failed") - ) + # No sense in resetting the connection, + # the server will have closed it already. + self.connection.kill() + handler = self.handlers.get("on_failure") + await AsyncUtil.callback(handler, metadata) + handler = self.handlers.get("on_summary") + await AsyncUtil.callback(handler) + metadata["message"] = metadata.get( + "message", + "Connection initialisation failed due to an unknown error" + ) + raise Neo4jError.hydrate(**metadata) -class LogonResponse(Response): +class LogonResponse(InitResponse): async def on_failure(self, metadata): + # No sense in resetting the connection, + # the server will have closed it already. 
+ self.connection.kill() + handler = self.handlers.get("on_failure") + await AsyncUtil.callback(handler, metadata) + handler = self.handlers.get("on_summary") + await AsyncUtil.callback(handler) raise Neo4jError.hydrate(**metadata) diff --git a/src/neo4j/_async/io/_pool.py b/src/neo4j/_async/io/_pool.py index 6e4343a3..65079b09 100644 --- a/src/neo4j/_async/io/_pool.py +++ b/src/neo4j/_async/io/_pool.py @@ -26,6 +26,8 @@ defaultdict, deque, ) +from copy import copy +from dataclasses import dataclass from logging import getLogger from random import choice @@ -47,6 +49,7 @@ from ..._exceptions import BoltError from ..._routing import RoutingTable from ...api import ( + Auth, READ_ACCESS, RenewableAuth, WRITE_ACCESS, @@ -68,6 +71,13 @@ log = getLogger("neo4j") +@dataclass +class AcquireAuth: + auth: t.Optional[Auth] + backwards_compatible: bool = False + force_auth: bool = False + + class AsyncIOPool(abc.ABC): """ A collection of connections to one or more server addresses. """ @@ -153,6 +163,22 @@ async def _acquire_from_pool(self, address): return connection return None # no free connection available + def _remove_connection(self, connection): + address = connection.unresolved_address + with self.lock: + log.debug( + "[#%04X] _: remove connection from pool %r %s", + connection.local_port, address, connection.connection_id + ) + try: + self.connections.get(address, []).remove(connection) + except ValueError: + # If closure fails (e.g. because the server went + # down), all connections to the same address will + # be removed. Therefore, we silently ignore if the + # connection isn't in the pool anymore. + pass + async def _acquire_from_pool_checked( self, address, health_check, deadline ): @@ -167,27 +193,20 @@ async def _acquire_from_pool_checked( # `stale` but still alive. 
if log.isEnabledFor(logging.DEBUG): log.debug( - "[#%04X] _: removing old connection %s " + "[#%04X] _: found unhealthy connection %s " "(closed=%s, defunct=%s, stale=%s, in_use=%s)", connection.local_port, connection.connection_id, connection.closed(), connection.defunct(), connection.stale(), connection.in_use ) await connection.close() - with self.lock: - try: - self.connections.get(address, []).remove(connection) - except ValueError: - # If closure fails (e.g. because the server went - # down), all connections to the same address will - # be removed. Therefore, we silently ignore if the - # connection isn't in the pool anymore. - pass + self._remove_connection(connection) continue # try again with a new connection else: return connection - def _acquire_new_later(self, address, auth, deadline): + def _acquire_new_later(self, address, auth, deadline, + backwards_compatible_auth): async def connection_creator(): released_reservation = False try: @@ -206,8 +225,14 @@ async def connection_creator(): try: connection.assert_re_auth_support() except ConfigurationError: - await connection.close() - raise + if not backwards_compatible_auth: + log.debug("[#%04X] _: no re-auth support", + connection.local_port) + await connection.close() + raise + log.debug("[#%04X] _: is throwaway connection", + connection.local_port) + connection.throwaway = True connection.pool = self connection.in_use = True with self.lock: @@ -233,8 +258,26 @@ async def connection_creator(): return connection_creator return None + async def _re_auth_connection(self, connection, auth, force): + new_auth = auth or await self.get_auth() + log_auth = "******" if auth else "None" + try: + updated = connection.re_auth(new_auth, force=force) + log.debug("[#%04X] _: checked re_auth auth=%s updated=%s " + "force=%s", + connection.local_port, log_auth, updated, force) + except Exception as exc: + log.debug("[#%04X] _: check re_auth failed %r auth=%s " + "force=%s", + connection.local_port, exc, log_auth, force) 
+ raise + assert not force or updated # force=True implies updated=True + if force: + await connection.send_all() + await connection.fetch_all() + async def _acquire( - self, address, auth, deadline, liveness_check_timeout, force_auth=False + self, address, auth, deadline, liveness_check_timeout ): """ Acquire a connection to a given address from the pool. The address supplied should always be an IP address, not @@ -242,25 +285,17 @@ async def _acquire( This method is thread safe. """ + if auth is None: + auth = AcquireAuth(None) + force_auth = auth.force_auth + backwards_compatible_auth = auth.backwards_compatible + auth = auth.auth + async def health_check(connection_, deadline_): if (connection_.closed() or connection_.defunct() or connection_.stale()): return False - try: - connection_.re_auth(auth or await self.get_auth(), - force=force_auth) - except ConfigurationError: - # protocol does not support re-authentication - if auth: - # session-level not supported - raise - # expiring tokens supported by flushing the pool - # => give up this connection - return False - if force_auth: - await connection_.send_all() - await connection_.fetch_all() if liveness_check_timeout is not None: if connection_.is_idle_for(liveness_check_timeout): with connection_deadline(connection_, deadline_): @@ -278,14 +313,54 @@ async def health_check(connection_, deadline_): address, health_check, deadline ) if connection: - log.debug("[#%04X] _: handing out existing connection " - "%s", connection.local_port, - connection.connection_id) + log.debug("[#%04X] _: picked existing connection %s", + connection.local_port, connection.connection_id) + try: + await self._re_auth_connection( + connection, auth, force_auth + ) + except ConfigurationError: + if not auth: + # expiring tokens supported by flushing the pool + # => give up this connection + log.debug("[#%04X] _: backwards compatible " + "auth token refresh: purge connection", + connection.local_port) + await connection.close() + await 
self.release(connection) + continue + if not backwards_compatible_auth: + raise + # backwards compatibility mode: + # create new throwaway connection, + connection_creator = self._acquire_new_later( + address, auth, deadline, + backwards_compatible_auth=True + ) + if connection_creator: + await self.release(connection) + if not connection_creator: + # pool is full => kill the picked connection + log.debug("[#%04X] _: backwards compatible " + "session auth making room by purge", + connection.local_port) + await connection.close() + with self.lock: + self._remove_connection(connection) + connection_creator = self._acquire_new_later( + address, auth, deadline, + backwards_compatible_auth=True + ) + assert connection_creator is not None + break + log.debug("[#%04X] _: handing out existing connection", + connection.local_port) return connection # all connections in pool are in-use with self.lock: connection_creator = self._acquire_new_later( - address, auth, deadline + address, auth, deadline, + backwards_compatible_auth=backwards_compatible_auth ) if connection_creator: break @@ -307,8 +382,8 @@ async def health_check(connection_, deadline_): @abc.abstractmethod async def acquire( - self, access_mode, timeout, database, bookmarks, auth, - liveness_check_timeout, force_re_auth=False + self, access_mode, timeout, database, bookmarks, auth: AcquireAuth, + liveness_check_timeout ): """ Acquire a connection to a server that can satisfy a set of parameters. @@ -319,7 +394,6 @@ async def acquire( :param bookmarks: :param auth: :param liveness_check_timeout: - :param force_re_auth: """ ... 
@@ -341,6 +415,29 @@ def kill_and_release(self, *connections): connection.in_use = False self.cond.notify_all() + @staticmethod + async def _close_throwaway_connection(connection, cancelled): + if connection.throwaway: + if cancelled is not None: + log.debug( + "[#%04X] _: kill throwaway connection %s", + connection.local_port, connection.connection_id + ) + connection.kill() + else: + try: + log.debug( + "[#%04X] _: close throwaway connection %s", + connection.local_port, connection.connection_id + ) + await connection.close() + except asyncio.CancelledError as exc: + log.debug("[#%04X] _: cancelled close of " + "throwaway connection: %r", + connection.local_port, exc) + cancelled = exc + return cancelled + async def release(self, *connections): """ Release connections back into the pool. @@ -348,29 +445,32 @@ async def release(self, *connections): """ cancelled = None for connection in connections: + cancelled = await self._close_throwaway_connection( + connection, cancelled + ) if not (connection.defunct() or connection.closed() or connection.is_reset): if cancelled is not None: log.debug( - "[#%04X] _: released unclean connection %s", + "[#%04X] _: kill unclean connection %s", connection.local_port, connection.connection_id ) connection.kill() continue try: log.debug( - "[#%04X] _: released unclean connection %s", + "[#%04X] _: release unclean connection %s", connection.local_port, connection.connection_id ) await connection.reset() - except (Neo4jError, DriverError, BoltError) as e: + except (Neo4jError, DriverError, BoltError) as exc: log.debug("[#%04X] _: failed to reset connection " - "on release: %r", connection.local_port, e) - except asyncio.CancelledError as e: + "on release: %r", connection.local_port, exc) + except asyncio.CancelledError as exc: log.debug("[#%04X] _: cancelled reset connection " - "on release: %r", connection.local_port, e) - cancelled = e + "on release: %r", connection.local_port, exc) + cancelled = exc connection.kill() with 
self.lock: for connection in connections: @@ -501,8 +601,8 @@ def __repr__(self): self.address) async def acquire( - self, access_mode, timeout, database, bookmarks, auth, - liveness_check_timeout, force_re_auth=False + self, access_mode, timeout, database, bookmarks, auth: AcquireAuth, + liveness_check_timeout ): # The access_mode and database is not needed for a direct connection, # it's just there for consistency. @@ -510,7 +610,7 @@ async def acquire( "access_mode=%r, database=%r", access_mode, database) deadline = Deadline.from_timeout_or_deadline(timeout) return await self._acquire( - self.address, auth, deadline, liveness_check_timeout, force_re_auth + self.address, auth, deadline, liveness_check_timeout ) @@ -602,6 +702,9 @@ async def fetch_routing_info( deadline = Deadline.from_timeout_or_deadline(acquisition_timeout) log.debug("[#0000] _: _acquire router connection, " "database=%r, address=%r", database, address) + if auth: + auth = copy(auth) + auth.force_auth = False cx = await self._acquire(address, auth, deadline, None) try: routing_table = await cx.route( @@ -863,8 +966,8 @@ async def _select_address(self, *, access_mode, database): return choice(addresses_by_usage[min(addresses_by_usage)]) async def acquire( - self, access_mode, timeout, database, bookmarks, auth, - liveness_check_timeout, force_re_auth=False + self, access_mode, timeout, database, bookmarks, auth: AcquireAuth, + liveness_check_timeout ): if access_mode not in (WRITE_ACCESS, READ_ACCESS): raise ClientError("Non valid 'access_mode'; {}".format(access_mode)) @@ -904,7 +1007,6 @@ async def acquire( # should always be a resolved address connection = await self._acquire( address, auth, deadline, liveness_check_timeout, - force_re_auth ) except (ServiceUnavailable, SessionExpired): await self.deactivate(address=address) diff --git a/src/neo4j/_async/work/session.py b/src/neo4j/_async/work/session.py index 58f4164b..c5e76ed7 100644 --- a/src/neo4j/_async/work/session.py +++ 
b/src/neo4j/_async/work/session.py @@ -114,12 +114,12 @@ async def __aexit__(self, exception_type, exception_value, traceback): self._state_failed = True await self.close() - async def _connect(self, access_mode, **access_kwargs): + async def _connect(self, access_mode, **acquire_kwargs): if access_mode is None: access_mode = self._config.default_access_mode try: await super()._connect( - access_mode, auth=self._config.auth, **access_kwargs + access_mode, auth=self._config.auth, **acquire_kwargs ) except asyncio.CancelledError: self._handle_cancellation(message="_connect") @@ -167,7 +167,14 @@ async def _get_server_info(self): async def _verify_authentication(self): assert not self._connection - await self._connect(READ_ACCESS, force_re_auth=True) + await self._connect(READ_ACCESS, force_auth=True) + if not self._config.backwards_compatible_auth: + # Even without backwards compatibility and an old server, the + # _connect call above can succeed if the connection is a new one. + # Hence, we enforce support explicitly to always let the user know + # that this is not supported. + self._connection.assert_re_auth_support() + await self._disconnect() async def close(self) -> None: """Close the session. 
diff --git a/src/neo4j/_async/work/workspace.py b/src/neo4j/_async/work/workspace.py index 75dc1d76..3c77096e 100644 --- a/src/neo4j/_async/work/workspace.py +++ b/src/neo4j/_async/work/workspace.py @@ -33,7 +33,10 @@ SessionError, SessionExpired, ) -from ..io import AsyncNeo4jPool +from ..io import ( + AcquireAuth, + AsyncNeo4jPool, +) log = logging.getLogger("neo4j") @@ -132,6 +135,12 @@ async def _update_bookmark(self, bookmark): async def _connect(self, access_mode, auth=None, **acquire_kwargs): acquisition_timeout = self._config.connection_acquisition_timeout + auth = AcquireAuth( + auth, + backwards_compatible=self._config.backwards_compatible_auth, + force_auth=acquire_kwargs.pop("force_auth", False), + ) + if self._connection: # TODO: Investigate this # log.warning("FIXME: should always disconnect before connect") diff --git a/src/neo4j/_conf.py b/src/neo4j/_conf.py index 9707334c..81242cc8 100644 --- a/src/neo4j/_conf.py +++ b/src/neo4j/_conf.py @@ -495,6 +495,13 @@ class WorkspaceConfig(Config): bookmark_manager = ExperimentalOption(None) # Specify the bookmark manager to be used for sessions by default. + #: Session Auth Backward Compatibility Layer + backwards_compatible_auth = False + # Enable session level authentication (user-switching) on session level + # even over Bolt 5.0 and earlier. This is done using a very costly + # backwards compatible authentication layer in the driver utilizing + # throwaway connections. + class SessionConfig(WorkspaceConfig): """ Session configuration. 
diff --git a/src/neo4j/_sync/io/__init__.py b/src/neo4j/_sync/io/__init__.py index 18aa1149..f5cddef2 100644 --- a/src/neo4j/_sync/io/__init__.py +++ b/src/neo4j/_sync/io/__init__.py @@ -24,6 +24,7 @@ __all__ = [ + "AcquireAuth", "Bolt", "BoltPool", "Neo4jPool", @@ -38,6 +39,7 @@ ConnectionErrorHandler, ) from ._pool import ( + AcquireAuth, BoltPool, Neo4jPool, ) diff --git a/src/neo4j/_sync/io/_bolt.py b/src/neo4j/_sync/io/_bolt.py index b2c7e27f..47b3144c 100644 --- a/src/neo4j/_sync/io/_bolt.py +++ b/src/neo4j/_sync/io/_bolt.py @@ -20,7 +20,6 @@ import abc import asyncio -import typing as t from collections import deque from logging import getLogger from time import perf_counter @@ -99,6 +98,8 @@ class Bolt: # The socket in_use = False + throwaway = False + # When the connection was last put back into the pool idle_since = float("-inf") diff --git a/src/neo4j/_sync/io/_bolt4.py b/src/neo4j/_sync/io/_bolt4.py index c374822d..5d373005 100644 --- a/src/neo4j/_sync/io/_bolt4.py +++ b/src/neo4j/_sync/io/_bolt4.py @@ -348,7 +348,8 @@ def _process_message(self, tag, fields): self.pool.on_write_failure(address=self.unresolved_address) raise except Neo4jError as e: - self.pool.on_neo4j_error(e, self.server_info.address) + if self.pool: + self.pool.on_neo4j_error(e, self.server_info.address) raise else: raise BoltProtocolError("Unexpected response message with signature " diff --git a/src/neo4j/_sync/io/_common.py b/src/neo4j/_sync/io/_common.py index a0250071..c4e29546 100644 --- a/src/neo4j/_sync/io/_common.py +++ b/src/neo4j/_sync/io/_common.py @@ -257,18 +257,29 @@ def on_ignored(self, metadata=None): class InitResponse(Response): def on_failure(self, metadata): - code = metadata.get("code") - if code == "Neo.ClientError.Security.Unauthorized": - # this branch is only needed as long as we support Bolt 5.0 - raise Neo4jError.hydrate(**metadata) - else: - raise ServiceUnavailable( - metadata.get("message", "Connection initialisation failed") - ) + # No sense in 
resetting the connection, + # the server will have closed it already. + self.connection.kill() + handler = self.handlers.get("on_failure") + Util.callback(handler, metadata) + handler = self.handlers.get("on_summary") + Util.callback(handler) + metadata["message"] = metadata.get( + "message", + "Connection initialisation failed due to an unknown error" + ) + raise Neo4jError.hydrate(**metadata) -class LogonResponse(Response): +class LogonResponse(InitResponse): def on_failure(self, metadata): + # No sense in resetting the connection, + # the server will have closed it already. + self.connection.kill() + handler = self.handlers.get("on_failure") + Util.callback(handler, metadata) + handler = self.handlers.get("on_summary") + Util.callback(handler) raise Neo4jError.hydrate(**metadata) diff --git a/src/neo4j/_sync/io/_pool.py b/src/neo4j/_sync/io/_pool.py index f1fb23b9..d806cc97 100644 --- a/src/neo4j/_sync/io/_pool.py +++ b/src/neo4j/_sync/io/_pool.py @@ -26,6 +26,8 @@ defaultdict, deque, ) +from copy import copy +from dataclasses import dataclass from logging import getLogger from random import choice @@ -47,6 +49,7 @@ from ..._exceptions import BoltError from ..._routing import RoutingTable from ...api import ( + Auth, READ_ACCESS, RenewableAuth, WRITE_ACCESS, @@ -68,6 +71,13 @@ log = getLogger("neo4j") +@dataclass +class AcquireAuth: + auth: t.Optional[Auth] + backwards_compatible: bool = False + force_auth: bool = False + + class IOPool(abc.ABC): """ A collection of connections to one or more server addresses. 
""" @@ -153,6 +163,22 @@ def _acquire_from_pool(self, address): return connection return None # no free connection available + def _remove_connection(self, connection): + address = connection.unresolved_address + with self.lock: + log.debug( + "[#%04X] _: remove connection from pool %r %s", + connection.local_port, address, connection.connection_id + ) + try: + self.connections.get(address, []).remove(connection) + except ValueError: + # If closure fails (e.g. because the server went + # down), all connections to the same address will + # be removed. Therefore, we silently ignore if the + # connection isn't in the pool anymore. + pass + def _acquire_from_pool_checked( self, address, health_check, deadline ): @@ -167,27 +193,20 @@ def _acquire_from_pool_checked( # `stale` but still alive. if log.isEnabledFor(logging.DEBUG): log.debug( - "[#%04X] _: removing old connection %s " + "[#%04X] _: found unhealthy connection %s " "(closed=%s, defunct=%s, stale=%s, in_use=%s)", connection.local_port, connection.connection_id, connection.closed(), connection.defunct(), connection.stale(), connection.in_use ) connection.close() - with self.lock: - try: - self.connections.get(address, []).remove(connection) - except ValueError: - # If closure fails (e.g. because the server went - # down), all connections to the same address will - # be removed. Therefore, we silently ignore if the - # connection isn't in the pool anymore. 
- pass + self._remove_connection(connection) continue # try again with a new connection else: return connection - def _acquire_new_later(self, address, auth, deadline): + def _acquire_new_later(self, address, auth, deadline, + backwards_compatible_auth): def connection_creator(): released_reservation = False try: @@ -206,8 +225,14 @@ def connection_creator(): try: connection.assert_re_auth_support() except ConfigurationError: - connection.close() - raise + if not backwards_compatible_auth: + log.debug("[#%04X] _: no re-auth support", + connection.local_port) + connection.close() + raise + log.debug("[#%04X] _: is throwaway connection", + connection.local_port) + connection.throwaway = True connection.pool = self connection.in_use = True with self.lock: @@ -233,8 +258,26 @@ def connection_creator(): return connection_creator return None + def _re_auth_connection(self, connection, auth, force): + new_auth = auth or self.get_auth() + log_auth = "******" if auth else "None" + try: + updated = connection.re_auth(new_auth, force=force) + log.debug("[#%04X] _: checked re_auth auth=%s updated=%s " + "force=%s", + connection.local_port, log_auth, updated, force) + except Exception as exc: + log.debug("[#%04X] _: check re_auth failed %r auth=%s " + "force=%s", + connection.local_port, exc, log_auth, force) + raise + assert not force or updated # force=True implies updated=True + if force: + connection.send_all() + connection.fetch_all() + def _acquire( - self, address, auth, deadline, liveness_check_timeout, force_auth=False + self, address, auth, deadline, liveness_check_timeout ): """ Acquire a connection to a given address from the pool. The address supplied should always be an IP address, not @@ -242,25 +285,17 @@ def _acquire( This method is thread safe. 
""" + if auth is None: + auth = AcquireAuth(None) + force_auth = auth.force_auth + backwards_compatible_auth = auth.backwards_compatible + auth = auth.auth + def health_check(connection_, deadline_): if (connection_.closed() or connection_.defunct() or connection_.stale()): return False - try: - connection_.re_auth(auth or self.get_auth(), - force=force_auth) - except ConfigurationError: - # protocol does not support re-authentication - if auth: - # session-level not supported - raise - # expiring tokens supported by flushing the pool - # => give up this connection - return False - if force_auth: - connection_.send_all() - connection_.fetch_all() if liveness_check_timeout is not None: if connection_.is_idle_for(liveness_check_timeout): with connection_deadline(connection_, deadline_): @@ -278,14 +313,54 @@ def health_check(connection_, deadline_): address, health_check, deadline ) if connection: - log.debug("[#%04X] _: handing out existing connection " - "%s", connection.local_port, - connection.connection_id) + log.debug("[#%04X] _: picked existing connection %s", + connection.local_port, connection.connection_id) + try: + self._re_auth_connection( + connection, auth, force_auth + ) + except ConfigurationError: + if not auth: + # expiring tokens supported by flushing the pool + # => give up this connection + log.debug("[#%04X] _: backwards compatible " + "auth token refresh: purge connection", + connection.local_port) + connection.close() + self.release(connection) + continue + if not backwards_compatible_auth: + raise + # backwards compatibility mode: + # create new throwaway connection, + connection_creator = self._acquire_new_later( + address, auth, deadline, + backwards_compatible_auth=True + ) + if connection_creator: + self.release(connection) + if not connection_creator: + # pool is full => kill the picked connection + log.debug("[#%04X] _: backwards compatible " + "session auth making room by purge", + connection.local_port) + connection.close() + with 
self.lock: + self._remove_connection(connection) + connection_creator = self._acquire_new_later( + address, auth, deadline, + backwards_compatible_auth=True + ) + assert connection_creator is not None + break + log.debug("[#%04X] _: handing out existing connection", + connection.local_port) return connection # all connections in pool are in-use with self.lock: connection_creator = self._acquire_new_later( - address, auth, deadline + address, auth, deadline, + backwards_compatible_auth=backwards_compatible_auth ) if connection_creator: break @@ -307,8 +382,8 @@ def health_check(connection_, deadline_): @abc.abstractmethod def acquire( - self, access_mode, timeout, database, bookmarks, auth, - liveness_check_timeout, force_re_auth=False + self, access_mode, timeout, database, bookmarks, auth: AcquireAuth, + liveness_check_timeout ): """ Acquire a connection to a server that can satisfy a set of parameters. @@ -319,7 +394,6 @@ def acquire( :param bookmarks: :param auth: :param liveness_check_timeout: - :param force_re_auth: """ ... @@ -341,6 +415,29 @@ def kill_and_release(self, *connections): connection.in_use = False self.cond.notify_all() + @staticmethod + def _close_throwaway_connection(connection, cancelled): + if connection.throwaway: + if cancelled is not None: + log.debug( + "[#%04X] _: kill throwaway connection %s", + connection.local_port, connection.connection_id + ) + connection.kill() + else: + try: + log.debug( + "[#%04X] _: close throwaway connection %s", + connection.local_port, connection.connection_id + ) + connection.close() + except asyncio.CancelledError as exc: + log.debug("[#%04X] _: cancelled close of " + "throwaway connection: %r", + connection.local_port, exc) + cancelled = exc + return cancelled + def release(self, *connections): """ Release connections back into the pool. 
@@ -348,29 +445,32 @@ def release(self, *connections): """ cancelled = None for connection in connections: + cancelled = self._close_throwaway_connection( + connection, cancelled + ) if not (connection.defunct() or connection.closed() or connection.is_reset): if cancelled is not None: log.debug( - "[#%04X] _: released unclean connection %s", + "[#%04X] _: kill unclean connection %s", connection.local_port, connection.connection_id ) connection.kill() continue try: log.debug( - "[#%04X] _: released unclean connection %s", + "[#%04X] _: release unclean connection %s", connection.local_port, connection.connection_id ) connection.reset() - except (Neo4jError, DriverError, BoltError) as e: + except (Neo4jError, DriverError, BoltError) as exc: log.debug("[#%04X] _: failed to reset connection " - "on release: %r", connection.local_port, e) - except asyncio.CancelledError as e: + "on release: %r", connection.local_port, exc) + except asyncio.CancelledError as exc: log.debug("[#%04X] _: cancelled reset connection " - "on release: %r", connection.local_port, e) - cancelled = e + "on release: %r", connection.local_port, exc) + cancelled = exc connection.kill() with self.lock: for connection in connections: @@ -501,8 +601,8 @@ def __repr__(self): self.address) def acquire( - self, access_mode, timeout, database, bookmarks, auth, - liveness_check_timeout, force_re_auth=False + self, access_mode, timeout, database, bookmarks, auth: AcquireAuth, + liveness_check_timeout ): # The access_mode and database is not needed for a direct connection, # it's just there for consistency. 
@@ -510,7 +610,7 @@ def acquire( "access_mode=%r, database=%r", access_mode, database) deadline = Deadline.from_timeout_or_deadline(timeout) return self._acquire( - self.address, auth, deadline, liveness_check_timeout, force_re_auth + self.address, auth, deadline, liveness_check_timeout ) @@ -602,6 +702,9 @@ def fetch_routing_info( deadline = Deadline.from_timeout_or_deadline(acquisition_timeout) log.debug("[#0000] _: _acquire router connection, " "database=%r, address=%r", database, address) + if auth: + auth = copy(auth) + auth.force_auth = False cx = self._acquire(address, auth, deadline, None) try: routing_table = cx.route( @@ -863,8 +966,8 @@ def _select_address(self, *, access_mode, database): return choice(addresses_by_usage[min(addresses_by_usage)]) def acquire( - self, access_mode, timeout, database, bookmarks, auth, - liveness_check_timeout, force_re_auth=False + self, access_mode, timeout, database, bookmarks, auth: AcquireAuth, + liveness_check_timeout ): if access_mode not in (WRITE_ACCESS, READ_ACCESS): raise ClientError("Non valid 'access_mode'; {}".format(access_mode)) @@ -904,7 +1007,6 @@ def acquire( # should always be a resolved address connection = self._acquire( address, auth, deadline, liveness_check_timeout, - force_re_auth ) except (ServiceUnavailable, SessionExpired): self.deactivate(address=address) diff --git a/src/neo4j/_sync/work/session.py b/src/neo4j/_sync/work/session.py index 4a5dcefb..bf0eac66 100644 --- a/src/neo4j/_sync/work/session.py +++ b/src/neo4j/_sync/work/session.py @@ -114,12 +114,12 @@ def __exit__(self, exception_type, exception_value, traceback): self._state_failed = True self.close() - def _connect(self, access_mode, **access_kwargs): + def _connect(self, access_mode, **acquire_kwargs): if access_mode is None: access_mode = self._config.default_access_mode try: super()._connect( - access_mode, auth=self._config.auth, **access_kwargs + access_mode, auth=self._config.auth, **acquire_kwargs ) except 
asyncio.CancelledError: self._handle_cancellation(message="_connect") @@ -167,7 +167,14 @@ def _get_server_info(self): def _verify_authentication(self): assert not self._connection - self._connect(READ_ACCESS, force_re_auth=True) + self._connect(READ_ACCESS, force_auth=True) + if not self._config.backwards_compatible_auth: + # Even without backwards compatibility and an old server, the + # _connect call above can succeed if the connection is a new one. + # Hence, we enforce support explicitly to always let the user know + # that this is not supported. + self._connection.assert_re_auth_support() + self._disconnect() def close(self) -> None: """Close the session. diff --git a/src/neo4j/_sync/work/workspace.py b/src/neo4j/_sync/work/workspace.py index 6fd37028..ff6b96e4 100644 --- a/src/neo4j/_sync/work/workspace.py +++ b/src/neo4j/_sync/work/workspace.py @@ -33,7 +33,10 @@ SessionError, SessionExpired, ) -from ..io import Neo4jPool +from ..io import ( + AcquireAuth, + Neo4jPool, +) log = logging.getLogger("neo4j") @@ -132,6 +135,12 @@ def _update_bookmark(self, bookmark): def _connect(self, access_mode, auth=None, **acquire_kwargs): acquisition_timeout = self._config.connection_acquisition_timeout + auth = AcquireAuth( + auth, + backwards_compatible=self._config.backwards_compatible_auth, + force_auth=acquire_kwargs.pop("force_auth", False), + ) + if self._connection: # TODO: Investigate this # log.warning("FIXME: should always disconnect before connect") diff --git a/testkitbackend/_async/requests.py b/testkitbackend/_async/requests.py index aec611c3..eb06046e 100644 --- a/testkitbackend/_async/requests.py +++ b/testkitbackend/_async/requests.py @@ -142,12 +142,18 @@ async def NewDriver(backend, data): for k in ("sessionConnectionTimeoutMs", "updateRoutingTableTimeoutMs"): if k in data: data.mark_item_as_read_if_equals(k, None) - if data.get("maxConnectionPoolSize"): - kwargs["max_connection_pool_size"] = data["maxConnectionPoolSize"] - if data.get("fetchSize"): - 
kwargs["fetch_size"] = data["fetchSize"] - if "encrypted" in data: - kwargs["encrypted"] = data["encrypted"] + for (conf_name, data_name) in ( + ("max_connection_pool_size", "maxConnectionPoolSize"), + ("fetch_size", "fetchSize"), + ): + if data.get(data_name): + kwargs[conf_name] = data[data_name] + for (conf_name, data_name) in ( + ("encrypted", "encrypted"), + ("backwards_compatible_auth", "backwardsCompatibleAuth"), + ): + if data_name in data: + kwargs[conf_name] = data[data_name] if "trustedCertificates" in data: if data["trustedCertificates"] is None: kwargs["trusted_certificates"] = neo4j.TrustSystemCAs() @@ -236,12 +242,7 @@ async def GetServerInfo(backend, data): async def CheckMultiDBSupport(backend, data): driver_id = data["driverId"] driver = backend.drivers[driver_id] - with warning_check( - neo4j.ExperimentalWarning, - "Feature support query, based on Bolt protocol version and Neo4j " - "server version will change in the future." - ): - available = await driver.supports_multi_db() + available = await driver.supports_multi_db() await backend.send_response("MultiDBSupport", { "id": backend.next_key(), "available": available }) @@ -250,9 +251,7 @@ async def CheckMultiDBSupport(backend, data): async def VerifyAuthentication(backend, data): driver_id = data["driverId"] driver = backend.drivers[driver_id] - auth = None - if data.get("auth_token"): - auth = _convert_auth_token(data, "auth_token") + auth = _convert_auth_token(data, "auth_token") authenticated = await driver.verify_authentication(auth=auth) await backend.send_response("DriverIsAuthenticated", { "id": backend.next_key(), "authenticated": authenticated diff --git a/testkitbackend/_sync/requests.py b/testkitbackend/_sync/requests.py index 2307d541..a12df510 100644 --- a/testkitbackend/_sync/requests.py +++ b/testkitbackend/_sync/requests.py @@ -142,12 +142,18 @@ def NewDriver(backend, data): for k in ("sessionConnectionTimeoutMs", "updateRoutingTableTimeoutMs"): if k in data: 
data.mark_item_as_read_if_equals(k, None) - if data.get("maxConnectionPoolSize"): - kwargs["max_connection_pool_size"] = data["maxConnectionPoolSize"] - if data.get("fetchSize"): - kwargs["fetch_size"] = data["fetchSize"] - if "encrypted" in data: - kwargs["encrypted"] = data["encrypted"] + for (conf_name, data_name) in ( + ("max_connection_pool_size", "maxConnectionPoolSize"), + ("fetch_size", "fetchSize"), + ): + if data.get(data_name): + kwargs[conf_name] = data[data_name] + for (conf_name, data_name) in ( + ("encrypted", "encrypted"), + ("backwards_compatible_auth", "backwardsCompatibleAuth"), + ): + if data_name in data: + kwargs[conf_name] = data[data_name] if "trustedCertificates" in data: if data["trustedCertificates"] is None: kwargs["trusted_certificates"] = neo4j.TrustSystemCAs() @@ -236,12 +242,7 @@ def GetServerInfo(backend, data): def CheckMultiDBSupport(backend, data): driver_id = data["driverId"] driver = backend.drivers[driver_id] - with warning_check( - neo4j.ExperimentalWarning, - "Feature support query, based on Bolt protocol version and Neo4j " - "server version will change in the future." 
- ): - available = driver.supports_multi_db() + available = driver.supports_multi_db() backend.send_response("MultiDBSupport", { "id": backend.next_key(), "available": available }) @@ -250,9 +251,7 @@ def CheckMultiDBSupport(backend, data): def VerifyAuthentication(backend, data): driver_id = data["driverId"] driver = backend.drivers[driver_id] - auth = None - if data.get("auth_token"): - auth = _convert_auth_token(data, "auth_token") + auth = _convert_auth_token(data, "auth_token") authenticated = driver.verify_authentication(auth=auth) backend.send_response("DriverIsAuthenticated", { "id": backend.next_key(), "authenticated": authenticated diff --git a/tests/unit/async_/io/test_direct.py b/tests/unit/async_/io/test_direct.py index 28b5773d..07ebb477 100644 --- a/tests/unit/async_/io/test_direct.py +++ b/tests/unit/async_/io/test_direct.py @@ -96,10 +96,10 @@ async def opener(addr, auth, timeout): async def acquire( self, access_mode, timeout, database, bookmarks, auth, - liveness_check_timeout, force_re_auth=False + liveness_check_timeout ): return await self._acquire( - self.address, auth, timeout, liveness_check_timeout, force_re_auth + self.address, auth, timeout, liveness_check_timeout ) diff --git a/tests/unit/sync/io/test_direct.py b/tests/unit/sync/io/test_direct.py index ef732942..2d13c669 100644 --- a/tests/unit/sync/io/test_direct.py +++ b/tests/unit/sync/io/test_direct.py @@ -96,10 +96,10 @@ def opener(addr, auth, timeout): def acquire( self, access_mode, timeout, database, bookmarks, auth, - liveness_check_timeout, force_re_auth=False + liveness_check_timeout ): return self._acquire( - self.address, auth, timeout, liveness_check_timeout, force_re_auth + self.address, auth, timeout, liveness_check_timeout ) From 304b471bf4b9a3ddce47b6a7fcefa173a9a46948 Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Tue, 7 Mar 2023 13:51:16 +0100 Subject: [PATCH 12/23] Implement AuthTokenManager Implement change to ADR 
https://github.com/neo-technology/drivers-adr/pull/60 The auth token provider function has been replaced by a more general-purpose auth token manager interface. --- docs/source/api.rst | 14 +- docs/source/async_api.rst | 24 ++- src/neo4j/__init__.py | 1 - src/neo4j/_async/auth_management.py | 190 ++++++++++++++++++ src/neo4j/_async/driver.py | 35 ++-- src/neo4j/_async/io/_bolt.py | 53 +++-- src/neo4j/_async/io/_bolt3.py | 2 +- src/neo4j/_async/io/_bolt4.py | 2 +- src/neo4j/_async/io/_bolt5.py | 2 +- src/neo4j/_async/io/_pool.py | 102 ++++------ src/neo4j/_async/work/result.py | 2 +- src/neo4j/_async/work/session.py | 3 + src/neo4j/_auth_management.py | 117 +++++++++++ src/neo4j/_sync/auth_management.py | 190 ++++++++++++++++++ src/neo4j/_sync/driver.py | 34 ++-- src/neo4j/_sync/io/_bolt.py | 53 +++-- src/neo4j/_sync/io/_bolt3.py | 2 +- src/neo4j/_sync/io/_bolt4.py | 2 +- src/neo4j/_sync/io/_bolt5.py | 2 +- src/neo4j/_sync/io/_pool.py | 99 +++------ src/neo4j/_sync/work/result.py | 2 +- src/neo4j/_sync/work/session.py | 3 + src/neo4j/api.py | 81 +------- src/neo4j/auth_management.py | 34 ++++ src/neo4j/time/__init__.py | 2 +- testkitbackend/_async/backend.py | 6 +- testkitbackend/_async/requests.py | 160 +++++++++------ testkitbackend/_sync/backend.py | 6 +- testkitbackend/_sync/requests.py | 160 +++++++++------ testkitbackend/fromtestkit.py | 26 +++ testkitbackend/totestkit.py | 4 + tests/unit/async_/fixtures/fake_connection.py | 6 +- tests/unit/async_/io/conftest.py | 3 + tests/unit/async_/io/test_class_bolt.py | 63 +++++- tests/unit/async_/io/test_class_bolt3.py | 28 +-- tests/unit/async_/io/test_class_bolt4x0.py | 36 ++-- tests/unit/async_/io/test_class_bolt4x1.py | 38 ++-- tests/unit/async_/io/test_class_bolt4x2.py | 38 ++-- tests/unit/async_/io/test_class_bolt4x3.py | 38 ++-- tests/unit/async_/io/test_class_bolt4x4.py | 38 ++-- tests/unit/async_/io/test_class_bolt5x0.py | 38 ++-- tests/unit/async_/io/test_class_bolt5x1.py | 44 ++-- 
tests/unit/async_/io/test_direct.py | 18 +- tests/unit/async_/io/test_neo4j_pool.py | 166 +++++++-------- tests/unit/async_/test_addressing.py | 5 +- tests/unit/async_/test_auth_manager.py | 116 +++++++++++ tests/unit/common/test_conf.py | 1 + tests/unit/mixed/io/test_direct.py | 12 +- tests/unit/mixed/io/test_pool_async.py | 72 ------- tests/unit/mixed/io/test_pool_sync.py | 79 -------- tests/unit/sync/fixtures/fake_connection.py | 6 +- tests/unit/sync/io/conftest.py | 3 + tests/unit/sync/io/test_class_bolt.py | 63 +++++- tests/unit/sync/io/test_class_bolt3.py | 28 +-- tests/unit/sync/io/test_class_bolt4x0.py | 36 ++-- tests/unit/sync/io/test_class_bolt4x1.py | 38 ++-- tests/unit/sync/io/test_class_bolt4x2.py | 38 ++-- tests/unit/sync/io/test_class_bolt4x3.py | 38 ++-- tests/unit/sync/io/test_class_bolt4x4.py | 38 ++-- tests/unit/sync/io/test_class_bolt5x0.py | 38 ++-- tests/unit/sync/io/test_class_bolt5x1.py | 44 ++-- tests/unit/sync/io/test_direct.py | 18 +- tests/unit/sync/io/test_neo4j_pool.py | 166 +++++++-------- tests/unit/sync/test_addressing.py | 5 +- tests/unit/sync/test_auth_manager.py | 116 +++++++++++ 65 files changed, 1859 insertions(+), 1068 deletions(-) create mode 100644 src/neo4j/_async/auth_management.py create mode 100644 src/neo4j/_auth_management.py create mode 100644 src/neo4j/_sync/auth_management.py create mode 100644 src/neo4j/auth_management.py create mode 100644 tests/unit/async_/test_auth_manager.py delete mode 100644 tests/unit/mixed/io/test_pool_async.py delete mode 100644 tests/unit/mixed/io/test_pool_sync.py create mode 100644 tests/unit/sync/test_auth_manager.py diff --git a/docs/source/api.rst b/docs/source/api.rst index 32e821be..0c28b413 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -107,11 +107,17 @@ Auth To authenticate with Neo4j the authentication details are supplied at driver creation. 
-The auth token is an object of the class :class:`neo4j.Auth` containing static details or a a callable that returns a :class:`neo4j.RenewableAuth` object. +The auth token is an object of the class :class:`neo4j.Auth` containing static details or :class:`neo4j.auth_management.AuthManager` object. .. autoclass:: neo4j.Auth -.. autoclass:: neo4j.RenewableAuth +.. autoclass:: neo4j.auth_management.AuthManager + :members: + +.. autoclass:: neo4j.auth_management.AuthManagers + :members: + +.. autoclass:: neo4j.auth_management.TemporalAuth Example: @@ -327,7 +333,7 @@ Closing a driver will immediately shut down all connections in the pool. Defaults to the driver's :attr:`.query_bookmark_manager`. - Pass :const:`None` to disable causal consistency. + Pass :data:`None` to disable causal consistency. :type bookmark_manager_: typing.Union[neo4j.BookmarkManager, neo4j.BookmarkManager, None] @@ -760,7 +766,7 @@ Optional :class:`neo4j.Bookmarks`. Use this to causally chain sessions. See :meth:`Session.last_bookmarks` or :meth:`AsyncSession.last_bookmarks` for more information. -:Default: ``None`` +:Default: :data:`None` .. deprecated:: 5.0 Alternatively, an iterable of strings can be passed. This usage is diff --git a/docs/source/async_api.rst b/docs/source/async_api.rst index c2b8e413..d130c33b 100644 --- a/docs/source/async_api.rst +++ b/docs/source/async_api.rst @@ -40,7 +40,7 @@ The :class:`neo4j.AsyncDriver` construction is done via a ``classmethod`` on the await driver.close() # close the driver object - asyncio.run(main()) + asyncio.run(main()) For basic authentication, ``auth`` can be a simple tuple, for example: @@ -67,7 +67,7 @@ The :class:`neo4j.AsyncDriver` construction is done via a ``classmethod`` on the async with AsyncGraphDatabase.driver(uri, auth=auth) as driver: ... 
# use the driver - asyncio.run(main()) + asyncio.run(main()) @@ -118,6 +118,20 @@ Each supported scheme maps to a particular :class:`neo4j.AsyncDriver` subclass t See https://neo4j.com/docs/operations-manual/current/configuration/ports/ for Neo4j ports. +.. _async-auth-ref: + +Async Auth +========== + +Authentication mostly works the same as in the synchronous driver. +However, there are async equivalents of the synchronous constructs. + +.. autoclass:: neo4j.auth_management.AsyncAuthManager + :members: + +.. autoclass:: neo4j.auth_management.AsyncAuthManagers + :members: + *********** AsyncDriver @@ -336,8 +350,10 @@ Async Driver Configuration :class:`neo4j.AsyncDriver` is configured exactly like :class:`neo4j.Driver` (see :ref:`driver-configuration-ref`). The only differences are that the async -driver accepts an async custom resolver function (see :ref:`async-resolver-ref`) -as well as an async auth token provider (see :class:`neo4j.RenewableAuth`). +driver accepts + + * a sync as well as an async custom resolver function (see :ref:`async-resolver-ref`) + * as sync as well as an async auth token manager (see :class:`neo4j.AsyncAuthManager`). .. _async-resolver-ref: diff --git a/src/neo4j/__init__.py b/src/neo4j/__init__.py index bc7e186a..2583462c 100644 --- a/src/neo4j/__init__.py +++ b/src/neo4j/__init__.py @@ -79,7 +79,6 @@ DEFAULT_DATABASE, kerberos_auth, READ_ACCESS, - RenewableAuth, ServerInfo, SYSTEM_DATABASE, TRUST_ALL_CERTIFICATES, diff --git a/src/neo4j/_async/auth_management.py b/src/neo4j/_async/auth_management.py new file mode 100644 index 00000000..04064f38 --- /dev/null +++ b/src/neo4j/_async/auth_management.py @@ -0,0 +1,190 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# from __future__ import annotations +# work around for https://github.com/sphinx-doc/sphinx/pull/10880 +# make sure TAuth is resolved in the docs, else they're pretty useless + +import time +import typing as t +from logging import getLogger + +from .._async_compat.concurrency import AsyncLock +from .._auth_management import ( + AsyncAuthManager, + TemporalAuth, +) + +# work around for https://github.com/sphinx-doc/sphinx/pull/10880 +# make sure TAuth is resolved in the docs, else they're pretty useless +# if t.TYPE_CHECKING: +from ..api import _TAuth + + +log = getLogger("neo4j") + + +class _AsyncStaticAuthManager(AsyncAuthManager): + _auth: _TAuth + + def __init__(self, auth: _TAuth) -> None: + self._auth = auth + + async def get_auth(self) -> _TAuth: + return self._auth + + async def on_auth_expired(self, auth: _TAuth) -> None: + pass + + +class _TemporalAuthHolder: + def __init__(self, auth: TemporalAuth) -> None: + self._auth = auth + self._expiry = None + if auth.expires_in is not None: + self._expiry = time.monotonic() + auth.expires_in + + @property + def auth(self) -> _TAuth: + return self._auth.auth + + def expired(self) -> bool: + if self._expiry is None: + return False + return time.monotonic() > self._expiry + +class _AsyncTemporalAuthManager(AsyncAuthManager): + _current_auth: t.Optional[_TemporalAuthHolder] + _provider: t.Callable[[], t.Awaitable[TemporalAuth]] + _lock: AsyncLock + + + def __init__( + self, + provider: t.Callable[[], t.Awaitable[TemporalAuth]] + ) -> None: + self._provider = provider + self._current_auth = None + 
self._lock = AsyncLock() + + async def _refresh_auth(self): + self._current_auth = _TemporalAuthHolder(await self._provider()) + + async def get_auth(self) -> _TAuth: + async with self._lock: + auth = self._current_auth + if auth is not None and not auth.expired(): + return auth.auth + log.debug("[ ] _: refreshing (time out)") + await self._refresh_auth() + assert self._current_auth is not None + return self._current_auth.auth + + async def on_auth_expired(self, auth: _TAuth) -> None: + async with self._lock: + cur_auth = self._current_auth + if cur_auth is not None and cur_auth.auth == auth: + log.debug("[ ] _: refreshing (error)") + await self._refresh_auth() + + +class AsyncAuthManagers: + """A collection of :class:`.AsyncAuthManager` factories.""" + + @staticmethod + def static(auth: _TAuth) -> AsyncAuthManager: + """Create a static auth manager. + + Example:: + + # NOTE: this example is for illustration purposes only. + # The driver will automatically wrap static auth info in a + # static auth manager. + + import neo4j + from neo4j.auth_management import AsyncAuthManagers + + + auth = neo4j.basic_auth("neo4j", "password") + + with neo4j.GraphDatabase.driver( + "neo4j://example.com:7687", + auth=AsyncAuthManagers.static(auth) + # auth=auth # this is equivalent + ) as driver: + ... # do stuff + + :param auth: The auth to return. + + :returns: + An instance of an implementation of :class:`.AsyncAuthManager` that + always returns the same auth. + """ + return _AsyncStaticAuthManager(auth) + + @staticmethod + def temporal( + provider: t.Callable[[], t.Awaitable[TemporalAuth]] + ) -> AsyncAuthManager: + """Create a auth manager for potentially expiring auth info. + + .. warning:: + + The provider function **must not** interact with the driver in any + way as this can cause deadlocks and undefined behaviour. 
+ + Example:: + + import neo4j + from neo4j.auth_management import ( + AsyncAuthManagers, + TemporalAuth, + ) + + + async def auth_provider(): + # some way to getting a token + sso_token = await get_sso_token() + # assume we know our tokens expire every 60 seconds + expires_in = 60 + + return TemporalAuth( + auth=neo4j.bearer_auth(sso_token), + # Include a little buffer so that we fetch a new token + # *before* the old one expires + expires_in=expires_in - 10 + ) + + + with neo4j.GraphDatabase.driver( + "neo4j://example.com:7687", + auth=AsyncAuthManagers.temporal(auth_provider) + ) as driver: + ... # do stuff + + :param provider: + A callable that provides a :class:`.TemporalAuth` instance. + + :returns: + An instance of an implementation of :class:`.AsyncAuthManager` that + returns auth info from the given provider and refreshes it, calling + the provider again, when the auth info expires (either because it's + reached its expiry time or because the server flagged it as + expired). + """ + return _AsyncTemporalAuthManager(provider) diff --git a/src/neo4j/_async/driver.py b/src/neo4j/_async/driver.py index c75bcee6..3022fb8c 100644 --- a/src/neo4j/_async/driver.py +++ b/src/neo4j/_async/driver.py @@ -57,7 +57,6 @@ parse_neo4j_uri, parse_routing_context, READ_ACCESS, - RenewableAuth, SECURITY_TYPE_SECURE, SECURITY_TYPE_SELF_SIGNED_CERTIFICATE, ServerInfo, @@ -70,6 +69,11 @@ URI_SCHEME_NEO4J_SECURE, URI_SCHEME_NEO4J_SELF_SIGNED_CERTIFICATE, ) +from ..auth_management import ( + AsyncAuthManager, + AsyncAuthManagers, + AuthManager, +) from ..exceptions import Neo4jError from .bookmark_manager import ( AsyncNeo4jBookmarkManager, @@ -90,6 +94,7 @@ import typing_extensions as te from .._api import T_RoutingControl + from ..api import _TAuth class _DefaultEnum(Enum): @@ -103,12 +108,6 @@ class _DefaultEnum(Enum): _T = t.TypeVar("_T") -_TAuthTokenProvider = t.Callable[[], t.Union[ - RenewableAuth, Auth, t.Tuple[t.Any, t.Any], None, - t.Awaitable[t.Union[RenewableAuth, 
Auth, t.Tuple[t.Any, t.Any], None]] -]] - - class AsyncGraphDatabase: """Accessor for :class:`neo4j.AsyncDriver` construction. """ @@ -121,7 +120,12 @@ def driver( uri: str, *, auth: t.Union[ - t.Tuple[t.Any, t.Any], Auth, _TAuthTokenProvider, None + # work around https://github.com/sphinx-doc/sphinx/pull/10880 + # make sure TAuth is resolved in the docs + # TAuth, + t.Union[t.Tuple[t.Any, t.Any], Auth, None], + AsyncAuthManager, + AuthManager ] = ..., max_connection_lifetime: float = ..., max_connection_pool_size: int = ..., @@ -161,7 +165,12 @@ def driver( def driver( cls, uri: str, *, auth: t.Union[ - t.Tuple[t.Any, t.Any], Auth, _TAuthTokenProvider, None + # work around https://github.com/sphinx-doc/sphinx/pull/10880 + # make sure TAuth is resolved in the docs + # TAuth, + t.Union[t.Tuple[t.Any, t.Any], Auth, None], + AsyncAuthManager, + AuthManager ] = None, **config ) -> AsyncDriver: @@ -178,7 +187,9 @@ def driver( driver_type, security_type, parsed = parse_neo4j_uri(uri) - config["auth"] = auth if callable(auth) else lambda: auth + if not isinstance(auth, (AsyncAuthManager, AuthManager)): + auth = AsyncAuthManagers.static(auth) + config["auth"] = auth # TODO: 6.0 - remove "trust" config option if "trust" in config.keys(): @@ -496,7 +507,7 @@ def session( default_access_mode: str = ..., bookmark_manager: t.Union[AsyncBookmarkManager, BookmarkManager, None] = ..., - auth: t.Union[Auth, t.Tuple[t.Any, t.Any]] = ..., + auth: _TAuth = ..., # undocumented/unsupported options # they may be change or removed any time without prior notice @@ -752,7 +763,7 @@ async def example(driver: neo4j.AsyncDriver) -> neo4j.Record:: Defaults to the driver's :attr:`.query_bookmark_manager`. - Pass :const:`None` to disable causal consistency. + Pass :data:`None` to disable causal consistency. 
:type bookmark_manager_: typing.Union[neo4j.AsyncBookmarkManager, neo4j.BookmarkManager, None] diff --git a/src/neo4j/_async/io/_bolt.py b/src/neo4j/_async/io/_bolt.py index 66fc124b..863e817b 100644 --- a/src/neo4j/_async/io/_bolt.py +++ b/src/neo4j/_async/io/_bolt.py @@ -25,6 +25,7 @@ from time import perf_counter from ..._async_compat.network import AsyncBoltSocket +from ..._async_compat.util import AsyncUtil from ..._codec.hydration import v1 as hydration_v1 from ..._codec.packstream import v1 as packstream_v1 from ..._conf import PoolConfig @@ -33,7 +34,7 @@ BoltHandshakeError, ) from ..._meta import get_user_agent -from ...addressing import Address +from ...addressing import ResolvedAddress from ...api import ( ServerInfo, Version, @@ -119,12 +120,16 @@ class AsyncBolt: most_recent_qid = None def __init__(self, unresolved_address, sock, max_connection_lifetime, *, - auth=None, user_agent=None, routing_context=None): + auth=None, auth_manager=None, user_agent=None, + routing_context=None): self.unresolved_address = unresolved_address self.socket = sock self.local_port = self.socket.getsockname()[1] - self.server_info = ServerInfo(Address(sock.getpeername()), - self.PROTOCOL_VERSION) + self.server_info = ServerInfo( + ResolvedAddress(sock.getpeername(), + host_name=unresolved_address.host), + self.PROTOCOL_VERSION + ) # so far `connection.recv_timeout_seconds` is the only available # configuration hint that exists. Therefore, all hints can be stored at # connection level. This might change in the future. 
@@ -151,16 +156,9 @@ def __init__(self, unresolved_address, sock, max_connection_lifetime, *, else: self.user_agent = get_user_agent() + self.auth = auth self.auth_dict = self._to_auth_dict(auth) - - # Check for missing password - try: - credentials = self.auth_dict["credentials"] - except KeyError: - pass - else: - if credentials is None: - raise AuthError("Password cannot be None") + self.auth_manager = auth_manager def __del__(self): if not asyncio.iscoroutinefunction(self.close): @@ -328,13 +326,13 @@ async def ping(cls, address, *, timeout=None, pool_config=None): @classmethod async def open( - cls, address, *, auth=None, timeout=None, routing_context=None, + cls, address, *, auth_manager=None, timeout=None, routing_context=None, pool_config=None ): """Open a new Bolt connection to a given server address. :param address: - :param auth: + :param auth_manager: :param timeout: the connection timeout in seconds :param routing_context: dict containing routing context :param pool_config: @@ -400,7 +398,7 @@ def time_remaining(): from ._bolt5 import AsyncBolt5x1 bolt_cls = AsyncBolt5x1 else: - log.debug("[#%04X] S: ", s.getsockname()[1]) + log.debug("[#%04X] C: ", s.getsockname()[1]) await AsyncBoltSocket.close_socket(s) supported_versions = cls.protocol_handlers().keys() @@ -411,9 +409,23 @@ def time_remaining(): address=address, request_data=handshake, response_data=data ) + try: + auth = await AsyncUtil.callback(auth_manager.get_auth) + except asyncio.CancelledError as e: + log.debug("[#%04X] C: open auth manager failed: %r", + s.getsockname()[1], e) + s.kill() + raise + except Exception as e: + log.debug("[#%04X] C: open auth manager failed: %r", + s.getsockname()[1], e) + await s.close() + raise + connection = bolt_cls( address, s, pool_config.max_connection_lifetime, auth=auth, - user_agent=pool_config.user_agent, routing_context=routing_context + auth_manager=auth_manager, user_agent=pool_config.user_agent, + routing_context=routing_context ) try: @@ -465,7 
+477,8 @@ def mark_unauthenticated(self): self.auth_dict = {} def re_auth( - self, auth, dehydration_hooks=None, hydration_hooks=None, force=False + self, auth, auth_manager, force=False, + dehydration_hooks=None, hydration_hooks=None, ): """Append LOGON, LOGOFF to the outgoing queue. @@ -475,10 +488,14 @@ def re_auth( """ new_auth_dict = self._to_auth_dict(auth) if not force and new_auth_dict == self.auth_dict: + self.auth_manager = auth_manager + self.auth = auth return False self.logoff(dehydration_hooks=dehydration_hooks, hydration_hooks=hydration_hooks) self.auth_dict = new_auth_dict + self.auth_manager = auth_manager + self.auth = auth self.logon(dehydration_hooks=dehydration_hooks, hydration_hooks=hydration_hooks) return True diff --git a/src/neo4j/_async/io/_bolt3.py b/src/neo4j/_async/io/_bolt3.py index e481de31..7cfca437 100644 --- a/src/neo4j/_async/io/_bolt3.py +++ b/src/neo4j/_async/io/_bolt3.py @@ -399,7 +399,7 @@ async def _process_message(self, tag, fields): self.pool.on_write_failure(address=self.unresolved_address) raise except Neo4jError as e: - await self.pool.on_neo4j_error(e, self.server_info.address) + await self.pool.on_neo4j_error(e, self) raise else: raise BoltProtocolError("Unexpected response message with signature %02X" % summary_signature, address=self.unresolved_address) diff --git a/src/neo4j/_async/io/_bolt4.py b/src/neo4j/_async/io/_bolt4.py index c96fc91d..e5ffa9fd 100644 --- a/src/neo4j/_async/io/_bolt4.py +++ b/src/neo4j/_async/io/_bolt4.py @@ -349,7 +349,7 @@ async def _process_message(self, tag, fields): raise except Neo4jError as e: if self.pool: - await self.pool.on_neo4j_error(e, self.server_info.address) + await self.pool.on_neo4j_error(e, self) raise else: raise BoltProtocolError("Unexpected response message with signature " diff --git a/src/neo4j/_async/io/_bolt5.py b/src/neo4j/_async/io/_bolt5.py index d1df4c0c..4f81c79b 100644 --- a/src/neo4j/_async/io/_bolt5.py +++ b/src/neo4j/_async/io/_bolt5.py @@ -347,7 +347,7 @@ 
async def _process_message(self, tag, fields): raise except Neo4jError as e: if self.pool: - await self.pool.on_neo4j_error(e, self.server_info.address) + await self.pool.on_neo4j_error(e, self) raise else: raise BoltProtocolError( diff --git a/src/neo4j/_async/io/_pool.py b/src/neo4j/_async/io/_pool.py index 65079b09..b76078b4 100644 --- a/src/neo4j/_async/io/_pool.py +++ b/src/neo4j/_async/io/_pool.py @@ -49,11 +49,13 @@ from ..._exceptions import BoltError from ..._routing import RoutingTable from ...api import ( - Auth, READ_ACCESS, - RenewableAuth, WRITE_ACCESS, ) +from ...auth_management import ( + AsyncAuthManager, + AuthManager, +) from ...exceptions import ( ClientError, ConfigurationError, @@ -73,7 +75,7 @@ @dataclass class AcquireAuth: - auth: t.Optional[Auth] + auth: t.Union[AsyncAuthManager, AuthManager, None] backwards_compatible: bool = False force_auth: bool = False @@ -94,9 +96,6 @@ def __init__(self, opener, pool_config, workspace_config): self.connections_reservations = defaultdict(lambda: 0) self.lock = AsyncCooperativeRLock() self.cond = AsyncCondition(self.lock) - self.refreshing_auth = False - self.auth_condition = AsyncCondition() - self.last_auth: t.Optional[RenewableAuth] = None async def __aenter__(self): return self @@ -105,53 +104,7 @@ async def __aexit__(self, exc_type, exc_value, traceback): await self.close() async def get_auth(self): - async with self.auth_condition: - auth_missing = self.last_auth is None - needs_refresh = auth_missing or self.last_auth.expired - if not needs_refresh: - return self.last_auth.auth - - if self.refreshing_auth: - # someone else is already getting new auth info - if not auth_missing: - # there is old auth info we can use in the meantime - return self.last_auth.auth - else: - while self.last_auth is None: - await self.auth_condition.wait() - return self.last_auth.auth - else: - # we need to get new auth info - self.refreshing_auth = True - - auth = await self._get_new_auth() - async with 
self.auth_condition: - self.last_auth = auth - self.refreshing_auth = False - self.auth_condition.notify_all() - return self.last_auth.auth - - async def force_new_auth(self): - async with self.auth_condition: - log.debug("[#0000] _: force new auth info") - if self.refreshing_auth: - return - self.last_auth = None - self.refreshing_auth = True - - auth = await self._get_new_auth() - async with self.auth_condition: - self.last_auth = auth - self.refreshing_auth = False - self.auth_condition.notify_all() - - async def _get_new_auth(self): - log.debug("[#0000] _: requesting new auth info from %r", - self.pool_config.auth) - new_auth = await AsyncUtil.callback(self.pool_config.auth) - if not isinstance(new_auth, RenewableAuth): - return RenewableAuth(new_auth) - return new_auth + return await AsyncUtil.callback(self.pool_config.auth.get_auth) async def _acquire_from_pool(self, address): with self.lock: @@ -212,7 +165,7 @@ async def connection_creator(): try: try: connection = await self.opener( - address, auth or await self.get_auth(), + address, auth or self.pool_config.auth, deadline.to_timeout() ) except ServiceUnavailable: @@ -258,11 +211,22 @@ async def connection_creator(): return connection_creator return None - async def _re_auth_connection(self, connection, auth, force): - new_auth = auth or await self.get_auth() + async def _re_auth_connection( + self, connection, auth, force, backwards_compatible_auth + ): + if auth and not backwards_compatible_auth: + # Assert session auth is supported by the protocol. + # The Bolt implementation will try as hard as it can to make the + # re-auth work. So if the session auth token is identical to the + # driver auth token, the connection doesn't have to do anything so + # it won't fail regardless of the protocol version. 
+ connection.assert_re_auth_support() + new_auth_manager = auth or self.pool_config.auth log_auth = "******" if auth else "None" + new_auth = await AsyncUtil.callback(new_auth_manager.get_auth) try: - updated = connection.re_auth(new_auth, force=force) + updated = connection.re_auth(new_auth, new_auth_manager, + force=force) log.debug("[#%04X] _: checked re_auth auth=%s updated=%s " "force=%s", connection.local_port, log_auth, updated, force) @@ -317,7 +281,7 @@ async def health_check(connection_, deadline_): connection.local_port, connection.connection_id) try: await self._re_auth_connection( - connection, auth, force_auth + connection, auth, force_auth, backwards_compatible_auth ) except ConfigurationError: if not auth: @@ -541,7 +505,8 @@ def on_write_failure(self, address): "No write service available for pool {}".format(self) ) - async def on_neo4j_error(self, error, address): + async def on_neo4j_error(self, error, connection): + address = connection.unresolved_address assert isinstance(error, Neo4jError) if error._unauthenticates_all_connections(): log.debug( @@ -552,7 +517,10 @@ async def on_neo4j_error(self, error, address): for connection in self.connections.get(address, ()): connection.mark_unauthenticated() if error._requires_new_credentials(): - await self.force_new_auth() + await AsyncUtil.callback( + connection.auth_manager.on_auth_expired, + connection.auth + ) async def close(self): """ Close all connections and empty the pool. 
@@ -582,10 +550,10 @@ def open(cls, address, *, pool_config, workspace_config): :returns: BoltPool """ - async def opener(addr, auth, timeout): + async def opener(addr, auth_manager, timeout): return await AsyncBolt.open( - addr, auth=auth, timeout=timeout, routing_context=None, - pool_config=pool_config + addr, auth_manager=auth_manager, timeout=timeout, + routing_context=None, pool_config=pool_config ) pool = cls(opener, pool_config, workspace_config, address) @@ -637,9 +605,9 @@ def open(cls, *addresses, pool_config, workspace_config, raise ConfigurationError("The key 'address' is reserved for routing context.") routing_context["address"] = str(address) - async def opener(addr, auth, timeout): + async def opener(addr, auth_manager, timeout): return await AsyncBolt.open( - addr, auth=auth, timeout=timeout, + addr, auth_manager=auth_manager, timeout=timeout, routing_context=routing_context, pool_config=pool_config ) @@ -966,8 +934,8 @@ async def _select_address(self, *, access_mode, database): return choice(addresses_by_usage[min(addresses_by_usage)]) async def acquire( - self, access_mode, timeout, database, bookmarks, auth: AcquireAuth, - liveness_check_timeout + self, access_mode, timeout, database, bookmarks, + auth: t.Optional[AcquireAuth], liveness_check_timeout ): if access_mode not in (WRITE_ACCESS, READ_ACCESS): raise ClientError("Non valid 'access_mode'; {}".format(access_mode)) diff --git a/src/neo4j/_async/work/result.py b/src/neo4j/_async/work/result.py index 5d17dd1a..351c45bf 100644 --- a/src/neo4j/_async/work/result.py +++ b/src/neo4j/_async/work/result.py @@ -278,7 +278,7 @@ async def _buffer(self, n=None): Might end up with more records in the buffer if the fetch size makes it overshoot. - Might ent up with fewer records in the buffer if there are not enough + Might end up with fewer records in the buffer if there are not enough records available. 
""" if self._out_of_scope: diff --git a/src/neo4j/_async/work/session.py b/src/neo4j/_async/work/session.py index c5e76ed7..2a3c0bab 100644 --- a/src/neo4j/_async/work/session.py +++ b/src/neo4j/_async/work/session.py @@ -42,6 +42,7 @@ TransactionError, ) from ...work import Query +from ..auth_management import AsyncAuthManagers from .result import AsyncResult from .transaction import ( AsyncManagedTransaction, @@ -97,6 +98,8 @@ class AsyncSession(AsyncWorkspace): def __init__(self, pool, session_config): assert isinstance(session_config, SessionConfig) + if session_config.auth is not None: + session_config.auth = AsyncAuthManagers.static(session_config.auth) super().__init__(pool, session_config) self._config = session_config self._initialize_bookmarks(session_config.bookmarks) diff --git a/src/neo4j/_auth_management.py b/src/neo4j/_auth_management.py new file mode 100644 index 00000000..d99468c8 --- /dev/null +++ b/src/neo4j/_auth_management.py @@ -0,0 +1,117 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+
+# from __future__ import annotations
+# work around for https://github.com/sphinx-doc/sphinx/pull/10880
+# make sure TAuth is resolved in the docs, else they're pretty useless
+
+import abc
+import typing as t
+from dataclasses import dataclass
+
+from .api import _TAuth
+
+
+@dataclass
+class TemporalAuth:
+    """Represents potentially expiring authentication information.
+
+    This class is used with :meth:`.AuthManagers.temporal` and
+    :meth:`.AsyncAuthManagers.temporal`.
+
+    :param auth: The authentication information.
+    :param expires_in: The number of seconds until the authentication
+        information expires. If :data:`None`, the authentication information
+        is considered to not expire until the server explicitly indicates so.
+
+    .. seealso::
+        :meth:`.AuthManagers.temporal`, :meth:`.AsyncAuthManagers.temporal`
+
+    .. versionadded:: 5.x
+    """
+    auth: _TAuth
+    expires_in: t.Optional[float] = None
+
+
+class AuthManager(metaclass=abc.ABCMeta):
+    """Base class for authentication information managers.
+
+    The driver provides some default implementations of this class in
+    :class:`.AuthManagers` for convenience.
+
+    Custom implementations of this class can be used to provide more complex
+    authentication refresh functionality.
+
+    .. warning::
+
+        The manager **must not** interact with the driver in any way as this
+        can cause deadlocks and undefined behaviour.
+
+        Furthermore, the manager is expected to be thread-safe.
+
+    .. seealso:: :class:`.AuthManagers`
+
+    .. versionadded:: 5.x
+    """
+
+    @abc.abstractmethod
+    def get_auth(self) -> _TAuth:
+        """Return the current authentication information.
+
+        The driver will call this method very frequently. It is recommended
+        to implement some form of caching to avoid unnecessary overhead.
+        """
+        ...
+
+    @abc.abstractmethod
+    def on_auth_expired(self, auth: _TAuth) -> None:
+        """Handle the server indicating expired authentication information.
+ + The driver will call this method when the server indicates that the + provided authentication information is no longer valid. + + :param auth: + The authentication information that the server flagged as no longer + valid. + """ + ... + + +class AsyncAuthManager(metaclass=abc.ABCMeta): + """Async version of :class:`.AuthManager`. + + .. seealso:: :class:`.AuthManager` + + .. versionadded:: 5.x + """ + + @abc.abstractmethod + async def get_auth(self) -> _TAuth: + """Async version of :meth:`.AuthManager.get_auth`. + + .. seealso:: :meth:`.AuthManager.get_auth` + """ + ... + + @abc.abstractmethod + async def on_auth_expired(self, auth: _TAuth) -> None: + """Async version of :meth:`.AuthManager.on_auth_expired`. + + .. seealso:: :meth:`.AuthManager.on_auth_expired` + """ + ... diff --git a/src/neo4j/_sync/auth_management.py b/src/neo4j/_sync/auth_management.py new file mode 100644 index 00000000..8a31a0af --- /dev/null +++ b/src/neo4j/_sync/auth_management.py @@ -0,0 +1,190 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +# from __future__ import annotations +# work around for https://github.com/sphinx-doc/sphinx/pull/10880 +# make sure TAuth is resolved in the docs, else they're pretty useless + +import time +import typing as t +from logging import getLogger + +from .._async_compat.concurrency import Lock +from .._auth_management import ( + AuthManager, + TemporalAuth, +) + +# work around for https://github.com/sphinx-doc/sphinx/pull/10880 +# make sure TAuth is resolved in the docs, else they're pretty useless +# if t.TYPE_CHECKING: +from ..api import _TAuth + + +log = getLogger("neo4j") + + +class _StaticAuthManager(AuthManager): + _auth: _TAuth + + def __init__(self, auth: _TAuth) -> None: + self._auth = auth + + def get_auth(self) -> _TAuth: + return self._auth + + def on_auth_expired(self, auth: _TAuth) -> None: + pass + + +class _TemporalAuthHolder: + def __init__(self, auth: TemporalAuth) -> None: + self._auth = auth + self._expiry = None + if auth.expires_in is not None: + self._expiry = time.monotonic() + auth.expires_in + + @property + def auth(self) -> _TAuth: + return self._auth.auth + + def expired(self) -> bool: + if self._expiry is None: + return False + return time.monotonic() > self._expiry + +class _TemporalAuthManager(AuthManager): + _current_auth: t.Optional[_TemporalAuthHolder] + _provider: t.Callable[[], t.Union[TemporalAuth]] + _lock: Lock + + + def __init__( + self, + provider: t.Callable[[], t.Union[TemporalAuth]] + ) -> None: + self._provider = provider + self._current_auth = None + self._lock = Lock() + + def _refresh_auth(self): + self._current_auth = _TemporalAuthHolder(self._provider()) + + def get_auth(self) -> _TAuth: + with self._lock: + auth = self._current_auth + if auth is not None and not auth.expired(): + return auth.auth + log.debug("[ ] _: refreshing (time out)") + self._refresh_auth() + assert self._current_auth is not None + return self._current_auth.auth + + def on_auth_expired(self, auth: _TAuth) -> None: + with self._lock: + cur_auth 
= self._current_auth
+            if cur_auth is not None and cur_auth.auth == auth:
+                log.debug("[ ] _: refreshing (error)")
+                self._refresh_auth()
+
+
+class AuthManagers:
+    """A collection of :class:`.AuthManager` factories."""
+
+    @staticmethod
+    def static(auth: _TAuth) -> AuthManager:
+        """Create a static auth manager.
+
+        Example::
+
+            # NOTE: this example is for illustration purposes only.
+            #       The driver will automatically wrap static auth info in a
+            #       static auth manager.
+
+            import neo4j
+            from neo4j.auth_management import AuthManagers
+
+
+            auth = neo4j.basic_auth("neo4j", "password")
+
+            with neo4j.GraphDatabase.driver(
+                "neo4j://example.com:7687",
+                auth=AuthManagers.static(auth)
+                # auth=auth  # this is equivalent
+            ) as driver:
+                ...  # do stuff
+
+        :param auth: The auth to return.
+
+        :returns:
+            An instance of an implementation of :class:`.AuthManager` that
+            always returns the same auth.
+        """
+        return _StaticAuthManager(auth)
+
+    @staticmethod
+    def temporal(
+        provider: t.Callable[[], t.Union[TemporalAuth]]
+    ) -> AuthManager:
+        """Create an auth manager for potentially expiring auth info.
+
+        .. warning::
+
+            The provider function **must not** interact with the driver in any
+            way as this can cause deadlocks and undefined behaviour.
+
+        Example::
+
+            import neo4j
+            from neo4j.auth_management import (
+                AuthManagers,
+                TemporalAuth,
+            )
+
+
+            def auth_provider():
+                # some way of getting a token
+                sso_token = get_sso_token()
+                # assume we know our tokens expire every 60 seconds
+                expires_in = 60
+
+                return TemporalAuth(
+                    auth=neo4j.bearer_auth(sso_token),
+                    # Include a little buffer so that we fetch a new token
+                    # *before* the old one expires
+                    expires_in=expires_in - 10
+                )
+
+
+            with neo4j.GraphDatabase.driver(
+                "neo4j://example.com:7687",
+                auth=AuthManagers.temporal(auth_provider)
+            ) as driver:
+                ...  # do stuff
+
+        :param provider:
+            A callable that provides a :class:`.TemporalAuth` instance.
+ + :returns: + An instance of an implementation of :class:`.AuthManager` that + returns auth info from the given provider and refreshes it, calling + the provider again, when the auth info expires (either because it's + reached its expiry time or because the server flagged it as + expired). + """ + return _TemporalAuthManager(provider) diff --git a/src/neo4j/_sync/driver.py b/src/neo4j/_sync/driver.py index 88cdd4dd..14568f5d 100644 --- a/src/neo4j/_sync/driver.py +++ b/src/neo4j/_sync/driver.py @@ -56,7 +56,6 @@ parse_neo4j_uri, parse_routing_context, READ_ACCESS, - RenewableAuth, SECURITY_TYPE_SECURE, SECURITY_TYPE_SELF_SIGNED_CERTIFICATE, ServerInfo, @@ -69,6 +68,10 @@ URI_SCHEME_NEO4J_SECURE, URI_SCHEME_NEO4J_SELF_SIGNED_CERTIFICATE, ) +from ..auth_management import ( + AuthManager, + AuthManagers, +) from ..exceptions import Neo4jError from .bookmark_manager import ( Neo4jBookmarkManager, @@ -89,6 +92,7 @@ import typing_extensions as te from .._api import T_RoutingControl + from ..api import _TAuth class _DefaultEnum(Enum): @@ -102,12 +106,6 @@ class _DefaultEnum(Enum): _T = t.TypeVar("_T") -_TAuthTokenProvider = t.Callable[[], t.Union[ - RenewableAuth, Auth, t.Tuple[t.Any, t.Any], None, - t.Union[t.Union[RenewableAuth, Auth, t.Tuple[t.Any, t.Any], None]] -]] - - class GraphDatabase: """Accessor for :class:`neo4j.Driver` construction. 
""" @@ -120,7 +118,12 @@ def driver( uri: str, *, auth: t.Union[ - t.Tuple[t.Any, t.Any], Auth, _TAuthTokenProvider, None + # work around https://github.com/sphinx-doc/sphinx/pull/10880 + # make sure TAuth is resolved in the docs + # TAuth, + t.Union[t.Tuple[t.Any, t.Any], Auth, None], + AuthManager, + AuthManager ] = ..., max_connection_lifetime: float = ..., max_connection_pool_size: int = ..., @@ -160,7 +163,12 @@ def driver( def driver( cls, uri: str, *, auth: t.Union[ - t.Tuple[t.Any, t.Any], Auth, _TAuthTokenProvider, None + # work around https://github.com/sphinx-doc/sphinx/pull/10880 + # make sure TAuth is resolved in the docs + # TAuth, + t.Union[t.Tuple[t.Any, t.Any], Auth, None], + AuthManager, + AuthManager ] = None, **config ) -> Driver: @@ -177,7 +185,9 @@ def driver( driver_type, security_type, parsed = parse_neo4j_uri(uri) - config["auth"] = auth if callable(auth) else lambda: auth + if not isinstance(auth, (AuthManager, AuthManager)): + auth = AuthManagers.static(auth) + config["auth"] = auth # TODO: 6.0 - remove "trust" config option if "trust" in config.keys(): @@ -495,7 +505,7 @@ def session( default_access_mode: str = ..., bookmark_manager: t.Union[BookmarkManager, BookmarkManager, None] = ..., - auth: t.Union[Auth, t.Tuple[t.Any, t.Any]] = ..., + auth: _TAuth = ..., # undocumented/unsupported options # they may be change or removed any time without prior notice @@ -751,7 +761,7 @@ def example(driver: neo4j.Driver) -> neo4j.Record:: Defaults to the driver's :attr:`.query_bookmark_manager`. - Pass :const:`None` to disable causal consistency. + Pass :data:`None` to disable causal consistency. 
:type bookmark_manager_: typing.Union[neo4j.BookmarkManager, neo4j.BookmarkManager, None] diff --git a/src/neo4j/_sync/io/_bolt.py b/src/neo4j/_sync/io/_bolt.py index 47b3144c..9047eeb1 100644 --- a/src/neo4j/_sync/io/_bolt.py +++ b/src/neo4j/_sync/io/_bolt.py @@ -25,6 +25,7 @@ from time import perf_counter from ..._async_compat.network import BoltSocket +from ..._async_compat.util import Util from ..._codec.hydration import v1 as hydration_v1 from ..._codec.packstream import v1 as packstream_v1 from ..._conf import PoolConfig @@ -33,7 +34,7 @@ BoltHandshakeError, ) from ..._meta import get_user_agent -from ...addressing import Address +from ...addressing import ResolvedAddress from ...api import ( ServerInfo, Version, @@ -119,12 +120,16 @@ class Bolt: most_recent_qid = None def __init__(self, unresolved_address, sock, max_connection_lifetime, *, - auth=None, user_agent=None, routing_context=None): + auth=None, auth_manager=None, user_agent=None, + routing_context=None): self.unresolved_address = unresolved_address self.socket = sock self.local_port = self.socket.getsockname()[1] - self.server_info = ServerInfo(Address(sock.getpeername()), - self.PROTOCOL_VERSION) + self.server_info = ServerInfo( + ResolvedAddress(sock.getpeername(), + host_name=unresolved_address.host), + self.PROTOCOL_VERSION + ) # so far `connection.recv_timeout_seconds` is the only available # configuration hint that exists. Therefore, all hints can be stored at # connection level. This might change in the future. 
@@ -151,16 +156,9 @@ def __init__(self, unresolved_address, sock, max_connection_lifetime, *, else: self.user_agent = get_user_agent() + self.auth = auth self.auth_dict = self._to_auth_dict(auth) - - # Check for missing password - try: - credentials = self.auth_dict["credentials"] - except KeyError: - pass - else: - if credentials is None: - raise AuthError("Password cannot be None") + self.auth_manager = auth_manager def __del__(self): if not asyncio.iscoroutinefunction(self.close): @@ -328,13 +326,13 @@ def ping(cls, address, *, timeout=None, pool_config=None): @classmethod def open( - cls, address, *, auth=None, timeout=None, routing_context=None, + cls, address, *, auth_manager=None, timeout=None, routing_context=None, pool_config=None ): """Open a new Bolt connection to a given server address. :param address: - :param auth: + :param auth_manager: :param timeout: the connection timeout in seconds :param routing_context: dict containing routing context :param pool_config: @@ -400,7 +398,7 @@ def time_remaining(): from ._bolt5 import Bolt5x1 bolt_cls = Bolt5x1 else: - log.debug("[#%04X] S: ", s.getsockname()[1]) + log.debug("[#%04X] C: ", s.getsockname()[1]) BoltSocket.close_socket(s) supported_versions = cls.protocol_handlers().keys() @@ -411,9 +409,23 @@ def time_remaining(): address=address, request_data=handshake, response_data=data ) + try: + auth = Util.callback(auth_manager.get_auth) + except asyncio.CancelledError as e: + log.debug("[#%04X] C: open auth manager failed: %r", + s.getsockname()[1], e) + s.kill() + raise + except Exception as e: + log.debug("[#%04X] C: open auth manager failed: %r", + s.getsockname()[1], e) + s.close() + raise + connection = bolt_cls( address, s, pool_config.max_connection_lifetime, auth=auth, - user_agent=pool_config.user_agent, routing_context=routing_context + auth_manager=auth_manager, user_agent=pool_config.user_agent, + routing_context=routing_context ) try: @@ -465,7 +477,8 @@ def mark_unauthenticated(self): 
self.auth_dict = {} def re_auth( - self, auth, dehydration_hooks=None, hydration_hooks=None, force=False + self, auth, auth_manager, force=False, + dehydration_hooks=None, hydration_hooks=None, ): """Append LOGON, LOGOFF to the outgoing queue. @@ -475,10 +488,14 @@ def re_auth( """ new_auth_dict = self._to_auth_dict(auth) if not force and new_auth_dict == self.auth_dict: + self.auth_manager = auth_manager + self.auth = auth return False self.logoff(dehydration_hooks=dehydration_hooks, hydration_hooks=hydration_hooks) self.auth_dict = new_auth_dict + self.auth_manager = auth_manager + self.auth = auth self.logon(dehydration_hooks=dehydration_hooks, hydration_hooks=hydration_hooks) return True diff --git a/src/neo4j/_sync/io/_bolt3.py b/src/neo4j/_sync/io/_bolt3.py index 7229fc99..25466e47 100644 --- a/src/neo4j/_sync/io/_bolt3.py +++ b/src/neo4j/_sync/io/_bolt3.py @@ -399,7 +399,7 @@ def _process_message(self, tag, fields): self.pool.on_write_failure(address=self.unresolved_address) raise except Neo4jError as e: - self.pool.on_neo4j_error(e, self.server_info.address) + self.pool.on_neo4j_error(e, self) raise else: raise BoltProtocolError("Unexpected response message with signature %02X" % summary_signature, address=self.unresolved_address) diff --git a/src/neo4j/_sync/io/_bolt4.py b/src/neo4j/_sync/io/_bolt4.py index 5d373005..bf37a717 100644 --- a/src/neo4j/_sync/io/_bolt4.py +++ b/src/neo4j/_sync/io/_bolt4.py @@ -349,7 +349,7 @@ def _process_message(self, tag, fields): raise except Neo4jError as e: if self.pool: - self.pool.on_neo4j_error(e, self.server_info.address) + self.pool.on_neo4j_error(e, self) raise else: raise BoltProtocolError("Unexpected response message with signature " diff --git a/src/neo4j/_sync/io/_bolt5.py b/src/neo4j/_sync/io/_bolt5.py index fcb68491..17e6f951 100644 --- a/src/neo4j/_sync/io/_bolt5.py +++ b/src/neo4j/_sync/io/_bolt5.py @@ -347,7 +347,7 @@ def _process_message(self, tag, fields): raise except Neo4jError as e: if self.pool: - 
self.pool.on_neo4j_error(e, self.server_info.address) + self.pool.on_neo4j_error(e, self) raise else: raise BoltProtocolError( diff --git a/src/neo4j/_sync/io/_pool.py b/src/neo4j/_sync/io/_pool.py index d806cc97..99964e02 100644 --- a/src/neo4j/_sync/io/_pool.py +++ b/src/neo4j/_sync/io/_pool.py @@ -49,11 +49,10 @@ from ..._exceptions import BoltError from ..._routing import RoutingTable from ...api import ( - Auth, READ_ACCESS, - RenewableAuth, WRITE_ACCESS, ) +from ...auth_management import AuthManager from ...exceptions import ( ClientError, ConfigurationError, @@ -73,7 +72,7 @@ @dataclass class AcquireAuth: - auth: t.Optional[Auth] + auth: t.Union[AuthManager, AuthManager, None] backwards_compatible: bool = False force_auth: bool = False @@ -94,9 +93,6 @@ def __init__(self, opener, pool_config, workspace_config): self.connections_reservations = defaultdict(lambda: 0) self.lock = CooperativeRLock() self.cond = Condition(self.lock) - self.refreshing_auth = False - self.auth_condition = Condition() - self.last_auth: t.Optional[RenewableAuth] = None def __enter__(self): return self @@ -105,53 +101,7 @@ def __exit__(self, exc_type, exc_value, traceback): self.close() def get_auth(self): - with self.auth_condition: - auth_missing = self.last_auth is None - needs_refresh = auth_missing or self.last_auth.expired - if not needs_refresh: - return self.last_auth.auth - - if self.refreshing_auth: - # someone else is already getting new auth info - if not auth_missing: - # there is old auth info we can use in the meantime - return self.last_auth.auth - else: - while self.last_auth is None: - self.auth_condition.wait() - return self.last_auth.auth - else: - # we need to get new auth info - self.refreshing_auth = True - - auth = self._get_new_auth() - with self.auth_condition: - self.last_auth = auth - self.refreshing_auth = False - self.auth_condition.notify_all() - return self.last_auth.auth - - def force_new_auth(self): - with self.auth_condition: - log.debug("[#0000] _: 
force new auth info") - if self.refreshing_auth: - return - self.last_auth = None - self.refreshing_auth = True - - auth = self._get_new_auth() - with self.auth_condition: - self.last_auth = auth - self.refreshing_auth = False - self.auth_condition.notify_all() - - def _get_new_auth(self): - log.debug("[#0000] _: requesting new auth info from %r", - self.pool_config.auth) - new_auth = Util.callback(self.pool_config.auth) - if not isinstance(new_auth, RenewableAuth): - return RenewableAuth(new_auth) - return new_auth + return Util.callback(self.pool_config.auth.get_auth) def _acquire_from_pool(self, address): with self.lock: @@ -212,7 +162,7 @@ def connection_creator(): try: try: connection = self.opener( - address, auth or self.get_auth(), + address, auth or self.pool_config.auth, deadline.to_timeout() ) except ServiceUnavailable: @@ -258,11 +208,22 @@ def connection_creator(): return connection_creator return None - def _re_auth_connection(self, connection, auth, force): - new_auth = auth or self.get_auth() + def _re_auth_connection( + self, connection, auth, force, backwards_compatible_auth + ): + if auth and not backwards_compatible_auth: + # Assert session auth is supported by the protocol. + # The Bolt implementation will try as hard as it can to make the + # re-auth work. So if the session auth token is identical to the + # driver auth token, the connection doesn't have to do anything so + # it won't fail regardless of the protocol version. 
+ connection.assert_re_auth_support() + new_auth_manager = auth or self.pool_config.auth log_auth = "******" if auth else "None" + new_auth = Util.callback(new_auth_manager.get_auth) try: - updated = connection.re_auth(new_auth, force=force) + updated = connection.re_auth(new_auth, new_auth_manager, + force=force) log.debug("[#%04X] _: checked re_auth auth=%s updated=%s " "force=%s", connection.local_port, log_auth, updated, force) @@ -317,7 +278,7 @@ def health_check(connection_, deadline_): connection.local_port, connection.connection_id) try: self._re_auth_connection( - connection, auth, force_auth + connection, auth, force_auth, backwards_compatible_auth ) except ConfigurationError: if not auth: @@ -541,7 +502,8 @@ def on_write_failure(self, address): "No write service available for pool {}".format(self) ) - def on_neo4j_error(self, error, address): + def on_neo4j_error(self, error, connection): + address = connection.unresolved_address assert isinstance(error, Neo4jError) if error._unauthenticates_all_connections(): log.debug( @@ -552,7 +514,10 @@ def on_neo4j_error(self, error, address): for connection in self.connections.get(address, ()): connection.mark_unauthenticated() if error._requires_new_credentials(): - self.force_new_auth() + Util.callback( + connection.auth_manager.on_auth_expired, + connection.auth + ) def close(self): """ Close all connections and empty the pool. 
@@ -582,10 +547,10 @@ def open(cls, address, *, pool_config, workspace_config): :returns: BoltPool """ - def opener(addr, auth, timeout): + def opener(addr, auth_manager, timeout): return Bolt.open( - addr, auth=auth, timeout=timeout, routing_context=None, - pool_config=pool_config + addr, auth_manager=auth_manager, timeout=timeout, + routing_context=None, pool_config=pool_config ) pool = cls(opener, pool_config, workspace_config, address) @@ -637,9 +602,9 @@ def open(cls, *addresses, pool_config, workspace_config, raise ConfigurationError("The key 'address' is reserved for routing context.") routing_context["address"] = str(address) - def opener(addr, auth, timeout): + def opener(addr, auth_manager, timeout): return Bolt.open( - addr, auth=auth, timeout=timeout, + addr, auth_manager=auth_manager, timeout=timeout, routing_context=routing_context, pool_config=pool_config ) @@ -966,8 +931,8 @@ def _select_address(self, *, access_mode, database): return choice(addresses_by_usage[min(addresses_by_usage)]) def acquire( - self, access_mode, timeout, database, bookmarks, auth: AcquireAuth, - liveness_check_timeout + self, access_mode, timeout, database, bookmarks, + auth: t.Optional[AcquireAuth], liveness_check_timeout ): if access_mode not in (WRITE_ACCESS, READ_ACCESS): raise ClientError("Non valid 'access_mode'; {}".format(access_mode)) diff --git a/src/neo4j/_sync/work/result.py b/src/neo4j/_sync/work/result.py index a8a2a2f8..c5d6a84a 100644 --- a/src/neo4j/_sync/work/result.py +++ b/src/neo4j/_sync/work/result.py @@ -278,7 +278,7 @@ def _buffer(self, n=None): Might end up with more records in the buffer if the fetch size makes it overshoot. - Might ent up with fewer records in the buffer if there are not enough + Might end up with fewer records in the buffer if there are not enough records available. 
""" if self._out_of_scope: diff --git a/src/neo4j/_sync/work/session.py b/src/neo4j/_sync/work/session.py index bf0eac66..f4f81b29 100644 --- a/src/neo4j/_sync/work/session.py +++ b/src/neo4j/_sync/work/session.py @@ -42,6 +42,7 @@ TransactionError, ) from ...work import Query +from ..auth_management import AuthManagers from .result import Result from .transaction import ( ManagedTransaction, @@ -97,6 +98,8 @@ class Session(Workspace): def __init__(self, pool, session_config): assert isinstance(session_config, SessionConfig) + if session_config.auth is not None: + session_config.auth = AuthManagers.static(session_config.auth) super().__init__(pool, session_config) self._config = session_config self._initialize_bookmarks(session_config.bookmarks) diff --git a/src/neo4j/api.py b/src/neo4j/api.py index b4809b4e..c5d94a33 100644 --- a/src/neo4j/api.py +++ b/src/neo4j/api.py @@ -21,7 +21,6 @@ from __future__ import annotations import abc -import time import typing as t from urllib.parse import ( parse_qs, @@ -113,6 +112,12 @@ def __eq__(self, other): # For backwards compatibility AuthToken = Auth +# if t.TYPE_CHECKING: +# commented out as work around for +# https://github.com/sphinx-doc/sphinx/pull/10880 +# make sure TAuth is resolved in the docs, else they're pretty useless +_TAuth = t.Union[t.Tuple[t.Any, t.Any], Auth, None] + def basic_auth( user: str, password: str, realm: t.Optional[str] = None @@ -181,80 +186,6 @@ def custom_auth( return Auth(scheme, principal, credentials, realm, **parameters) -class RenewableAuth: - """Container for authentication details which potentially expire. - - This is meant to be used as a return value for a callable auth token - provider to accommodate for expiring authentication information. 
- - For Example:: - - import neo4j - - - def auth_provider(): - sso_token = get_sso_token() # some way to getting a fresh token - expires_in = 60 # we know our tokens expire every 60 seconds - - return neo4j.RenewableAuth( - neo4j.bearer_auth(sso_token), - # The driver will continue to use the old token until a new one - # has been fetched. So we want the auth provider to be called - # a little before the token expires. - expires_in - 10 - ) - - - with neo4j.GraphDatabase.driver( - "neo4j://example.com:7687", - auth=auth_provider - ) as driver: - ... # do stuff - - .. warning:: - - The auth provider **must not** interact with the driver in any way as - this can cause deadlocks and undefined behaviour. - - The driver will call the auth provider when either the last token it - provided has expired (see :attr:`expires_in`) or when the driver has - received an authentication error from the server that indicates new - authentication information is required. - - :param auth: The auth token. - :param expires_in: The expected expiry time - of the auth token in seconds from now. It is recommended to set this - a little before the actual expiry time to give the driver time to - renew the auth token before connections start to fail. If set to - :data:`None`, the token is assumed to never expire. - - .. 
versionadded:: 5.x - """ - - def __init__( - self, - auth: t.Union[Auth, t.Tuple[t.Any, t.Any], None], - expires_in: t.Optional[float] = None, - ) -> None: - self.auth = auth - self.created_at: t.Optional[float] - self.expires_in: t.Optional[float] - self.expires_at: t.Optional[float] - if expires_in is not None: - self.expires_in = expires_in - self.created_at = time.monotonic() - self.expires_at = self.created_at + expires_in - else: - self.expires_in = None - self.created_at = None - self.expires_at = None - - @property - def expired(self): - return (self.expires_at is not None - and self.expires_at < time.monotonic()) - - # TODO: 6.0 - remove this class class Bookmark: """A Bookmark object contains an immutable list of bookmark string values. diff --git a/src/neo4j/auth_management.py b/src/neo4j/auth_management.py new file mode 100644 index 00000000..84b2a386 --- /dev/null +++ b/src/neo4j/auth_management.py @@ -0,0 +1,34 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +from ._async.auth_management import AsyncAuthManagers +from ._auth_management import ( + AsyncAuthManager, + AuthManager, + TemporalAuth, +) +from ._sync.auth_management import AuthManagers + + +__all__ = [ + "AsyncAuthManager", + "AsyncAuthManagers", + "AuthManager", + "AuthManagers", + "TemporalAuth", +] diff --git a/src/neo4j/time/__init__.py b/src/neo4j/time/__init__.py index c6b70cf2..da696eb9 100644 --- a/src/neo4j/time/__init__.py +++ b/src/neo4j/time/__init__.py @@ -2544,7 +2544,7 @@ def as_timezone(self, tz: _tzinfo) -> DateTime: :param tz: the new timezone - :returns: the same object if ``tz`` is :const:``None``. + :returns: the same object if ``tz`` is :data:``None``. Else, a new :class:`.DateTime` that's the same point in time but in a different timezone. """ diff --git a/testkitbackend/_async/backend.py b/testkitbackend/_async/backend.py index 17724a54..5dcac273 100644 --- a/testkitbackend/_async/backend.py +++ b/testkitbackend/_async/backend.py @@ -55,8 +55,10 @@ def __init__(self, rd, wr): self.drivers = {} self.custom_resolutions = {} self.dns_resolutions = {} - self.auth_token_providers = {} - self.renewable_auth_token_supplies = {} + self.auth_token_managers = {} + self.auth_token_supplies = {} + self.auth_token_on_expiration_supplies = {} + self.temporal_auth_token_supplies = {} self.bookmark_managers = {} self.bookmarks_consumptions = {} self.bookmarks_supplies = {} diff --git a/testkitbackend/_async/requests.py b/testkitbackend/_async/requests.py index eb06046e..ab81c41c 100644 --- a/testkitbackend/_async/requests.py +++ b/testkitbackend/_async/requests.py @@ -15,6 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+ import datetime import json import re @@ -25,7 +26,13 @@ import neo4j import neo4j.api +import neo4j.auth_management from neo4j._async_compat.util import AsyncUtil +from neo4j.auth_management import ( + AsyncAuthManager, + AsyncAuthManagers, + TemporalAuth, +) from .. import ( fromtestkit, @@ -95,37 +102,12 @@ async def GetFeatures(backend, data): await backend.send_response("FeatureList", {"features": FEATURES}) -def _convert_auth_token(data, key): - if data[key] is None: - return None - auth_token = data[key]["data"] - data[key].mark_item_as_read_if_equals("name", "AuthorizationToken") - scheme = auth_token["scheme"] - if scheme == "basic": - auth = neo4j.basic_auth( - auth_token["principal"], auth_token["credentials"], - realm=auth_token.get("realm", None) - ) - elif scheme == "kerberos": - auth = neo4j.kerberos_auth(auth_token["credentials"]) - elif scheme == "bearer": - auth = neo4j.bearer_auth(auth_token["credentials"]) - else: - auth = neo4j.custom_auth( - auth_token["principal"], auth_token["credentials"], - auth_token["realm"], auth_token["scheme"], - **auth_token.get("parameters", {}) - ) - auth_token.mark_item_as_read("parameters", recursive=True) - return auth - - async def NewDriver(backend, data): - auth = _convert_auth_token(data, "authorizationToken") - if auth is None and data.get("authTokenProviderId") is not None: - auth = backend.auth_token_providers[data["authTokenProviderId"]] + auth = fromtestkit.to_auth_token(data, "authorizationToken") + if auth is None and data.get("authTokenManagerId") is not None: + auth = backend.auth_token_managers[data["authTokenManagerId"]] else: - data.mark_item_as_read_if_equals("authTokenProviderId", None) + data.mark_item_as_read_if_equals("authTokenManagerId", None) kwargs = {} if data["resolverRegistered"] or data["domainNameResolverRegistered"]: kwargs["resolver"] = resolution_func( @@ -173,52 +155,106 @@ async def NewDriver(backend, data): await backend.send_response("Driver", {"id": key}) -async def 
NewAuthTokenProvider(backend, data): - auth_token_provider_id = backend.next_key() +async def NewAuthTokenManager(backend, data): + auth_token_manager_id = backend.next_key() + + class TestKitAuthManager(AsyncAuthManager): + async def get_auth(self): + key = backend.next_key() + await backend.send_response("AuthTokenManagerGetAuthRequest", { + "id": key, + "authTokenManagerId": auth_token_manager_id, + }) + if not await backend.process_request(): + # connection was closed before end of next message + return None + if key not in backend.auth_token_supplies: + raise RuntimeError( + "Backend did not receive expected " + f"AuthTokenManagerGetAuthCompleted message for id {key}" + ) + return backend.auth_token_supplies.pop(key) + + async def on_auth_expired(self, auth): + key = backend.next_key() + await backend.send_response( + "AuthTokenManagerOnAuthExpiredRequest", { + "id": key, + "authTokenManagerId": auth_token_manager_id, + "auth": totestkit.auth_token(auth), + } + ) + if not await backend.process_request(): + # connection was closed before end of next message + return None + if key not in backend.auth_token_on_expiration_supplies: + raise RuntimeError( + "Backend did not receive expected " + "AuthTokenManagerOnAuthExpiredCompleted message for id " + f"{key}" + ) + backend.auth_token_on_expiration_supplies.pop(key) + + auth_manager = TestKitAuthManager() + backend.auth_token_managers[auth_token_manager_id] = auth_manager + await backend.send_response("AuthTokenManager", + {"id": auth_token_manager_id}) + + +async def AuthTokenManagerGetAuthCompleted(backend, data): + auth_token = fromtestkit.to_auth_token(data, "auth") + + backend.auth_token_supplies[data["requestId"]] = auth_token + + +async def AuthTokenManagerOnAuthExpiredCompleted(backend, data): + backend.auth_token_on_expiration_supplies[data["requestId"]] = True + + +async def AuthTokenManagerClose(backend, data): + auth_token_manager_id = data["id"] + del backend.auth_token_managers[auth_token_manager_id] 
+ await backend.send_response("AuthTokenManager", + {"id": auth_token_manager_id}) + + +async def NewTemporalAuthTokenManager(backend, data): + auth_token_manager_id = backend.next_key() async def auth_token_provider(): key = backend.next_key() - await backend.send_response("AuthTokenProviderRequest", { + await backend.send_response("TemporalAuthTokenProviderRequest", { "id": key, - "authTokenProviderId": auth_token_provider_id, + "temporalAuthTokenManagerId": auth_token_manager_id, }) if not await backend.process_request(): # connection was closed before end of next message - return None - if key not in backend.renewable_auth_token_supplies: + return neo4j.auth_management.TemporalAuth(None, None) + if key not in backend.temporal_auth_token_supplies: raise RuntimeError( "Backend did not receive expected " - f"AuthTokenProviderCompleted message for id {key}" + f"TemporalAuthTokenManagerCompleted message for id {key}" ) - return backend.renewable_auth_token_supplies.pop(key) - + return backend.temporal_auth_token_supplies.pop(key) - backend.auth_token_providers[auth_token_provider_id] = auth_token_provider - await backend.send_response("AuthTokenProvider", - {"id": auth_token_provider_id}) + auth_manager = AsyncAuthManagers.temporal(auth_token_provider) + backend.auth_token_managers[auth_token_manager_id] = auth_manager + await backend.send_response("TemporalAuthTokenManager", + {"id": auth_token_manager_id}) -async def AuthTokenProviderClose(backend, data): - auth_token_provider_id = data["id"] - del backend.auth_token_providers[auth_token_provider_id] - await backend.send_response("AuthTokenProvider", - {"id": auth_token_provider_id}) - - -async def AuthTokenProviderCompleted(backend, data): - backend.renewable_auth_token_supplies[data["requestId"]] = \ - parse_renewable_auth(data["auth"]) - - -def parse_renewable_auth(data): - data.mark_item_as_read_if_equals("name", "RenewableAuthToken") - data = data["data"] - auth_token = _convert_auth_token(data, "auth") - if 
data["expiresInMs"] is not None: - expires_in = data["expiresInMs"] / 1000 +async def TemporalAuthTokenProviderCompleted(backend, data): + temp_auth_data = data["auth"] + temp_auth_data.mark_item_as_read_if_equals("name", "TemporalAuthToken") + temp_auth_data = temp_auth_data["data"] + auth_token = fromtestkit.to_auth_token(temp_auth_data, "auth") + if temp_auth_data["expiresInMs"] is not None: + expires_in = temp_auth_data["expiresInMs"] / 1000 else: expires_in = None - return neo4j.RenewableAuth(auth_token, expires_in) + temporal_auth = TemporalAuth(auth_token, expires_in) + + backend.temporal_auth_token_supplies[data["requestId"]] = temporal_auth async def VerifyConnectivity(backend, data): @@ -251,7 +287,7 @@ async def CheckMultiDBSupport(backend, data): async def VerifyAuthentication(backend, data): driver_id = data["driverId"] driver = backend.drivers[driver_id] - auth = _convert_auth_token(data, "auth_token") + auth = fromtestkit.to_auth_token(data, "auth_token") authenticated = await driver.verify_authentication(auth=auth) await backend.send_response("DriverIsAuthenticated", { "id": backend.next_key(), "authenticated": authenticated @@ -499,7 +535,7 @@ async def NewSession(backend, data): if data_name in data: config[conf_name] = data[data_name] if data.get("authorizationToken"): - config["auth"] = _convert_auth_token(data, "authorizationToken") + config["auth"] = fromtestkit.to_auth_token(data, "authorizationToken") if "bookmark_manager" in config: with warning_check( neo4j.ExperimentalWarning, diff --git a/testkitbackend/_sync/backend.py b/testkitbackend/_sync/backend.py index 052bce25..2ce5bc69 100644 --- a/testkitbackend/_sync/backend.py +++ b/testkitbackend/_sync/backend.py @@ -55,8 +55,10 @@ def __init__(self, rd, wr): self.drivers = {} self.custom_resolutions = {} self.dns_resolutions = {} - self.auth_token_providers = {} - self.renewable_auth_token_supplies = {} + self.auth_token_managers = {} + self.auth_token_supplies = {} + 
self.auth_token_on_expiration_supplies = {} + self.temporal_auth_token_supplies = {} self.bookmark_managers = {} self.bookmarks_consumptions = {} self.bookmarks_supplies = {} diff --git a/testkitbackend/_sync/requests.py b/testkitbackend/_sync/requests.py index a12df510..ca037d0b 100644 --- a/testkitbackend/_sync/requests.py +++ b/testkitbackend/_sync/requests.py @@ -15,6 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. + import datetime import json import re @@ -25,7 +26,13 @@ import neo4j import neo4j.api +import neo4j.auth_management from neo4j._async_compat.util import Util +from neo4j.auth_management import ( + AuthManager, + AuthManagers, + TemporalAuth, +) from .. import ( fromtestkit, @@ -95,37 +102,12 @@ def GetFeatures(backend, data): backend.send_response("FeatureList", {"features": FEATURES}) -def _convert_auth_token(data, key): - if data[key] is None: - return None - auth_token = data[key]["data"] - data[key].mark_item_as_read_if_equals("name", "AuthorizationToken") - scheme = auth_token["scheme"] - if scheme == "basic": - auth = neo4j.basic_auth( - auth_token["principal"], auth_token["credentials"], - realm=auth_token.get("realm", None) - ) - elif scheme == "kerberos": - auth = neo4j.kerberos_auth(auth_token["credentials"]) - elif scheme == "bearer": - auth = neo4j.bearer_auth(auth_token["credentials"]) - else: - auth = neo4j.custom_auth( - auth_token["principal"], auth_token["credentials"], - auth_token["realm"], auth_token["scheme"], - **auth_token.get("parameters", {}) - ) - auth_token.mark_item_as_read("parameters", recursive=True) - return auth - - def NewDriver(backend, data): - auth = _convert_auth_token(data, "authorizationToken") - if auth is None and data.get("authTokenProviderId") is not None: - auth = backend.auth_token_providers[data["authTokenProviderId"]] + auth = fromtestkit.to_auth_token(data, "authorizationToken") + if auth is None and data.get("authTokenManagerId") is 
not None: + auth = backend.auth_token_managers[data["authTokenManagerId"]] else: - data.mark_item_as_read_if_equals("authTokenProviderId", None) + data.mark_item_as_read_if_equals("authTokenManagerId", None) kwargs = {} if data["resolverRegistered"] or data["domainNameResolverRegistered"]: kwargs["resolver"] = resolution_func( @@ -173,52 +155,106 @@ def NewDriver(backend, data): backend.send_response("Driver", {"id": key}) -def NewAuthTokenProvider(backend, data): - auth_token_provider_id = backend.next_key() +def NewAuthTokenManager(backend, data): + auth_token_manager_id = backend.next_key() + + class TestKitAuthManager(AuthManager): + def get_auth(self): + key = backend.next_key() + backend.send_response("AuthTokenManagerGetAuthRequest", { + "id": key, + "authTokenManagerId": auth_token_manager_id, + }) + if not backend.process_request(): + # connection was closed before end of next message + return None + if key not in backend.auth_token_supplies: + raise RuntimeError( + "Backend did not receive expected " + f"AuthTokenManagerGetAuthCompleted message for id {key}" + ) + return backend.auth_token_supplies.pop(key) + + def on_auth_expired(self, auth): + key = backend.next_key() + backend.send_response( + "AuthTokenManagerOnAuthExpiredRequest", { + "id": key, + "authTokenManagerId": auth_token_manager_id, + "auth": totestkit.auth_token(auth), + } + ) + if not backend.process_request(): + # connection was closed before end of next message + return None + if key not in backend.auth_token_on_expiration_supplies: + raise RuntimeError( + "Backend did not receive expected " + "AuthTokenManagerOnAuthExpiredCompleted message for id " + f"{key}" + ) + backend.auth_token_on_expiration_supplies.pop(key) + + auth_manager = TestKitAuthManager() + backend.auth_token_managers[auth_token_manager_id] = auth_manager + backend.send_response("AuthTokenManager", + {"id": auth_token_manager_id}) + + +def AuthTokenManagerGetAuthCompleted(backend, data): + auth_token = 
fromtestkit.to_auth_token(data, "auth") + + backend.auth_token_supplies[data["requestId"]] = auth_token + + +def AuthTokenManagerOnAuthExpiredCompleted(backend, data): + backend.auth_token_on_expiration_supplies[data["requestId"]] = True + + +def AuthTokenManagerClose(backend, data): + auth_token_manager_id = data["id"] + del backend.auth_token_managers[auth_token_manager_id] + backend.send_response("AuthTokenManager", + {"id": auth_token_manager_id}) + + +def NewTemporalAuthTokenManager(backend, data): + auth_token_manager_id = backend.next_key() def auth_token_provider(): key = backend.next_key() - backend.send_response("AuthTokenProviderRequest", { + backend.send_response("TemporalAuthTokenProviderRequest", { "id": key, - "authTokenProviderId": auth_token_provider_id, + "temporalAuthTokenManagerId": auth_token_manager_id, }) if not backend.process_request(): # connection was closed before end of next message - return None - if key not in backend.renewable_auth_token_supplies: + return neo4j.auth_management.TemporalAuth(None, None) + if key not in backend.temporal_auth_token_supplies: raise RuntimeError( "Backend did not receive expected " - f"AuthTokenProviderCompleted message for id {key}" + f"TemporalAuthTokenManagerCompleted message for id {key}" ) - return backend.renewable_auth_token_supplies.pop(key) - + return backend.temporal_auth_token_supplies.pop(key) - backend.auth_token_providers[auth_token_provider_id] = auth_token_provider - backend.send_response("AuthTokenProvider", - {"id": auth_token_provider_id}) + auth_manager = AuthManagers.temporal(auth_token_provider) + backend.auth_token_managers[auth_token_manager_id] = auth_manager + backend.send_response("TemporalAuthTokenManager", + {"id": auth_token_manager_id}) -def AuthTokenProviderClose(backend, data): - auth_token_provider_id = data["id"] - del backend.auth_token_providers[auth_token_provider_id] - backend.send_response("AuthTokenProvider", - {"id": auth_token_provider_id}) - - -def 
AuthTokenProviderCompleted(backend, data): - backend.renewable_auth_token_supplies[data["requestId"]] = \ - parse_renewable_auth(data["auth"]) - - -def parse_renewable_auth(data): - data.mark_item_as_read_if_equals("name", "RenewableAuthToken") - data = data["data"] - auth_token = _convert_auth_token(data, "auth") - if data["expiresInMs"] is not None: - expires_in = data["expiresInMs"] / 1000 +def TemporalAuthTokenProviderCompleted(backend, data): + temp_auth_data = data["auth"] + temp_auth_data.mark_item_as_read_if_equals("name", "TemporalAuthToken") + temp_auth_data = temp_auth_data["data"] + auth_token = fromtestkit.to_auth_token(temp_auth_data, "auth") + if temp_auth_data["expiresInMs"] is not None: + expires_in = temp_auth_data["expiresInMs"] / 1000 else: expires_in = None - return neo4j.RenewableAuth(auth_token, expires_in) + temporal_auth = TemporalAuth(auth_token, expires_in) + + backend.temporal_auth_token_supplies[data["requestId"]] = temporal_auth def VerifyConnectivity(backend, data): @@ -251,7 +287,7 @@ def CheckMultiDBSupport(backend, data): def VerifyAuthentication(backend, data): driver_id = data["driverId"] driver = backend.drivers[driver_id] - auth = _convert_auth_token(data, "auth_token") + auth = fromtestkit.to_auth_token(data, "auth_token") authenticated = driver.verify_authentication(auth=auth) backend.send_response("DriverIsAuthenticated", { "id": backend.next_key(), "authenticated": authenticated @@ -499,7 +535,7 @@ def NewSession(backend, data): if data_name in data: config[conf_name] = data[data_name] if data.get("authorizationToken"): - config["auth"] = _convert_auth_token(data, "authorizationToken") + config["auth"] = fromtestkit.to_auth_token(data, "authorizationToken") if "bookmark_manager" in config: with warning_check( neo4j.ExperimentalWarning, diff --git a/testkitbackend/fromtestkit.py b/testkitbackend/fromtestkit.py index 0c540590..8451ab49 100644 --- a/testkitbackend/fromtestkit.py +++ b/testkitbackend/fromtestkit.py @@ -20,6 
+20,7 @@ import pytz +import neo4j from neo4j import Query from neo4j.spatial import ( CartesianPoint, @@ -151,3 +152,28 @@ def to_param(m): seconds=data["seconds"], nanoseconds=data["nanoseconds"] ) raise ValueError("Unknown param type " + name) + + +def to_auth_token(data, key): + if data[key] is None: + return None + auth_token = data[key]["data"] + data[key].mark_item_as_read_if_equals("name", "AuthorizationToken") + scheme = auth_token["scheme"] + if scheme == "basic": + auth = neo4j.basic_auth( + auth_token["principal"], auth_token["credentials"], + realm=auth_token.get("realm", None) + ) + elif scheme == "kerberos": + auth = neo4j.kerberos_auth(auth_token["credentials"]) + elif scheme == "bearer": + auth = neo4j.bearer_auth(auth_token["credentials"]) + else: + auth = neo4j.custom_auth( + auth_token["principal"], auth_token["credentials"], + auth_token["realm"], auth_token["scheme"], + **auth_token.get("parameters", {}) + ) + auth_token.mark_item_as_read("parameters", recursive=True) + return auth diff --git a/testkitbackend/totestkit.py b/testkitbackend/totestkit.py index fadc8b5f..f422d5de 100644 --- a/testkitbackend/totestkit.py +++ b/testkitbackend/totestkit.py @@ -230,3 +230,7 @@ def to(name, val): } raise ValueError("Unhandled type:" + str(type(v))) + + +def auth_token(auth): + return {"name": "AuthorizationToken", "data": vars(auth)} diff --git a/tests/unit/async_/fixtures/fake_connection.py b/tests/unit/async_/fixtures/fake_connection.py index 48e357d8..5bf1a7ee 100644 --- a/tests/unit/async_/fixtures/fake_connection.py +++ b/tests/unit/async_/fixtures/fake_connection.py @@ -23,6 +23,7 @@ from neo4j import ServerInfo from neo4j._async.io import AsyncBolt from neo4j._deadline import Deadline +from neo4j.auth_management import AsyncAuthManager __all__ = [ @@ -51,7 +52,10 @@ def __init__(self, *args, **kwargs): self.attach_mock(mock.Mock(return_value=False), "closed") self.attach_mock(mock.Mock(return_value=False), "socket") 
self.attach_mock(mock.Mock(return_value=False), "re_auth") - self.attach_mock(mock.Mock(), "unresolved_address") + self.attach_mock(mock.AsyncMock(spec=AsyncAuthManager), + "auth_manager") + self.unresolved_address = next(iter(args), "localhost") + self.throwaway = False def close_side_effect(): self.closed.return_value = True diff --git a/tests/unit/async_/io/conftest.py b/tests/unit/async_/io/conftest.py index 08b6e9c4..0d3afa85 100644 --- a/tests/unit/async_/io/conftest.py +++ b/tests/unit/async_/io/conftest.py @@ -103,6 +103,9 @@ async def sendall(self, data): def close(self): return + def kill(self): + return + def inject(self, data): self.recv_buffer += data diff --git a/tests/unit/async_/io/test_class_bolt.py b/tests/unit/async_/io/test_class_bolt.py index 60adfd28..657acb7d 100644 --- a/tests/unit/async_/io/test_class_bolt.py +++ b/tests/unit/async_/io/test_class_bolt.py @@ -20,6 +20,7 @@ import pytest +import neo4j.auth_management from neo4j._async.io import AsyncBolt from neo4j._async_compat.network import AsyncBoltSocket @@ -102,6 +103,66 @@ async def test_cancel_hello_in_open(mocker): bolt_mock.local_port = 1234 with pytest.raises(asyncio.CancelledError): - await AsyncBolt.open(address) + await AsyncBolt.open( + address, + auth_manager=neo4j.auth_management.AsyncAuthManagers.static(None) + ) bolt_mock.kill.assert_called_once_with() + + +@AsyncTestDecorators.mark_async_only_test +async def test_cancel_manager_in_open(mocker): + address = ("localhost", 7687) + socket_mock = mocker.AsyncMock(spec=AsyncBoltSocket) + + socket_cls_mock = mocker.patch("neo4j._async.io._bolt.AsyncBoltSocket", + autospec=True) + socket_cls_mock.connect.return_value = ( + socket_mock, (5, 0), None, None + ) + socket_mock.getpeername.return_value = address + bolt_cls_mock = mocker.patch("neo4j._async.io._bolt5.AsyncBolt5x0", + autospec=True) + bolt_mock = bolt_cls_mock.return_value + bolt_mock.socket = socket_mock + bolt_mock.local_port = 1234 + + auth_manager = mocker.AsyncMock( 
+ spec=neo4j.auth_management.AsyncAuthManager + ) + auth_manager.get_auth.side_effect = asyncio.CancelledError() + + with pytest.raises(asyncio.CancelledError): + await AsyncBolt.open(address, auth_manager=auth_manager) + + socket_mock.kill.assert_called_once_with() + + +@AsyncTestDecorators.mark_async_only_test +async def test_fail_manager_in_open(mocker): + address = ("localhost", 7687) + socket_mock = mocker.AsyncMock(spec=AsyncBoltSocket) + + socket_cls_mock = mocker.patch("neo4j._async.io._bolt.AsyncBoltSocket", + autospec=True) + socket_cls_mock.connect.return_value = ( + socket_mock, (5, 0), None, None + ) + socket_mock.getpeername.return_value = address + bolt_cls_mock = mocker.patch("neo4j._async.io._bolt5.AsyncBolt5x0", + autospec=True) + bolt_mock = bolt_cls_mock.return_value + bolt_mock.socket = socket_mock + bolt_mock.local_port = 1234 + + auth_manager = mocker.AsyncMock( + spec=neo4j.auth_management.AsyncAuthManager + ) + auth_manager.get_auth.side_effect = RuntimeError("token fetching failed") + + with pytest.raises(RuntimeError) as exc: + await AsyncBolt.open(address, auth_manager=auth_manager) + assert exc.value is auth_manager.get_auth.side_effect + + socket_mock.close.assert_called_once_with() diff --git a/tests/unit/async_/io/test_class_bolt3.py b/tests/unit/async_/io/test_class_bolt3.py index a1500f44..f0a78975 100644 --- a/tests/unit/async_/io/test_class_bolt3.py +++ b/tests/unit/async_/io/test_class_bolt3.py @@ -31,7 +31,7 @@ @pytest.mark.parametrize("set_stale", (True, False)) def test_conn_is_stale(fake_socket, set_stale): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) max_connection_lifetime = 0 connection = AsyncBolt3(address, fake_socket(address), max_connection_lifetime) if set_stale: @@ -41,7 +41,7 @@ def test_conn_is_stale(fake_socket, set_stale): @pytest.mark.parametrize("set_stale", (True, False)) def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): - address = ("127.0.0.1", 7687) + 
address = neo4j.Address(("127.0.0.1", 7687)) max_connection_lifetime = -1 connection = AsyncBolt3(address, fake_socket(address), max_connection_lifetime) if set_stale: @@ -51,7 +51,7 @@ def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): @pytest.mark.parametrize("set_stale", (True, False)) def test_conn_is_not_stale(fake_socket, set_stale): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) max_connection_lifetime = 999999999 connection = AsyncBolt3(address, fake_socket(address), max_connection_lifetime) if set_stale: @@ -60,14 +60,14 @@ def test_conn_is_not_stale(fake_socket, set_stale): def test_db_extra_not_supported_in_begin(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) connection = AsyncBolt3(address, fake_socket(address), PoolConfig.max_connection_lifetime) with pytest.raises(ConfigurationError): connection.begin(db="something") def test_db_extra_not_supported_in_run(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) connection = AsyncBolt3(address, fake_socket(address), PoolConfig.max_connection_lifetime) with pytest.raises(ConfigurationError): connection.run("", db="something") @@ -75,7 +75,7 @@ def test_db_extra_not_supported_in_run(fake_socket): @mark_async_test async def test_simple_discard(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt3.UNPACKER_CLS) connection = AsyncBolt3(address, socket, PoolConfig.max_connection_lifetime) connection.discard() @@ -87,7 +87,7 @@ async def test_simple_discard(fake_socket): @mark_async_test async def test_simple_pull(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt3.UNPACKER_CLS) connection = AsyncBolt3(address, socket, PoolConfig.max_connection_lifetime) connection.pull() @@ -102,7 +102,7 @@ async def 
test_simple_pull(fake_socket): async def test_hint_recv_timeout_seconds_gets_ignored( fake_socket_pair, recv_timeout, mocker ): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) sockets = fake_socket_pair( address, AsyncBolt3.PACKER_CLS, AsyncBolt3.UNPACKER_CLS ) @@ -133,7 +133,7 @@ async def test_hint_recv_timeout_seconds_gets_ignored( async def test_credentials_are_not_logged( auth, fake_socket_pair, mocker, caplog ): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) sockets = fake_socket_pair(address, packer_cls=AsyncBolt3.PACKER_CLS, unpacker_cls=AsyncBolt3.UNPACKER_CLS) @@ -156,7 +156,7 @@ async def test_credentials_are_not_logged( @pytest.mark.parametrize("message", ("logon", "logoff")) def test_auth_message_raises_configuration_error(message, fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) connection = AsyncBolt3(address, fake_socket(address), PoolConfig.max_connection_lifetime) with pytest.raises(ConfigurationError, @@ -171,12 +171,12 @@ def test_auth_message_raises_configuration_error(message, fake_socket): )) @mark_async_test async def test_re_auth_noop(auth, fake_socket, mocker): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) connection = AsyncBolt3(address, fake_socket(address), PoolConfig.max_connection_lifetime, auth=auth) logon_spy = mocker.spy(connection, "logon") logoff_spy = mocker.spy(connection, "logoff") - res = connection.re_auth(auth) + res = connection.re_auth(auth, None) assert res is False logon_spy.assert_not_called() @@ -196,9 +196,9 @@ async def test_re_auth_noop(auth, fake_socket, mocker): ) @mark_async_test async def test_re_auth(auth1, auth2, fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) connection = AsyncBolt3(address, fake_socket(address), PoolConfig.max_connection_lifetime, auth=auth1) with pytest.raises(ConfigurationError, match="Session level 
authentication is not supported"): - connection.re_auth(auth2) + connection.re_auth(auth2, None) diff --git a/tests/unit/async_/io/test_class_bolt4x0.py b/tests/unit/async_/io/test_class_bolt4x0.py index f0df6c40..34aa4961 100644 --- a/tests/unit/async_/io/test_class_bolt4x0.py +++ b/tests/unit/async_/io/test_class_bolt4x0.py @@ -31,7 +31,7 @@ @pytest.mark.parametrize("set_stale", (True, False)) def test_conn_is_stale(fake_socket, set_stale): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) max_connection_lifetime = 0 connection = AsyncBolt4x0(address, fake_socket(address), max_connection_lifetime) if set_stale: @@ -41,7 +41,7 @@ def test_conn_is_stale(fake_socket, set_stale): @pytest.mark.parametrize("set_stale", (True, False)) def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) max_connection_lifetime = -1 connection = AsyncBolt4x0(address, fake_socket(address), max_connection_lifetime) if set_stale: @@ -51,7 +51,7 @@ def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): @pytest.mark.parametrize("set_stale", (True, False)) def test_conn_is_not_stale(fake_socket, set_stale): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) max_connection_lifetime = 999999999 connection = AsyncBolt4x0(address, fake_socket(address), max_connection_lifetime) if set_stale: @@ -61,7 +61,7 @@ def test_conn_is_not_stale(fake_socket, set_stale): @mark_async_test async def test_db_extra_in_begin(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x0.UNPACKER_CLS) connection = AsyncBolt4x0(address, socket, PoolConfig.max_connection_lifetime) connection.begin(db="something") @@ -74,7 +74,7 @@ async def test_db_extra_in_begin(fake_socket): @mark_async_test async def test_db_extra_in_run(fake_socket): - address = ("127.0.0.1", 7687) + address = 
neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x0.UNPACKER_CLS) connection = AsyncBolt4x0(address, socket, PoolConfig.max_connection_lifetime) connection.run("", {}, db="something") @@ -89,7 +89,7 @@ async def test_db_extra_in_run(fake_socket): @mark_async_test async def test_n_extra_in_discard(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x0.UNPACKER_CLS) connection = AsyncBolt4x0(address, socket, PoolConfig.max_connection_lifetime) connection.discard(n=666) @@ -109,7 +109,7 @@ async def test_n_extra_in_discard(fake_socket): ) @mark_async_test async def test_qid_extra_in_discard(fake_socket, test_input, expected): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x0.UNPACKER_CLS) connection = AsyncBolt4x0(address, socket, PoolConfig.max_connection_lifetime) connection.discard(qid=test_input) @@ -129,7 +129,7 @@ async def test_qid_extra_in_discard(fake_socket, test_input, expected): ) @mark_async_test async def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x0.UNPACKER_CLS) connection = AsyncBolt4x0(address, socket, PoolConfig.max_connection_lifetime) connection.discard(n=666, qid=test_input) @@ -149,7 +149,7 @@ async def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): ) @mark_async_test async def test_n_extra_in_pull(fake_socket, test_input, expected): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x0.UNPACKER_CLS) connection = AsyncBolt4x0(address, socket, PoolConfig.max_connection_lifetime) connection.pull(n=test_input) @@ -169,7 +169,7 @@ async def test_n_extra_in_pull(fake_socket, test_input, expected): ) @mark_async_test async def 
test_qid_extra_in_pull(fake_socket, test_input, expected): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x0.UNPACKER_CLS) connection = AsyncBolt4x0(address, socket, PoolConfig.max_connection_lifetime) connection.pull(qid=test_input) @@ -182,7 +182,7 @@ async def test_qid_extra_in_pull(fake_socket, test_input, expected): @mark_async_test async def test_n_and_qid_extras_in_pull(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x0.UNPACKER_CLS) connection = AsyncBolt4x0(address, socket, PoolConfig.max_connection_lifetime) connection.pull(n=666, qid=777) @@ -198,7 +198,7 @@ async def test_n_and_qid_extras_in_pull(fake_socket): async def test_hint_recv_timeout_seconds_gets_ignored( fake_socket_pair, recv_timeout, mocker ): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) sockets = fake_socket_pair(address, packer_cls=AsyncBolt4x0.PACKER_CLS, unpacker_cls=AsyncBolt4x0.UNPACKER_CLS) @@ -229,7 +229,7 @@ async def test_hint_recv_timeout_seconds_gets_ignored( async def test_credentials_are_not_logged( auth, fake_socket_pair, mocker, caplog ): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) sockets = fake_socket_pair(address, packer_cls=AsyncBolt4x0.PACKER_CLS, unpacker_cls=AsyncBolt4x0.UNPACKER_CLS) @@ -252,7 +252,7 @@ async def test_credentials_are_not_logged( @pytest.mark.parametrize("message", ("logon", "logoff")) def test_auth_message_raises_configuration_error(message, fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) connection = AsyncBolt4x0(address, fake_socket(address), PoolConfig.max_connection_lifetime) with pytest.raises(ConfigurationError, @@ -267,12 +267,12 @@ def test_auth_message_raises_configuration_error(message, fake_socket): )) @mark_async_test async def test_re_auth_noop(auth, fake_socket, mocker): 
- address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) connection = AsyncBolt4x0(address, fake_socket(address), PoolConfig.max_connection_lifetime, auth=auth) logon_spy = mocker.spy(connection, "logon") logoff_spy = mocker.spy(connection, "logoff") - res = connection.re_auth(auth) + res = connection.re_auth(auth, None) assert res is False logon_spy.assert_not_called() @@ -292,9 +292,9 @@ async def test_re_auth_noop(auth, fake_socket, mocker): ) @mark_async_test async def test_re_auth(auth1, auth2, fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) connection = AsyncBolt4x0(address, fake_socket(address), PoolConfig.max_connection_lifetime, auth=auth1) with pytest.raises(ConfigurationError, match="Session level authentication is not supported"): - connection.re_auth(auth2) + connection.re_auth(auth2, None) diff --git a/tests/unit/async_/io/test_class_bolt4x1.py b/tests/unit/async_/io/test_class_bolt4x1.py index 1509ddc7..e47c253c 100644 --- a/tests/unit/async_/io/test_class_bolt4x1.py +++ b/tests/unit/async_/io/test_class_bolt4x1.py @@ -31,7 +31,7 @@ @pytest.mark.parametrize("set_stale", (True, False)) def test_conn_is_stale(fake_socket, set_stale): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) max_connection_lifetime = 0 connection = AsyncBolt4x1(address, fake_socket(address), max_connection_lifetime) if set_stale: @@ -41,7 +41,7 @@ def test_conn_is_stale(fake_socket, set_stale): @pytest.mark.parametrize("set_stale", (True, False)) def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) max_connection_lifetime = -1 connection = AsyncBolt4x1(address, fake_socket(address), max_connection_lifetime) if set_stale: @@ -51,7 +51,7 @@ def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): @pytest.mark.parametrize("set_stale", (True, False)) def 
test_conn_is_not_stale(fake_socket, set_stale): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) max_connection_lifetime = 999999999 connection = AsyncBolt4x1(address, fake_socket(address), max_connection_lifetime) if set_stale: @@ -61,7 +61,7 @@ def test_conn_is_not_stale(fake_socket, set_stale): @mark_async_test async def test_db_extra_in_begin(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x1.UNPACKER_CLS) connection = AsyncBolt4x1(address, socket, PoolConfig.max_connection_lifetime) connection.begin(db="something") @@ -74,7 +74,7 @@ async def test_db_extra_in_begin(fake_socket): @mark_async_test async def test_db_extra_in_run(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x1.UNPACKER_CLS) connection = AsyncBolt4x1(address, socket, PoolConfig.max_connection_lifetime) connection.run("", {}, db="something") @@ -89,7 +89,7 @@ async def test_db_extra_in_run(fake_socket): @mark_async_test async def test_n_extra_in_discard(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x1.UNPACKER_CLS) connection = AsyncBolt4x1(address, socket, PoolConfig.max_connection_lifetime) connection.discard(n=666) @@ -109,7 +109,7 @@ async def test_n_extra_in_discard(fake_socket): ) @mark_async_test async def test_qid_extra_in_discard(fake_socket, test_input, expected): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x1.UNPACKER_CLS) connection = AsyncBolt4x1(address, socket, PoolConfig.max_connection_lifetime) connection.discard(qid=test_input) @@ -130,7 +130,7 @@ async def test_qid_extra_in_discard(fake_socket, test_input, expected): @mark_async_test async def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): # python -m 
pytest tests/unit/io/test_class_bolt4x0.py -s -k test_n_and_qid_extras_in_discard - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x1.UNPACKER_CLS) connection = AsyncBolt4x1(address, socket, PoolConfig.max_connection_lifetime) connection.discard(n=666, qid=test_input) @@ -150,7 +150,7 @@ async def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): ) @mark_async_test async def test_n_extra_in_pull(fake_socket, test_input, expected): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x1.UNPACKER_CLS) connection = AsyncBolt4x1(address, socket, PoolConfig.max_connection_lifetime) connection.pull(n=test_input) @@ -171,7 +171,7 @@ async def test_n_extra_in_pull(fake_socket, test_input, expected): @mark_async_test async def test_qid_extra_in_pull(fake_socket, test_input, expected): # python -m pytest tests/unit/io/test_class_bolt4x0.py -s -k test_qid_extra_in_pull - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x1.UNPACKER_CLS) connection = AsyncBolt4x1(address, socket, PoolConfig.max_connection_lifetime) connection.pull(qid=test_input) @@ -184,7 +184,7 @@ async def test_qid_extra_in_pull(fake_socket, test_input, expected): @mark_async_test async def test_n_and_qid_extras_in_pull(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x1.UNPACKER_CLS) connection = AsyncBolt4x1(address, socket, PoolConfig.max_connection_lifetime) connection.pull(n=666, qid=777) @@ -197,7 +197,7 @@ async def test_n_and_qid_extras_in_pull(fake_socket): @mark_async_test async def test_hello_passes_routing_metadata(fake_socket_pair): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) sockets = fake_socket_pair(address, packer_cls=AsyncBolt4x1.PACKER_CLS, 
unpacker_cls=AsyncBolt4x1.UNPACKER_CLS) @@ -218,7 +218,7 @@ async def test_hello_passes_routing_metadata(fake_socket_pair): async def test_hint_recv_timeout_seconds_gets_ignored( fake_socket_pair, recv_timeout, mocker ): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) sockets = fake_socket_pair(address, packer_cls=AsyncBolt4x1.PACKER_CLS, unpacker_cls=AsyncBolt4x1.UNPACKER_CLS) @@ -248,7 +248,7 @@ async def test_hint_recv_timeout_seconds_gets_ignored( async def test_credentials_are_not_logged( auth, fake_socket_pair, mocker, caplog ): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) sockets = fake_socket_pair(address, packer_cls=AsyncBolt4x1.PACKER_CLS, unpacker_cls=AsyncBolt4x1.UNPACKER_CLS) @@ -271,7 +271,7 @@ async def test_credentials_are_not_logged( @pytest.mark.parametrize("message", ("logon", "logoff")) def test_auth_message_raises_configuration_error(message, fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) connection = AsyncBolt4x1(address, fake_socket(address), PoolConfig.max_connection_lifetime) with pytest.raises(ConfigurationError, @@ -286,12 +286,12 @@ def test_auth_message_raises_configuration_error(message, fake_socket): )) @mark_async_test async def test_re_auth_noop(auth, fake_socket, mocker): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) connection = AsyncBolt4x1(address, fake_socket(address), PoolConfig.max_connection_lifetime, auth=auth) logon_spy = mocker.spy(connection, "logon") logoff_spy = mocker.spy(connection, "logoff") - res = connection.re_auth(auth) + res = connection.re_auth(auth, None) assert res is False logon_spy.assert_not_called() @@ -311,9 +311,9 @@ async def test_re_auth_noop(auth, fake_socket, mocker): ) @mark_async_test async def test_re_auth(auth1, auth2, fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) connection = AsyncBolt4x1(address, 
fake_socket(address), PoolConfig.max_connection_lifetime, auth=auth1) with pytest.raises(ConfigurationError, match="Session level authentication is not supported"): - connection.re_auth(auth2) + connection.re_auth(auth2, None) diff --git a/tests/unit/async_/io/test_class_bolt4x2.py b/tests/unit/async_/io/test_class_bolt4x2.py index ad1dec8a..fc1a3d1e 100644 --- a/tests/unit/async_/io/test_class_bolt4x2.py +++ b/tests/unit/async_/io/test_class_bolt4x2.py @@ -31,7 +31,7 @@ @pytest.mark.parametrize("set_stale", (True, False)) def test_conn_is_stale(fake_socket, set_stale): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) max_connection_lifetime = 0 connection = AsyncBolt4x2(address, fake_socket(address), max_connection_lifetime) if set_stale: @@ -41,7 +41,7 @@ def test_conn_is_stale(fake_socket, set_stale): @pytest.mark.parametrize("set_stale", (True, False)) def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) max_connection_lifetime = -1 connection = AsyncBolt4x2(address, fake_socket(address), max_connection_lifetime) if set_stale: @@ -51,7 +51,7 @@ def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): @pytest.mark.parametrize("set_stale", (True, False)) def test_conn_is_not_stale(fake_socket, set_stale): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) max_connection_lifetime = 999999999 connection = AsyncBolt4x2(address, fake_socket(address), max_connection_lifetime) if set_stale: @@ -61,7 +61,7 @@ def test_conn_is_not_stale(fake_socket, set_stale): @mark_async_test async def test_db_extra_in_begin(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x2.UNPACKER_CLS) connection = AsyncBolt4x2(address, socket, PoolConfig.max_connection_lifetime) connection.begin(db="something") @@ -74,7 +74,7 @@ async def 
test_db_extra_in_begin(fake_socket): @mark_async_test async def test_db_extra_in_run(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x2.UNPACKER_CLS) connection = AsyncBolt4x2(address, socket, PoolConfig.max_connection_lifetime) connection.run("", {}, db="something") @@ -89,7 +89,7 @@ async def test_db_extra_in_run(fake_socket): @mark_async_test async def test_n_extra_in_discard(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x2.UNPACKER_CLS) connection = AsyncBolt4x2(address, socket, PoolConfig.max_connection_lifetime) connection.discard(n=666) @@ -109,7 +109,7 @@ async def test_n_extra_in_discard(fake_socket): ) @mark_async_test async def test_qid_extra_in_discard(fake_socket, test_input, expected): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x2.UNPACKER_CLS) connection = AsyncBolt4x2(address, socket, PoolConfig.max_connection_lifetime) connection.discard(qid=test_input) @@ -130,7 +130,7 @@ async def test_qid_extra_in_discard(fake_socket, test_input, expected): @mark_async_test async def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): # python -m pytest tests/unit/io/test_class_bolt4x0.py -s -k test_n_and_qid_extras_in_discard - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x2.UNPACKER_CLS) connection = AsyncBolt4x2(address, socket, PoolConfig.max_connection_lifetime) connection.discard(n=666, qid=test_input) @@ -150,7 +150,7 @@ async def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): ) @mark_async_test async def test_n_extra_in_pull(fake_socket, test_input, expected): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x2.UNPACKER_CLS) 
connection = AsyncBolt4x2(address, socket, PoolConfig.max_connection_lifetime) connection.pull(n=test_input) @@ -171,7 +171,7 @@ async def test_n_extra_in_pull(fake_socket, test_input, expected): @mark_async_test async def test_qid_extra_in_pull(fake_socket, test_input, expected): # python -m pytest tests/unit/io/test_class_bolt4x0.py -s -k test_qid_extra_in_pull - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x2.UNPACKER_CLS) connection = AsyncBolt4x2(address, socket, PoolConfig.max_connection_lifetime) connection.pull(qid=test_input) @@ -184,7 +184,7 @@ async def test_qid_extra_in_pull(fake_socket, test_input, expected): @mark_async_test async def test_n_and_qid_extras_in_pull(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x2.UNPACKER_CLS) connection = AsyncBolt4x2(address, socket, PoolConfig.max_connection_lifetime) connection.pull(n=666, qid=777) @@ -197,7 +197,7 @@ async def test_n_and_qid_extras_in_pull(fake_socket): @mark_async_test async def test_hello_passes_routing_metadata(fake_socket_pair): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) sockets = fake_socket_pair(address, packer_cls=AsyncBolt4x2.PACKER_CLS, unpacker_cls=AsyncBolt4x2.UNPACKER_CLS) @@ -218,7 +218,7 @@ async def test_hello_passes_routing_metadata(fake_socket_pair): async def test_hint_recv_timeout_seconds_gets_ignored( fake_socket_pair, recv_timeout, mocker ): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) sockets = fake_socket_pair(address, packer_cls=AsyncBolt4x2.PACKER_CLS, unpacker_cls=AsyncBolt4x2.UNPACKER_CLS) @@ -249,7 +249,7 @@ async def test_hint_recv_timeout_seconds_gets_ignored( async def test_credentials_are_not_logged( auth, fake_socket_pair, mocker, caplog ): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) sockets = 
fake_socket_pair(address, packer_cls=AsyncBolt4x2.PACKER_CLS, unpacker_cls=AsyncBolt4x2.UNPACKER_CLS) @@ -272,7 +272,7 @@ async def test_credentials_are_not_logged( @pytest.mark.parametrize("message", ("logon", "logoff")) def test_auth_message_raises_configuration_error(message, fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) connection = AsyncBolt4x2(address, fake_socket(address), PoolConfig.max_connection_lifetime) with pytest.raises(ConfigurationError, @@ -287,12 +287,12 @@ def test_auth_message_raises_configuration_error(message, fake_socket): )) @mark_async_test async def test_re_auth_noop(auth, fake_socket, mocker): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) connection = AsyncBolt4x2(address, fake_socket(address), PoolConfig.max_connection_lifetime, auth=auth) logon_spy = mocker.spy(connection, "logon") logoff_spy = mocker.spy(connection, "logoff") - res = connection.re_auth(auth) + res = connection.re_auth(auth, None) assert res is False logon_spy.assert_not_called() @@ -312,9 +312,9 @@ async def test_re_auth_noop(auth, fake_socket, mocker): ) @mark_async_test async def test_re_auth(auth1, auth2, fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) connection = AsyncBolt4x2(address, fake_socket(address), PoolConfig.max_connection_lifetime, auth=auth1) with pytest.raises(ConfigurationError, match="Session level authentication is not supported"): - connection.re_auth(auth2) + connection.re_auth(auth2, None) diff --git a/tests/unit/async_/io/test_class_bolt4x3.py b/tests/unit/async_/io/test_class_bolt4x3.py index 436f2baf..82716329 100644 --- a/tests/unit/async_/io/test_class_bolt4x3.py +++ b/tests/unit/async_/io/test_class_bolt4x3.py @@ -31,7 +31,7 @@ @pytest.mark.parametrize("set_stale", (True, False)) def test_conn_is_stale(fake_socket, set_stale): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) 
max_connection_lifetime = 0 connection = AsyncBolt4x3(address, fake_socket(address), max_connection_lifetime) if set_stale: @@ -41,7 +41,7 @@ def test_conn_is_stale(fake_socket, set_stale): @pytest.mark.parametrize("set_stale", (True, False)) def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) max_connection_lifetime = -1 connection = AsyncBolt4x3(address, fake_socket(address), max_connection_lifetime) if set_stale: @@ -51,7 +51,7 @@ def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): @pytest.mark.parametrize("set_stale", (True, False)) def test_conn_is_not_stale(fake_socket, set_stale): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) max_connection_lifetime = 999999999 connection = AsyncBolt4x3(address, fake_socket(address), max_connection_lifetime) if set_stale: @@ -61,7 +61,7 @@ def test_conn_is_not_stale(fake_socket, set_stale): @mark_async_test async def test_db_extra_in_begin(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x3.UNPACKER_CLS) connection = AsyncBolt4x3(address, socket, PoolConfig.max_connection_lifetime) connection.begin(db="something") @@ -74,7 +74,7 @@ async def test_db_extra_in_begin(fake_socket): @mark_async_test async def test_db_extra_in_run(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x3.UNPACKER_CLS) connection = AsyncBolt4x3(address, socket, PoolConfig.max_connection_lifetime) connection.run("", {}, db="something") @@ -89,7 +89,7 @@ async def test_db_extra_in_run(fake_socket): @mark_async_test async def test_n_extra_in_discard(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x3.UNPACKER_CLS) connection = AsyncBolt4x3(address, socket, 
PoolConfig.max_connection_lifetime) connection.discard(n=666) @@ -109,7 +109,7 @@ async def test_n_extra_in_discard(fake_socket): ) @mark_async_test async def test_qid_extra_in_discard(fake_socket, test_input, expected): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x3.UNPACKER_CLS) connection = AsyncBolt4x3(address, socket, PoolConfig.max_connection_lifetime) connection.discard(qid=test_input) @@ -130,7 +130,7 @@ async def test_qid_extra_in_discard(fake_socket, test_input, expected): @mark_async_test async def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): # python -m pytest tests/unit/io/test_class_bolt4x0.py -s -k test_n_and_qid_extras_in_discard - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x3.UNPACKER_CLS) connection = AsyncBolt4x3(address, socket, PoolConfig.max_connection_lifetime) connection.discard(n=666, qid=test_input) @@ -150,7 +150,7 @@ async def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): ) @mark_async_test async def test_n_extra_in_pull(fake_socket, test_input, expected): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x3.UNPACKER_CLS) connection = AsyncBolt4x3(address, socket, PoolConfig.max_connection_lifetime) connection.pull(n=test_input) @@ -171,7 +171,7 @@ async def test_n_extra_in_pull(fake_socket, test_input, expected): @mark_async_test async def test_qid_extra_in_pull(fake_socket, test_input, expected): # python -m pytest tests/unit/io/test_class_bolt4x0.py -s -k test_qid_extra_in_pull - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x3.UNPACKER_CLS) connection = AsyncBolt4x3(address, socket, PoolConfig.max_connection_lifetime) connection.pull(qid=test_input) @@ -184,7 +184,7 @@ async def 
test_qid_extra_in_pull(fake_socket, test_input, expected): @mark_async_test async def test_n_and_qid_extras_in_pull(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x3.UNPACKER_CLS) connection = AsyncBolt4x3(address, socket, PoolConfig.max_connection_lifetime) connection.pull(n=666, qid=777) @@ -197,7 +197,7 @@ async def test_n_and_qid_extras_in_pull(fake_socket): @mark_async_test async def test_hello_passes_routing_metadata(fake_socket_pair): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) sockets = fake_socket_pair(address, packer_cls=AsyncBolt4x3.PACKER_CLS, unpacker_cls=AsyncBolt4x3.UNPACKER_CLS) @@ -229,7 +229,7 @@ async def test_hello_passes_routing_metadata(fake_socket_pair): async def test_hint_recv_timeout_seconds( fake_socket_pair, hints, valid, caplog, mocker ): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) sockets = fake_socket_pair(address, packer_cls=AsyncBolt4x3.PACKER_CLS, unpacker_cls=AsyncBolt4x3.UNPACKER_CLS) @@ -275,7 +275,7 @@ async def test_hint_recv_timeout_seconds( async def test_credentials_are_not_logged( auth, fake_socket_pair, mocker, caplog ): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) sockets = fake_socket_pair(address, packer_cls=AsyncBolt4x3.PACKER_CLS, unpacker_cls=AsyncBolt4x3.UNPACKER_CLS) @@ -299,7 +299,7 @@ async def test_credentials_are_not_logged( @pytest.mark.parametrize("message", ("logon", "logoff")) def test_auth_message_raises_configuration_error(message, fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) connection = AsyncBolt4x3(address, fake_socket(address), PoolConfig.max_connection_lifetime) with pytest.raises(ConfigurationError, @@ -314,12 +314,12 @@ def test_auth_message_raises_configuration_error(message, fake_socket): )) @mark_async_test async def test_re_auth_noop(auth, fake_socket, 
mocker): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) connection = AsyncBolt4x3(address, fake_socket(address), PoolConfig.max_connection_lifetime, auth=auth) logon_spy = mocker.spy(connection, "logon") logoff_spy = mocker.spy(connection, "logoff") - res = connection.re_auth(auth) + res = connection.re_auth(auth, None) assert res is False logon_spy.assert_not_called() @@ -339,9 +339,9 @@ async def test_re_auth_noop(auth, fake_socket, mocker): ) @mark_async_test async def test_re_auth(auth1, auth2, fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) connection = AsyncBolt4x3(address, fake_socket(address), PoolConfig.max_connection_lifetime, auth=auth1) with pytest.raises(ConfigurationError, match="Session level authentication is not supported"): - connection.re_auth(auth2) + connection.re_auth(auth2, None) diff --git a/tests/unit/async_/io/test_class_bolt4x4.py b/tests/unit/async_/io/test_class_bolt4x4.py index a6b0eae3..4255fe2e 100644 --- a/tests/unit/async_/io/test_class_bolt4x4.py +++ b/tests/unit/async_/io/test_class_bolt4x4.py @@ -31,7 +31,7 @@ @pytest.mark.parametrize("set_stale", (True, False)) def test_conn_is_stale(fake_socket, set_stale): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) max_connection_lifetime = 0 connection = AsyncBolt4x4(address, fake_socket(address), max_connection_lifetime) if set_stale: @@ -41,7 +41,7 @@ def test_conn_is_stale(fake_socket, set_stale): @pytest.mark.parametrize("set_stale", (True, False)) def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) max_connection_lifetime = -1 connection = AsyncBolt4x4(address, fake_socket(address), max_connection_lifetime) if set_stale: @@ -51,7 +51,7 @@ def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): @pytest.mark.parametrize("set_stale", (True, False)) def 
test_conn_is_not_stale(fake_socket, set_stale): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) max_connection_lifetime = 999999999 connection = AsyncBolt4x4(address, fake_socket(address), max_connection_lifetime) if set_stale: @@ -70,7 +70,7 @@ def test_conn_is_not_stale(fake_socket, set_stale): )) @mark_async_test async def test_extra_in_begin(fake_socket, args, kwargs, expected_fields): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x4.UNPACKER_CLS) connection = AsyncBolt4x4(address, socket, PoolConfig.max_connection_lifetime) connection.begin(*args, **kwargs) @@ -91,7 +91,7 @@ async def test_extra_in_begin(fake_socket, args, kwargs, expected_fields): )) @mark_async_test async def test_extra_in_run(fake_socket, args, kwargs, expected_fields): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x4.UNPACKER_CLS) connection = AsyncBolt4x4(address, socket, PoolConfig.max_connection_lifetime) connection.run(*args, **kwargs) @@ -103,7 +103,7 @@ async def test_extra_in_run(fake_socket, args, kwargs, expected_fields): @mark_async_test async def test_n_extra_in_discard(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x4.UNPACKER_CLS) connection = AsyncBolt4x4(address, socket, PoolConfig.max_connection_lifetime) connection.discard(n=666) @@ -123,7 +123,7 @@ async def test_n_extra_in_discard(fake_socket): ) @mark_async_test async def test_qid_extra_in_discard(fake_socket, test_input, expected): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x4.UNPACKER_CLS) connection = AsyncBolt4x4(address, socket, PoolConfig.max_connection_lifetime) connection.discard(qid=test_input) @@ -144,7 +144,7 @@ async def test_qid_extra_in_discard(fake_socket, test_input, 
expected): @mark_async_test async def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): # python -m pytest tests/unit/io/test_class_bolt4x0.py -s -k test_n_and_qid_extras_in_discard - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x4.UNPACKER_CLS) connection = AsyncBolt4x4(address, socket, PoolConfig.max_connection_lifetime) connection.discard(n=666, qid=test_input) @@ -164,7 +164,7 @@ async def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): ) @mark_async_test async def test_n_extra_in_pull(fake_socket, test_input, expected): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x4.UNPACKER_CLS) connection = AsyncBolt4x4(address, socket, PoolConfig.max_connection_lifetime) connection.pull(n=test_input) @@ -185,7 +185,7 @@ async def test_n_extra_in_pull(fake_socket, test_input, expected): @mark_async_test async def test_qid_extra_in_pull(fake_socket, test_input, expected): # python -m pytest tests/unit/io/test_class_bolt4x0.py -s -k test_qid_extra_in_pull - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x4.UNPACKER_CLS) connection = AsyncBolt4x4(address, socket, PoolConfig.max_connection_lifetime) connection.pull(qid=test_input) @@ -198,7 +198,7 @@ async def test_qid_extra_in_pull(fake_socket, test_input, expected): @mark_async_test async def test_n_and_qid_extras_in_pull(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt4x4.UNPACKER_CLS) connection = AsyncBolt4x4(address, socket, PoolConfig.max_connection_lifetime) connection.pull(n=666, qid=777) @@ -211,7 +211,7 @@ async def test_n_and_qid_extras_in_pull(fake_socket): @mark_async_test async def test_hello_passes_routing_metadata(fake_socket_pair): - address = ("127.0.0.1", 7687) + address = 
neo4j.Address(("127.0.0.1", 7687)) sockets = fake_socket_pair(address, packer_cls=AsyncBolt4x4.PACKER_CLS, unpacker_cls=AsyncBolt4x4.UNPACKER_CLS) @@ -243,7 +243,7 @@ async def test_hello_passes_routing_metadata(fake_socket_pair): async def test_hint_recv_timeout_seconds( fake_socket_pair, hints, valid, caplog, mocker ): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) sockets = fake_socket_pair(address, packer_cls=AsyncBolt4x4.PACKER_CLS, unpacker_cls=AsyncBolt4x4.UNPACKER_CLS) @@ -289,7 +289,7 @@ async def test_hint_recv_timeout_seconds( async def test_credentials_are_not_logged( auth, fake_socket_pair, mocker, caplog ): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) sockets = fake_socket_pair(address, packer_cls=AsyncBolt4x4.PACKER_CLS, unpacker_cls=AsyncBolt4x4.UNPACKER_CLS) @@ -312,7 +312,7 @@ async def test_credentials_are_not_logged( @pytest.mark.parametrize("message", ("logon", "logoff")) def test_auth_message_raises_configuration_error(message, fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) connection = AsyncBolt4x4(address, fake_socket(address), PoolConfig.max_connection_lifetime) with pytest.raises(ConfigurationError, @@ -327,12 +327,12 @@ def test_auth_message_raises_configuration_error(message, fake_socket): )) @mark_async_test async def test_re_auth_noop(auth, fake_socket, mocker): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) connection = AsyncBolt4x4(address, fake_socket(address), PoolConfig.max_connection_lifetime, auth=auth) logon_spy = mocker.spy(connection, "logon") logoff_spy = mocker.spy(connection, "logoff") - res = connection.re_auth(auth) + res = connection.re_auth(auth, None) assert res is False logon_spy.assert_not_called() @@ -352,9 +352,9 @@ async def test_re_auth_noop(auth, fake_socket, mocker): ) @mark_async_test async def test_re_auth(auth1, auth2, fake_socket): - address = ("127.0.0.1", 7687) + 
address = neo4j.Address(("127.0.0.1", 7687)) connection = AsyncBolt4x4(address, fake_socket(address), PoolConfig.max_connection_lifetime, auth=auth1) with pytest.raises(ConfigurationError, match="Session level authentication is not supported"): - connection.re_auth(auth2) + connection.re_auth(auth2, None) diff --git a/tests/unit/async_/io/test_class_bolt5x0.py b/tests/unit/async_/io/test_class_bolt5x0.py index f1f549d5..1f9d9b8c 100644 --- a/tests/unit/async_/io/test_class_bolt5x0.py +++ b/tests/unit/async_/io/test_class_bolt5x0.py @@ -31,7 +31,7 @@ @pytest.mark.parametrize("set_stale", (True, False)) def test_conn_is_stale(fake_socket, set_stale): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) max_connection_lifetime = 0 connection = AsyncBolt5x0(address, fake_socket(address), max_connection_lifetime) if set_stale: @@ -41,7 +41,7 @@ def test_conn_is_stale(fake_socket, set_stale): @pytest.mark.parametrize("set_stale", (True, False)) def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) max_connection_lifetime = -1 connection = AsyncBolt5x0(address, fake_socket(address), max_connection_lifetime) if set_stale: @@ -51,7 +51,7 @@ def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): @pytest.mark.parametrize("set_stale", (True, False)) def test_conn_is_not_stale(fake_socket, set_stale): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) max_connection_lifetime = 999999999 connection = AsyncBolt5x0(address, fake_socket(address), max_connection_lifetime) if set_stale: @@ -70,7 +70,7 @@ def test_conn_is_not_stale(fake_socket, set_stale): )) @mark_async_test async def test_extra_in_begin(fake_socket, args, kwargs, expected_fields): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt5x0.UNPACKER_CLS) connection = AsyncBolt5x0(address, socket, 
PoolConfig.max_connection_lifetime) connection.begin(*args, **kwargs) @@ -91,7 +91,7 @@ async def test_extra_in_begin(fake_socket, args, kwargs, expected_fields): )) @mark_async_test async def test_extra_in_run(fake_socket, args, kwargs, expected_fields): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt5x0.UNPACKER_CLS) connection = AsyncBolt5x0(address, socket, PoolConfig.max_connection_lifetime) connection.run(*args, **kwargs) @@ -103,7 +103,7 @@ async def test_extra_in_run(fake_socket, args, kwargs, expected_fields): @mark_async_test async def test_n_extra_in_discard(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt5x0.UNPACKER_CLS) connection = AsyncBolt5x0(address, socket, PoolConfig.max_connection_lifetime) connection.discard(n=666) @@ -123,7 +123,7 @@ async def test_n_extra_in_discard(fake_socket): ) @mark_async_test async def test_qid_extra_in_discard(fake_socket, test_input, expected): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt5x0.UNPACKER_CLS) connection = AsyncBolt5x0(address, socket, PoolConfig.max_connection_lifetime) connection.discard(qid=test_input) @@ -143,7 +143,7 @@ async def test_qid_extra_in_discard(fake_socket, test_input, expected): ) @mark_async_test async def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt5x0.UNPACKER_CLS) connection = AsyncBolt5x0(address, socket, PoolConfig.max_connection_lifetime) connection.discard(n=666, qid=test_input) @@ -163,7 +163,7 @@ async def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): ) @mark_async_test async def test_n_extra_in_pull(fake_socket, test_input, expected): - address = ("127.0.0.1", 7687) + address = 
neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt5x0.UNPACKER_CLS) connection = AsyncBolt5x0(address, socket, PoolConfig.max_connection_lifetime) connection.pull(n=test_input) @@ -183,7 +183,7 @@ async def test_n_extra_in_pull(fake_socket, test_input, expected): ) @mark_async_test async def test_qid_extra_in_pull(fake_socket, test_input, expected): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt5x0.UNPACKER_CLS) connection = AsyncBolt5x0(address, socket, PoolConfig.max_connection_lifetime) connection.pull(qid=test_input) @@ -196,7 +196,7 @@ async def test_qid_extra_in_pull(fake_socket, test_input, expected): @mark_async_test async def test_n_and_qid_extras_in_pull(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt5x0.UNPACKER_CLS) connection = AsyncBolt5x0(address, socket, PoolConfig.max_connection_lifetime) connection.pull(n=666, qid=777) @@ -209,7 +209,7 @@ async def test_n_and_qid_extras_in_pull(fake_socket): @mark_async_test async def test_hello_passes_routing_metadata(fake_socket_pair): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) sockets = fake_socket_pair(address, packer_cls=AsyncBolt5x0.PACKER_CLS, unpacker_cls=AsyncBolt5x0.UNPACKER_CLS) @@ -241,7 +241,7 @@ async def test_hello_passes_routing_metadata(fake_socket_pair): async def test_hint_recv_timeout_seconds( fake_socket_pair, hints, valid, caplog, mocker ): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) sockets = fake_socket_pair(address, packer_cls=AsyncBolt5x0.PACKER_CLS, unpacker_cls=AsyncBolt5x0.UNPACKER_CLS) @@ -287,7 +287,7 @@ async def test_hint_recv_timeout_seconds( async def test_credentials_are_not_logged( auth, fake_socket_pair, mocker, caplog ): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) sockets = 
fake_socket_pair(address, packer_cls=AsyncBolt5x0.PACKER_CLS, unpacker_cls=AsyncBolt5x0.UNPACKER_CLS) @@ -310,7 +310,7 @@ async def test_credentials_are_not_logged( @pytest.mark.parametrize("message", ("logon", "logoff")) def test_auth_message_raises_configuration_error(message, fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) connection = AsyncBolt5x0(address, fake_socket(address), PoolConfig.max_connection_lifetime) with pytest.raises(ConfigurationError, @@ -325,12 +325,12 @@ def test_auth_message_raises_configuration_error(message, fake_socket): )) @mark_async_test async def test_re_auth_noop(auth, fake_socket, mocker): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) connection = AsyncBolt5x0(address, fake_socket(address), PoolConfig.max_connection_lifetime, auth=auth) logon_spy = mocker.spy(connection, "logon") logoff_spy = mocker.spy(connection, "logoff") - res = connection.re_auth(auth) + res = connection.re_auth(auth, None) assert res is False logon_spy.assert_not_called() @@ -350,9 +350,9 @@ async def test_re_auth_noop(auth, fake_socket, mocker): ) @mark_async_test async def test_re_auth(auth1, auth2, fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) connection = AsyncBolt5x0(address, fake_socket(address), PoolConfig.max_connection_lifetime, auth=auth1) with pytest.raises(ConfigurationError, match="Session level authentication is not supported"): - connection.re_auth(auth2) + connection.re_auth(auth2, None) diff --git a/tests/unit/async_/io/test_class_bolt5x1.py b/tests/unit/async_/io/test_class_bolt5x1.py index 4ba9b47a..95debda2 100644 --- a/tests/unit/async_/io/test_class_bolt5x1.py +++ b/tests/unit/async_/io/test_class_bolt5x1.py @@ -24,13 +24,14 @@ import neo4j.exceptions from neo4j._async.io._bolt5 import AsyncBolt5x1 from neo4j._conf import PoolConfig +from neo4j.auth_management import AsyncAuthManagers from ...._async_compat import 
mark_async_test @pytest.mark.parametrize("set_stale", (True, False)) def test_conn_is_stale(fake_socket, set_stale): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) max_connection_lifetime = 0 connection = AsyncBolt5x1(address, fake_socket(address), max_connection_lifetime) @@ -41,7 +42,7 @@ def test_conn_is_stale(fake_socket, set_stale): @pytest.mark.parametrize("set_stale", (True, False)) def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) max_connection_lifetime = -1 connection = AsyncBolt5x1(address, fake_socket(address), max_connection_lifetime) @@ -52,7 +53,7 @@ def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): @pytest.mark.parametrize("set_stale", (True, False)) def test_conn_is_not_stale(fake_socket, set_stale): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) max_connection_lifetime = 999999999 connection = AsyncBolt5x1(address, fake_socket(address), max_connection_lifetime) @@ -72,7 +73,7 @@ def test_conn_is_not_stale(fake_socket, set_stale): )) @mark_async_test async def test_extra_in_begin(fake_socket, args, kwargs, expected_fields): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt5x1.UNPACKER_CLS) connection = AsyncBolt5x1(address, socket, PoolConfig.max_connection_lifetime) @@ -94,7 +95,7 @@ async def test_extra_in_begin(fake_socket, args, kwargs, expected_fields): )) @mark_async_test async def test_extra_in_run(fake_socket, args, kwargs, expected_fields): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt5x1.UNPACKER_CLS) connection = AsyncBolt5x1(address, socket, PoolConfig.max_connection_lifetime) @@ -107,7 +108,7 @@ async def test_extra_in_run(fake_socket, args, kwargs, expected_fields): @mark_async_test async def 
test_n_extra_in_discard(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt5x1.UNPACKER_CLS) connection = AsyncBolt5x1(address, socket, PoolConfig.max_connection_lifetime) @@ -128,7 +129,7 @@ async def test_n_extra_in_discard(fake_socket): ) @mark_async_test async def test_qid_extra_in_discard(fake_socket, test_input, expected): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt5x1.UNPACKER_CLS) connection = AsyncBolt5x1(address, socket, PoolConfig.max_connection_lifetime) @@ -149,7 +150,7 @@ async def test_qid_extra_in_discard(fake_socket, test_input, expected): ) @mark_async_test async def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt5x1.UNPACKER_CLS) connection = AsyncBolt5x1(address, socket, PoolConfig.max_connection_lifetime) @@ -170,7 +171,7 @@ async def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): ) @mark_async_test async def test_n_extra_in_pull(fake_socket, test_input, expected): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt5x1.UNPACKER_CLS) connection = AsyncBolt5x1(address, socket, PoolConfig.max_connection_lifetime) @@ -191,7 +192,7 @@ async def test_n_extra_in_pull(fake_socket, test_input, expected): ) @mark_async_test async def test_qid_extra_in_pull(fake_socket, test_input, expected): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt5x1.UNPACKER_CLS) connection = AsyncBolt5x1(address, socket, PoolConfig.max_connection_lifetime) @@ -205,7 +206,7 @@ async def test_qid_extra_in_pull(fake_socket, test_input, expected): @mark_async_test async def test_n_and_qid_extras_in_pull(fake_socket): - address = 
("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, AsyncBolt5x1.UNPACKER_CLS) connection = AsyncBolt5x1(address, socket, PoolConfig.max_connection_lifetime) @@ -219,7 +220,7 @@ async def test_n_and_qid_extras_in_pull(fake_socket): @mark_async_test async def test_hello_passes_routing_metadata(fake_socket_pair): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) sockets = fake_socket_pair(address, packer_cls=AsyncBolt5x1.PACKER_CLS, unpacker_cls=AsyncBolt5x1.UNPACKER_CLS) @@ -249,7 +250,7 @@ async def _assert_logon_message(sockets, auth): @mark_async_test async def test_hello_pipelines_logon(fake_socket_pair): auth = neo4j.Auth("basic", "alice123", "supersecret123") - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) sockets = fake_socket_pair(address, packer_cls=AsyncBolt5x1.PACKER_CLS, unpacker_cls=AsyncBolt5x1.UNPACKER_CLS) @@ -260,7 +261,7 @@ async def test_hello_pipelines_logon(fake_socket_pair): connection = AsyncBolt5x1( address, sockets.client, PoolConfig.max_connection_lifetime, auth=auth ) - with pytest.raises(neo4j.exceptions.ServiceUnavailable): + with pytest.raises(neo4j.exceptions.Neo4jError): await connection.hello() tag, fields = await sockets.server.pop_message() assert tag == b"\x01" # HELLO @@ -273,7 +274,7 @@ async def test_hello_pipelines_logon(fake_socket_pair): @mark_async_test async def test_logon(fake_socket_pair): auth = neo4j.Auth("basic", "alice123", "supersecret123") - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) sockets = fake_socket_pair(address, packer_cls=AsyncBolt5x1.PACKER_CLS, unpacker_cls=AsyncBolt5x1.UNPACKER_CLS) @@ -287,7 +288,8 @@ async def test_logon(fake_socket_pair): @mark_async_test async def test_re_auth(fake_socket_pair, mocker): auth = neo4j.Auth("basic", "alice123", "supersecret123") - address = ("127.0.0.1", 7687) + auth_manager = AsyncAuthManagers.static(auth) + address = 
neo4j.Address(("127.0.0.1", 7687)) sockets = fake_socket_pair(address, packer_cls=AsyncBolt5x1.PACKER_CLS, unpacker_cls=AsyncBolt5x1.UNPACKER_CLS) @@ -298,7 +300,7 @@ async def test_re_auth(fake_socket_pair, mocker): connection = AsyncBolt5x1(address, sockets.client, PoolConfig.max_connection_lifetime) connection.pool = mocker.AsyncMock() - connection.re_auth(auth) + connection.re_auth(auth, auth_manager) await connection.send_all() with pytest.raises(neo4j.exceptions.Neo4jError): await connection.fetch_all() @@ -306,11 +308,13 @@ async def test_re_auth(fake_socket_pair, mocker): assert tag == b"\x6B" # LOGOFF assert len(fields) == 0 await _assert_logon_message(sockets, auth) + assert connection.auth is auth + assert connection.auth_manager is auth_manager @mark_async_test async def test_logoff(fake_socket_pair): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) sockets = fake_socket_pair(address, packer_cls=AsyncBolt5x1.PACKER_CLS, unpacker_cls=AsyncBolt5x1.UNPACKER_CLS) @@ -342,7 +346,7 @@ async def test_logoff(fake_socket_pair): async def test_hint_recv_timeout_seconds( fake_socket_pair, hints, valid, caplog, mocker ): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) sockets = fake_socket_pair(address, packer_cls=AsyncBolt5x1.PACKER_CLS, unpacker_cls=AsyncBolt5x1.UNPACKER_CLS) @@ -387,7 +391,7 @@ async def test_hint_recv_timeout_seconds( )) @mark_async_test async def test_credentials_are_not_logged(auth, fake_socket_pair, caplog): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) sockets = fake_socket_pair(address, packer_cls=AsyncBolt5x1.PACKER_CLS, unpacker_cls=AsyncBolt5x1.UNPACKER_CLS) diff --git a/tests/unit/async_/io/test_direct.py b/tests/unit/async_/io/test_direct.py index 07ebb477..8397d995 100644 --- a/tests/unit/async_/io/test_direct.py +++ b/tests/unit/async_/io/test_direct.py @@ -26,6 +26,7 @@ WorkspaceConfig, ) from neo4j._deadline import Deadline +from 
neo4j.auth_management import AsyncAuthManagers from neo4j.exceptions import ( ClientError, ServiceUnavailable, @@ -59,13 +60,17 @@ def __init__(self, socket): def is_reset(self): return True + @property + def throwaway(self): + return False + def stale(self): return False async def reset(self): pass - def re_auth(self, auth, force=False): + def re_auth(self, auth, auth_manager, force=False): return False def close(self): @@ -82,8 +87,8 @@ def timedout(self): class AsyncFakeBoltPool(AsyncIOPool): - def __init__(self, address, *, auth=None, **config): + config["auth"] = AsyncAuthManagers.static(None) self.pool_config, self.workspace_config = Config.consume_chain(config, PoolConfig, WorkspaceConfig) if config: raise ValueError("Unexpected config keys: %s" % ", ".join(config.keys())) @@ -105,15 +110,18 @@ async def acquire( @mark_async_test async def test_bolt_connection_open(): + auth_manager = AsyncAuthManagers.static(("test", "test")) with pytest.raises(ServiceUnavailable): - await AsyncBolt.open(("localhost", 9999), auth=("test", "test")) + await AsyncBolt.open(("localhost", 9999), auth_manager=auth_manager) @mark_async_test async def test_bolt_connection_open_timeout(): + auth_manager = AsyncAuthManagers.static(("test", "test")) with pytest.raises(ServiceUnavailable): - await AsyncBolt.open(("localhost", 9999), auth=("test", "test"), - timeout=1) + await AsyncBolt.open( + ("localhost", 9999), auth_manager=auth_manager, timeout=1 + ) @mark_async_test diff --git a/tests/unit/async_/io/test_neo4j_pool.py b/tests/unit/async_/io/test_neo4j_pool.py index a417423f..98707139 100644 --- a/tests/unit/async_/io/test_neo4j_pool.py +++ b/tests/unit/async_/io/test_neo4j_pool.py @@ -33,6 +33,7 @@ ) from neo4j._deadline import Deadline from neo4j.addressing import ResolvedAddress +from neo4j.auth_management import AsyncAuthManagers from neo4j.exceptions import ( Neo4jError, ServiceUnavailable, @@ -71,7 +72,7 @@ def routing_side_effect(*args, **kwargs): async def open_(addr, auth, 
timeout): connection = async_fake_connection_generator() - connection.addr = addr + connection.unresolved_address = addr connection.timeout = timeout connection.auth = auth route_mock = mocker.AsyncMock() @@ -95,11 +96,21 @@ def opener(routing_failure_opener): return routing_failure_opener() +def _pool_config(): + pool_config = PoolConfig() + pool_config.auth = AsyncAuthManagers.static(("user", "pass")) + return pool_config + + +def _simple_pool(opener) -> AsyncNeo4jPool: + return AsyncNeo4jPool( + opener, _pool_config(), WorkspaceConfig(), ROUTER1_ADDRESS + ) + + @mark_async_test async def test_acquires_new_routing_table_if_deleted(opener): - pool = AsyncNeo4jPool( - opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS - ) + pool = _simple_pool(opener) cx = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) await pool.release(cx) assert pool.routing_tables.get("test_db") @@ -113,9 +124,7 @@ async def test_acquires_new_routing_table_if_deleted(opener): @mark_async_test async def test_acquires_new_routing_table_if_stale(opener): - pool = AsyncNeo4jPool( - opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS - ) + pool = _simple_pool(opener) cx = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) await pool.release(cx) assert pool.routing_tables.get("test_db") @@ -130,9 +139,7 @@ async def test_acquires_new_routing_table_if_stale(opener): @mark_async_test async def test_removes_old_routing_table(opener): - pool = AsyncNeo4jPool( - opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS - ) + pool = _simple_pool(opener) cx = await pool.acquire(READ_ACCESS, 30, "test_db1", None, None, None) await pool.release(cx) assert pool.routing_tables.get("test_db1") @@ -154,25 +161,21 @@ async def test_removes_old_routing_table(opener): @pytest.mark.parametrize("type_", ("r", "w")) @mark_async_test async def test_chooses_right_connection_type(opener, type_): - pool = AsyncNeo4jPool( - opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS - 
) + pool = _simple_pool(opener) cx1 = await pool.acquire( READ_ACCESS if type_ == "r" else WRITE_ACCESS, 30, "test_db", None, None, None ) await pool.release(cx1) if type_ == "r": - assert cx1.addr == READER_ADDRESS + assert cx1.unresolved_address == READER_ADDRESS else: - assert cx1.addr == WRITER_ADDRESS + assert cx1.unresolved_address == WRITER_ADDRESS @mark_async_test async def test_reuses_connection(opener): - pool = AsyncNeo4jPool( - opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS - ) + pool = _simple_pool(opener) cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) await pool.release(cx1) cx2 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) @@ -183,21 +186,19 @@ async def test_reuses_connection(opener): @mark_async_test async def test_closes_stale_connections(opener, break_on_close): async def break_connection(): - await pool.deactivate(cx1.addr) + await pool.deactivate(cx1.unresolved_address) if cx_close_mock_side_effect: res = cx_close_mock_side_effect() if inspect.isawaitable(res): return await res - pool = AsyncNeo4jPool( - opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS - ) + pool = _simple_pool(opener) cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) await pool.release(cx1) - assert cx1 in pool.connections[cx1.addr] - # simulate connection going stale (e.g. exceeding) and then breaking when - # the pool tries to close the connection + assert cx1 in pool.connections[cx1.unresolved_address] + # simulate connection going stale (e.g. 
exceeding idle timeout) and then + # breaking when the pool tries to close the connection cx1.stale.return_value = True cx_close_mock = cx1.close if break_on_close: @@ -210,27 +211,26 @@ async def break_connection(): else: cx1.close.assert_called_once() assert cx2 is not cx1 - assert cx2.addr == cx1.addr - assert cx1 not in pool.connections[cx1.addr] - assert cx2 in pool.connections[cx2.addr] + assert cx2.unresolved_address == cx1.unresolved_address + assert cx1 not in pool.connections[cx1.unresolved_address] + assert cx2 in pool.connections[cx2.unresolved_address] @mark_async_test async def test_does_not_close_stale_connections_in_use(opener): - pool = AsyncNeo4jPool( - opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS - ) + pool = _simple_pool(opener) cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) - assert cx1 in pool.connections[cx1.addr] - # simulate connection going stale (e.g. exceeding) while being in use + assert cx1 in pool.connections[cx1.unresolved_address] + # simulate connection going stale (e.g. 
exceeding idle timeout) while being + # in use cx1.stale.return_value = True cx2 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) await pool.release(cx2) cx1.close.assert_not_called() assert cx2 is not cx1 - assert cx2.addr == cx1.addr - assert cx1 in pool.connections[cx1.addr] - assert cx2 in pool.connections[cx2.addr] + assert cx2.unresolved_address == cx1.unresolved_address + assert cx1 in pool.connections[cx1.unresolved_address] + assert cx2 in pool.connections[cx2.unresolved_address] await pool.release(cx1) # now that cx1 is back in the pool and still stale, @@ -241,16 +241,14 @@ async def test_does_not_close_stale_connections_in_use(opener): await pool.release(cx3) cx1.close.assert_called_once() assert cx2 is cx3 - assert cx3.addr == cx1.addr - assert cx1 not in pool.connections[cx1.addr] - assert cx3 in pool.connections[cx2.addr] + assert cx3.unresolved_address == cx1.unresolved_address + assert cx1 not in pool.connections[cx1.unresolved_address] + assert cx3 in pool.connections[cx2.unresolved_address] @mark_async_test async def test_release_resets_connections(opener): - pool = AsyncNeo4jPool( - opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS - ) + pool = _simple_pool(opener) cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) cx1.is_reset_mock.return_value = False cx1.is_reset_mock.reset_mock() @@ -261,9 +259,7 @@ async def test_release_resets_connections(opener): @mark_async_test async def test_release_does_not_resets_closed_connections(opener): - pool = AsyncNeo4jPool( - opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS - ) + pool = _simple_pool(opener) cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) cx1.closed.return_value = True cx1.closed.reset_mock() @@ -276,9 +272,7 @@ async def test_release_does_not_resets_closed_connections(opener): @mark_async_test async def test_release_does_not_resets_defunct_connections(opener): - pool = AsyncNeo4jPool( - opener, PoolConfig(), 
WorkspaceConfig(), ROUTER1_ADDRESS - ) + pool = _simple_pool(opener) cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) cx1.defunct.return_value = True cx1.defunct.reset_mock() @@ -294,12 +288,10 @@ async def test_release_does_not_resets_defunct_connections(opener): async def test_acquire_performs_no_liveness_check_on_fresh_connection( opener, liveness_timeout ): - pool = AsyncNeo4jPool( - opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS - ) + pool = _simple_pool(opener) cx1 = await pool._acquire(READER_ADDRESS, None, Deadline(30), liveness_timeout) - assert cx1.addr == READER_ADDRESS + assert cx1.unresolved_address == READER_ADDRESS cx1.reset.assert_not_called() @@ -308,15 +300,13 @@ async def test_acquire_performs_no_liveness_check_on_fresh_connection( async def test_acquire_performs_liveness_check_on_existing_connection( opener, liveness_timeout ): - pool = AsyncNeo4jPool( - opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS - ) + pool = _simple_pool(opener) # populate the pool with a connection cx1 = await pool._acquire(READER_ADDRESS, None, Deadline(30), liveness_timeout) # make sure we assume the right state - assert cx1.addr == READER_ADDRESS + assert cx1.unresolved_address == READER_ADDRESS cx1.is_idle_for.assert_not_called() cx1.reset.assert_not_called() @@ -344,15 +334,13 @@ def liveness_side_effect(*args, **kwargs): raise liveness_error("liveness check failed") liveness_timeout = 1 - pool = AsyncNeo4jPool( - opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS - ) + pool = _simple_pool(opener) # populate the pool with a connection cx1 = await pool._acquire(READER_ADDRESS, None, Deadline(30), liveness_timeout) # make sure we assume the right state - assert cx1.addr == READER_ADDRESS + assert cx1.unresolved_address == READER_ADDRESS cx1.is_idle_for.assert_not_called() cx1.reset.assert_not_called() @@ -368,11 +356,11 @@ def liveness_side_effect(*args, **kwargs): cx2 = await pool._acquire(READER_ADDRESS, None, 
Deadline(30), liveness_timeout) assert cx1 is not cx2 - assert cx1.addr == cx2.addr + assert cx1.unresolved_address == cx2.unresolved_address cx1.is_idle_for.assert_called_once_with(liveness_timeout) cx2.reset.assert_not_called() - assert cx1 not in pool.connections[cx1.addr] - assert cx2 in pool.connections[cx1.addr] + assert cx1 not in pool.connections[cx1.unresolved_address] + assert cx2 in pool.connections[cx1.unresolved_address] @pytest.mark.parametrize("liveness_error", @@ -385,9 +373,7 @@ def liveness_side_effect(*args, **kwargs): raise liveness_error("liveness check failed") liveness_timeout = 1 - pool = AsyncNeo4jPool( - opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS - ) + pool = _simple_pool(opener) # populate the pool with a connection cx1 = await pool._acquire(READER_ADDRESS, None, Deadline(30), liveness_timeout) @@ -395,8 +381,8 @@ def liveness_side_effect(*args, **kwargs): liveness_timeout) # make sure we assume the right state - assert cx1.addr == READER_ADDRESS - assert cx2.addr == READER_ADDRESS + assert cx1.unresolved_address == READER_ADDRESS + assert cx2.unresolved_address == READER_ADDRESS assert cx1 is not cx2 cx1.is_idle_for.assert_not_called() cx2.is_idle_for.assert_not_called() @@ -421,8 +407,8 @@ def liveness_side_effect(*args, **kwargs): cx1.reset.assert_awaited_once() cx3.is_idle_for.assert_called_once_with(liveness_timeout) cx3.reset.assert_awaited_once() - assert cx1 not in pool.connections[cx1.addr] - assert cx3 in pool.connections[cx1.addr] + assert cx1 not in pool.connections[cx1.unresolved_address] + assert cx3 in pool.connections[cx1.unresolved_address] @mark_async_test @@ -437,9 +423,7 @@ async def close_side_effect(): "close") # create pool with 2 idle connections - pool = AsyncNeo4jPool( - opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS - ) + pool = _simple_pool(opener) cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) cx2 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) 
await pool.release(cx1) @@ -461,9 +445,7 @@ async def close_side_effect(): @mark_async_test async def test_failing_opener_leaves_connections_in_use_alone(opener): - pool = AsyncNeo4jPool( - opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS - ) + pool = _simple_pool(opener) cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) opener.side_effect = ServiceUnavailable("Server overloaded") @@ -474,13 +456,13 @@ async def test_failing_opener_leaves_connections_in_use_alone(opener): @mark_async_test async def test__acquire_new_later_with_room(opener): - config = PoolConfig() + config = _pool_config() config.max_connection_pool_size = 1 pool = AsyncNeo4jPool( opener, config, WorkspaceConfig(), ROUTER1_ADDRESS ) assert pool.connections_reservations[READER_ADDRESS] == 0 - creator = pool._acquire_new_later(READER_ADDRESS, None, Deadline(1)) + creator = pool._acquire_new_later(READER_ADDRESS, None, Deadline(1), False) assert pool.connections_reservations[READER_ADDRESS] == 1 assert callable(creator) if AsyncUtil.is_async_code: @@ -489,7 +471,7 @@ async def test__acquire_new_later_with_room(opener): @mark_async_test async def test__acquire_new_later_without_room(opener): - config = PoolConfig() + config = _pool_config() config.max_connection_pool_size = 1 pool = AsyncNeo4jPool( opener, config, WorkspaceConfig(), ROUTER1_ADDRESS @@ -497,7 +479,7 @@ async def test__acquire_new_later_without_room(opener): _ = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) # pool is full now assert pool.connections_reservations[READER_ADDRESS] == 0 - creator = pool._acquire_new_later(READER_ADDRESS, None, Deadline(1)) + creator = pool._acquire_new_later(READER_ADDRESS, None, Deadline(1), False) assert pool.connections_reservations[READER_ADDRESS] == 0 assert creator is None @@ -515,7 +497,7 @@ async def test_discovery_is_retried(routing_failure_opener, error): error, # will be retried ]) pool = AsyncNeo4jPool( - opener, PoolConfig(), WorkspaceConfig(), + 
opener, _pool_config(), WorkspaceConfig(), ResolvedAddress(("1.2.3.1", 9999), host_name="host") ) cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) @@ -559,7 +541,7 @@ async def test_fast_failing_discovery(routing_failure_opener, error): error, # will be retried ]) pool = AsyncNeo4jPool( - opener, PoolConfig(), WorkspaceConfig(), + opener, _pool_config(), WorkspaceConfig(), ResolvedAddress(("1.2.3.1", 9999), host_name="host") ) cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) @@ -601,11 +583,13 @@ async def test_fast_failing_discovery(routing_failure_opener, error): async def test_connection_error_callback( opener, error, marks_unauthenticated, fetches_new, mocker ): + config = _pool_config() + auth_manager = AsyncAuthManagers.static(("user", "auth")) + on_auth_expired_mock = mocker.patch.object(auth_manager, "on_auth_expired", + autospec=True) + config.auth = auth_manager pool = AsyncNeo4jPool( - opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS - ) - force_new_auth_mock = mocker.patch.object( - pool, "force_new_auth", autospec=True + opener, config, WorkspaceConfig(), ROUTER1_ADDRESS ) cxs_read = [ await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) @@ -616,16 +600,18 @@ async def test_connection_error_callback( for _ in range(5) ] - force_new_auth_mock.assert_not_called() + on_auth_expired_mock.assert_not_called() for cx in cxs_read + cxs_write: cx.mark_unauthenticated.assert_not_called() - await pool.on_neo4j_error(error, cxs_read[0].addr) + await pool.on_neo4j_error(error, cxs_read[0]) if fetches_new: - force_new_auth_mock.assert_awaited_once() + cxs_read[0].auth_manager.on_auth_expired.assert_awaited_once() else: - force_new_auth_mock.assert_not_called() + on_auth_expired_mock.assert_not_called() + for cx in cxs_read: + cx.auth_manager.on_auth_expired.assert_not_called() for cx in cxs_read: if marks_unauthenticated: diff --git a/tests/unit/async_/test_addressing.py 
b/tests/unit/async_/test_addressing.py index ca5eb0f0..54113cb2 100644 --- a/tests/unit/async_/test_addressing.py +++ b/tests/unit/async_/test_addressing.py @@ -16,10 +16,7 @@ # limitations under the License. -from socket import ( - AF_INET, - AF_INET6, -) +from socket import AF_INET import pytest diff --git a/tests/unit/async_/test_auth_manager.py b/tests/unit/async_/test_auth_manager.py new file mode 100644 index 00000000..f201cbeb --- /dev/null +++ b/tests/unit/async_/test_auth_manager.py @@ -0,0 +1,116 @@ +import itertools +import typing as t + +import pytest +from freezegun import freeze_time +from freezegun.api import FrozenDateTimeFactory + +from neo4j import ( + Auth, + basic_auth, +) +from neo4j.auth_management import ( + AsyncAuthManager, + AsyncAuthManagers, + TemporalAuth, +) + +from ..._async_compat import mark_async_test + + +SAMPLE_AUTHS = ( + None, + ("user", "password"), + basic_auth("foo", "bar"), + basic_auth("foo", "bar", "baz"), + Auth("scheme", "principal", "credentials", "realm", para="meter"), +) + + +@mark_async_test +@pytest.mark.parametrize("auth", SAMPLE_AUTHS) +async def test_static_manager( + auth +) -> None: + manager: AsyncAuthManager = AsyncAuthManagers.static(auth) + assert await manager.get_auth() is auth + + await manager.on_auth_expired(("something", "else")) + assert await manager.get_auth() is auth + + await manager.on_auth_expired(auth) + assert await manager.get_auth() is auth + + +@mark_async_test +@pytest.mark.parametrize(("auth1", "auth2"), + itertools.product(SAMPLE_AUTHS, repeat=2)) +@pytest.mark.parametrize("expires_in", (None, -1, 1., 1, 1000.)) +async def test_temporal_manager_manual_expiry( + auth1: t.Union[t.Tuple[str, str], Auth, None], + auth2: t.Union[t.Tuple[str, str], Auth, None], + expires_in: t.Union[float, int], + mocker +) -> None: + if expires_in is None or expires_in >= 0: + temporal_auth = TemporalAuth(auth1, expires_in) + else: + temporal_auth = TemporalAuth(auth1) + provider = 
mocker.AsyncMock(return_value=temporal_auth) + manager: AsyncAuthManager = AsyncAuthManagers.temporal(provider) + + provider.assert_not_called() + assert await manager.get_auth() is auth1 + provider.assert_awaited_once() + provider.reset_mock() + + provider.return_value = TemporalAuth(auth2) + + await manager.on_auth_expired(("something", "else")) + assert await manager.get_auth() is auth1 + provider.assert_not_called() + + await manager.on_auth_expired(auth1) + provider.assert_awaited_once() + provider.reset_mock() + assert await manager.get_auth() is auth2 + provider.assert_not_called() + + +@mark_async_test +@pytest.mark.parametrize(("auth1", "auth2"), + itertools.product(SAMPLE_AUTHS, repeat=2)) +@pytest.mark.parametrize("expires_in", (None, -1, 1., 1, 1000.)) +async def test_temporal_manager_time_expiry( + auth1: t.Union[t.Tuple[str, str], Auth, None], + auth2: t.Union[t.Tuple[str, str], Auth, None], + expires_in: t.Union[float, int, None], + mocker +) -> None: + with freeze_time() as frozen_time: + assert isinstance(frozen_time, FrozenDateTimeFactory) + if expires_in is None or expires_in >= 0: + temporal_auth = TemporalAuth(auth1, expires_in) + else: + temporal_auth = TemporalAuth(auth1) + provider = mocker.AsyncMock(return_value=temporal_auth) + manager: AsyncAuthManager = AsyncAuthManagers.temporal(provider) + + provider.assert_not_called() + assert await manager.get_auth() is auth1 + provider.assert_awaited_once() + provider.reset_mock() + + provider.return_value = TemporalAuth(auth2) + + if expires_in is None or expires_in < 0: + frozen_time.tick(1_000_000) + assert await manager.get_auth() is auth1 + provider.assert_not_called() + else: + frozen_time.tick(expires_in - 0.000001) + assert await manager.get_auth() is auth1 + provider.assert_not_called() + frozen_time.tick(0.000002) + assert await manager.get_auth() is auth2 + provider.assert_awaited_once() diff --git a/tests/unit/common/test_conf.py b/tests/unit/common/test_conf.py index 88e67197..41f0647a 
100644 --- a/tests/unit/common/test_conf.py +++ b/tests/unit/common/test_conf.py @@ -72,6 +72,7 @@ "fetch_size": 100, "bookmark_manager": object(), "auth": None, + "backwards_compatible_auth": False, } diff --git a/tests/unit/mixed/io/test_direct.py b/tests/unit/mixed/io/test_direct.py index 3446e87a..ec1e9bb3 100644 --- a/tests/unit/mixed/io/test_direct.py +++ b/tests/unit/mixed/io/test_direct.py @@ -26,7 +26,9 @@ import pytest +from neo4j._async.io._pool import AcquireAuth as AsyncAcquireAuth from neo4j._deadline import Deadline +from neo4j._sync.io._pool import AcquireAuth from ...async_.io.test_direct import AsyncFakeBoltPool from ...sync.io.test_direct import FakeBoltPool @@ -59,7 +61,7 @@ def test_multithread(self, pre_populated): def acquire_release_conn(pool_, address_, acquired_counter_, release_event_): nonlocal connections, connections_lock - conn_ = pool_._acquire(address_, None, Deadline(3), None, False) + conn_ = pool_._acquire(address_, None, Deadline(3), None) with connections_lock: if connections is not None: connections.append(conn_) @@ -74,7 +76,7 @@ def acquire_release_conn(pool_, address_, acquired_counter_, # pre-populate the pool with connections for _ in range(pre_populated): - conn = pool._acquire(address, None, Deadline(3), None, False) + conn = pool._acquire(address, None, Deadline(3), None) pre_populated_connections.append(conn) for conn in pre_populated_connections: pool.release(conn) @@ -120,8 +122,7 @@ async def test_multi_coroutine(self, pre_populated): async def acquire_release_conn(pool_, address_, acquired_counter_, release_event_): nonlocal connections - conn_ = await pool_._acquire(address_, None, Deadline(3), None, - False) + conn_ = await pool_._acquire(address_, None, Deadline(3), None) if connections is not None: connections.append(conn_) await acquired_counter_.increment() @@ -155,8 +156,7 @@ async def waiter(pool_, acquired_counter_, release_event_): # pre-populate the pool with connections for _ in range(pre_populated): - 
conn = await pool._acquire(address, None, Deadline(3), None, - False) + conn = await pool._acquire(address, None, Deadline(3), None) pre_populated_connections.append(conn) for conn in pre_populated_connections: await pool.release(conn) diff --git a/tests/unit/mixed/io/test_pool_async.py b/tests/unit/mixed/io/test_pool_async.py deleted file mode 100644 index 58f17616..00000000 --- a/tests/unit/mixed/io/test_pool_async.py +++ /dev/null @@ -1,72 +0,0 @@ -# Copyright (c) "Neo4j" -# Neo4j Sweden AB [https://neo4j.com] -# -# This file is part of Neo4j. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- - -import asyncio -from asyncio import Condition - -from ...async_.fixtures import * # fixtures necessary for pytest -from ...async_.io.test_neo4j_pool import * -from ._common import AsyncMultiEvent - - -@pytest.mark.asyncio -async def test_force_new_auth_blocks(opener): - count = 0 - done = False - condition = Condition() - event = AsyncMultiEvent() - - async def auth_provider(): - nonlocal done, count - count += 1 - if count == 1: - return "user1", "pass1" - await event.increment() - async with condition: - await event.wait(2) - await condition.wait() - await asyncio.sleep(0.1) # block - done = True - return "user", "password" - - config = PoolConfig() - config.auth = auth_provider - pool = AsyncNeo4jPool( - opener, config, WorkspaceConfig(), ROUTER1_ADDRESS - ) - - assert count == 0 - cx = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) - await pool.release(cx) - assert count == 1 - - async def task1(): - assert count == 1 - await pool.force_new_auth() - assert count == 2 - - async def task2(): - await event.increment() - await event.wait(2) - async with condition: - condition.notify() - cx = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) - assert done # assert waited for blocking auth provider - await pool.release(cx) - - await asyncio.gather(task1(), task2()) diff --git a/tests/unit/mixed/io/test_pool_sync.py b/tests/unit/mixed/io/test_pool_sync.py deleted file mode 100644 index 58d98d8c..00000000 --- a/tests/unit/mixed/io/test_pool_sync.py +++ /dev/null @@ -1,79 +0,0 @@ -# Copyright (c) "Neo4j" -# Neo4j Sweden AB [https://neo4j.com] -# -# This file is part of Neo4j. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from threading import ( - Condition, - Thread, -) -from time import sleep - -from ...sync.fixtures import * # fixtures necessary for pytest -from ...sync.io.test_neo4j_pool import * -from ._common import MultiEvent - - -def test_force_new_auth_blocks(opener): - count = 0 - done = False - condition = Condition() - event = MultiEvent() - - def auth_provider(): - nonlocal done, condition, count - count += 1 - if count == 1: - return "user1", "pass1" - event.wait(1) - with condition: - event.increment() - condition.wait() - sleep(0.1) # block - done = True - return "user", "password" - - config = PoolConfig() - config.auth = auth_provider - pool = Neo4jPool( - opener, config, WorkspaceConfig(), ROUTER1_ADDRESS - ) - - assert count == 0 - cx = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) - pool.release(cx) - assert count == 1 - - def task1(): - assert count == 1 - pool.force_new_auth() - assert count == 2 - - def task2(): - event.increment() - event.wait(2) - with condition: - condition.notify() - cx = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) - assert done # assert waited for blocking auth provider - pool.release(cx) - - t1 = Thread(target=task1) - t2 = Thread(target=task2) - t1.start() - t2.start() - t1.join() - t2.join() diff --git a/tests/unit/sync/fixtures/fake_connection.py b/tests/unit/sync/fixtures/fake_connection.py index fdf5ad90..5d66c9d9 100644 --- a/tests/unit/sync/fixtures/fake_connection.py +++ b/tests/unit/sync/fixtures/fake_connection.py @@ -23,6 +23,7 @@ from neo4j import ServerInfo from neo4j._deadline 
import Deadline from neo4j._sync.io import Bolt +from neo4j.auth_management import AuthManager __all__ = [ @@ -51,7 +52,10 @@ def __init__(self, *args, **kwargs): self.attach_mock(mock.Mock(return_value=False), "closed") self.attach_mock(mock.Mock(return_value=False), "socket") self.attach_mock(mock.Mock(return_value=False), "re_auth") - self.attach_mock(mock.Mock(), "unresolved_address") + self.attach_mock(mock.Mock(spec=AuthManager), + "auth_manager") + self.unresolved_address = next(iter(args), "localhost") + self.throwaway = False def close_side_effect(): self.closed.return_value = True diff --git a/tests/unit/sync/io/conftest.py b/tests/unit/sync/io/conftest.py index a06cf3c3..beba1e27 100644 --- a/tests/unit/sync/io/conftest.py +++ b/tests/unit/sync/io/conftest.py @@ -103,6 +103,9 @@ def sendall(self, data): def close(self): return + def kill(self): + return + def inject(self, data): self.recv_buffer += data diff --git a/tests/unit/sync/io/test_class_bolt.py b/tests/unit/sync/io/test_class_bolt.py index ed96bef5..86b963dc 100644 --- a/tests/unit/sync/io/test_class_bolt.py +++ b/tests/unit/sync/io/test_class_bolt.py @@ -20,6 +20,7 @@ import pytest +import neo4j.auth_management from neo4j._async_compat.network import BoltSocket from neo4j._sync.io import Bolt @@ -102,6 +103,66 @@ def test_cancel_hello_in_open(mocker): bolt_mock.local_port = 1234 with pytest.raises(asyncio.CancelledError): - Bolt.open(address) + Bolt.open( + address, + auth_manager=neo4j.auth_management.AuthManagers.static(None) + ) bolt_mock.kill.assert_called_once_with() + + +@TestDecorators.mark_async_only_test +def test_cancel_manager_in_open(mocker): + address = ("localhost", 7687) + socket_mock = mocker.Mock(spec=BoltSocket) + + socket_cls_mock = mocker.patch("neo4j._sync.io._bolt.BoltSocket", + autospec=True) + socket_cls_mock.connect.return_value = ( + socket_mock, (5, 0), None, None + ) + socket_mock.getpeername.return_value = address + bolt_cls_mock = 
mocker.patch("neo4j._sync.io._bolt5.Bolt5x0", + autospec=True) + bolt_mock = bolt_cls_mock.return_value + bolt_mock.socket = socket_mock + bolt_mock.local_port = 1234 + + auth_manager = mocker.Mock( + spec=neo4j.auth_management.AuthManager + ) + auth_manager.get_auth.side_effect = asyncio.CancelledError() + + with pytest.raises(asyncio.CancelledError): + Bolt.open(address, auth_manager=auth_manager) + + socket_mock.kill.assert_called_once_with() + + +@TestDecorators.mark_async_only_test +def test_fail_manager_in_open(mocker): + address = ("localhost", 7687) + socket_mock = mocker.Mock(spec=BoltSocket) + + socket_cls_mock = mocker.patch("neo4j._sync.io._bolt.BoltSocket", + autospec=True) + socket_cls_mock.connect.return_value = ( + socket_mock, (5, 0), None, None + ) + socket_mock.getpeername.return_value = address + bolt_cls_mock = mocker.patch("neo4j._sync.io._bolt5.Bolt5x0", + autospec=True) + bolt_mock = bolt_cls_mock.return_value + bolt_mock.socket = socket_mock + bolt_mock.local_port = 1234 + + auth_manager = mocker.Mock( + spec=neo4j.auth_management.AuthManager + ) + auth_manager.get_auth.side_effect = RuntimeError("token fetching failed") + + with pytest.raises(RuntimeError) as exc: + Bolt.open(address, auth_manager=auth_manager) + assert exc.value is auth_manager.get_auth.side_effect + + socket_mock.close.assert_called_once_with() diff --git a/tests/unit/sync/io/test_class_bolt3.py b/tests/unit/sync/io/test_class_bolt3.py index 0c39167d..f445b84e 100644 --- a/tests/unit/sync/io/test_class_bolt3.py +++ b/tests/unit/sync/io/test_class_bolt3.py @@ -31,7 +31,7 @@ @pytest.mark.parametrize("set_stale", (True, False)) def test_conn_is_stale(fake_socket, set_stale): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) max_connection_lifetime = 0 connection = Bolt3(address, fake_socket(address), max_connection_lifetime) if set_stale: @@ -41,7 +41,7 @@ def test_conn_is_stale(fake_socket, set_stale): @pytest.mark.parametrize("set_stale", 
(True, False)) def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) max_connection_lifetime = -1 connection = Bolt3(address, fake_socket(address), max_connection_lifetime) if set_stale: @@ -51,7 +51,7 @@ def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): @pytest.mark.parametrize("set_stale", (True, False)) def test_conn_is_not_stale(fake_socket, set_stale): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) max_connection_lifetime = 999999999 connection = Bolt3(address, fake_socket(address), max_connection_lifetime) if set_stale: @@ -60,14 +60,14 @@ def test_conn_is_not_stale(fake_socket, set_stale): def test_db_extra_not_supported_in_begin(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) connection = Bolt3(address, fake_socket(address), PoolConfig.max_connection_lifetime) with pytest.raises(ConfigurationError): connection.begin(db="something") def test_db_extra_not_supported_in_run(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) connection = Bolt3(address, fake_socket(address), PoolConfig.max_connection_lifetime) with pytest.raises(ConfigurationError): connection.run("", db="something") @@ -75,7 +75,7 @@ def test_db_extra_not_supported_in_run(fake_socket): @mark_sync_test def test_simple_discard(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt3.UNPACKER_CLS) connection = Bolt3(address, socket, PoolConfig.max_connection_lifetime) connection.discard() @@ -87,7 +87,7 @@ def test_simple_discard(fake_socket): @mark_sync_test def test_simple_pull(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt3.UNPACKER_CLS) connection = Bolt3(address, socket, PoolConfig.max_connection_lifetime) 
connection.pull() @@ -102,7 +102,7 @@ def test_simple_pull(fake_socket): def test_hint_recv_timeout_seconds_gets_ignored( fake_socket_pair, recv_timeout, mocker ): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) sockets = fake_socket_pair( address, Bolt3.PACKER_CLS, Bolt3.UNPACKER_CLS ) @@ -133,7 +133,7 @@ def test_hint_recv_timeout_seconds_gets_ignored( def test_credentials_are_not_logged( auth, fake_socket_pair, mocker, caplog ): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) sockets = fake_socket_pair(address, packer_cls=Bolt3.PACKER_CLS, unpacker_cls=Bolt3.UNPACKER_CLS) @@ -156,7 +156,7 @@ def test_credentials_are_not_logged( @pytest.mark.parametrize("message", ("logon", "logoff")) def test_auth_message_raises_configuration_error(message, fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) connection = Bolt3(address, fake_socket(address), PoolConfig.max_connection_lifetime) with pytest.raises(ConfigurationError, @@ -171,12 +171,12 @@ def test_auth_message_raises_configuration_error(message, fake_socket): )) @mark_sync_test def test_re_auth_noop(auth, fake_socket, mocker): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) connection = Bolt3(address, fake_socket(address), PoolConfig.max_connection_lifetime, auth=auth) logon_spy = mocker.spy(connection, "logon") logoff_spy = mocker.spy(connection, "logoff") - res = connection.re_auth(auth) + res = connection.re_auth(auth, None) assert res is False logon_spy.assert_not_called() @@ -196,9 +196,9 @@ def test_re_auth_noop(auth, fake_socket, mocker): ) @mark_sync_test def test_re_auth(auth1, auth2, fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) connection = Bolt3(address, fake_socket(address), PoolConfig.max_connection_lifetime, auth=auth1) with pytest.raises(ConfigurationError, match="Session level authentication is not supported"): - 
connection.re_auth(auth2) + connection.re_auth(auth2, None) diff --git a/tests/unit/sync/io/test_class_bolt4x0.py b/tests/unit/sync/io/test_class_bolt4x0.py index ff3a595c..743fd05d 100644 --- a/tests/unit/sync/io/test_class_bolt4x0.py +++ b/tests/unit/sync/io/test_class_bolt4x0.py @@ -31,7 +31,7 @@ @pytest.mark.parametrize("set_stale", (True, False)) def test_conn_is_stale(fake_socket, set_stale): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) max_connection_lifetime = 0 connection = Bolt4x0(address, fake_socket(address), max_connection_lifetime) if set_stale: @@ -41,7 +41,7 @@ def test_conn_is_stale(fake_socket, set_stale): @pytest.mark.parametrize("set_stale", (True, False)) def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) max_connection_lifetime = -1 connection = Bolt4x0(address, fake_socket(address), max_connection_lifetime) if set_stale: @@ -51,7 +51,7 @@ def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): @pytest.mark.parametrize("set_stale", (True, False)) def test_conn_is_not_stale(fake_socket, set_stale): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) max_connection_lifetime = 999999999 connection = Bolt4x0(address, fake_socket(address), max_connection_lifetime) if set_stale: @@ -61,7 +61,7 @@ def test_conn_is_not_stale(fake_socket, set_stale): @mark_sync_test def test_db_extra_in_begin(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt4x0.UNPACKER_CLS) connection = Bolt4x0(address, socket, PoolConfig.max_connection_lifetime) connection.begin(db="something") @@ -74,7 +74,7 @@ def test_db_extra_in_begin(fake_socket): @mark_sync_test def test_db_extra_in_run(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt4x0.UNPACKER_CLS) 
connection = Bolt4x0(address, socket, PoolConfig.max_connection_lifetime) connection.run("", {}, db="something") @@ -89,7 +89,7 @@ def test_db_extra_in_run(fake_socket): @mark_sync_test def test_n_extra_in_discard(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt4x0.UNPACKER_CLS) connection = Bolt4x0(address, socket, PoolConfig.max_connection_lifetime) connection.discard(n=666) @@ -109,7 +109,7 @@ def test_n_extra_in_discard(fake_socket): ) @mark_sync_test def test_qid_extra_in_discard(fake_socket, test_input, expected): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt4x0.UNPACKER_CLS) connection = Bolt4x0(address, socket, PoolConfig.max_connection_lifetime) connection.discard(qid=test_input) @@ -129,7 +129,7 @@ def test_qid_extra_in_discard(fake_socket, test_input, expected): ) @mark_sync_test def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt4x0.UNPACKER_CLS) connection = Bolt4x0(address, socket, PoolConfig.max_connection_lifetime) connection.discard(n=666, qid=test_input) @@ -149,7 +149,7 @@ def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): ) @mark_sync_test def test_n_extra_in_pull(fake_socket, test_input, expected): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt4x0.UNPACKER_CLS) connection = Bolt4x0(address, socket, PoolConfig.max_connection_lifetime) connection.pull(n=test_input) @@ -169,7 +169,7 @@ def test_n_extra_in_pull(fake_socket, test_input, expected): ) @mark_sync_test def test_qid_extra_in_pull(fake_socket, test_input, expected): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt4x0.UNPACKER_CLS) connection = Bolt4x0(address, 
socket, PoolConfig.max_connection_lifetime) connection.pull(qid=test_input) @@ -182,7 +182,7 @@ def test_qid_extra_in_pull(fake_socket, test_input, expected): @mark_sync_test def test_n_and_qid_extras_in_pull(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt4x0.UNPACKER_CLS) connection = Bolt4x0(address, socket, PoolConfig.max_connection_lifetime) connection.pull(n=666, qid=777) @@ -198,7 +198,7 @@ def test_n_and_qid_extras_in_pull(fake_socket): def test_hint_recv_timeout_seconds_gets_ignored( fake_socket_pair, recv_timeout, mocker ): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) sockets = fake_socket_pair(address, packer_cls=Bolt4x0.PACKER_CLS, unpacker_cls=Bolt4x0.UNPACKER_CLS) @@ -229,7 +229,7 @@ def test_hint_recv_timeout_seconds_gets_ignored( def test_credentials_are_not_logged( auth, fake_socket_pair, mocker, caplog ): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) sockets = fake_socket_pair(address, packer_cls=Bolt4x0.PACKER_CLS, unpacker_cls=Bolt4x0.UNPACKER_CLS) @@ -252,7 +252,7 @@ def test_credentials_are_not_logged( @pytest.mark.parametrize("message", ("logon", "logoff")) def test_auth_message_raises_configuration_error(message, fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) connection = Bolt4x0(address, fake_socket(address), PoolConfig.max_connection_lifetime) with pytest.raises(ConfigurationError, @@ -267,12 +267,12 @@ def test_auth_message_raises_configuration_error(message, fake_socket): )) @mark_sync_test def test_re_auth_noop(auth, fake_socket, mocker): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) connection = Bolt4x0(address, fake_socket(address), PoolConfig.max_connection_lifetime, auth=auth) logon_spy = mocker.spy(connection, "logon") logoff_spy = mocker.spy(connection, "logoff") - res = connection.re_auth(auth) + res = 
connection.re_auth(auth, None) assert res is False logon_spy.assert_not_called() @@ -292,9 +292,9 @@ def test_re_auth_noop(auth, fake_socket, mocker): ) @mark_sync_test def test_re_auth(auth1, auth2, fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) connection = Bolt4x0(address, fake_socket(address), PoolConfig.max_connection_lifetime, auth=auth1) with pytest.raises(ConfigurationError, match="Session level authentication is not supported"): - connection.re_auth(auth2) + connection.re_auth(auth2, None) diff --git a/tests/unit/sync/io/test_class_bolt4x1.py b/tests/unit/sync/io/test_class_bolt4x1.py index 81b25336..b72e2fe2 100644 --- a/tests/unit/sync/io/test_class_bolt4x1.py +++ b/tests/unit/sync/io/test_class_bolt4x1.py @@ -31,7 +31,7 @@ @pytest.mark.parametrize("set_stale", (True, False)) def test_conn_is_stale(fake_socket, set_stale): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) max_connection_lifetime = 0 connection = Bolt4x1(address, fake_socket(address), max_connection_lifetime) if set_stale: @@ -41,7 +41,7 @@ def test_conn_is_stale(fake_socket, set_stale): @pytest.mark.parametrize("set_stale", (True, False)) def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) max_connection_lifetime = -1 connection = Bolt4x1(address, fake_socket(address), max_connection_lifetime) if set_stale: @@ -51,7 +51,7 @@ def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): @pytest.mark.parametrize("set_stale", (True, False)) def test_conn_is_not_stale(fake_socket, set_stale): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) max_connection_lifetime = 999999999 connection = Bolt4x1(address, fake_socket(address), max_connection_lifetime) if set_stale: @@ -61,7 +61,7 @@ def test_conn_is_not_stale(fake_socket, set_stale): @mark_sync_test def test_db_extra_in_begin(fake_socket): - 
address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt4x1.UNPACKER_CLS) connection = Bolt4x1(address, socket, PoolConfig.max_connection_lifetime) connection.begin(db="something") @@ -74,7 +74,7 @@ def test_db_extra_in_begin(fake_socket): @mark_sync_test def test_db_extra_in_run(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt4x1.UNPACKER_CLS) connection = Bolt4x1(address, socket, PoolConfig.max_connection_lifetime) connection.run("", {}, db="something") @@ -89,7 +89,7 @@ def test_db_extra_in_run(fake_socket): @mark_sync_test def test_n_extra_in_discard(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt4x1.UNPACKER_CLS) connection = Bolt4x1(address, socket, PoolConfig.max_connection_lifetime) connection.discard(n=666) @@ -109,7 +109,7 @@ def test_n_extra_in_discard(fake_socket): ) @mark_sync_test def test_qid_extra_in_discard(fake_socket, test_input, expected): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt4x1.UNPACKER_CLS) connection = Bolt4x1(address, socket, PoolConfig.max_connection_lifetime) connection.discard(qid=test_input) @@ -130,7 +130,7 @@ def test_qid_extra_in_discard(fake_socket, test_input, expected): @mark_sync_test def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): # python -m pytest tests/unit/io/test_class_bolt4x0.py -s -k test_n_and_qid_extras_in_discard - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt4x1.UNPACKER_CLS) connection = Bolt4x1(address, socket, PoolConfig.max_connection_lifetime) connection.discard(n=666, qid=test_input) @@ -150,7 +150,7 @@ def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): ) @mark_sync_test def test_n_extra_in_pull(fake_socket, 
test_input, expected): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt4x1.UNPACKER_CLS) connection = Bolt4x1(address, socket, PoolConfig.max_connection_lifetime) connection.pull(n=test_input) @@ -171,7 +171,7 @@ def test_n_extra_in_pull(fake_socket, test_input, expected): @mark_sync_test def test_qid_extra_in_pull(fake_socket, test_input, expected): # python -m pytest tests/unit/io/test_class_bolt4x0.py -s -k test_qid_extra_in_pull - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt4x1.UNPACKER_CLS) connection = Bolt4x1(address, socket, PoolConfig.max_connection_lifetime) connection.pull(qid=test_input) @@ -184,7 +184,7 @@ def test_qid_extra_in_pull(fake_socket, test_input, expected): @mark_sync_test def test_n_and_qid_extras_in_pull(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt4x1.UNPACKER_CLS) connection = Bolt4x1(address, socket, PoolConfig.max_connection_lifetime) connection.pull(n=666, qid=777) @@ -197,7 +197,7 @@ def test_n_and_qid_extras_in_pull(fake_socket): @mark_sync_test def test_hello_passes_routing_metadata(fake_socket_pair): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) sockets = fake_socket_pair(address, packer_cls=Bolt4x1.PACKER_CLS, unpacker_cls=Bolt4x1.UNPACKER_CLS) @@ -218,7 +218,7 @@ def test_hello_passes_routing_metadata(fake_socket_pair): def test_hint_recv_timeout_seconds_gets_ignored( fake_socket_pair, recv_timeout, mocker ): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) sockets = fake_socket_pair(address, packer_cls=Bolt4x1.PACKER_CLS, unpacker_cls=Bolt4x1.UNPACKER_CLS) @@ -248,7 +248,7 @@ def test_hint_recv_timeout_seconds_gets_ignored( def test_credentials_are_not_logged( auth, fake_socket_pair, mocker, caplog ): - address = ("127.0.0.1", 7687) + address = 
neo4j.Address(("127.0.0.1", 7687)) sockets = fake_socket_pair(address, packer_cls=Bolt4x1.PACKER_CLS, unpacker_cls=Bolt4x1.UNPACKER_CLS) @@ -271,7 +271,7 @@ def test_credentials_are_not_logged( @pytest.mark.parametrize("message", ("logon", "logoff")) def test_auth_message_raises_configuration_error(message, fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) connection = Bolt4x1(address, fake_socket(address), PoolConfig.max_connection_lifetime) with pytest.raises(ConfigurationError, @@ -286,12 +286,12 @@ def test_auth_message_raises_configuration_error(message, fake_socket): )) @mark_sync_test def test_re_auth_noop(auth, fake_socket, mocker): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) connection = Bolt4x1(address, fake_socket(address), PoolConfig.max_connection_lifetime, auth=auth) logon_spy = mocker.spy(connection, "logon") logoff_spy = mocker.spy(connection, "logoff") - res = connection.re_auth(auth) + res = connection.re_auth(auth, None) assert res is False logon_spy.assert_not_called() @@ -311,9 +311,9 @@ def test_re_auth_noop(auth, fake_socket, mocker): ) @mark_sync_test def test_re_auth(auth1, auth2, fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) connection = Bolt4x1(address, fake_socket(address), PoolConfig.max_connection_lifetime, auth=auth1) with pytest.raises(ConfigurationError, match="Session level authentication is not supported"): - connection.re_auth(auth2) + connection.re_auth(auth2, None) diff --git a/tests/unit/sync/io/test_class_bolt4x2.py b/tests/unit/sync/io/test_class_bolt4x2.py index d5217c6d..d9c5e843 100644 --- a/tests/unit/sync/io/test_class_bolt4x2.py +++ b/tests/unit/sync/io/test_class_bolt4x2.py @@ -31,7 +31,7 @@ @pytest.mark.parametrize("set_stale", (True, False)) def test_conn_is_stale(fake_socket, set_stale): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) max_connection_lifetime = 0 
connection = Bolt4x2(address, fake_socket(address), max_connection_lifetime) if set_stale: @@ -41,7 +41,7 @@ def test_conn_is_stale(fake_socket, set_stale): @pytest.mark.parametrize("set_stale", (True, False)) def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) max_connection_lifetime = -1 connection = Bolt4x2(address, fake_socket(address), max_connection_lifetime) if set_stale: @@ -51,7 +51,7 @@ def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): @pytest.mark.parametrize("set_stale", (True, False)) def test_conn_is_not_stale(fake_socket, set_stale): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) max_connection_lifetime = 999999999 connection = Bolt4x2(address, fake_socket(address), max_connection_lifetime) if set_stale: @@ -61,7 +61,7 @@ def test_conn_is_not_stale(fake_socket, set_stale): @mark_sync_test def test_db_extra_in_begin(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt4x2.UNPACKER_CLS) connection = Bolt4x2(address, socket, PoolConfig.max_connection_lifetime) connection.begin(db="something") @@ -74,7 +74,7 @@ def test_db_extra_in_begin(fake_socket): @mark_sync_test def test_db_extra_in_run(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt4x2.UNPACKER_CLS) connection = Bolt4x2(address, socket, PoolConfig.max_connection_lifetime) connection.run("", {}, db="something") @@ -89,7 +89,7 @@ def test_db_extra_in_run(fake_socket): @mark_sync_test def test_n_extra_in_discard(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt4x2.UNPACKER_CLS) connection = Bolt4x2(address, socket, PoolConfig.max_connection_lifetime) connection.discard(n=666) @@ -109,7 +109,7 @@ def 
test_n_extra_in_discard(fake_socket): ) @mark_sync_test def test_qid_extra_in_discard(fake_socket, test_input, expected): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt4x2.UNPACKER_CLS) connection = Bolt4x2(address, socket, PoolConfig.max_connection_lifetime) connection.discard(qid=test_input) @@ -130,7 +130,7 @@ def test_qid_extra_in_discard(fake_socket, test_input, expected): @mark_sync_test def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): # python -m pytest tests/unit/io/test_class_bolt4x0.py -s -k test_n_and_qid_extras_in_discard - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt4x2.UNPACKER_CLS) connection = Bolt4x2(address, socket, PoolConfig.max_connection_lifetime) connection.discard(n=666, qid=test_input) @@ -150,7 +150,7 @@ def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): ) @mark_sync_test def test_n_extra_in_pull(fake_socket, test_input, expected): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt4x2.UNPACKER_CLS) connection = Bolt4x2(address, socket, PoolConfig.max_connection_lifetime) connection.pull(n=test_input) @@ -171,7 +171,7 @@ def test_n_extra_in_pull(fake_socket, test_input, expected): @mark_sync_test def test_qid_extra_in_pull(fake_socket, test_input, expected): # python -m pytest tests/unit/io/test_class_bolt4x0.py -s -k test_qid_extra_in_pull - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt4x2.UNPACKER_CLS) connection = Bolt4x2(address, socket, PoolConfig.max_connection_lifetime) connection.pull(qid=test_input) @@ -184,7 +184,7 @@ def test_qid_extra_in_pull(fake_socket, test_input, expected): @mark_sync_test def test_n_and_qid_extras_in_pull(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket 
= fake_socket(address, Bolt4x2.UNPACKER_CLS) connection = Bolt4x2(address, socket, PoolConfig.max_connection_lifetime) connection.pull(n=666, qid=777) @@ -197,7 +197,7 @@ def test_n_and_qid_extras_in_pull(fake_socket): @mark_sync_test def test_hello_passes_routing_metadata(fake_socket_pair): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) sockets = fake_socket_pair(address, packer_cls=Bolt4x2.PACKER_CLS, unpacker_cls=Bolt4x2.UNPACKER_CLS) @@ -218,7 +218,7 @@ def test_hello_passes_routing_metadata(fake_socket_pair): def test_hint_recv_timeout_seconds_gets_ignored( fake_socket_pair, recv_timeout, mocker ): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) sockets = fake_socket_pair(address, packer_cls=Bolt4x2.PACKER_CLS, unpacker_cls=Bolt4x2.UNPACKER_CLS) @@ -249,7 +249,7 @@ def test_hint_recv_timeout_seconds_gets_ignored( def test_credentials_are_not_logged( auth, fake_socket_pair, mocker, caplog ): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) sockets = fake_socket_pair(address, packer_cls=Bolt4x2.PACKER_CLS, unpacker_cls=Bolt4x2.UNPACKER_CLS) @@ -272,7 +272,7 @@ def test_credentials_are_not_logged( @pytest.mark.parametrize("message", ("logon", "logoff")) def test_auth_message_raises_configuration_error(message, fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) connection = Bolt4x2(address, fake_socket(address), PoolConfig.max_connection_lifetime) with pytest.raises(ConfigurationError, @@ -287,12 +287,12 @@ def test_auth_message_raises_configuration_error(message, fake_socket): )) @mark_sync_test def test_re_auth_noop(auth, fake_socket, mocker): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) connection = Bolt4x2(address, fake_socket(address), PoolConfig.max_connection_lifetime, auth=auth) logon_spy = mocker.spy(connection, "logon") logoff_spy = mocker.spy(connection, "logoff") - res = 
connection.re_auth(auth) + res = connection.re_auth(auth, None) assert res is False logon_spy.assert_not_called() @@ -312,9 +312,9 @@ def test_re_auth_noop(auth, fake_socket, mocker): ) @mark_sync_test def test_re_auth(auth1, auth2, fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) connection = Bolt4x2(address, fake_socket(address), PoolConfig.max_connection_lifetime, auth=auth1) with pytest.raises(ConfigurationError, match="Session level authentication is not supported"): - connection.re_auth(auth2) + connection.re_auth(auth2, None) diff --git a/tests/unit/sync/io/test_class_bolt4x3.py b/tests/unit/sync/io/test_class_bolt4x3.py index 4de581ea..fd55cb29 100644 --- a/tests/unit/sync/io/test_class_bolt4x3.py +++ b/tests/unit/sync/io/test_class_bolt4x3.py @@ -31,7 +31,7 @@ @pytest.mark.parametrize("set_stale", (True, False)) def test_conn_is_stale(fake_socket, set_stale): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) max_connection_lifetime = 0 connection = Bolt4x3(address, fake_socket(address), max_connection_lifetime) if set_stale: @@ -41,7 +41,7 @@ def test_conn_is_stale(fake_socket, set_stale): @pytest.mark.parametrize("set_stale", (True, False)) def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) max_connection_lifetime = -1 connection = Bolt4x3(address, fake_socket(address), max_connection_lifetime) if set_stale: @@ -51,7 +51,7 @@ def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): @pytest.mark.parametrize("set_stale", (True, False)) def test_conn_is_not_stale(fake_socket, set_stale): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) max_connection_lifetime = 999999999 connection = Bolt4x3(address, fake_socket(address), max_connection_lifetime) if set_stale: @@ -61,7 +61,7 @@ def test_conn_is_not_stale(fake_socket, set_stale): @mark_sync_test def 
test_db_extra_in_begin(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt4x3.UNPACKER_CLS) connection = Bolt4x3(address, socket, PoolConfig.max_connection_lifetime) connection.begin(db="something") @@ -74,7 +74,7 @@ def test_db_extra_in_begin(fake_socket): @mark_sync_test def test_db_extra_in_run(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt4x3.UNPACKER_CLS) connection = Bolt4x3(address, socket, PoolConfig.max_connection_lifetime) connection.run("", {}, db="something") @@ -89,7 +89,7 @@ def test_db_extra_in_run(fake_socket): @mark_sync_test def test_n_extra_in_discard(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt4x3.UNPACKER_CLS) connection = Bolt4x3(address, socket, PoolConfig.max_connection_lifetime) connection.discard(n=666) @@ -109,7 +109,7 @@ def test_n_extra_in_discard(fake_socket): ) @mark_sync_test def test_qid_extra_in_discard(fake_socket, test_input, expected): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt4x3.UNPACKER_CLS) connection = Bolt4x3(address, socket, PoolConfig.max_connection_lifetime) connection.discard(qid=test_input) @@ -130,7 +130,7 @@ def test_qid_extra_in_discard(fake_socket, test_input, expected): @mark_sync_test def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): # python -m pytest tests/unit/io/test_class_bolt4x0.py -s -k test_n_and_qid_extras_in_discard - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt4x3.UNPACKER_CLS) connection = Bolt4x3(address, socket, PoolConfig.max_connection_lifetime) connection.discard(n=666, qid=test_input) @@ -150,7 +150,7 @@ def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): ) @mark_sync_test def 
test_n_extra_in_pull(fake_socket, test_input, expected): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt4x3.UNPACKER_CLS) connection = Bolt4x3(address, socket, PoolConfig.max_connection_lifetime) connection.pull(n=test_input) @@ -171,7 +171,7 @@ def test_n_extra_in_pull(fake_socket, test_input, expected): @mark_sync_test def test_qid_extra_in_pull(fake_socket, test_input, expected): # python -m pytest tests/unit/io/test_class_bolt4x0.py -s -k test_qid_extra_in_pull - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt4x3.UNPACKER_CLS) connection = Bolt4x3(address, socket, PoolConfig.max_connection_lifetime) connection.pull(qid=test_input) @@ -184,7 +184,7 @@ def test_qid_extra_in_pull(fake_socket, test_input, expected): @mark_sync_test def test_n_and_qid_extras_in_pull(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt4x3.UNPACKER_CLS) connection = Bolt4x3(address, socket, PoolConfig.max_connection_lifetime) connection.pull(n=666, qid=777) @@ -197,7 +197,7 @@ def test_n_and_qid_extras_in_pull(fake_socket): @mark_sync_test def test_hello_passes_routing_metadata(fake_socket_pair): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) sockets = fake_socket_pair(address, packer_cls=Bolt4x3.PACKER_CLS, unpacker_cls=Bolt4x3.UNPACKER_CLS) @@ -229,7 +229,7 @@ def test_hello_passes_routing_metadata(fake_socket_pair): def test_hint_recv_timeout_seconds( fake_socket_pair, hints, valid, caplog, mocker ): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) sockets = fake_socket_pair(address, packer_cls=Bolt4x3.PACKER_CLS, unpacker_cls=Bolt4x3.UNPACKER_CLS) @@ -275,7 +275,7 @@ def test_hint_recv_timeout_seconds( def test_credentials_are_not_logged( auth, fake_socket_pair, mocker, caplog ): - address = ("127.0.0.1", 7687) + 
address = neo4j.Address(("127.0.0.1", 7687)) sockets = fake_socket_pair(address, packer_cls=Bolt4x3.PACKER_CLS, unpacker_cls=Bolt4x3.UNPACKER_CLS) @@ -299,7 +299,7 @@ def test_credentials_are_not_logged( @pytest.mark.parametrize("message", ("logon", "logoff")) def test_auth_message_raises_configuration_error(message, fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) connection = Bolt4x3(address, fake_socket(address), PoolConfig.max_connection_lifetime) with pytest.raises(ConfigurationError, @@ -314,12 +314,12 @@ def test_auth_message_raises_configuration_error(message, fake_socket): )) @mark_sync_test def test_re_auth_noop(auth, fake_socket, mocker): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) connection = Bolt4x3(address, fake_socket(address), PoolConfig.max_connection_lifetime, auth=auth) logon_spy = mocker.spy(connection, "logon") logoff_spy = mocker.spy(connection, "logoff") - res = connection.re_auth(auth) + res = connection.re_auth(auth, None) assert res is False logon_spy.assert_not_called() @@ -339,9 +339,9 @@ def test_re_auth_noop(auth, fake_socket, mocker): ) @mark_sync_test def test_re_auth(auth1, auth2, fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) connection = Bolt4x3(address, fake_socket(address), PoolConfig.max_connection_lifetime, auth=auth1) with pytest.raises(ConfigurationError, match="Session level authentication is not supported"): - connection.re_auth(auth2) + connection.re_auth(auth2, None) diff --git a/tests/unit/sync/io/test_class_bolt4x4.py b/tests/unit/sync/io/test_class_bolt4x4.py index 6305b95c..83e2f7bd 100644 --- a/tests/unit/sync/io/test_class_bolt4x4.py +++ b/tests/unit/sync/io/test_class_bolt4x4.py @@ -31,7 +31,7 @@ @pytest.mark.parametrize("set_stale", (True, False)) def test_conn_is_stale(fake_socket, set_stale): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) 
max_connection_lifetime = 0 connection = Bolt4x4(address, fake_socket(address), max_connection_lifetime) if set_stale: @@ -41,7 +41,7 @@ def test_conn_is_stale(fake_socket, set_stale): @pytest.mark.parametrize("set_stale", (True, False)) def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) max_connection_lifetime = -1 connection = Bolt4x4(address, fake_socket(address), max_connection_lifetime) if set_stale: @@ -51,7 +51,7 @@ def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): @pytest.mark.parametrize("set_stale", (True, False)) def test_conn_is_not_stale(fake_socket, set_stale): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) max_connection_lifetime = 999999999 connection = Bolt4x4(address, fake_socket(address), max_connection_lifetime) if set_stale: @@ -70,7 +70,7 @@ def test_conn_is_not_stale(fake_socket, set_stale): )) @mark_sync_test def test_extra_in_begin(fake_socket, args, kwargs, expected_fields): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt4x4.UNPACKER_CLS) connection = Bolt4x4(address, socket, PoolConfig.max_connection_lifetime) connection.begin(*args, **kwargs) @@ -91,7 +91,7 @@ def test_extra_in_begin(fake_socket, args, kwargs, expected_fields): )) @mark_sync_test def test_extra_in_run(fake_socket, args, kwargs, expected_fields): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt4x4.UNPACKER_CLS) connection = Bolt4x4(address, socket, PoolConfig.max_connection_lifetime) connection.run(*args, **kwargs) @@ -103,7 +103,7 @@ def test_extra_in_run(fake_socket, args, kwargs, expected_fields): @mark_sync_test def test_n_extra_in_discard(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt4x4.UNPACKER_CLS) connection = 
Bolt4x4(address, socket, PoolConfig.max_connection_lifetime) connection.discard(n=666) @@ -123,7 +123,7 @@ def test_n_extra_in_discard(fake_socket): ) @mark_sync_test def test_qid_extra_in_discard(fake_socket, test_input, expected): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt4x4.UNPACKER_CLS) connection = Bolt4x4(address, socket, PoolConfig.max_connection_lifetime) connection.discard(qid=test_input) @@ -144,7 +144,7 @@ def test_qid_extra_in_discard(fake_socket, test_input, expected): @mark_sync_test def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): # python -m pytest tests/unit/io/test_class_bolt4x0.py -s -k test_n_and_qid_extras_in_discard - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt4x4.UNPACKER_CLS) connection = Bolt4x4(address, socket, PoolConfig.max_connection_lifetime) connection.discard(n=666, qid=test_input) @@ -164,7 +164,7 @@ def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): ) @mark_sync_test def test_n_extra_in_pull(fake_socket, test_input, expected): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt4x4.UNPACKER_CLS) connection = Bolt4x4(address, socket, PoolConfig.max_connection_lifetime) connection.pull(n=test_input) @@ -185,7 +185,7 @@ def test_n_extra_in_pull(fake_socket, test_input, expected): @mark_sync_test def test_qid_extra_in_pull(fake_socket, test_input, expected): # python -m pytest tests/unit/io/test_class_bolt4x0.py -s -k test_qid_extra_in_pull - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt4x4.UNPACKER_CLS) connection = Bolt4x4(address, socket, PoolConfig.max_connection_lifetime) connection.pull(qid=test_input) @@ -198,7 +198,7 @@ def test_qid_extra_in_pull(fake_socket, test_input, expected): @mark_sync_test def 
test_n_and_qid_extras_in_pull(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt4x4.UNPACKER_CLS) connection = Bolt4x4(address, socket, PoolConfig.max_connection_lifetime) connection.pull(n=666, qid=777) @@ -211,7 +211,7 @@ def test_n_and_qid_extras_in_pull(fake_socket): @mark_sync_test def test_hello_passes_routing_metadata(fake_socket_pair): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) sockets = fake_socket_pair(address, packer_cls=Bolt4x4.PACKER_CLS, unpacker_cls=Bolt4x4.UNPACKER_CLS) @@ -243,7 +243,7 @@ def test_hello_passes_routing_metadata(fake_socket_pair): def test_hint_recv_timeout_seconds( fake_socket_pair, hints, valid, caplog, mocker ): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) sockets = fake_socket_pair(address, packer_cls=Bolt4x4.PACKER_CLS, unpacker_cls=Bolt4x4.UNPACKER_CLS) @@ -289,7 +289,7 @@ def test_hint_recv_timeout_seconds( def test_credentials_are_not_logged( auth, fake_socket_pair, mocker, caplog ): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) sockets = fake_socket_pair(address, packer_cls=Bolt4x4.PACKER_CLS, unpacker_cls=Bolt4x4.UNPACKER_CLS) @@ -312,7 +312,7 @@ def test_credentials_are_not_logged( @pytest.mark.parametrize("message", ("logon", "logoff")) def test_auth_message_raises_configuration_error(message, fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) connection = Bolt4x4(address, fake_socket(address), PoolConfig.max_connection_lifetime) with pytest.raises(ConfigurationError, @@ -327,12 +327,12 @@ def test_auth_message_raises_configuration_error(message, fake_socket): )) @mark_sync_test def test_re_auth_noop(auth, fake_socket, mocker): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) connection = Bolt4x4(address, fake_socket(address), PoolConfig.max_connection_lifetime, auth=auth) 
logon_spy = mocker.spy(connection, "logon") logoff_spy = mocker.spy(connection, "logoff") - res = connection.re_auth(auth) + res = connection.re_auth(auth, None) assert res is False logon_spy.assert_not_called() @@ -352,9 +352,9 @@ def test_re_auth_noop(auth, fake_socket, mocker): ) @mark_sync_test def test_re_auth(auth1, auth2, fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) connection = Bolt4x4(address, fake_socket(address), PoolConfig.max_connection_lifetime, auth=auth1) with pytest.raises(ConfigurationError, match="Session level authentication is not supported"): - connection.re_auth(auth2) + connection.re_auth(auth2, None) diff --git a/tests/unit/sync/io/test_class_bolt5x0.py b/tests/unit/sync/io/test_class_bolt5x0.py index 9c7e4654..8e224603 100644 --- a/tests/unit/sync/io/test_class_bolt5x0.py +++ b/tests/unit/sync/io/test_class_bolt5x0.py @@ -31,7 +31,7 @@ @pytest.mark.parametrize("set_stale", (True, False)) def test_conn_is_stale(fake_socket, set_stale): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) max_connection_lifetime = 0 connection = Bolt5x0(address, fake_socket(address), max_connection_lifetime) if set_stale: @@ -41,7 +41,7 @@ def test_conn_is_stale(fake_socket, set_stale): @pytest.mark.parametrize("set_stale", (True, False)) def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) max_connection_lifetime = -1 connection = Bolt5x0(address, fake_socket(address), max_connection_lifetime) if set_stale: @@ -51,7 +51,7 @@ def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): @pytest.mark.parametrize("set_stale", (True, False)) def test_conn_is_not_stale(fake_socket, set_stale): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) max_connection_lifetime = 999999999 connection = Bolt5x0(address, fake_socket(address), max_connection_lifetime) if 
set_stale: @@ -70,7 +70,7 @@ def test_conn_is_not_stale(fake_socket, set_stale): )) @mark_sync_test def test_extra_in_begin(fake_socket, args, kwargs, expected_fields): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt5x0.UNPACKER_CLS) connection = Bolt5x0(address, socket, PoolConfig.max_connection_lifetime) connection.begin(*args, **kwargs) @@ -91,7 +91,7 @@ def test_extra_in_begin(fake_socket, args, kwargs, expected_fields): )) @mark_sync_test def test_extra_in_run(fake_socket, args, kwargs, expected_fields): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt5x0.UNPACKER_CLS) connection = Bolt5x0(address, socket, PoolConfig.max_connection_lifetime) connection.run(*args, **kwargs) @@ -103,7 +103,7 @@ def test_extra_in_run(fake_socket, args, kwargs, expected_fields): @mark_sync_test def test_n_extra_in_discard(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt5x0.UNPACKER_CLS) connection = Bolt5x0(address, socket, PoolConfig.max_connection_lifetime) connection.discard(n=666) @@ -123,7 +123,7 @@ def test_n_extra_in_discard(fake_socket): ) @mark_sync_test def test_qid_extra_in_discard(fake_socket, test_input, expected): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt5x0.UNPACKER_CLS) connection = Bolt5x0(address, socket, PoolConfig.max_connection_lifetime) connection.discard(qid=test_input) @@ -143,7 +143,7 @@ def test_qid_extra_in_discard(fake_socket, test_input, expected): ) @mark_sync_test def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt5x0.UNPACKER_CLS) connection = Bolt5x0(address, socket, PoolConfig.max_connection_lifetime) connection.discard(n=666, 
qid=test_input) @@ -163,7 +163,7 @@ def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): ) @mark_sync_test def test_n_extra_in_pull(fake_socket, test_input, expected): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt5x0.UNPACKER_CLS) connection = Bolt5x0(address, socket, PoolConfig.max_connection_lifetime) connection.pull(n=test_input) @@ -183,7 +183,7 @@ def test_n_extra_in_pull(fake_socket, test_input, expected): ) @mark_sync_test def test_qid_extra_in_pull(fake_socket, test_input, expected): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt5x0.UNPACKER_CLS) connection = Bolt5x0(address, socket, PoolConfig.max_connection_lifetime) connection.pull(qid=test_input) @@ -196,7 +196,7 @@ def test_qid_extra_in_pull(fake_socket, test_input, expected): @mark_sync_test def test_n_and_qid_extras_in_pull(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt5x0.UNPACKER_CLS) connection = Bolt5x0(address, socket, PoolConfig.max_connection_lifetime) connection.pull(n=666, qid=777) @@ -209,7 +209,7 @@ def test_n_and_qid_extras_in_pull(fake_socket): @mark_sync_test def test_hello_passes_routing_metadata(fake_socket_pair): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) sockets = fake_socket_pair(address, packer_cls=Bolt5x0.PACKER_CLS, unpacker_cls=Bolt5x0.UNPACKER_CLS) @@ -241,7 +241,7 @@ def test_hello_passes_routing_metadata(fake_socket_pair): def test_hint_recv_timeout_seconds( fake_socket_pair, hints, valid, caplog, mocker ): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) sockets = fake_socket_pair(address, packer_cls=Bolt5x0.PACKER_CLS, unpacker_cls=Bolt5x0.UNPACKER_CLS) @@ -287,7 +287,7 @@ def test_hint_recv_timeout_seconds( def test_credentials_are_not_logged( auth, fake_socket_pair, mocker, 
caplog ): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) sockets = fake_socket_pair(address, packer_cls=Bolt5x0.PACKER_CLS, unpacker_cls=Bolt5x0.UNPACKER_CLS) @@ -310,7 +310,7 @@ def test_credentials_are_not_logged( @pytest.mark.parametrize("message", ("logon", "logoff")) def test_auth_message_raises_configuration_error(message, fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) connection = Bolt5x0(address, fake_socket(address), PoolConfig.max_connection_lifetime) with pytest.raises(ConfigurationError, @@ -325,12 +325,12 @@ def test_auth_message_raises_configuration_error(message, fake_socket): )) @mark_sync_test def test_re_auth_noop(auth, fake_socket, mocker): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) connection = Bolt5x0(address, fake_socket(address), PoolConfig.max_connection_lifetime, auth=auth) logon_spy = mocker.spy(connection, "logon") logoff_spy = mocker.spy(connection, "logoff") - res = connection.re_auth(auth) + res = connection.re_auth(auth, None) assert res is False logon_spy.assert_not_called() @@ -350,9 +350,9 @@ def test_re_auth_noop(auth, fake_socket, mocker): ) @mark_sync_test def test_re_auth(auth1, auth2, fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) connection = Bolt5x0(address, fake_socket(address), PoolConfig.max_connection_lifetime, auth=auth1) with pytest.raises(ConfigurationError, match="Session level authentication is not supported"): - connection.re_auth(auth2) + connection.re_auth(auth2, None) diff --git a/tests/unit/sync/io/test_class_bolt5x1.py b/tests/unit/sync/io/test_class_bolt5x1.py index 8e9b6f11..3f736a4a 100644 --- a/tests/unit/sync/io/test_class_bolt5x1.py +++ b/tests/unit/sync/io/test_class_bolt5x1.py @@ -24,13 +24,14 @@ import neo4j.exceptions from neo4j._conf import PoolConfig from neo4j._sync.io._bolt5 import Bolt5x1 +from neo4j.auth_management import AuthManagers 
from ...._async_compat import mark_sync_test @pytest.mark.parametrize("set_stale", (True, False)) def test_conn_is_stale(fake_socket, set_stale): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) max_connection_lifetime = 0 connection = Bolt5x1(address, fake_socket(address), max_connection_lifetime) @@ -41,7 +42,7 @@ def test_conn_is_stale(fake_socket, set_stale): @pytest.mark.parametrize("set_stale", (True, False)) def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) max_connection_lifetime = -1 connection = Bolt5x1(address, fake_socket(address), max_connection_lifetime) @@ -52,7 +53,7 @@ def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): @pytest.mark.parametrize("set_stale", (True, False)) def test_conn_is_not_stale(fake_socket, set_stale): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) max_connection_lifetime = 999999999 connection = Bolt5x1(address, fake_socket(address), max_connection_lifetime) @@ -72,7 +73,7 @@ def test_conn_is_not_stale(fake_socket, set_stale): )) @mark_sync_test def test_extra_in_begin(fake_socket, args, kwargs, expected_fields): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt5x1.UNPACKER_CLS) connection = Bolt5x1(address, socket, PoolConfig.max_connection_lifetime) @@ -94,7 +95,7 @@ def test_extra_in_begin(fake_socket, args, kwargs, expected_fields): )) @mark_sync_test def test_extra_in_run(fake_socket, args, kwargs, expected_fields): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt5x1.UNPACKER_CLS) connection = Bolt5x1(address, socket, PoolConfig.max_connection_lifetime) @@ -107,7 +108,7 @@ def test_extra_in_run(fake_socket, args, kwargs, expected_fields): @mark_sync_test def test_n_extra_in_discard(fake_socket): - address = ("127.0.0.1", 
7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt5x1.UNPACKER_CLS) connection = Bolt5x1(address, socket, PoolConfig.max_connection_lifetime) @@ -128,7 +129,7 @@ def test_n_extra_in_discard(fake_socket): ) @mark_sync_test def test_qid_extra_in_discard(fake_socket, test_input, expected): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt5x1.UNPACKER_CLS) connection = Bolt5x1(address, socket, PoolConfig.max_connection_lifetime) @@ -149,7 +150,7 @@ def test_qid_extra_in_discard(fake_socket, test_input, expected): ) @mark_sync_test def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt5x1.UNPACKER_CLS) connection = Bolt5x1(address, socket, PoolConfig.max_connection_lifetime) @@ -170,7 +171,7 @@ def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): ) @mark_sync_test def test_n_extra_in_pull(fake_socket, test_input, expected): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt5x1.UNPACKER_CLS) connection = Bolt5x1(address, socket, PoolConfig.max_connection_lifetime) @@ -191,7 +192,7 @@ def test_n_extra_in_pull(fake_socket, test_input, expected): ) @mark_sync_test def test_qid_extra_in_pull(fake_socket, test_input, expected): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt5x1.UNPACKER_CLS) connection = Bolt5x1(address, socket, PoolConfig.max_connection_lifetime) @@ -205,7 +206,7 @@ def test_qid_extra_in_pull(fake_socket, test_input, expected): @mark_sync_test def test_n_and_qid_extras_in_pull(fake_socket): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) socket = fake_socket(address, Bolt5x1.UNPACKER_CLS) connection = Bolt5x1(address, socket, 
PoolConfig.max_connection_lifetime) @@ -219,7 +220,7 @@ def test_n_and_qid_extras_in_pull(fake_socket): @mark_sync_test def test_hello_passes_routing_metadata(fake_socket_pair): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) sockets = fake_socket_pair(address, packer_cls=Bolt5x1.PACKER_CLS, unpacker_cls=Bolt5x1.UNPACKER_CLS) @@ -249,7 +250,7 @@ def _assert_logon_message(sockets, auth): @mark_sync_test def test_hello_pipelines_logon(fake_socket_pair): auth = neo4j.Auth("basic", "alice123", "supersecret123") - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) sockets = fake_socket_pair(address, packer_cls=Bolt5x1.PACKER_CLS, unpacker_cls=Bolt5x1.UNPACKER_CLS) @@ -260,7 +261,7 @@ def test_hello_pipelines_logon(fake_socket_pair): connection = Bolt5x1( address, sockets.client, PoolConfig.max_connection_lifetime, auth=auth ) - with pytest.raises(neo4j.exceptions.ServiceUnavailable): + with pytest.raises(neo4j.exceptions.Neo4jError): connection.hello() tag, fields = sockets.server.pop_message() assert tag == b"\x01" # HELLO @@ -273,7 +274,7 @@ def test_hello_pipelines_logon(fake_socket_pair): @mark_sync_test def test_logon(fake_socket_pair): auth = neo4j.Auth("basic", "alice123", "supersecret123") - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) sockets = fake_socket_pair(address, packer_cls=Bolt5x1.PACKER_CLS, unpacker_cls=Bolt5x1.UNPACKER_CLS) @@ -287,7 +288,8 @@ def test_logon(fake_socket_pair): @mark_sync_test def test_re_auth(fake_socket_pair, mocker): auth = neo4j.Auth("basic", "alice123", "supersecret123") - address = ("127.0.0.1", 7687) + auth_manager = AuthManagers.static(auth) + address = neo4j.Address(("127.0.0.1", 7687)) sockets = fake_socket_pair(address, packer_cls=Bolt5x1.PACKER_CLS, unpacker_cls=Bolt5x1.UNPACKER_CLS) @@ -298,7 +300,7 @@ def test_re_auth(fake_socket_pair, mocker): connection = Bolt5x1(address, sockets.client, PoolConfig.max_connection_lifetime) 
connection.pool = mocker.Mock() - connection.re_auth(auth) + connection.re_auth(auth, auth_manager) connection.send_all() with pytest.raises(neo4j.exceptions.Neo4jError): connection.fetch_all() @@ -306,11 +308,13 @@ def test_re_auth(fake_socket_pair, mocker): assert tag == b"\x6B" # LOGOFF assert len(fields) == 0 _assert_logon_message(sockets, auth) + assert connection.auth is auth + assert connection.auth_manager is auth_manager @mark_sync_test def test_logoff(fake_socket_pair): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) sockets = fake_socket_pair(address, packer_cls=Bolt5x1.PACKER_CLS, unpacker_cls=Bolt5x1.UNPACKER_CLS) @@ -342,7 +346,7 @@ def test_logoff(fake_socket_pair): def test_hint_recv_timeout_seconds( fake_socket_pair, hints, valid, caplog, mocker ): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) sockets = fake_socket_pair(address, packer_cls=Bolt5x1.PACKER_CLS, unpacker_cls=Bolt5x1.UNPACKER_CLS) @@ -387,7 +391,7 @@ def test_hint_recv_timeout_seconds( )) @mark_sync_test def test_credentials_are_not_logged(auth, fake_socket_pair, caplog): - address = ("127.0.0.1", 7687) + address = neo4j.Address(("127.0.0.1", 7687)) sockets = fake_socket_pair(address, packer_cls=Bolt5x1.PACKER_CLS, unpacker_cls=Bolt5x1.UNPACKER_CLS) diff --git a/tests/unit/sync/io/test_direct.py b/tests/unit/sync/io/test_direct.py index 2d13c669..ea3ea83e 100644 --- a/tests/unit/sync/io/test_direct.py +++ b/tests/unit/sync/io/test_direct.py @@ -26,6 +26,7 @@ from neo4j._deadline import Deadline from neo4j._sync.io import Bolt from neo4j._sync.io._pool import IOPool +from neo4j.auth_management import AuthManagers from neo4j.exceptions import ( ClientError, ServiceUnavailable, @@ -59,13 +60,17 @@ def __init__(self, socket): def is_reset(self): return True + @property + def throwaway(self): + return False + def stale(self): return False def reset(self): pass - def re_auth(self, auth, force=False): + def re_auth(self, auth, 
auth_manager, force=False): return False def close(self): @@ -82,8 +87,8 @@ def timedout(self): class FakeBoltPool(IOPool): - def __init__(self, address, *, auth=None, **config): + config["auth"] = AuthManagers.static(None) self.pool_config, self.workspace_config = Config.consume_chain(config, PoolConfig, WorkspaceConfig) if config: raise ValueError("Unexpected config keys: %s" % ", ".join(config.keys())) @@ -105,15 +110,18 @@ def acquire( @mark_sync_test def test_bolt_connection_open(): + auth_manager = AuthManagers.static(("test", "test")) with pytest.raises(ServiceUnavailable): - Bolt.open(("localhost", 9999), auth=("test", "test")) + Bolt.open(("localhost", 9999), auth_manager=auth_manager) @mark_sync_test def test_bolt_connection_open_timeout(): + auth_manager = AuthManagers.static(("test", "test")) with pytest.raises(ServiceUnavailable): - Bolt.open(("localhost", 9999), auth=("test", "test"), - timeout=1) + Bolt.open( + ("localhost", 9999), auth_manager=auth_manager, timeout=1 + ) @mark_sync_test diff --git a/tests/unit/sync/io/test_neo4j_pool.py b/tests/unit/sync/io/test_neo4j_pool.py index 24957c72..52f44742 100644 --- a/tests/unit/sync/io/test_neo4j_pool.py +++ b/tests/unit/sync/io/test_neo4j_pool.py @@ -33,6 +33,7 @@ from neo4j._deadline import Deadline from neo4j._sync.io import Neo4jPool from neo4j.addressing import ResolvedAddress +from neo4j.auth_management import AuthManagers from neo4j.exceptions import ( Neo4jError, ServiceUnavailable, @@ -71,7 +72,7 @@ def routing_side_effect(*args, **kwargs): def open_(addr, auth, timeout): connection = fake_connection_generator() - connection.addr = addr + connection.unresolved_address = addr connection.timeout = timeout connection.auth = auth route_mock = mocker.Mock() @@ -95,11 +96,21 @@ def opener(routing_failure_opener): return routing_failure_opener() +def _pool_config(): + pool_config = PoolConfig() + pool_config.auth = AuthManagers.static(("user", "pass")) + return pool_config + + +def 
_simple_pool(opener) -> Neo4jPool: + return Neo4jPool( + opener, _pool_config(), WorkspaceConfig(), ROUTER1_ADDRESS + ) + + @mark_sync_test def test_acquires_new_routing_table_if_deleted(opener): - pool = Neo4jPool( - opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS - ) + pool = _simple_pool(opener) cx = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) pool.release(cx) assert pool.routing_tables.get("test_db") @@ -113,9 +124,7 @@ def test_acquires_new_routing_table_if_deleted(opener): @mark_sync_test def test_acquires_new_routing_table_if_stale(opener): - pool = Neo4jPool( - opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS - ) + pool = _simple_pool(opener) cx = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) pool.release(cx) assert pool.routing_tables.get("test_db") @@ -130,9 +139,7 @@ def test_acquires_new_routing_table_if_stale(opener): @mark_sync_test def test_removes_old_routing_table(opener): - pool = Neo4jPool( - opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS - ) + pool = _simple_pool(opener) cx = pool.acquire(READ_ACCESS, 30, "test_db1", None, None, None) pool.release(cx) assert pool.routing_tables.get("test_db1") @@ -154,25 +161,21 @@ def test_removes_old_routing_table(opener): @pytest.mark.parametrize("type_", ("r", "w")) @mark_sync_test def test_chooses_right_connection_type(opener, type_): - pool = Neo4jPool( - opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS - ) + pool = _simple_pool(opener) cx1 = pool.acquire( READ_ACCESS if type_ == "r" else WRITE_ACCESS, 30, "test_db", None, None, None ) pool.release(cx1) if type_ == "r": - assert cx1.addr == READER_ADDRESS + assert cx1.unresolved_address == READER_ADDRESS else: - assert cx1.addr == WRITER_ADDRESS + assert cx1.unresolved_address == WRITER_ADDRESS @mark_sync_test def test_reuses_connection(opener): - pool = Neo4jPool( - opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS - ) + pool = _simple_pool(opener) cx1 = pool.acquire(READ_ACCESS, 30, 
"test_db", None, None, None) pool.release(cx1) cx2 = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) @@ -183,21 +186,19 @@ def test_reuses_connection(opener): @mark_sync_test def test_closes_stale_connections(opener, break_on_close): def break_connection(): - pool.deactivate(cx1.addr) + pool.deactivate(cx1.unresolved_address) if cx_close_mock_side_effect: res = cx_close_mock_side_effect() if inspect.isawaitable(res): return res - pool = Neo4jPool( - opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS - ) + pool = _simple_pool(opener) cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) pool.release(cx1) - assert cx1 in pool.connections[cx1.addr] - # simulate connection going stale (e.g. exceeding) and then breaking when - # the pool tries to close the connection + assert cx1 in pool.connections[cx1.unresolved_address] + # simulate connection going stale (e.g. exceeding idle timeout) and then + # breaking when the pool tries to close the connection cx1.stale.return_value = True cx_close_mock = cx1.close if break_on_close: @@ -210,27 +211,26 @@ def break_connection(): else: cx1.close.assert_called_once() assert cx2 is not cx1 - assert cx2.addr == cx1.addr - assert cx1 not in pool.connections[cx1.addr] - assert cx2 in pool.connections[cx2.addr] + assert cx2.unresolved_address == cx1.unresolved_address + assert cx1 not in pool.connections[cx1.unresolved_address] + assert cx2 in pool.connections[cx2.unresolved_address] @mark_sync_test def test_does_not_close_stale_connections_in_use(opener): - pool = Neo4jPool( - opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS - ) + pool = _simple_pool(opener) cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) - assert cx1 in pool.connections[cx1.addr] - # simulate connection going stale (e.g. exceeding) while being in use + assert cx1 in pool.connections[cx1.unresolved_address] + # simulate connection going stale (e.g. 
exceeding idle timeout) while being + # in use cx1.stale.return_value = True cx2 = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) pool.release(cx2) cx1.close.assert_not_called() assert cx2 is not cx1 - assert cx2.addr == cx1.addr - assert cx1 in pool.connections[cx1.addr] - assert cx2 in pool.connections[cx2.addr] + assert cx2.unresolved_address == cx1.unresolved_address + assert cx1 in pool.connections[cx1.unresolved_address] + assert cx2 in pool.connections[cx2.unresolved_address] pool.release(cx1) # now that cx1 is back in the pool and still stale, @@ -241,16 +241,14 @@ def test_does_not_close_stale_connections_in_use(opener): pool.release(cx3) cx1.close.assert_called_once() assert cx2 is cx3 - assert cx3.addr == cx1.addr - assert cx1 not in pool.connections[cx1.addr] - assert cx3 in pool.connections[cx2.addr] + assert cx3.unresolved_address == cx1.unresolved_address + assert cx1 not in pool.connections[cx1.unresolved_address] + assert cx3 in pool.connections[cx2.unresolved_address] @mark_sync_test def test_release_resets_connections(opener): - pool = Neo4jPool( - opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS - ) + pool = _simple_pool(opener) cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) cx1.is_reset_mock.return_value = False cx1.is_reset_mock.reset_mock() @@ -261,9 +259,7 @@ def test_release_resets_connections(opener): @mark_sync_test def test_release_does_not_resets_closed_connections(opener): - pool = Neo4jPool( - opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS - ) + pool = _simple_pool(opener) cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) cx1.closed.return_value = True cx1.closed.reset_mock() @@ -276,9 +272,7 @@ def test_release_does_not_resets_closed_connections(opener): @mark_sync_test def test_release_does_not_resets_defunct_connections(opener): - pool = Neo4jPool( - opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS - ) + pool = _simple_pool(opener) cx1 = 
pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) cx1.defunct.return_value = True cx1.defunct.reset_mock() @@ -294,12 +288,10 @@ def test_release_does_not_resets_defunct_connections(opener): def test_acquire_performs_no_liveness_check_on_fresh_connection( opener, liveness_timeout ): - pool = Neo4jPool( - opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS - ) + pool = _simple_pool(opener) cx1 = pool._acquire(READER_ADDRESS, None, Deadline(30), liveness_timeout) - assert cx1.addr == READER_ADDRESS + assert cx1.unresolved_address == READER_ADDRESS cx1.reset.assert_not_called() @@ -308,15 +300,13 @@ def test_acquire_performs_no_liveness_check_on_fresh_connection( def test_acquire_performs_liveness_check_on_existing_connection( opener, liveness_timeout ): - pool = Neo4jPool( - opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS - ) + pool = _simple_pool(opener) # populate the pool with a connection cx1 = pool._acquire(READER_ADDRESS, None, Deadline(30), liveness_timeout) # make sure we assume the right state - assert cx1.addr == READER_ADDRESS + assert cx1.unresolved_address == READER_ADDRESS cx1.is_idle_for.assert_not_called() cx1.reset.assert_not_called() @@ -344,15 +334,13 @@ def liveness_side_effect(*args, **kwargs): raise liveness_error("liveness check failed") liveness_timeout = 1 - pool = Neo4jPool( - opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS - ) + pool = _simple_pool(opener) # populate the pool with a connection cx1 = pool._acquire(READER_ADDRESS, None, Deadline(30), liveness_timeout) # make sure we assume the right state - assert cx1.addr == READER_ADDRESS + assert cx1.unresolved_address == READER_ADDRESS cx1.is_idle_for.assert_not_called() cx1.reset.assert_not_called() @@ -368,11 +356,11 @@ def liveness_side_effect(*args, **kwargs): cx2 = pool._acquire(READER_ADDRESS, None, Deadline(30), liveness_timeout) assert cx1 is not cx2 - assert cx1.addr == cx2.addr + assert cx1.unresolved_address == cx2.unresolved_address 
cx1.is_idle_for.assert_called_once_with(liveness_timeout) cx2.reset.assert_not_called() - assert cx1 not in pool.connections[cx1.addr] - assert cx2 in pool.connections[cx1.addr] + assert cx1 not in pool.connections[cx1.unresolved_address] + assert cx2 in pool.connections[cx1.unresolved_address] @pytest.mark.parametrize("liveness_error", @@ -385,9 +373,7 @@ def liveness_side_effect(*args, **kwargs): raise liveness_error("liveness check failed") liveness_timeout = 1 - pool = Neo4jPool( - opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS - ) + pool = _simple_pool(opener) # populate the pool with a connection cx1 = pool._acquire(READER_ADDRESS, None, Deadline(30), liveness_timeout) @@ -395,8 +381,8 @@ def liveness_side_effect(*args, **kwargs): liveness_timeout) # make sure we assume the right state - assert cx1.addr == READER_ADDRESS - assert cx2.addr == READER_ADDRESS + assert cx1.unresolved_address == READER_ADDRESS + assert cx2.unresolved_address == READER_ADDRESS assert cx1 is not cx2 cx1.is_idle_for.assert_not_called() cx2.is_idle_for.assert_not_called() @@ -421,8 +407,8 @@ def liveness_side_effect(*args, **kwargs): cx1.reset.assert_called_once() cx3.is_idle_for.assert_called_once_with(liveness_timeout) cx3.reset.assert_called_once() - assert cx1 not in pool.connections[cx1.addr] - assert cx3 in pool.connections[cx1.addr] + assert cx1 not in pool.connections[cx1.unresolved_address] + assert cx3 in pool.connections[cx1.unresolved_address] @mark_sync_test @@ -437,9 +423,7 @@ def close_side_effect(): "close") # create pool with 2 idle connections - pool = Neo4jPool( - opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS - ) + pool = _simple_pool(opener) cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) cx2 = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) pool.release(cx1) @@ -461,9 +445,7 @@ def close_side_effect(): @mark_sync_test def test_failing_opener_leaves_connections_in_use_alone(opener): - pool = Neo4jPool( - opener, 
PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS - ) + pool = _simple_pool(opener) cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) opener.side_effect = ServiceUnavailable("Server overloaded") @@ -474,13 +456,13 @@ def test_failing_opener_leaves_connections_in_use_alone(opener): @mark_sync_test def test__acquire_new_later_with_room(opener): - config = PoolConfig() + config = _pool_config() config.max_connection_pool_size = 1 pool = Neo4jPool( opener, config, WorkspaceConfig(), ROUTER1_ADDRESS ) assert pool.connections_reservations[READER_ADDRESS] == 0 - creator = pool._acquire_new_later(READER_ADDRESS, None, Deadline(1)) + creator = pool._acquire_new_later(READER_ADDRESS, None, Deadline(1), False) assert pool.connections_reservations[READER_ADDRESS] == 1 assert callable(creator) if Util.is_async_code: @@ -489,7 +471,7 @@ def test__acquire_new_later_with_room(opener): @mark_sync_test def test__acquire_new_later_without_room(opener): - config = PoolConfig() + config = _pool_config() config.max_connection_pool_size = 1 pool = Neo4jPool( opener, config, WorkspaceConfig(), ROUTER1_ADDRESS @@ -497,7 +479,7 @@ def test__acquire_new_later_without_room(opener): _ = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) # pool is full now assert pool.connections_reservations[READER_ADDRESS] == 0 - creator = pool._acquire_new_later(READER_ADDRESS, None, Deadline(1)) + creator = pool._acquire_new_later(READER_ADDRESS, None, Deadline(1), False) assert pool.connections_reservations[READER_ADDRESS] == 0 assert creator is None @@ -515,7 +497,7 @@ def test_discovery_is_retried(routing_failure_opener, error): error, # will be retried ]) pool = Neo4jPool( - opener, PoolConfig(), WorkspaceConfig(), + opener, _pool_config(), WorkspaceConfig(), ResolvedAddress(("1.2.3.1", 9999), host_name="host") ) cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) @@ -559,7 +541,7 @@ def test_fast_failing_discovery(routing_failure_opener, error): error, # will be 
retried ]) pool = Neo4jPool( - opener, PoolConfig(), WorkspaceConfig(), + opener, _pool_config(), WorkspaceConfig(), ResolvedAddress(("1.2.3.1", 9999), host_name="host") ) cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) @@ -601,11 +583,13 @@ def test_fast_failing_discovery(routing_failure_opener, error): def test_connection_error_callback( opener, error, marks_unauthenticated, fetches_new, mocker ): + config = _pool_config() + auth_manager = AuthManagers.static(("user", "auth")) + on_auth_expired_mock = mocker.patch.object(auth_manager, "on_auth_expired", + autospec=True) + config.auth = auth_manager pool = Neo4jPool( - opener, PoolConfig(), WorkspaceConfig(), ROUTER1_ADDRESS - ) - force_new_auth_mock = mocker.patch.object( - pool, "force_new_auth", autospec=True + opener, config, WorkspaceConfig(), ROUTER1_ADDRESS ) cxs_read = [ pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) @@ -616,16 +600,18 @@ def test_connection_error_callback( for _ in range(5) ] - force_new_auth_mock.assert_not_called() + on_auth_expired_mock.assert_not_called() for cx in cxs_read + cxs_write: cx.mark_unauthenticated.assert_not_called() - pool.on_neo4j_error(error, cxs_read[0].addr) + pool.on_neo4j_error(error, cxs_read[0]) if fetches_new: - force_new_auth_mock.assert_called_once() + cxs_read[0].auth_manager.on_auth_expired.assert_called_once() else: - force_new_auth_mock.assert_not_called() + on_auth_expired_mock.assert_not_called() + for cx in cxs_read: + cx.auth_manager.on_auth_expired.assert_not_called() for cx in cxs_read: if marks_unauthenticated: diff --git a/tests/unit/sync/test_addressing.py b/tests/unit/sync/test_addressing.py index 190ac416..9274f449 100644 --- a/tests/unit/sync/test_addressing.py +++ b/tests/unit/sync/test_addressing.py @@ -16,10 +16,7 @@ # limitations under the License. 
-from socket import ( - AF_INET, - AF_INET6, -) +from socket import AF_INET import pytest diff --git a/tests/unit/sync/test_auth_manager.py b/tests/unit/sync/test_auth_manager.py new file mode 100644 index 00000000..6cb7817f --- /dev/null +++ b/tests/unit/sync/test_auth_manager.py @@ -0,0 +1,116 @@ +import itertools +import typing as t + +import pytest +from freezegun import freeze_time +from freezegun.api import FrozenDateTimeFactory + +from neo4j import ( + Auth, + basic_auth, +) +from neo4j.auth_management import ( + AuthManager, + AuthManagers, + TemporalAuth, +) + +from ..._async_compat import mark_sync_test + + +SAMPLE_AUTHS = ( + None, + ("user", "password"), + basic_auth("foo", "bar"), + basic_auth("foo", "bar", "baz"), + Auth("scheme", "principal", "credentials", "realm", para="meter"), +) + + +@mark_sync_test +@pytest.mark.parametrize("auth", SAMPLE_AUTHS) +def test_static_manager( + auth +) -> None: + manager: AuthManager = AuthManagers.static(auth) + assert manager.get_auth() is auth + + manager.on_auth_expired(("something", "else")) + assert manager.get_auth() is auth + + manager.on_auth_expired(auth) + assert manager.get_auth() is auth + + +@mark_sync_test +@pytest.mark.parametrize(("auth1", "auth2"), + itertools.product(SAMPLE_AUTHS, repeat=2)) +@pytest.mark.parametrize("expires_in", (None, -1, 1., 1, 1000.)) +def test_temporal_manager_manual_expiry( + auth1: t.Union[t.Tuple[str, str], Auth, None], + auth2: t.Union[t.Tuple[str, str], Auth, None], + expires_in: t.Union[float, int], + mocker +) -> None: + if expires_in is None or expires_in >= 0: + temporal_auth = TemporalAuth(auth1, expires_in) + else: + temporal_auth = TemporalAuth(auth1) + provider = mocker.Mock(return_value=temporal_auth) + manager: AuthManager = AuthManagers.temporal(provider) + + provider.assert_not_called() + assert manager.get_auth() is auth1 + provider.assert_called_once() + provider.reset_mock() + + provider.return_value = TemporalAuth(auth2) + + 
manager.on_auth_expired(("something", "else")) + assert manager.get_auth() is auth1 + provider.assert_not_called() + + manager.on_auth_expired(auth1) + provider.assert_called_once() + provider.reset_mock() + assert manager.get_auth() is auth2 + provider.assert_not_called() + + +@mark_sync_test +@pytest.mark.parametrize(("auth1", "auth2"), + itertools.product(SAMPLE_AUTHS, repeat=2)) +@pytest.mark.parametrize("expires_in", (None, -1, 1., 1, 1000.)) +def test_temporal_manager_time_expiry( + auth1: t.Union[t.Tuple[str, str], Auth, None], + auth2: t.Union[t.Tuple[str, str], Auth, None], + expires_in: t.Union[float, int, None], + mocker +) -> None: + with freeze_time() as frozen_time: + assert isinstance(frozen_time, FrozenDateTimeFactory) + if expires_in is None or expires_in >= 0: + temporal_auth = TemporalAuth(auth1, expires_in) + else: + temporal_auth = TemporalAuth(auth1) + provider = mocker.Mock(return_value=temporal_auth) + manager: AuthManager = AuthManagers.temporal(provider) + + provider.assert_not_called() + assert manager.get_auth() is auth1 + provider.assert_called_once() + provider.reset_mock() + + provider.return_value = TemporalAuth(auth2) + + if expires_in is None or expires_in < 0: + frozen_time.tick(1_000_000) + assert manager.get_auth() is auth1 + provider.assert_not_called() + else: + frozen_time.tick(expires_in - 0.000001) + assert manager.get_auth() is auth1 + provider.assert_not_called() + frozen_time.tick(0.000002) + assert manager.get_auth() is auth2 + provider.assert_called_once() From 660fa538df48c9b1dd7cf40ecc7857f57d37c9a8 Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Wed, 8 Mar 2023 13:43:05 +0100 Subject: [PATCH 13/23] Don't expose backwards compatible user-switching https://github.com/neo-technology/drivers-adr/pull/64 --- docs/source/api.rst | 11 +++++++++++ src/neo4j/_async/io/_pool.py | 1 + src/neo4j/_async/work/session.py | 13 +++++++------ src/neo4j/_async/work/workspace.py | 3 ++- src/neo4j/_conf.py | 13 +++++++------ 
src/neo4j/_sync/io/_pool.py | 1 + src/neo4j/_sync/work/session.py | 13 +++++++------ src/neo4j/_sync/work/workspace.py | 3 ++- testkitbackend/_async/requests.py | 6 +++++- testkitbackend/_sync/requests.py | 6 +++++- testkitbackend/test_config.json | 1 + tests/unit/common/test_conf.py | 3 ++- 12 files changed, 51 insertions(+), 23 deletions(-) diff --git a/docs/source/api.rst b/docs/source/api.rst index 0c28b413..35587a86 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -360,6 +360,9 @@ Driver Configuration Additional configuration can be provided via the :class:`neo4j.Driver` constructor. +.. TODO: wait for decision on backwards_compatible_auth + + :ref:`backwards-compatible-auth-ref` + + :ref:`connection-acquisition-timeout-ref` + :ref:`connection-timeout-ref` + :ref:`encrypted-ref` @@ -374,6 +377,14 @@ Additional configuration can be provided via the :class:`neo4j.Driver` construct + :ref:`user-agent-ref` +.. TODO: wait for decision on backwards_compatible_auth + .. _backwards-compatible-auth-ref: + + ``backwards_compatible_auth`` + ----------------------------- + ... + + .. 
_connection-acquisition-timeout-ref: ``connection_acquisition_timeout`` diff --git a/src/neo4j/_async/io/_pool.py b/src/neo4j/_async/io/_pool.py index b76078b4..efa79d0f 100644 --- a/src/neo4j/_async/io/_pool.py +++ b/src/neo4j/_async/io/_pool.py @@ -76,6 +76,7 @@ @dataclass class AcquireAuth: auth: t.Union[AsyncAuthManager, AuthManager, None] + # TODO: wait for decision on backwards_compatible_auth backwards_compatible: bool = False force_auth: bool = False diff --git a/src/neo4j/_async/work/session.py b/src/neo4j/_async/work/session.py index 2a3c0bab..dd60b5b8 100644 --- a/src/neo4j/_async/work/session.py +++ b/src/neo4j/_async/work/session.py @@ -171,12 +171,13 @@ async def _get_server_info(self): async def _verify_authentication(self): assert not self._connection await self._connect(READ_ACCESS, force_auth=True) - if not self._config.backwards_compatible_auth: - # Even without backwards compatibility and an old server, the - # _connect call above can succeed if the connection is a new one. - # Hence, we enforce support explicitly to always let the user know - # that this is not supported. - self._connection.assert_re_auth_support() + # TODO: wait for decision on backwards_compatible_auth + # if not self._config.backwards_compatible_auth: + # # Even without backwards compatibility and an old server, the + # # _connect call above can succeed if the connection is a new one. + # # Hence, we enforce support explicitly to always let the user know + # # that this is not supported. 
+ # self._connection.assert_re_auth_support() await self._disconnect() async def close(self) -> None: diff --git a/src/neo4j/_async/work/workspace.py b/src/neo4j/_async/work/workspace.py index 3c77096e..ebbc950d 100644 --- a/src/neo4j/_async/work/workspace.py +++ b/src/neo4j/_async/work/workspace.py @@ -137,7 +137,8 @@ async def _connect(self, access_mode, auth=None, **acquire_kwargs): acquisition_timeout = self._config.connection_acquisition_timeout auth = AcquireAuth( auth, - backwards_compatible=self._config.backwards_compatible_auth, + # TODO: wait for decision on backwards_compatible_auth + # backwards_compatible=self._config.backwards_compatible_auth, force_auth=acquire_kwargs.pop("force_auth", False), ) diff --git a/src/neo4j/_conf.py b/src/neo4j/_conf.py index 81242cc8..97cd21f6 100644 --- a/src/neo4j/_conf.py +++ b/src/neo4j/_conf.py @@ -495,12 +495,13 @@ class WorkspaceConfig(Config): bookmark_manager = ExperimentalOption(None) # Specify the bookmark manager to be used for sessions by default. - #: Session Auth Backward Compatibility Layer - backwards_compatible_auth = False - # Enable session level authentication (user-switching) on session level - # even over Bolt 5.0 and earlier. This is done using a very costly - # backwards compatible authentication layer in the driver utilizing - # throwaway connections. + # TODO: wait for decision on backwards_compatible_auth + # #: Session Auth Backward Compatibility Layer + # backwards_compatible_auth = False + # # Enable session level authentication (user-switching) on session level + # # even over Bolt 5.0 and earlier. This is done using a very costly + # # backwards compatible authentication layer in the driver utilizing + # # throwaway connections. 
class SessionConfig(WorkspaceConfig): diff --git a/src/neo4j/_sync/io/_pool.py b/src/neo4j/_sync/io/_pool.py index 99964e02..b3bcdc68 100644 --- a/src/neo4j/_sync/io/_pool.py +++ b/src/neo4j/_sync/io/_pool.py @@ -73,6 +73,7 @@ @dataclass class AcquireAuth: auth: t.Union[AuthManager, AuthManager, None] + # TODO: wait for decision on backwards_compatible_auth backwards_compatible: bool = False force_auth: bool = False diff --git a/src/neo4j/_sync/work/session.py b/src/neo4j/_sync/work/session.py index f4f81b29..af220c1e 100644 --- a/src/neo4j/_sync/work/session.py +++ b/src/neo4j/_sync/work/session.py @@ -171,12 +171,13 @@ def _get_server_info(self): def _verify_authentication(self): assert not self._connection self._connect(READ_ACCESS, force_auth=True) - if not self._config.backwards_compatible_auth: - # Even without backwards compatibility and an old server, the - # _connect call above can succeed if the connection is a new one. - # Hence, we enforce support explicitly to always let the user know - # that this is not supported. - self._connection.assert_re_auth_support() + # TODO: wait for decision on backwards_compatible_auth + # if not self._config.backwards_compatible_auth: + # # Even without backwards compatibility and an old server, the + # # _connect call above can succeed if the connection is a new one. + # # Hence, we enforce support explicitly to always let the user know + # # that this is not supported. 
+ # self._connection.assert_re_auth_support() self._disconnect() def close(self) -> None: diff --git a/src/neo4j/_sync/work/workspace.py b/src/neo4j/_sync/work/workspace.py index ff6b96e4..f0cfc563 100644 --- a/src/neo4j/_sync/work/workspace.py +++ b/src/neo4j/_sync/work/workspace.py @@ -137,7 +137,8 @@ def _connect(self, access_mode, auth=None, **acquire_kwargs): acquisition_timeout = self._config.connection_acquisition_timeout auth = AcquireAuth( auth, - backwards_compatible=self._config.backwards_compatible_auth, + # TODO: wait for decision on backwards_compatible_auth + # backwards_compatible=self._config.backwards_compatible_auth, force_auth=acquire_kwargs.pop("force_auth", False), ) diff --git a/testkitbackend/_async/requests.py b/testkitbackend/_async/requests.py index ab81c41c..7a36a7a6 100644 --- a/testkitbackend/_async/requests.py +++ b/testkitbackend/_async/requests.py @@ -132,10 +132,14 @@ async def NewDriver(backend, data): kwargs[conf_name] = data[data_name] for (conf_name, data_name) in ( ("encrypted", "encrypted"), - ("backwards_compatible_auth", "backwardsCompatibleAuth"), + # TODO: wait for decision on backwards_compatible_auth + # ("backwards_compatible_auth", "backwardsCompatibleAuth"), ): if data_name in data: kwargs[conf_name] = data[data_name] + # TODO: wait for decision on backwards_compatible_auth + if "backwardsCompatibleAuth" in data: + data.mark_item_as_read_if_equals("backwardsCompatibleAuth", False) if "trustedCertificates" in data: if data["trustedCertificates"] is None: kwargs["trusted_certificates"] = neo4j.TrustSystemCAs() diff --git a/testkitbackend/_sync/requests.py b/testkitbackend/_sync/requests.py index ca037d0b..757d0fd7 100644 --- a/testkitbackend/_sync/requests.py +++ b/testkitbackend/_sync/requests.py @@ -132,10 +132,14 @@ def NewDriver(backend, data): kwargs[conf_name] = data[data_name] for (conf_name, data_name) in ( ("encrypted", "encrypted"), - ("backwards_compatible_auth", "backwardsCompatibleAuth"), + # TODO: wait 
for decision on backwards_compatible_auth + # ("backwards_compatible_auth", "backwardsCompatibleAuth"), ): if data_name in data: kwargs[conf_name] = data[data_name] + # TODO: wait for decision on backwards_compatible_auth + if "backwardsCompatibleAuth" in data: + data.mark_item_as_read_if_equals("backwardsCompatibleAuth", False) if "trustedCertificates" in data: if data["trustedCertificates"] is None: kwargs["trusted_certificates"] = neo4j.TrustSystemCAs() diff --git a/testkitbackend/test_config.json b/testkitbackend/test_config.json index fd29479b..a13a2e82 100644 --- a/testkitbackend/test_config.json +++ b/testkitbackend/test_config.json @@ -37,6 +37,7 @@ "Feature:Auth:Bearer": true, "Feature:Auth:Custom": true, "Feature:Auth:Kerberos": true, + "Feature:Auth:Managed": true, "Feature:Bolt:3.0": true, "Feature:Bolt:4.1": true, "Feature:Bolt:4.2": true, diff --git a/tests/unit/common/test_conf.py b/tests/unit/common/test_conf.py index 41f0647a..52179568 100644 --- a/tests/unit/common/test_conf.py +++ b/tests/unit/common/test_conf.py @@ -72,7 +72,8 @@ "fetch_size": 100, "bookmark_manager": object(), "auth": None, - "backwards_compatible_auth": False, + # TODO: wait for decision on backwards_compatible_auth + # "backwards_compatible_auth": False, } From 25f91e55ca968343593c57c395c4ffe843c37214 Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Thu, 9 Mar 2023 11:26:26 +0100 Subject: [PATCH 14/23] Differentiate TokenExpired Token expired code should be marked retryable iff the driver was configured with a (non-static) auth token manager. 
https://github.com/neo-technology/drivers-adr/pull/64 --- docs/source/api.rst | 5 +++++ src/neo4j/_async/auth_management.py | 8 ++++---- src/neo4j/_async/io/_pool.py | 8 ++++++++ src/neo4j/_sync/auth_management.py | 8 ++++---- src/neo4j/_sync/io/_pool.py | 8 ++++++++ src/neo4j/exceptions.py | 18 +++++++++++++++++- 6 files changed, 46 insertions(+), 9 deletions(-) diff --git a/docs/source/api.rst b/docs/source/api.rst index 35587a86..c43a7346 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -1651,6 +1651,8 @@ Server-side errors * :class:`neo4j.exceptions.TokenExpired` + * :class:`neo4j.exceptions.TokenExpiredRetryable` + * :class:`neo4j.exceptions.Forbidden` * :class:`neo4j.exceptions.DatabaseError` @@ -1686,6 +1688,9 @@ Server-side errors .. autoexception:: neo4j.exceptions.TokenExpired() :show-inheritance: +.. autoexception:: neo4j.exceptions.TokenExpiredRetryable() + :show-inheritance: + .. autoexception:: neo4j.exceptions.Forbidden() :show-inheritance: diff --git a/src/neo4j/_async/auth_management.py b/src/neo4j/_async/auth_management.py index 04064f38..c9fb568b 100644 --- a/src/neo4j/_async/auth_management.py +++ b/src/neo4j/_async/auth_management.py @@ -39,7 +39,7 @@ log = getLogger("neo4j") -class _AsyncStaticAuthManager(AsyncAuthManager): +class AsyncStaticAuthManager(AsyncAuthManager): _auth: _TAuth def __init__(self, auth: _TAuth) -> None: @@ -68,7 +68,7 @@ def expired(self) -> bool: return False return time.monotonic() > self._expiry -class _AsyncTemporalAuthManager(AsyncAuthManager): +class AsyncTemporalAuthManager(AsyncAuthManager): _current_auth: t.Optional[_TemporalAuthHolder] _provider: t.Callable[[], t.Awaitable[TemporalAuth]] _lock: AsyncLock @@ -135,7 +135,7 @@ def static(auth: _TAuth) -> AsyncAuthManager: An instance of an implementation of :class:`.AsyncAuthManager` that always returns the same auth. 
""" - return _AsyncStaticAuthManager(auth) + return AsyncStaticAuthManager(auth) @staticmethod def temporal( @@ -187,4 +187,4 @@ async def auth_provider(): reached its expiry time or because the server flagged it as expired). """ - return _AsyncTemporalAuthManager(provider) + return AsyncTemporalAuthManager(provider) diff --git a/src/neo4j/_async/io/_pool.py b/src/neo4j/_async/io/_pool.py index efa79d0f..3e182db0 100644 --- a/src/neo4j/_async/io/_pool.py +++ b/src/neo4j/_async/io/_pool.py @@ -48,6 +48,7 @@ ) from ..._exceptions import BoltError from ..._routing import RoutingTable +from ..._sync.auth_management import StaticAuthManager from ...api import ( READ_ACCESS, WRITE_ACCESS, @@ -64,8 +65,11 @@ ReadServiceUnavailable, ServiceUnavailable, SessionExpired, + TokenExpired, + TokenExpiredRetryable, WriteServiceUnavailable, ) +from ..auth_management import AsyncStaticAuthManager from ._bolt import AsyncBolt @@ -522,6 +526,10 @@ async def on_neo4j_error(self, error, connection): connection.auth_manager.on_auth_expired, connection.auth ) + if (isinstance(error, TokenExpired) + and not isinstance(self.pool_config.auth, (AsyncStaticAuthManager, + StaticAuthManager))): + error.__class__ = TokenExpiredRetryable async def close(self): """ Close all connections and empty the pool. 
diff --git a/src/neo4j/_sync/auth_management.py b/src/neo4j/_sync/auth_management.py index 8a31a0af..579f811d 100644 --- a/src/neo4j/_sync/auth_management.py +++ b/src/neo4j/_sync/auth_management.py @@ -39,7 +39,7 @@ log = getLogger("neo4j") -class _StaticAuthManager(AuthManager): +class StaticAuthManager(AuthManager): _auth: _TAuth def __init__(self, auth: _TAuth) -> None: @@ -68,7 +68,7 @@ def expired(self) -> bool: return False return time.monotonic() > self._expiry -class _TemporalAuthManager(AuthManager): +class TemporalAuthManager(AuthManager): _current_auth: t.Optional[_TemporalAuthHolder] _provider: t.Callable[[], t.Union[TemporalAuth]] _lock: Lock @@ -135,7 +135,7 @@ def static(auth: _TAuth) -> AuthManager: An instance of an implementation of :class:`.AuthManager` that always returns the same auth. """ - return _StaticAuthManager(auth) + return StaticAuthManager(auth) @staticmethod def temporal( @@ -187,4 +187,4 @@ def auth_provider(): reached its expiry time or because the server flagged it as expired). 
""" - return _TemporalAuthManager(provider) + return TemporalAuthManager(provider) diff --git a/src/neo4j/_sync/io/_pool.py b/src/neo4j/_sync/io/_pool.py index b3bcdc68..7b25f050 100644 --- a/src/neo4j/_sync/io/_pool.py +++ b/src/neo4j/_sync/io/_pool.py @@ -48,6 +48,7 @@ ) from ..._exceptions import BoltError from ..._routing import RoutingTable +from ..._sync.auth_management import StaticAuthManager from ...api import ( READ_ACCESS, WRITE_ACCESS, @@ -61,8 +62,11 @@ ReadServiceUnavailable, ServiceUnavailable, SessionExpired, + TokenExpired, + TokenExpiredRetryable, WriteServiceUnavailable, ) +from ..auth_management import StaticAuthManager from ._bolt import Bolt @@ -519,6 +523,10 @@ def on_neo4j_error(self, error, connection): connection.auth_manager.on_auth_expired, connection.auth ) + if (isinstance(error, TokenExpired) + and not isinstance(self.pool_config.auth, (StaticAuthManager, + StaticAuthManager))): + error.__class__ = TokenExpiredRetryable def close(self): """ Close all connections and empty the pool. diff --git a/src/neo4j/exceptions.py b/src/neo4j/exceptions.py index 22f770ea..8360e451 100644 --- a/src/neo4j/exceptions.py +++ b/src/neo4j/exceptions.py @@ -310,10 +310,26 @@ class AuthError(ClientError): class TokenExpired(AuthError): """ Raised when the authentication token has expired. - A new driver instance with a fresh authentication token needs to be created. + A new driver instance with a fresh authentication token needs to be + created, unless the driver was configured using a non-static + :class:`.AuthManager`. In that case, the error will be + :exc:`.TokenExpiredRetryable` instead. """ +# Neo4jError > ClientError > AuthError > TokenExpired > TokenExpiredRetryable +class TokenExpiredRetryable(TokenExpired): + """Raised when the authentication token has expired but can be refreshed. 
+ + This is the same server error as :exc:`.TokenExpired`, but raised when + the driver is configured to be able to refresh the token, hence making + the error retryable. + """ + + def is_retryable(self) -> bool: + return True + + # Neo4jError > ClientError > Forbidden class Forbidden(ClientError): """ From 22a7f28decfbf638f40a68cf7d569bea41f16142 Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Mon, 20 Mar 2023 11:37:56 +0100 Subject: [PATCH 15/23] minor refactoring --- src/neo4j/_async/io/_pool.py | 2 +- src/neo4j/_sync/io/_pool.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/neo4j/_async/io/_pool.py b/src/neo4j/_async/io/_pool.py index 3e182db0..c4857f87 100644 --- a/src/neo4j/_async/io/_pool.py +++ b/src/neo4j/_async/io/_pool.py @@ -511,9 +511,9 @@ def on_write_failure(self, address): ) async def on_neo4j_error(self, error, connection): - address = connection.unresolved_address assert isinstance(error, Neo4jError) if error._unauthenticates_all_connections(): + address = connection.unresolved_address log.debug( "[#0000] _: mark all connections to %r as " "unauthenticated", address diff --git a/src/neo4j/_sync/io/_pool.py b/src/neo4j/_sync/io/_pool.py index 7b25f050..b9221dd7 100644 --- a/src/neo4j/_sync/io/_pool.py +++ b/src/neo4j/_sync/io/_pool.py @@ -508,9 +508,9 @@ def on_write_failure(self, address): ) def on_neo4j_error(self, error, connection): - address = connection.unresolved_address assert isinstance(error, Neo4jError) if error._unauthenticates_all_connections(): + address = connection.unresolved_address log.debug( "[#0000] _: mark all connections to %r as " "unauthenticated", address From 22fb61a37b35aff8e488f8d963c76f0f741b5c39 Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Mon, 27 Mar 2023 09:33:08 +0200 Subject: [PATCH 16/23] Remove backwards compatible session auth (polyfill). 
--- docs/source/api.rst | 12 --- src/neo4j/_async/io/_bolt.py | 2 - src/neo4j/_async/io/_pool.py | 94 ++++--------------- src/neo4j/_async/work/session.py | 7 -- src/neo4j/_async/work/workspace.py | 2 - src/neo4j/_conf.py | 8 -- src/neo4j/_sync/io/_bolt.py | 2 - src/neo4j/_sync/io/_pool.py | 94 ++++--------------- src/neo4j/_sync/work/session.py | 7 -- src/neo4j/_sync/work/workspace.py | 2 - testkitbackend/_async/requests.py | 5 - testkitbackend/_sync/requests.py | 5 - tests/unit/async_/fixtures/fake_connection.py | 1 - tests/unit/async_/io/test_direct.py | 4 - tests/unit/async_/io/test_neo4j_pool.py | 4 +- tests/unit/common/test_conf.py | 2 - tests/unit/sync/fixtures/fake_connection.py | 1 - tests/unit/sync/io/test_direct.py | 4 - tests/unit/sync/io/test_neo4j_pool.py | 4 +- 19 files changed, 40 insertions(+), 220 deletions(-) diff --git a/docs/source/api.rst b/docs/source/api.rst index c43a7346..0abb5164 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -360,9 +360,6 @@ Driver Configuration Additional configuration can be provided via the :class:`neo4j.Driver` constructor. -.. TODO: wait for decision on backwards_compatible_auth - + :ref:`backwards-compatible-auth-ref` - + :ref:`connection-acquisition-timeout-ref` + :ref:`connection-timeout-ref` + :ref:`encrypted-ref` @@ -374,15 +371,6 @@ Additional configuration can be provided via the :class:`neo4j.Driver` construct + :ref:`trust-ref` + :ref:`ssl-context-ref` + :ref:`trusted-certificates-ref` -+ :ref:`user-agent-ref` - - -.. TODO: wait for decision on backwards_compatible_auth - .. _backwards-compatible-auth-ref: - - ``backwards_compatible_auth`` - ----------------------------- - ... .. 
_connection-acquisition-timeout-ref: diff --git a/src/neo4j/_async/io/_bolt.py b/src/neo4j/_async/io/_bolt.py index 863e817b..2b60d932 100644 --- a/src/neo4j/_async/io/_bolt.py +++ b/src/neo4j/_async/io/_bolt.py @@ -99,8 +99,6 @@ class AsyncBolt: # The socket in_use = False - throwaway = False - # When the connection was last put back into the pool idle_since = float("-inf") diff --git a/src/neo4j/_async/io/_pool.py b/src/neo4j/_async/io/_pool.py index c4857f87..0f5a8b67 100644 --- a/src/neo4j/_async/io/_pool.py +++ b/src/neo4j/_async/io/_pool.py @@ -80,8 +80,6 @@ @dataclass class AcquireAuth: auth: t.Union[AsyncAuthManager, AuthManager, None] - # TODO: wait for decision on backwards_compatible_auth - backwards_compatible: bool = False force_auth: bool = False @@ -163,8 +161,7 @@ async def _acquire_from_pool_checked( else: return connection - def _acquire_new_later(self, address, auth, deadline, - backwards_compatible_auth): + def _acquire_new_later(self, address, auth, deadline): async def connection_creator(): released_reservation = False try: @@ -183,14 +180,10 @@ async def connection_creator(): try: connection.assert_re_auth_support() except ConfigurationError: - if not backwards_compatible_auth: - log.debug("[#%04X] _: no re-auth support", - connection.local_port) - await connection.close() - raise - log.debug("[#%04X] _: is throwaway connection", + log.debug("[#%04X] _: no re-auth support", connection.local_port) - connection.throwaway = True + await connection.close() + raise connection.pool = self connection.in_use = True with self.lock: @@ -216,10 +209,8 @@ async def connection_creator(): return connection_creator return None - async def _re_auth_connection( - self, connection, auth, force, backwards_compatible_auth - ): - if auth and not backwards_compatible_auth: + async def _re_auth_connection(self, connection, auth, force): + if auth: # Assert session auth is supported by the protocol. 
# The Bolt implementation will try as hard as it can to make the # re-auth work. So if the session auth token is identical to the @@ -257,7 +248,6 @@ async def _acquire( if auth is None: auth = AcquireAuth(None) force_auth = auth.force_auth - backwards_compatible_auth = auth.backwards_compatible auth = auth.auth async def health_check(connection_, deadline_): @@ -286,42 +276,21 @@ async def health_check(connection_, deadline_): connection.local_port, connection.connection_id) try: await self._re_auth_connection( - connection, auth, force_auth, backwards_compatible_auth + connection, auth, force_auth ) except ConfigurationError: - if not auth: - # expiring tokens supported by flushing the pool - # => give up this connection - log.debug("[#%04X] _: backwards compatible " - "auth token refresh: purge connection", - connection.local_port) - await connection.close() - await self.release(connection) - continue - if not backwards_compatible_auth: + if auth: + # protocol version lacks support for re-auth + # => session auth token is not supported raise - # backwards compatibility mode: - # create new throwaway connection, - connection_creator = self._acquire_new_later( - address, auth, deadline, - backwards_compatible_auth=True - ) - if connection_creator: - await self.release(connection) - if not connection_creator: - # pool is full => kill the picked connection - log.debug("[#%04X] _: backwards compatible " - "session auth making room by purge", - connection.local_port) - await connection.close() - with self.lock: - self._remove_connection(connection) - connection_creator = self._acquire_new_later( - address, auth, deadline, - backwards_compatible_auth=True - ) - assert connection_creator is not None - break + # expiring tokens supported by flushing the pool + # => give up this connection + log.debug("[#%04X] _: backwards compatible " + "auth token refresh: purge connection", + connection.local_port) + await connection.close() + await self.release(connection) + continue 
log.debug("[#%04X] _: handing out existing connection", connection.local_port) return connection @@ -329,7 +298,6 @@ async def health_check(connection_, deadline_): with self.lock: connection_creator = self._acquire_new_later( address, auth, deadline, - backwards_compatible_auth=backwards_compatible_auth ) if connection_creator: break @@ -384,29 +352,6 @@ def kill_and_release(self, *connections): connection.in_use = False self.cond.notify_all() - @staticmethod - async def _close_throwaway_connection(connection, cancelled): - if connection.throwaway: - if cancelled is not None: - log.debug( - "[#%04X] _: kill throwaway connection %s", - connection.local_port, connection.connection_id - ) - connection.kill() - else: - try: - log.debug( - "[#%04X] _: close throwaway connection %s", - connection.local_port, connection.connection_id - ) - await connection.close() - except asyncio.CancelledError as exc: - log.debug("[#%04X] _: cancelled close of " - "throwaway connection: %r", - connection.local_port, exc) - cancelled = exc - return cancelled - async def release(self, *connections): """ Release connections back into the pool. 
@@ -414,9 +359,6 @@ async def release(self, *connections): """ cancelled = None for connection in connections: - cancelled = await self._close_throwaway_connection( - connection, cancelled - ) if not (connection.defunct() or connection.closed() or connection.is_reset): diff --git a/src/neo4j/_async/work/session.py b/src/neo4j/_async/work/session.py index dd60b5b8..2bbebc39 100644 --- a/src/neo4j/_async/work/session.py +++ b/src/neo4j/_async/work/session.py @@ -171,13 +171,6 @@ async def _get_server_info(self): async def _verify_authentication(self): assert not self._connection await self._connect(READ_ACCESS, force_auth=True) - # TODO: wait for decision on backwards_compatible_auth - # if not self._config.backwards_compatible_auth: - # # Even without backwards compatibility and an old server, the - # # _connect call above can succeed if the connection is a new one. - # # Hence, we enforce support explicitly to always let the user know - # # that this is not supported. - # self._connection.assert_re_auth_support() await self._disconnect() async def close(self) -> None: diff --git a/src/neo4j/_async/work/workspace.py b/src/neo4j/_async/work/workspace.py index ebbc950d..6a3d43d6 100644 --- a/src/neo4j/_async/work/workspace.py +++ b/src/neo4j/_async/work/workspace.py @@ -137,8 +137,6 @@ async def _connect(self, access_mode, auth=None, **acquire_kwargs): acquisition_timeout = self._config.connection_acquisition_timeout auth = AcquireAuth( auth, - # TODO: wait for decision on backwards_compatible_auth - # backwards_compatible=self._config.backwards_compatible_auth, force_auth=acquire_kwargs.pop("force_auth", False), ) diff --git a/src/neo4j/_conf.py b/src/neo4j/_conf.py index 97cd21f6..9707334c 100644 --- a/src/neo4j/_conf.py +++ b/src/neo4j/_conf.py @@ -495,14 +495,6 @@ class WorkspaceConfig(Config): bookmark_manager = ExperimentalOption(None) # Specify the bookmark manager to be used for sessions by default. 
- # TODO: wait for decision on backwards_compatible_auth - # #: Session Auth Backward Compatibility Layer - # backwards_compatible_auth = False - # # Enable session level authentication (user-switching) on session level - # # even over Bolt 5.0 and earlier. This is done using a very costly - # # backwards compatible authentication layer in the driver utilizing - # # throwaway connections. - class SessionConfig(WorkspaceConfig): """ Session configuration. diff --git a/src/neo4j/_sync/io/_bolt.py b/src/neo4j/_sync/io/_bolt.py index 9047eeb1..2676ffa8 100644 --- a/src/neo4j/_sync/io/_bolt.py +++ b/src/neo4j/_sync/io/_bolt.py @@ -99,8 +99,6 @@ class Bolt: # The socket in_use = False - throwaway = False - # When the connection was last put back into the pool idle_since = float("-inf") diff --git a/src/neo4j/_sync/io/_pool.py b/src/neo4j/_sync/io/_pool.py index b9221dd7..f31f3637 100644 --- a/src/neo4j/_sync/io/_pool.py +++ b/src/neo4j/_sync/io/_pool.py @@ -77,8 +77,6 @@ @dataclass class AcquireAuth: auth: t.Union[AuthManager, AuthManager, None] - # TODO: wait for decision on backwards_compatible_auth - backwards_compatible: bool = False force_auth: bool = False @@ -160,8 +158,7 @@ def _acquire_from_pool_checked( else: return connection - def _acquire_new_later(self, address, auth, deadline, - backwards_compatible_auth): + def _acquire_new_later(self, address, auth, deadline): def connection_creator(): released_reservation = False try: @@ -180,14 +177,10 @@ def connection_creator(): try: connection.assert_re_auth_support() except ConfigurationError: - if not backwards_compatible_auth: - log.debug("[#%04X] _: no re-auth support", - connection.local_port) - connection.close() - raise - log.debug("[#%04X] _: is throwaway connection", + log.debug("[#%04X] _: no re-auth support", connection.local_port) - connection.throwaway = True + connection.close() + raise connection.pool = self connection.in_use = True with self.lock: @@ -213,10 +206,8 @@ def connection_creator(): return 
connection_creator return None - def _re_auth_connection( - self, connection, auth, force, backwards_compatible_auth - ): - if auth and not backwards_compatible_auth: + def _re_auth_connection(self, connection, auth, force): + if auth: # Assert session auth is supported by the protocol. # The Bolt implementation will try as hard as it can to make the # re-auth work. So if the session auth token is identical to the @@ -254,7 +245,6 @@ def _acquire( if auth is None: auth = AcquireAuth(None) force_auth = auth.force_auth - backwards_compatible_auth = auth.backwards_compatible auth = auth.auth def health_check(connection_, deadline_): @@ -283,42 +273,21 @@ def health_check(connection_, deadline_): connection.local_port, connection.connection_id) try: self._re_auth_connection( - connection, auth, force_auth, backwards_compatible_auth + connection, auth, force_auth ) except ConfigurationError: - if not auth: - # expiring tokens supported by flushing the pool - # => give up this connection - log.debug("[#%04X] _: backwards compatible " - "auth token refresh: purge connection", - connection.local_port) - connection.close() - self.release(connection) - continue - if not backwards_compatible_auth: + if auth: + # protocol version lacks support for re-auth + # => session auth token is not supported raise - # backwards compatibility mode: - # create new throwaway connection, - connection_creator = self._acquire_new_later( - address, auth, deadline, - backwards_compatible_auth=True - ) - if connection_creator: - self.release(connection) - if not connection_creator: - # pool is full => kill the picked connection - log.debug("[#%04X] _: backwards compatible " - "session auth making room by purge", - connection.local_port) - connection.close() - with self.lock: - self._remove_connection(connection) - connection_creator = self._acquire_new_later( - address, auth, deadline, - backwards_compatible_auth=True - ) - assert connection_creator is not None - break + # expiring tokens 
supported by flushing the pool + # => give up this connection + log.debug("[#%04X] _: backwards compatible " + "auth token refresh: purge connection", + connection.local_port) + connection.close() + self.release(connection) + continue log.debug("[#%04X] _: handing out existing connection", connection.local_port) return connection @@ -326,7 +295,6 @@ def health_check(connection_, deadline_): with self.lock: connection_creator = self._acquire_new_later( address, auth, deadline, - backwards_compatible_auth=backwards_compatible_auth ) if connection_creator: break @@ -381,29 +349,6 @@ def kill_and_release(self, *connections): connection.in_use = False self.cond.notify_all() - @staticmethod - def _close_throwaway_connection(connection, cancelled): - if connection.throwaway: - if cancelled is not None: - log.debug( - "[#%04X] _: kill throwaway connection %s", - connection.local_port, connection.connection_id - ) - connection.kill() - else: - try: - log.debug( - "[#%04X] _: close throwaway connection %s", - connection.local_port, connection.connection_id - ) - connection.close() - except asyncio.CancelledError as exc: - log.debug("[#%04X] _: cancelled close of " - "throwaway connection: %r", - connection.local_port, exc) - cancelled = exc - return cancelled - def release(self, *connections): """ Release connections back into the pool. 
@@ -411,9 +356,6 @@ def release(self, *connections): """ cancelled = None for connection in connections: - cancelled = self._close_throwaway_connection( - connection, cancelled - ) if not (connection.defunct() or connection.closed() or connection.is_reset): diff --git a/src/neo4j/_sync/work/session.py b/src/neo4j/_sync/work/session.py index af220c1e..f32116fe 100644 --- a/src/neo4j/_sync/work/session.py +++ b/src/neo4j/_sync/work/session.py @@ -171,13 +171,6 @@ def _get_server_info(self): def _verify_authentication(self): assert not self._connection self._connect(READ_ACCESS, force_auth=True) - # TODO: wait for decision on backwards_compatible_auth - # if not self._config.backwards_compatible_auth: - # # Even without backwards compatibility and an old server, the - # # _connect call above can succeed if the connection is a new one. - # # Hence, we enforce support explicitly to always let the user know - # # that this is not supported. - # self._connection.assert_re_auth_support() self._disconnect() def close(self) -> None: diff --git a/src/neo4j/_sync/work/workspace.py b/src/neo4j/_sync/work/workspace.py index f0cfc563..b098c73d 100644 --- a/src/neo4j/_sync/work/workspace.py +++ b/src/neo4j/_sync/work/workspace.py @@ -137,8 +137,6 @@ def _connect(self, access_mode, auth=None, **acquire_kwargs): acquisition_timeout = self._config.connection_acquisition_timeout auth = AcquireAuth( auth, - # TODO: wait for decision on backwards_compatible_auth - # backwards_compatible=self._config.backwards_compatible_auth, force_auth=acquire_kwargs.pop("force_auth", False), ) diff --git a/testkitbackend/_async/requests.py b/testkitbackend/_async/requests.py index 7a36a7a6..471ed6d9 100644 --- a/testkitbackend/_async/requests.py +++ b/testkitbackend/_async/requests.py @@ -132,14 +132,9 @@ async def NewDriver(backend, data): kwargs[conf_name] = data[data_name] for (conf_name, data_name) in ( ("encrypted", "encrypted"), - # TODO: wait for decision on backwards_compatible_auth - # 
("backwards_compatible_auth", "backwardsCompatibleAuth"), ): if data_name in data: kwargs[conf_name] = data[data_name] - # TODO: wait for decision on backwards_compatible_auth - if "backwardsCompatibleAuth" in data: - data.mark_item_as_read_if_equals("backwardsCompatibleAuth", False) if "trustedCertificates" in data: if data["trustedCertificates"] is None: kwargs["trusted_certificates"] = neo4j.TrustSystemCAs() diff --git a/testkitbackend/_sync/requests.py b/testkitbackend/_sync/requests.py index 757d0fd7..c31ddb6d 100644 --- a/testkitbackend/_sync/requests.py +++ b/testkitbackend/_sync/requests.py @@ -132,14 +132,9 @@ def NewDriver(backend, data): kwargs[conf_name] = data[data_name] for (conf_name, data_name) in ( ("encrypted", "encrypted"), - # TODO: wait for decision on backwards_compatible_auth - # ("backwards_compatible_auth", "backwardsCompatibleAuth"), ): if data_name in data: kwargs[conf_name] = data[data_name] - # TODO: wait for decision on backwards_compatible_auth - if "backwardsCompatibleAuth" in data: - data.mark_item_as_read_if_equals("backwardsCompatibleAuth", False) if "trustedCertificates" in data: if data["trustedCertificates"] is None: kwargs["trusted_certificates"] = neo4j.TrustSystemCAs() diff --git a/tests/unit/async_/fixtures/fake_connection.py b/tests/unit/async_/fixtures/fake_connection.py index 5bf1a7ee..9e3995cf 100644 --- a/tests/unit/async_/fixtures/fake_connection.py +++ b/tests/unit/async_/fixtures/fake_connection.py @@ -55,7 +55,6 @@ def __init__(self, *args, **kwargs): self.attach_mock(mock.AsyncMock(spec=AsyncAuthManager), "auth_manager") self.unresolved_address = next(iter(args), "localhost") - self.throwaway = False def close_side_effect(): self.closed.return_value = True diff --git a/tests/unit/async_/io/test_direct.py b/tests/unit/async_/io/test_direct.py index 8397d995..cdbaa802 100644 --- a/tests/unit/async_/io/test_direct.py +++ b/tests/unit/async_/io/test_direct.py @@ -60,10 +60,6 @@ def __init__(self, socket): def 
is_reset(self): return True - @property - def throwaway(self): - return False - def stale(self): return False diff --git a/tests/unit/async_/io/test_neo4j_pool.py b/tests/unit/async_/io/test_neo4j_pool.py index 98707139..cb365df7 100644 --- a/tests/unit/async_/io/test_neo4j_pool.py +++ b/tests/unit/async_/io/test_neo4j_pool.py @@ -462,7 +462,7 @@ async def test__acquire_new_later_with_room(opener): opener, config, WorkspaceConfig(), ROUTER1_ADDRESS ) assert pool.connections_reservations[READER_ADDRESS] == 0 - creator = pool._acquire_new_later(READER_ADDRESS, None, Deadline(1), False) + creator = pool._acquire_new_later(READER_ADDRESS, None, Deadline(1)) assert pool.connections_reservations[READER_ADDRESS] == 1 assert callable(creator) if AsyncUtil.is_async_code: @@ -479,7 +479,7 @@ async def test__acquire_new_later_without_room(opener): _ = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) # pool is full now assert pool.connections_reservations[READER_ADDRESS] == 0 - creator = pool._acquire_new_later(READER_ADDRESS, None, Deadline(1), False) + creator = pool._acquire_new_later(READER_ADDRESS, None, Deadline(1)) assert pool.connections_reservations[READER_ADDRESS] == 0 assert creator is None diff --git a/tests/unit/common/test_conf.py b/tests/unit/common/test_conf.py index 52179568..88e67197 100644 --- a/tests/unit/common/test_conf.py +++ b/tests/unit/common/test_conf.py @@ -72,8 +72,6 @@ "fetch_size": 100, "bookmark_manager": object(), "auth": None, - # TODO: wait for decision on backwards_compatible_auth - # "backwards_compatible_auth": False, } diff --git a/tests/unit/sync/fixtures/fake_connection.py b/tests/unit/sync/fixtures/fake_connection.py index 5d66c9d9..d2c7dded 100644 --- a/tests/unit/sync/fixtures/fake_connection.py +++ b/tests/unit/sync/fixtures/fake_connection.py @@ -55,7 +55,6 @@ def __init__(self, *args, **kwargs): self.attach_mock(mock.Mock(spec=AuthManager), "auth_manager") self.unresolved_address = next(iter(args), "localhost") - 
self.throwaway = False def close_side_effect(): self.closed.return_value = True diff --git a/tests/unit/sync/io/test_direct.py b/tests/unit/sync/io/test_direct.py index ea3ea83e..455a2de5 100644 --- a/tests/unit/sync/io/test_direct.py +++ b/tests/unit/sync/io/test_direct.py @@ -60,10 +60,6 @@ def __init__(self, socket): def is_reset(self): return True - @property - def throwaway(self): - return False - def stale(self): return False diff --git a/tests/unit/sync/io/test_neo4j_pool.py b/tests/unit/sync/io/test_neo4j_pool.py index 52f44742..a6a9db78 100644 --- a/tests/unit/sync/io/test_neo4j_pool.py +++ b/tests/unit/sync/io/test_neo4j_pool.py @@ -462,7 +462,7 @@ def test__acquire_new_later_with_room(opener): opener, config, WorkspaceConfig(), ROUTER1_ADDRESS ) assert pool.connections_reservations[READER_ADDRESS] == 0 - creator = pool._acquire_new_later(READER_ADDRESS, None, Deadline(1), False) + creator = pool._acquire_new_later(READER_ADDRESS, None, Deadline(1)) assert pool.connections_reservations[READER_ADDRESS] == 1 assert callable(creator) if Util.is_async_code: @@ -479,7 +479,7 @@ def test__acquire_new_later_without_room(opener): _ = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) # pool is full now assert pool.connections_reservations[READER_ADDRESS] == 0 - creator = pool._acquire_new_later(READER_ADDRESS, None, Deadline(1), False) + creator = pool._acquire_new_later(READER_ADDRESS, None, Deadline(1)) assert pool.connections_reservations[READER_ADDRESS] == 0 assert creator is None From b9377d19b4216209a1e2d7f002e1e100a38ee4ea Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Fri, 31 Mar 2023 09:16:58 +0200 Subject: [PATCH 17/23] AuthManager docs += must not switch identities https://github.com/neo-technology/drivers-adr/pull/67 --- src/neo4j/_async/auth_management.py | 6 +++++- src/neo4j/_auth_management.py | 9 +++++++++ src/neo4j/_sync/auth_management.py | 6 +++++- 3 files changed, 19 insertions(+), 2 deletions(-) diff --git 
a/src/neo4j/_async/auth_management.py b/src/neo4j/_async/auth_management.py index c9fb568b..ad9deeef 100644 --- a/src/neo4j/_async/auth_management.py +++ b/src/neo4j/_async/auth_management.py @@ -141,13 +141,17 @@ def temporal( provider: t.Callable[[], t.Awaitable[TemporalAuth]] ) -> AsyncAuthManager: - """Create a auth manager for potentially expiring auth info. + """Create an auth manager for potentially expiring auth info. .. warning:: The provider function **must not** interact with the driver in any way as this can cause deadlocks and undefined behaviour. + The provider function must only ever return auth information belonging + to the same identity. + Switching identities is undefined behavior. + Example:: import neo4j diff --git a/src/neo4j/_auth_management.py b/src/neo4j/_auth_management.py index d99468c8..f8c28570 100644 --- a/src/neo4j/_auth_management.py +++ b/src/neo4j/_auth_management.py @@ -64,6 +64,9 @@ class AuthManager(metaclass=abc.ABCMeta): Furthermore, the manager is expected to be thread-safe. + The token returned must always belong to the same identity. + Switching identities using the `AuthManager` is undefined behavior. + .. seealso:: :class:`.AuthManagers` .. versionadded:: 5.x @@ -75,6 +78,12 @@ def get_auth(self) -> _TAuth: The driver will call this method very frequently. It is recommended to implement some form of caching to avoid unnecessary overhead. + + .. warning:: + + The method must only ever return auth information belonging to the + same identity. + Switching identities using the `AuthManager` is undefined behavior. """ ...
diff --git a/src/neo4j/_sync/auth_management.py b/src/neo4j/_sync/auth_management.py index 579f811d..329a375b 100644 --- a/src/neo4j/_sync/auth_management.py +++ b/src/neo4j/_sync/auth_management.py @@ -141,13 +141,17 @@ def temporal( provider: t.Callable[[], t.Union[TemporalAuth]] ) -> AuthManager: - """Create a auth manager for potentially expiring auth info. + """Create an auth manager for potentially expiring auth info. .. warning:: The provider function **must not** interact with the driver in any way as this can cause deadlocks and undefined behaviour. + The provider function must only ever return auth information belonging + to the same identity. + Switching identities is undefined behavior. + Example:: import neo4j From 8b506f9c7eb3327d01366ac75e78a24295f96017 Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Mon, 3 Apr 2023 13:06:34 +0200 Subject: [PATCH 18/23] Rename messages following ADR change AuthManagers.temporal -> AuthManagers.expiration_based TemporalAuth-> ExpiringAuth --- docs/source/api.rst | 2 +- src/neo4j/_async/auth_management.py | 28 +++++++-------- src/neo4j/_auth_management.py | 2 +- src/neo4j/_sync/auth_management.py | 28 +++++++-------- src/neo4j/auth_management.py | 4 +-- testkitbackend/_async/backend.py | 2 +- testkitbackend/_async/requests.py | 50 +++++++++++++++----------- testkitbackend/_sync/backend.py | 2 +- testkitbackend/_sync/requests.py | 50 +++++++++++++++----------- tests/unit/async_/test_auth_manager.py | 18 +++++----- tests/unit/sync/test_auth_manager.py | 18 +++++----- 11 files changed, 110 insertions(+), 94 deletions(-) diff --git a/docs/source/api.rst index 28b653f2..cfa13af3 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -117,7 +117,7 @@ The auth token is an object of the class :class:`neo4j.Auth` containing static d .. autoclass:: neo4j.auth_management.AuthManagers :members: -.. autoclass:: neo4j.auth_management.TemporalAuth +..
autoclass:: neo4j.auth_management.ExpiringAuth Example: diff --git a/src/neo4j/_async/auth_management.py b/src/neo4j/_async/auth_management.py index ad9deeef..37b7037a 100644 --- a/src/neo4j/_async/auth_management.py +++ b/src/neo4j/_async/auth_management.py @@ -27,7 +27,7 @@ from .._async_compat.concurrency import AsyncLock from .._auth_management import ( AsyncAuthManager, - TemporalAuth, + ExpiringAuth, ) # work around for https://github.com/sphinx-doc/sphinx/pull/10880 @@ -52,8 +52,8 @@ async def on_auth_expired(self, auth: _TAuth) -> None: pass -class _TemporalAuthHolder: - def __init__(self, auth: TemporalAuth) -> None: +class _ExpiringAuthHolder: + def __init__(self, auth: ExpiringAuth) -> None: self._auth = auth self._expiry = None if auth.expires_in is not None: @@ -68,22 +68,22 @@ def expired(self) -> bool: return False return time.monotonic() > self._expiry -class AsyncTemporalAuthManager(AsyncAuthManager): - _current_auth: t.Optional[_TemporalAuthHolder] - _provider: t.Callable[[], t.Awaitable[TemporalAuth]] +class AsyncExpirationBasedAuthManager(AsyncAuthManager): + _current_auth: t.Optional[_ExpiringAuthHolder] + _provider: t.Callable[[], t.Awaitable[ExpiringAuth]] _lock: AsyncLock def __init__( self, - provider: t.Callable[[], t.Awaitable[TemporalAuth]] + provider: t.Callable[[], t.Awaitable[ExpiringAuth]] ) -> None: self._provider = provider self._current_auth = None self._lock = AsyncLock() async def _refresh_auth(self): - self._current_auth = _TemporalAuthHolder(await self._provider()) + self._current_auth = _ExpiringAuthHolder(await self._provider()) async def get_auth(self) -> _TAuth: async with self._lock: @@ -138,8 +138,8 @@ def static(auth: _TAuth) -> AsyncAuthManager: return AsyncStaticAuthManager(auth) @staticmethod - def temporal( - provider: t.Callable[[], t.Awaitable[TemporalAuth]] + def expiration_based( + provider: t.Callable[[], t.Awaitable[ExpiringAuth]] ) -> AsyncAuthManager: """Create an auth manager for potentially expiring auth 
info. @@ -157,7 +157,7 @@ def temporal( import neo4j from neo4j.auth_management import ( AsyncAuthManagers, - TemporalAuth, + ExpiringAuth, ) @@ -167,7 +167,7 @@ async def auth_provider(): # assume we know our tokens expire every 60 seconds expires_in = 60 - return TemporalAuth( + return ExpiringAuth( auth=neo4j.bearer_auth(sso_token), # Include a little buffer so that we fetch a new token # *before* the old one expires @@ -182,7 +182,7 @@ async def auth_provider(): ... # do stuff :param provider: - A callable that provides a :class:`.TemporalAuth` instance. + A callable that provides a :class:`.ExpiringAuth` instance. :returns: An instance of an implementation of :class:`.AsyncAuthManager` that @@ -191,4 +191,4 @@ async def auth_provider(): reached its expiry time or because the server flagged it as expired). """ - return AsyncTemporalAuthManager(provider) + return AsyncExpirationBasedAuthManager(provider) diff --git a/src/neo4j/_auth_management.py b/src/neo4j/_auth_management.py index f8c28570..5becce4d 100644 --- a/src/neo4j/_auth_management.py +++ b/src/neo4j/_auth_management.py @@ -28,7 +28,7 @@ @dataclass -class TemporalAuth: +class ExpiringAuth: """Represents potentially expiring authentication information. 
This class is used with :meth:`.AuthManagers.temporal` and diff --git a/src/neo4j/_sync/auth_management.py b/src/neo4j/_sync/auth_management.py index 329a375b..96d89973 100644 --- a/src/neo4j/_sync/auth_management.py +++ b/src/neo4j/_sync/auth_management.py @@ -27,7 +27,7 @@ from .._async_compat.concurrency import Lock from .._auth_management import ( AuthManager, - TemporalAuth, + ExpiringAuth, ) # work around for https://github.com/sphinx-doc/sphinx/pull/10880 @@ -52,8 +52,8 @@ def on_auth_expired(self, auth: _TAuth) -> None: pass -class _TemporalAuthHolder: - def __init__(self, auth: TemporalAuth) -> None: +class _ExpiringAuthHolder: + def __init__(self, auth: ExpiringAuth) -> None: self._auth = auth self._expiry = None if auth.expires_in is not None: @@ -68,22 +68,22 @@ def expired(self) -> bool: return False return time.monotonic() > self._expiry -class TemporalAuthManager(AuthManager): - _current_auth: t.Optional[_TemporalAuthHolder] - _provider: t.Callable[[], t.Union[TemporalAuth]] +class ExpirationBasedAuthManager(AuthManager): + _current_auth: t.Optional[_ExpiringAuthHolder] + _provider: t.Callable[[], t.Union[ExpiringAuth]] _lock: Lock def __init__( self, - provider: t.Callable[[], t.Union[TemporalAuth]] + provider: t.Callable[[], t.Union[ExpiringAuth]] ) -> None: self._provider = provider self._current_auth = None self._lock = Lock() def _refresh_auth(self): - self._current_auth = _TemporalAuthHolder(self._provider()) + self._current_auth = _ExpiringAuthHolder(self._provider()) def get_auth(self) -> _TAuth: with self._lock: @@ -138,8 +138,8 @@ def static(auth: _TAuth) -> AuthManager: return StaticAuthManager(auth) @staticmethod - def temporal( - provider: t.Callable[[], t.Union[TemporalAuth]] + def expiration_based( + provider: t.Callable[[], t.Union[ExpiringAuth]] ) -> AuthManager: """Create an auth manager for potentially expiring auth info. 
@@ -157,7 +157,7 @@ def temporal( import neo4j from neo4j.auth_management import ( AuthManagers, - TemporalAuth, + ExpiringAuth, ) @@ -167,7 +167,7 @@ def auth_provider(): # assume we know our tokens expire every 60 seconds expires_in = 60 - return TemporalAuth( + return ExpiringAuth( auth=neo4j.bearer_auth(sso_token), # Include a little buffer so that we fetch a new token # *before* the old one expires @@ -182,7 +182,7 @@ def auth_provider(): ... # do stuff :param provider: - A callable that provides a :class:`.TemporalAuth` instance. + A callable that provides a :class:`.ExpiringAuth` instance. :returns: An instance of an implementation of :class:`.AuthManager` that @@ -191,4 +191,4 @@ def auth_provider(): reached its expiry time or because the server flagged it as expired). """ - return TemporalAuthManager(provider) + return ExpirationBasedAuthManager(provider) diff --git a/src/neo4j/auth_management.py b/src/neo4j/auth_management.py index 84b2a386..8fdad964 100644 --- a/src/neo4j/auth_management.py +++ b/src/neo4j/auth_management.py @@ -20,7 +20,7 @@ from ._auth_management import ( AsyncAuthManager, AuthManager, - TemporalAuth, + ExpiringAuth, ) from ._sync.auth_management import AuthManagers @@ -30,5 +30,5 @@ "AsyncAuthManagers", "AuthManager", "AuthManagers", - "TemporalAuth", + "ExpiringAuth", ] diff --git a/testkitbackend/_async/backend.py b/testkitbackend/_async/backend.py index 2927381e..b3bdd885 100644 --- a/testkitbackend/_async/backend.py +++ b/testkitbackend/_async/backend.py @@ -58,7 +58,7 @@ def __init__(self, rd, wr): self.auth_token_managers = {} self.auth_token_supplies = {} self.auth_token_on_expiration_supplies = {} - self.temporal_auth_token_supplies = {} + self.expiring_auth_token_supplies = {} self.bookmark_managers = {} self.bookmarks_consumptions = {} self.bookmarks_supplies = {} diff --git a/testkitbackend/_async/requests.py b/testkitbackend/_async/requests.py index 76d60569..ec867870 100644 --- a/testkitbackend/_async/requests.py +++ 
b/testkitbackend/_async/requests.py @@ -31,7 +31,7 @@ from neo4j.auth_management import ( AsyncAuthManager, AsyncAuthManagers, - TemporalAuth, + ExpiringAuth, ) from .. import ( @@ -197,8 +197,9 @@ async def on_auth_expired(self, auth): auth_manager = TestKitAuthManager() backend.auth_token_managers[auth_token_manager_id] = auth_manager - await backend.send_response("AuthTokenManager", - {"id": auth_token_manager_id}) + await backend.send_response( + "AuthTokenManager", {"id": auth_token_manager_id} + ) async def AuthTokenManagerGetAuthCompleted(backend, data): @@ -214,47 +215,54 @@ async def AuthTokenManagerOnAuthExpiredCompleted(backend, data): async def AuthTokenManagerClose(backend, data): auth_token_manager_id = data["id"] del backend.auth_token_managers[auth_token_manager_id] - await backend.send_response("AuthTokenManager", - {"id": auth_token_manager_id}) + await backend.send_response( + "AuthTokenManager", {"id": auth_token_manager_id} + ) -async def NewTemporalAuthTokenManager(backend, data): +async def NewExpirationBasedAuthTokenManager(backend, data): auth_token_manager_id = backend.next_key() async def auth_token_provider(): key = backend.next_key() - await backend.send_response("TemporalAuthTokenProviderRequest", { - "id": key, - "temporalAuthTokenManagerId": auth_token_manager_id, - }) + await backend.send_response( + "ExpirationBasedAuthTokenProviderRequest", + { + "id": key, + "expirationBasedAuthTokenManagerId": auth_token_manager_id, + } + ) if not await backend.process_request(): # connection was closed before end of next message - return neo4j.auth_management.TemporalAuth(None, None) - if key not in backend.temporal_auth_token_supplies: + return neo4j.auth_management.ExpiringAuth(None, None) + if key not in backend.expiring_auth_token_supplies: raise RuntimeError( "Backend did not receive expected " - f"TemporalAuthTokenManagerCompleted message for id {key}" + "ExpirationBasedAuthTokenManagerCompleted message for id " + f"{key}" ) - return 
backend.temporal_auth_token_supplies.pop(key) + return backend.expiring_auth_token_supplies.pop(key) - auth_manager = AsyncAuthManagers.temporal(auth_token_provider) + auth_manager = AsyncAuthManagers.expiration_based(auth_token_provider) backend.auth_token_managers[auth_token_manager_id] = auth_manager - await backend.send_response("TemporalAuthTokenManager", - {"id": auth_token_manager_id}) + await backend.send_response( + "ExpirationBasedAuthTokenManager", {"id": auth_token_manager_id} + ) -async def TemporalAuthTokenProviderCompleted(backend, data): +async def ExpirationBasedAuthTokenProviderCompleted(backend, data): temp_auth_data = data["auth"] - temp_auth_data.mark_item_as_read_if_equals("name", "TemporalAuthToken") + temp_auth_data.mark_item_as_read_if_equals("name", + "AuthTokenAndExpiration") temp_auth_data = temp_auth_data["data"] auth_token = fromtestkit.to_auth_token(temp_auth_data, "auth") if temp_auth_data["expiresInMs"] is not None: expires_in = temp_auth_data["expiresInMs"] / 1000 else: expires_in = None - temporal_auth = TemporalAuth(auth_token, expires_in) + expiring_auth = ExpiringAuth(auth_token, expires_in) - backend.temporal_auth_token_supplies[data["requestId"]] = temporal_auth + backend.expiring_auth_token_supplies[data["requestId"]] = expiring_auth async def VerifyConnectivity(backend, data): diff --git a/testkitbackend/_sync/backend.py b/testkitbackend/_sync/backend.py index eb342c0f..b5ad66ac 100644 --- a/testkitbackend/_sync/backend.py +++ b/testkitbackend/_sync/backend.py @@ -58,7 +58,7 @@ def __init__(self, rd, wr): self.auth_token_managers = {} self.auth_token_supplies = {} self.auth_token_on_expiration_supplies = {} - self.temporal_auth_token_supplies = {} + self.expiring_auth_token_supplies = {} self.bookmark_managers = {} self.bookmarks_consumptions = {} self.bookmarks_supplies = {} diff --git a/testkitbackend/_sync/requests.py b/testkitbackend/_sync/requests.py index 95b44de4..a6fdf45c 100644 --- 
a/testkitbackend/_sync/requests.py +++ b/testkitbackend/_sync/requests.py @@ -31,7 +31,7 @@ from neo4j.auth_management import ( AuthManager, AuthManagers, - TemporalAuth, + ExpiringAuth, ) from .. import ( @@ -197,8 +197,9 @@ def on_auth_expired(self, auth): auth_manager = TestKitAuthManager() backend.auth_token_managers[auth_token_manager_id] = auth_manager - backend.send_response("AuthTokenManager", - {"id": auth_token_manager_id}) + backend.send_response( + "AuthTokenManager", {"id": auth_token_manager_id} + ) def AuthTokenManagerGetAuthCompleted(backend, data): @@ -214,47 +215,54 @@ def AuthTokenManagerOnAuthExpiredCompleted(backend, data): def AuthTokenManagerClose(backend, data): auth_token_manager_id = data["id"] del backend.auth_token_managers[auth_token_manager_id] - backend.send_response("AuthTokenManager", - {"id": auth_token_manager_id}) + backend.send_response( + "AuthTokenManager", {"id": auth_token_manager_id} + ) -def NewTemporalAuthTokenManager(backend, data): +def NewExpirationBasedAuthTokenManager(backend, data): auth_token_manager_id = backend.next_key() def auth_token_provider(): key = backend.next_key() - backend.send_response("TemporalAuthTokenProviderRequest", { - "id": key, - "temporalAuthTokenManagerId": auth_token_manager_id, - }) + backend.send_response( + "ExpirationBasedAuthTokenProviderRequest", + { + "id": key, + "expirationBasedAuthTokenManagerId": auth_token_manager_id, + } + ) if not backend.process_request(): # connection was closed before end of next message - return neo4j.auth_management.TemporalAuth(None, None) - if key not in backend.temporal_auth_token_supplies: + return neo4j.auth_management.ExpiringAuth(None, None) + if key not in backend.expiring_auth_token_supplies: raise RuntimeError( "Backend did not receive expected " - f"TemporalAuthTokenManagerCompleted message for id {key}" + "ExpirationBasedAuthTokenManagerCompleted message for id " + f"{key}" ) - return backend.temporal_auth_token_supplies.pop(key) + return 
backend.expiring_auth_token_supplies.pop(key) - auth_manager = AuthManagers.temporal(auth_token_provider) + auth_manager = AuthManagers.expiration_based(auth_token_provider) backend.auth_token_managers[auth_token_manager_id] = auth_manager - backend.send_response("TemporalAuthTokenManager", - {"id": auth_token_manager_id}) + backend.send_response( + "ExpirationBasedAuthTokenManager", {"id": auth_token_manager_id} + ) -def TemporalAuthTokenProviderCompleted(backend, data): +def ExpirationBasedAuthTokenProviderCompleted(backend, data): temp_auth_data = data["auth"] - temp_auth_data.mark_item_as_read_if_equals("name", "TemporalAuthToken") + temp_auth_data.mark_item_as_read_if_equals("name", + "AuthTokenAndExpiration") temp_auth_data = temp_auth_data["data"] auth_token = fromtestkit.to_auth_token(temp_auth_data, "auth") if temp_auth_data["expiresInMs"] is not None: expires_in = temp_auth_data["expiresInMs"] / 1000 else: expires_in = None - temporal_auth = TemporalAuth(auth_token, expires_in) + expiring_auth = ExpiringAuth(auth_token, expires_in) - backend.temporal_auth_token_supplies[data["requestId"]] = temporal_auth + backend.expiring_auth_token_supplies[data["requestId"]] = expiring_auth def VerifyConnectivity(backend, data): diff --git a/tests/unit/async_/test_auth_manager.py b/tests/unit/async_/test_auth_manager.py index f201cbeb..9312824a 100644 --- a/tests/unit/async_/test_auth_manager.py +++ b/tests/unit/async_/test_auth_manager.py @@ -12,7 +12,7 @@ from neo4j.auth_management import ( AsyncAuthManager, AsyncAuthManagers, - TemporalAuth, + ExpiringAuth, ) from ..._async_compat import mark_async_test @@ -53,18 +53,18 @@ async def test_temporal_manager_manual_expiry( mocker ) -> None: if expires_in is None or expires_in >= 0: - temporal_auth = TemporalAuth(auth1, expires_in) + temporal_auth = ExpiringAuth(auth1, expires_in) else: - temporal_auth = TemporalAuth(auth1) + temporal_auth = ExpiringAuth(auth1) provider = mocker.AsyncMock(return_value=temporal_auth) - 
manager: AsyncAuthManager = AsyncAuthManagers.temporal(provider) + manager: AsyncAuthManager = AsyncAuthManagers.expiration_based(provider) provider.assert_not_called() assert await manager.get_auth() is auth1 provider.assert_awaited_once() provider.reset_mock() - provider.return_value = TemporalAuth(auth2) + provider.return_value = ExpiringAuth(auth2) await manager.on_auth_expired(("something", "else")) assert await manager.get_auth() is auth1 @@ -90,18 +90,18 @@ async def test_temporal_manager_time_expiry( with freeze_time() as frozen_time: assert isinstance(frozen_time, FrozenDateTimeFactory) if expires_in is None or expires_in >= 0: - temporal_auth = TemporalAuth(auth1, expires_in) + temporal_auth = ExpiringAuth(auth1, expires_in) else: - temporal_auth = TemporalAuth(auth1) + temporal_auth = ExpiringAuth(auth1) provider = mocker.AsyncMock(return_value=temporal_auth) - manager: AsyncAuthManager = AsyncAuthManagers.temporal(provider) + manager: AsyncAuthManager = AsyncAuthManagers.expiration_based(provider) provider.assert_not_called() assert await manager.get_auth() is auth1 provider.assert_awaited_once() provider.reset_mock() - provider.return_value = TemporalAuth(auth2) + provider.return_value = ExpiringAuth(auth2) if expires_in is None or expires_in < 0: frozen_time.tick(1_000_000) diff --git a/tests/unit/sync/test_auth_manager.py b/tests/unit/sync/test_auth_manager.py index 9d18f789..0b91714f 100644 --- a/tests/unit/sync/test_auth_manager.py +++ b/tests/unit/sync/test_auth_manager.py @@ -12,7 +12,7 @@ from neo4j.auth_management import ( AuthManager, AuthManagers, - TemporalAuth, + ExpiringAuth, ) from ..._async_compat import mark_sync_test @@ -53,18 +53,18 @@ def test_temporal_manager_manual_expiry( mocker ) -> None: if expires_in is None or expires_in >= 0: - temporal_auth = TemporalAuth(auth1, expires_in) + temporal_auth = ExpiringAuth(auth1, expires_in) else: - temporal_auth = TemporalAuth(auth1) + temporal_auth = ExpiringAuth(auth1) provider = 
mocker.MagicMock(return_value=temporal_auth) - manager: AuthManager = AuthManagers.temporal(provider) + manager: AuthManager = AuthManagers.expiration_based(provider) provider.assert_not_called() assert manager.get_auth() is auth1 provider.assert_called_once() provider.reset_mock() - provider.return_value = TemporalAuth(auth2) + provider.return_value = ExpiringAuth(auth2) manager.on_auth_expired(("something", "else")) assert manager.get_auth() is auth1 @@ -90,18 +90,18 @@ def test_temporal_manager_time_expiry( with freeze_time() as frozen_time: assert isinstance(frozen_time, FrozenDateTimeFactory) if expires_in is None or expires_in >= 0: - temporal_auth = TemporalAuth(auth1, expires_in) + temporal_auth = ExpiringAuth(auth1, expires_in) else: - temporal_auth = TemporalAuth(auth1) + temporal_auth = ExpiringAuth(auth1) provider = mocker.MagicMock(return_value=temporal_auth) - manager: AuthManager = AuthManagers.temporal(provider) + manager: AuthManager = AuthManagers.expiration_based(provider) provider.assert_not_called() assert manager.get_auth() is auth1 provider.assert_called_once() provider.reset_mock() - provider.return_value = TemporalAuth(auth2) + provider.return_value = ExpiringAuth(auth2) if expires_in is None or expires_in < 0: frozen_time.tick(1_000_000) From 6aef317873fddd208a2d74b8a1bb9664d4847625 Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Mon, 3 Apr 2023 16:11:36 +0200 Subject: [PATCH 19/23] Mark re-auth as preview feature and schedule for 5.8 release --- docs/source/api.rst | 33 ++++++++-- docs/source/async_api.rst | 29 +++++++-- src/neo4j/__init__.py | 2 + src/neo4j/_async/auth_management.py | 14 +++- src/neo4j/_async/driver.py | 74 +++++++++++++++++++--- src/neo4j/_async/io/_bolt.py | 2 +- src/neo4j/_async/work/result.py | 2 +- src/neo4j/_async/work/session.py | 15 ++++- src/neo4j/_auth_management.py | 20 +++++- src/neo4j/_meta.py | 46 ++++++++++++++ src/neo4j/_sync/auth_management.py | 14 +++- src/neo4j/_sync/driver.py | 74 
+++++++++++++++++++--- src/neo4j/_sync/io/_bolt.py | 2 +- src/neo4j/_sync/work/result.py | 2 +- src/neo4j/_sync/work/session.py | 15 ++++- src/neo4j/_work/eager_result.py | 2 +- testkitbackend/_async/requests.py | 50 +++++++++++---- testkitbackend/_sync/requests.py | 50 +++++++++++---- tests/unit/async_/io/conftest.py | 15 +++++ tests/unit/async_/io/test_class_bolt.py | 23 +++---- tests/unit/async_/io/test_class_bolt3.py | 4 +- tests/unit/async_/io/test_class_bolt4x0.py | 4 +- tests/unit/async_/io/test_class_bolt4x1.py | 4 +- tests/unit/async_/io/test_class_bolt4x2.py | 4 +- tests/unit/async_/io/test_class_bolt4x3.py | 4 +- tests/unit/async_/io/test_class_bolt4x4.py | 4 +- tests/unit/async_/io/test_class_bolt5x0.py | 4 +- tests/unit/async_/io/test_class_bolt5x1.py | 4 +- tests/unit/async_/io/test_class_bolt5x2.py | 4 +- tests/unit/async_/io/test_direct.py | 16 +++-- tests/unit/async_/io/test_neo4j_pool.py | 10 ++- tests/unit/async_/test_auth_manager.py | 59 +++++++++++++---- tests/unit/async_/test_driver.py | 3 +- tests/unit/sync/io/conftest.py | 15 +++++ tests/unit/sync/io/test_class_bolt.py | 23 +++---- tests/unit/sync/io/test_class_bolt3.py | 4 +- tests/unit/sync/io/test_class_bolt4x0.py | 4 +- tests/unit/sync/io/test_class_bolt4x1.py | 4 +- tests/unit/sync/io/test_class_bolt4x2.py | 4 +- tests/unit/sync/io/test_class_bolt4x3.py | 4 +- tests/unit/sync/io/test_class_bolt4x4.py | 4 +- tests/unit/sync/io/test_class_bolt5x0.py | 4 +- tests/unit/sync/io/test_class_bolt5x1.py | 4 +- tests/unit/sync/io/test_class_bolt5x2.py | 4 +- tests/unit/sync/io/test_direct.py | 16 +++-- tests/unit/sync/io/test_neo4j_pool.py | 10 ++- tests/unit/sync/test_auth_manager.py | 59 +++++++++++++---- tests/unit/sync/test_driver.py | 3 +- 48 files changed, 597 insertions(+), 173 deletions(-) diff --git a/docs/source/api.rst b/docs/source/api.rst index cfa13af3..5ae6ff11 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -182,7 +182,7 @@ Closing a driver will immediately shut down 
all connections in the pool. def execute_query( query_, parameters_, routing_, database_, impersonated_user_, - bookmark_manager_, result_transformer_, **kwargs + bookmark_manager_, auth_, result_transformer_, **kwargs ): def work(tx): result = tx.run(query_, parameters_, **kwargs) @@ -192,6 +192,7 @@ Closing a driver will immediately shut down all connections in the pool. database=database_, impersonated_user=impersonated_user_, bookmark_manager=bookmark_manager_, + auth=auth_, ) as session: if routing_ == RoutingControl.WRITERS: return session.execute_write(work) @@ -271,6 +272,20 @@ Closing a driver will immediately shut down all connections in the pool. See also the Session config :ref:`impersonated-user-ref`. :type impersonated_user_: typing.Optional[str] + :param auth_: + Authentication information to use for this query. + + By default, the driver configuration is used. + + **This is a preview** (see :ref:`filter-warnings-ref`). + It might be changed without following the deprecation policy. + See also + https://github.com/neo4j/neo4j-python-driver/wiki/preview-features + + See also the Session config :ref:`session-auth-ref`. + :type auth_: typing.Union[ + typing.Tuple[typing.Any, typing.Any], neo4j.Auth, None + ] :param result_transformer_: A function that gets passed the :class:`neo4j.Result` object resulting from the query and converts it to a different type. The @@ -281,7 +296,6 @@ Closing a driver will immediately shut down all connections in the pool. The transformer function must **not** return the :class:`neo4j.Result` itself. - .. warning:: N.B. the driver might retry the underlying transaction so the @@ -356,7 +370,7 @@ Closing a driver will immediately shut down all connections in the pool. :returns: the result of the ``result_transformer`` :rtype: T - **This is experimental.** (See :ref:`filter-warnings-ref`) + **This is experimental** (see :ref:`filter-warnings-ref`). It might be changed or removed any time even without prior notice. 
We are looking for feedback on this feature. Please let us know what @@ -365,6 +379,9 @@ Closing a driver will immediately shut down all connections in the pool. .. versionadded:: 5.5 + .. versionchanged:: 5.8 + Added the ``auth_`` parameter. + .. _driver-configuration-ref: @@ -441,7 +458,7 @@ Specify whether TCP keep-alive should be enabled. :Type: ``bool`` :Default: ``True`` -**This is experimental.** (See :ref:`filter-warnings-ref`) +**This is experimental** (see :ref:`filter-warnings-ref`). It might be changed or removed any time even without prior notice. @@ -1005,7 +1022,7 @@ See :class:`.BookmarkManager` for more information. .. versionadded:: 5.0 -**This is experimental.** (See :ref:`filter-warnings-ref`) +**This is experimental** (see :ref:`filter-warnings-ref`). It might be changed or removed any time even without prior notice. @@ -1022,6 +1039,10 @@ It is not possible to overwrite the authentication information for the session w i.e., downgrade the authentication at session level. Instead, you should create a driver with no authentication and upgrade the authentication at session level as needed. +**This is a preview** (see :ref:`filter-warnings-ref`). +It might be changed without following the deprecation policy. +See also https://github.com/neo4j/neo4j-python-driver/wiki/preview-features + :Type: :data:`None`, :class:`.Auth` or ``(user, password)``-tuple :Default: :data:`None` - use the authentication information provided during driver creation. @@ -1322,7 +1343,7 @@ Graph .. automethod:: relationship_type -**This is experimental.** (See :ref:`filter-warnings-ref`) +**This is experimental** (see :ref:`filter-warnings-ref`). It might be changed or removed any time even without prior notice. diff --git a/docs/source/async_api.rst b/docs/source/async_api.rst index 87a1edc8..603a9c97 100644 --- a/docs/source/async_api.rst +++ b/docs/source/async_api.rst @@ -172,7 +172,7 @@ Closing a driver will immediately shut down all connections in the pool. 
async def execute_query( query_, parameters_, routing_, database_, impersonated_user_, - bookmark_manager_, result_transformer_, **kwargs + bookmark_manager_, auth_, result_transformer_, **kwargs ): async def work(tx): result = await tx.run(query_, parameters_, **kwargs) @@ -182,6 +182,7 @@ Closing a driver will immediately shut down all connections in the pool. database=database_, impersonated_user=impersonated_user_, bookmark_manager=bookmark_manager_, + auth=auth_, ) as session: if routing_ == RoutingControl.WRITERS: return await session.execute_write(work) @@ -217,13 +218,14 @@ Closing a driver will immediately shut down all connections in the pool. async def example(driver: neo4j.AsyncDriver) -> int: """Call all young people "My dear" and get their count.""" record = await driver.execute_query( - "MATCH (p:Person) WHERE p.age <= 15 " + "MATCH (p:Person) WHERE p.age <= $age " "SET p.nickname = 'My dear' " "RETURN count(*)", # optional routing parameter, as write is default # routing_=neo4j.RoutingControl.WRITERS, # or just "w", database_="neo4j", result_transformer_=neo4j.AsyncResult.single, + age=15, ) assert record is not None # for typechecking and illustration count = record[0] @@ -260,6 +262,20 @@ Closing a driver will immediately shut down all connections in the pool. See also the Session config :ref:`impersonated-user-ref`. :type impersonated_user_: typing.Optional[str] + :param auth_: + Authentication information to use for this query. + + By default, the driver configuration is used. + + **This is a preview** (see :ref:`filter-warnings-ref`). + It might be changed without following the deprecation policy. + See also + https://github.com/neo4j/neo4j-python-driver/wiki/preview-features + + See also the Session config :ref:`session-auth-ref`. 
+ :type auth_: typing.Union[ + typing.Tuple[typing.Any, typing.Any], neo4j.Auth, None + ] :param result_transformer_: A function that gets passed the :class:`neo4j.AsyncResult` object resulting from the query and converts it to a different type. The @@ -330,7 +346,7 @@ Closing a driver will immediately shut down all connections in the pool. Defaults to the driver's :attr:`.query_bookmark_manager`. - Pass :const:`None` to disable causal consistency. + Pass :data:`None` to disable causal consistency. :type bookmark_manager_: typing.Union[neo4j.AsyncBookmarkManager, neo4j.BookmarkManager, None] @@ -344,7 +360,7 @@ Closing a driver will immediately shut down all connections in the pool. :returns: the result of the ``result_transformer`` :rtype: T - **This is experimental.** (See :ref:`filter-warnings-ref`) + **This is experimental** (see :ref:`filter-warnings-ref`). It might be changed or removed any time even without prior notice. We are looking for feedback on this feature. Please let us know what @@ -353,6 +369,9 @@ Closing a driver will immediately shut down all connections in the pool. .. versionadded:: 5.5 + .. versionchanged:: 5.8 + Added the ``auth_`` parameter. + .. _async-driver-configuration-ref: @@ -640,7 +659,7 @@ See :class:`BookmarkManager` for more information. :Type: :data:`None`, :class:`BookmarkManager`, or :class:`AsyncBookmarkManager` :Default: :data:`None` -**This is experimental.** (See :ref:`filter-warnings-ref`) +**This is experimental** (see :ref:`filter-warnings-ref`). It might be changed or removed any time even without prior notice. 
diff --git a/src/neo4j/__init__.py b/src/neo4j/__init__.py index 28cbec5a..b18803d5 100644 --- a/src/neo4j/__init__.py +++ b/src/neo4j/__init__.py @@ -52,6 +52,7 @@ deprecation_warn as _deprecation_warn, ExperimentalWarning, get_user_agent, + PreviewWarning, version as __version__, ) from ._sync.driver import ( @@ -140,6 +141,7 @@ "NotificationFilter", "NotificationSeverity", "PoolConfig", + "PreviewWarning", "Query", "READ_ACCESS", "Record", diff --git a/src/neo4j/_async/auth_management.py b/src/neo4j/_async/auth_management.py index 37b7037a..de6d0055 100644 --- a/src/neo4j/_async/auth_management.py +++ b/src/neo4j/_async/auth_management.py @@ -29,6 +29,7 @@ AsyncAuthManager, ExpiringAuth, ) +from .._meta import preview # work around for https://github.com/sphinx-doc/sphinx/pull/10880 # make sure TAuth is resolved in the docs, else they're pretty useless @@ -104,9 +105,17 @@ async def on_auth_expired(self, auth: _TAuth) -> None: class AsyncAuthManagers: - """A collection of :class:`.AsyncAuthManager` factories.""" + """A collection of :class:`.AsyncAuthManager` factories. + + **This is a preview** (see :ref:`filter-warnings-ref`). + It might be changed without following the deprecation policy. + See also https://github.com/neo4j/neo4j-python-driver/wiki/preview-features + + .. versionadded:: 5.8 + """ @staticmethod + @preview("Auth managers are a preview feature.") def static(auth: _TAuth) -> AsyncAuthManager: """Create a static auth manager. @@ -138,6 +147,7 @@ def static(auth: _TAuth) -> AsyncAuthManager: return AsyncStaticAuthManager(auth) @staticmethod + @preview("Auth managers are a preview feature.") def expiration_based( provider: t.Callable[[], t.Awaitable[ExpiringAuth]] ) -> AsyncAuthManager: @@ -190,5 +200,7 @@ async def auth_provider(): the provider again, when the auth info expires (either because it's reached its expiry time or because the server flagged it as expired). 
+ + """ return AsyncExpirationBasedAuthManager(provider) diff --git a/src/neo4j/_async/driver.py b/src/neo4j/_async/driver.py index 0b101829..88de88a6 100644 --- a/src/neo4j/_async/driver.py +++ b/src/neo4j/_async/driver.py @@ -47,6 +47,9 @@ experimental, experimental_warn, ExperimentalWarning, + preview, + preview_warn, + PreviewWarning, unclosed_resource_warn, ) from .._work import EagerResult @@ -198,7 +201,15 @@ def driver( driver_type, security_type, parsed = parse_neo4j_uri(uri) if not isinstance(auth, (AsyncAuthManager, AuthManager)): - auth = AsyncAuthManagers.static(auth) + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", message=r".*\bAuth managers\b.*", + category=PreviewWarning + ) + auth = AsyncAuthManagers.static(auth) + else: + preview_warn("Auth managers are a preview feature.", + stack_level=2) config["auth"] = auth # TODO: 6.0 - remove "trust" config option @@ -349,7 +360,7 @@ def bookmark_manager( :returns: A default implementation of :class:`AsyncBookmarkManager`. - **This is experimental.** (See :ref:`filter-warnings-ref`) + **This is experimental** (see :ref:`filter-warnings-ref`). It might be changed or removed any time even without prior notice. .. 
versionadded:: 5.0 @@ -506,6 +517,9 @@ def encrypted(self) -> bool: return bool(self._pool.pool_config.encrypted) def _prepare_session_config(self, **config): + if "auth" in config: + preview_warn("User switching is a preview features.", + stack_level=3) _normalize_notifications_config(config) return config @@ -574,6 +588,7 @@ async def execute_query( bookmark_manager_: t.Union[ AsyncBookmarkManager, BookmarkManager, None ] = ..., + auth_: _TAuth = None, result_transformer_: t.Callable[ [AsyncResult], t.Awaitable[EagerResult] ] = ..., @@ -592,6 +607,7 @@ async def execute_query( bookmark_manager_: t.Union[ AsyncBookmarkManager, BookmarkManager, None ] = ..., + auth_: _TAuth = None, result_transformer_: t.Callable[ [AsyncResult], t.Awaitable[_T] ] = ..., @@ -614,6 +630,7 @@ async def execute_query( AsyncBookmarkManager, BookmarkManager, None, te.Literal[_DefaultEnum.default] ] = _default, + auth_: _TAuth = None, result_transformer_: t.Callable[ [AsyncResult], t.Awaitable[t.Any] ] = AsyncResult.to_eager_result, @@ -635,7 +652,7 @@ async def execute_query( async def execute_query( query_, parameters_, routing_, database_, impersonated_user_, - bookmark_manager_, result_transformer_, **kwargs + bookmark_manager_, auth_, result_transformer_, **kwargs ): async def work(tx): result = await tx.run(query_, parameters_, **kwargs) @@ -645,6 +662,7 @@ async def work(tx): database=database_, impersonated_user=impersonated_user_, bookmark_manager=bookmark_manager_, + auth=auth_, ) as session: if routing_ == RoutingControl.WRITERS: return await session.execute_write(work) @@ -680,13 +698,14 @@ async def example(driver: neo4j.AsyncDriver) -> List[str]: async def example(driver: neo4j.AsyncDriver) -> int: \"""Call all young people "My dear" and get their count.\""" record = await driver.execute_query( - "MATCH (p:Person) WHERE p.age <= 15 " + "MATCH (p:Person) WHERE p.age <= $age " "SET p.nickname = 'My dear' " "RETURN count(*)", # optional routing parameter, as write is default # 
routing_=neo4j.RoutingControl.WRITERS, # or just "w", database_="neo4j", result_transformer_=neo4j.AsyncResult.single, + age=15, ) assert record is not None # for typechecking and illustration count = record[0] @@ -723,6 +742,20 @@ async def example(driver: neo4j.AsyncDriver) -> int: See also the Session config :ref:`impersonated-user-ref`. :type impersonated_user_: typing.Optional[str] + :param auth_: + Authentication information to use for this query. + + By default, the driver configuration is used. + + **This is a preview** (see :ref:`filter-warnings-ref`). + It might be changed without following the deprecation policy. + See also + https://github.com/neo4j/neo4j-python-driver/wiki/preview-features + + See also the Session config :ref:`session-auth-ref`. + :type auth_: typing.Union[ + typing.Tuple[typing.Any, typing.Any], neo4j.Auth, None + ] :param result_transformer_: A function that gets passed the :class:`neo4j.AsyncResult` object resulting from the query and converts it to a different type. The @@ -807,10 +840,17 @@ async def example(driver: neo4j.AsyncDriver) -> neo4j.Record:: :returns: the result of the ``result_transformer`` :rtype: T - **This is experimental.** (See :ref:`filter-warnings-ref`) + **This is experimental** (see :ref:`filter-warnings-ref`). It might be changed or removed any time even without prior notice. + We are looking for feedback on this feature. Please let us know what + you think about it here: + https://github.com/neo4j/neo4j-python-driver/discussions/896 + .. versionadded:: 5.5 + + .. versionchanged:: 5.8 + Added the ``auth_`` parameter. 
""" invalid_kwargs = [k for k in kwargs if k[-2:-1] != "_" and k[-1:] == "_"] @@ -832,9 +872,13 @@ async def example(driver: neo4j.AsyncDriver) -> neo4j.Record:: warnings.filterwarnings("ignore", message=r".*\bbookmark_manager\b.*", category=ExperimentalWarning) + warnings.filterwarnings("ignore", + message=r"^User switching\b.*", + category=PreviewWarning) session = self.session(database=database_, impersonated_user=impersonated_user_, - bookmark_manager=bookmark_manager_) + bookmark_manager=bookmark_manager_, + auth=auth_) async with session: if routing_ == RoutingControl.WRITERS: executor = session.execute_write @@ -871,7 +915,7 @@ async def example(driver: neo4j.AsyncDriver) -> None: # (i.e., can read what was written by ) await driver.execute_query("") - **This is experimental.** (See :ref:`filter-warnings-ref`) + **This is experimental** (see :ref:`filter-warnings-ref`). It might be changed or removed any time even without prior notice. .. versionadded:: 5.5 @@ -1065,6 +1109,7 @@ async def verify_authentication( else: + @preview("User switching is a preview feature.") async def verify_authentication( self, auth: t.Union[Auth, t.Tuple[t.Any, t.Any], None] = None, @@ -1098,7 +1143,12 @@ async def verify_authentication( Use the exception to further understand the cause of the connectivity problem. - .. versionadded:: 5.x + **This is a preview** (see :ref:`filter-warnings-ref`). + It might be changed without following the deprecation policy. + See also + https://github.com/neo4j/neo4j-python-driver/wiki/preview-features + + .. 
versionadded:: 5.8 """ if config: experimental_warn( @@ -1111,7 +1161,13 @@ async def verify_authentication( if "database" not in config: config["database"] = "system" try: - async with self.session(**config) as session: + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", message=r"^User switching\b.*", + category=PreviewWarning + ) + session = self.session(**config) + async with session as session: await session._verify_authentication() except Neo4jError as exc: if exc.code in ( diff --git a/src/neo4j/_async/io/_bolt.py b/src/neo4j/_async/io/_bolt.py index 5be70afe..28472ecd 100644 --- a/src/neo4j/_async/io/_bolt.py +++ b/src/neo4j/_async/io/_bolt.py @@ -216,7 +216,7 @@ def supports_re_auth(self): def assert_re_auth_support(self): if not self.supports_re_auth: raise ConfigurationError( - "Session level authentication is not supported for Bolt " + "User switching is not supported for Bolt " f"Protocol {self.PROTOCOL_VERSION!r}. Server Agent " f"{self.server_info.agent!r}" ) diff --git a/src/neo4j/_async/work/result.py b/src/neo4j/_async/work/result.py index 767391f1..5a8dea97 100644 --- a/src/neo4j/_async/work/result.py +++ b/src/neo4j/_async/work/result.py @@ -624,7 +624,7 @@ async def to_eager_result(self) -> EagerResult: was obtained has been closed or the Result has been explicitly consumed. - **This is experimental.** (See :ref:`filter-warnings-ref`) + **This is experimental** (see :ref:`filter-warnings-ref`). It might be changed or removed any time even without prior notice. .. 
versionadded:: 5.5 diff --git a/src/neo4j/_async/work/session.py b/src/neo4j/_async/work/session.py index 57df3c99..41227ea0 100644 --- a/src/neo4j/_async/work/session.py +++ b/src/neo4j/_async/work/session.py @@ -20,6 +20,7 @@ import asyncio import typing as t +import warnings from logging import getLogger from random import random from time import perf_counter @@ -27,7 +28,10 @@ from ..._async_compat import async_sleep from ..._async_compat.util import AsyncUtil from ..._conf import SessionConfig -from ..._meta import deprecated +from ..._meta import ( + deprecated, + PreviewWarning, +) from ..._work import Query from ...api import ( Bookmarks, @@ -99,7 +103,14 @@ class AsyncSession(AsyncWorkspace): def __init__(self, pool, session_config): assert isinstance(session_config, SessionConfig) if session_config.auth is not None: - session_config.auth = AsyncAuthManagers.static(session_config.auth) + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", message=r".*\bAuth managers\b.*", + category=PreviewWarning + ) + session_config.auth = AsyncAuthManagers.static( + session_config.auth + ) super().__init__(pool, session_config) self._config = session_config self._initialize_bookmarks(session_config.bookmarks) diff --git a/src/neo4j/_auth_management.py b/src/neo4j/_auth_management.py index 5becce4d..6d3c7a92 100644 --- a/src/neo4j/_auth_management.py +++ b/src/neo4j/_auth_management.py @@ -24,9 +24,11 @@ import typing as t from dataclasses import dataclass +from ._meta import preview from .api import _TAuth +@preview("Auth managers are a preview feature.") @dataclass class ExpiringAuth: """Represents potentially expiring authentication information. @@ -39,10 +41,14 @@ class ExpiringAuth: information expires. If :data:`None`, the authentication information is considered to not expire until the server explicitly indicates so. + **This is a preview** (see :ref:`filter-warnings-ref`). + It might be changed without following the deprecation policy. 
+ See also https://github.com/neo4j/neo4j-python-driver/wiki/preview-features + .. seealso:: :meth:`.AuthManagers.temporal`, :meth:`.AsyncAuthManagers.temporal` - .. versionadded:: 5.x + .. versionadded:: 5.8 """ auth: _TAuth expires_in: t.Optional[float] = None @@ -67,9 +73,13 @@ class AuthManager(metaclass=abc.ABCMeta): The token returned must always belong to the same identity. Switching identities using the `AuthManager` is undefined behavior. + **This is a preview** (see :ref:`filter-warnings-ref`). + It might be changed without following the deprecation policy. + See also https://github.com/neo4j/neo4j-python-driver/wiki/preview-features + .. seealso:: :class:`.AuthManagers` - .. versionadded:: 5.x + .. versionadded:: 5.8 """ @abc.abstractmethod @@ -104,9 +114,13 @@ def on_auth_expired(self, auth: _TAuth) -> None: class AsyncAuthManager(metaclass=abc.ABCMeta): """Async version of :class:`.AuthManager`. + **This is a preview** (see :ref:`filter-warnings-ref`). + It might be changed without following the deprecation policy. + See also https://github.com/neo4j/neo4j-python-driver/wiki/preview-features + .. seealso:: :class:`.AuthManager` - .. versionadded:: 5.x + .. versionadded:: 5.8 """ @abc.abstractmethod diff --git a/src/neo4j/_meta.py b/src/neo4j/_meta.py index f8f8cd9c..5489fc44 100644 --- a/src/neo4j/_meta.py +++ b/src/neo4j/_meta.py @@ -94,6 +94,9 @@ def decorator(f): class ExperimentalWarning(Warning): """ Base class for warnings about experimental features. + + .. deprecated:: 5.8 + we now use "preview" instead of "experimental". """ @@ -101,6 +104,20 @@ def experimental_warn(message, stack_level=1): warn(message, category=ExperimentalWarning, stacklevel=stack_level + 1) +class PreviewWarning(Warning): + """ Base class for warnings about preview features. + """ + + +def preview_warn(message, stack_level=1): + message += ( + " It might be changed without following the deprecation policy. 
" + "See also " + "https://github.com/neo4j/neo4j-python-driver/wiki/preview-features." + ) + warn(message, category=PreviewWarning, stacklevel=stack_level + 1) + + def experimental(message) -> t.Callable[[_FuncT], _FuncT]: """ Decorator for tagging experimental functions and methods. @@ -111,6 +128,8 @@ def experimental(message) -> t.Callable[[_FuncT], _FuncT]: def foo(x): pass + .. deprecated:: 5.8 + we now use "preview" instead of "experimental". """ def decorator(f): if asyncio.iscoroutinefunction(f): @@ -131,6 +150,33 @@ def inner(*args, **kwargs): return decorator +def preview(message) -> t.Callable[[_FuncT], _FuncT]: + """ + Decorator for tagging preview functions and methods. + + @preview("foo is a preview.") + def foo(x): + pass + """ + def decorator(f): + if asyncio.iscoroutinefunction(f): + @wraps(f) + async def inner(*args, **kwargs): + preview_warn(message, stack_level=2) + return await f(*args, **kwargs) + + return inner + else: + @wraps(f) + def inner(*args, **kwargs): + preview_warn(message, stack_level=2) + return f(*args, **kwargs) + + return inner + + return decorator + + def unclosed_resource_warn(obj): msg = f"Unclosed {obj!r}." trace = tracemalloc.get_object_traceback(obj) diff --git a/src/neo4j/_sync/auth_management.py b/src/neo4j/_sync/auth_management.py index 96d89973..fc016f4b 100644 --- a/src/neo4j/_sync/auth_management.py +++ b/src/neo4j/_sync/auth_management.py @@ -29,6 +29,7 @@ AuthManager, ExpiringAuth, ) +from .._meta import preview # work around for https://github.com/sphinx-doc/sphinx/pull/10880 # make sure TAuth is resolved in the docs, else they're pretty useless @@ -104,9 +105,17 @@ def on_auth_expired(self, auth: _TAuth) -> None: class AuthManagers: - """A collection of :class:`.AuthManager` factories.""" + """A collection of :class:`.AuthManager` factories. + + **This is a preview** (see :ref:`filter-warnings-ref`). + It might be changed without following the deprecation policy. 
+ See also https://github.com/neo4j/neo4j-python-driver/wiki/preview-features + + .. versionadded:: 5.8 + """ @staticmethod + @preview("Auth managers are a preview feature.") def static(auth: _TAuth) -> AuthManager: """Create a static auth manager. @@ -138,6 +147,7 @@ def static(auth: _TAuth) -> AuthManager: return StaticAuthManager(auth) @staticmethod + @preview("Auth managers are a preview feature.") def expiration_based( provider: t.Callable[[], t.Union[ExpiringAuth]] ) -> AuthManager: @@ -190,5 +200,7 @@ def auth_provider(): the provider again, when the auth info expires (either because it's reached its expiry time or because the server flagged it as expired). + + """ return ExpirationBasedAuthManager(provider) diff --git a/src/neo4j/_sync/driver.py b/src/neo4j/_sync/driver.py index 3065f894..e9d9f046 100644 --- a/src/neo4j/_sync/driver.py +++ b/src/neo4j/_sync/driver.py @@ -47,6 +47,9 @@ experimental, experimental_warn, ExperimentalWarning, + preview, + preview_warn, + PreviewWarning, unclosed_resource_warn, ) from .._work import EagerResult @@ -196,7 +199,15 @@ def driver( driver_type, security_type, parsed = parse_neo4j_uri(uri) if not isinstance(auth, (AuthManager, AuthManager)): - auth = AuthManagers.static(auth) + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", message=r".*\bAuth managers\b.*", + category=PreviewWarning + ) + auth = AuthManagers.static(auth) + else: + preview_warn("Auth managers are a preview feature.", + stack_level=2) config["auth"] = auth # TODO: 6.0 - remove "trust" config option @@ -347,7 +358,7 @@ def bookmark_manager( :returns: A default implementation of :class:`BookmarkManager`. - **This is experimental.** (See :ref:`filter-warnings-ref`) + **This is experimental** (see :ref:`filter-warnings-ref`). It might be changed or removed any time even without prior notice. .. 
versionadded:: 5.0 @@ -504,6 +515,9 @@ def encrypted(self) -> bool: return bool(self._pool.pool_config.encrypted) def _prepare_session_config(self, **config): + if "auth" in config: + preview_warn("User switching is a preview feature.", + stack_level=3) _normalize_notifications_config(config) return config @@ -572,6 +586,7 @@ def execute_query( bookmark_manager_: t.Union[ BookmarkManager, BookmarkManager, None ] = ..., + auth_: _TAuth = None, result_transformer_: t.Callable[ [Result], t.Union[EagerResult] ] = ..., @@ -590,6 +605,7 @@ def execute_query( bookmark_manager_: t.Union[ BookmarkManager, BookmarkManager, None ] = ..., + auth_: _TAuth = None, result_transformer_: t.Callable[ [Result], t.Union[_T] ] = ..., @@ -612,6 +628,7 @@ def execute_query( BookmarkManager, BookmarkManager, None, te.Literal[_DefaultEnum.default] ] = _default, + auth_: _TAuth = None, result_transformer_: t.Callable[ [Result], t.Union[t.Any] ] = Result.to_eager_result, @@ -633,7 +650,7 @@ def execute_query( def execute_query( query_, parameters_, routing_, database_, impersonated_user_, - bookmark_manager_, result_transformer_, **kwargs + bookmark_manager_, auth_, result_transformer_, **kwargs ): def work(tx): result = tx.run(query_, parameters_, **kwargs) @@ -643,6 +660,7 @@ def work(tx): database=database_, impersonated_user=impersonated_user_, bookmark_manager=bookmark_manager_, + auth=auth_, ) as session: if routing_ == RoutingControl.WRITERS: return session.execute_write(work) @@ -678,13 +696,14 @@ def example(driver: neo4j.Driver) -> List[str]: def example(driver: neo4j.Driver) -> int: \"""Call all young people "My dear" and get their count.\""" record = driver.execute_query( - "MATCH (p:Person) WHERE p.age <= 15 " + "MATCH (p:Person) WHERE p.age <= $age " "SET p.nickname = 'My dear' " "RETURN count(*)", # optional routing parameter, as write is default # routing_=neo4j.RoutingControl.WRITERS, # or just "w", database_="neo4j", result_transformer_=neo4j.Result.single, + age=15, ) 
assert record is not None # for typechecking and illustration count = record[0] @@ -721,6 +740,20 @@ def example(driver: neo4j.Driver) -> int: See also the Session config :ref:`impersonated-user-ref`. :type impersonated_user_: typing.Optional[str] + :param auth_: + Authentication information to use for this query. + + By default, the driver configuration is used. + + **This is a preview** (see :ref:`filter-warnings-ref`). + It might be changed without following the deprecation policy. + See also + https://github.com/neo4j/neo4j-python-driver/wiki/preview-features + + See also the Session config :ref:`session-auth-ref`. + :type auth_: typing.Union[ + typing.Tuple[typing.Any, typing.Any], neo4j.Auth, None + ] :param result_transformer_: A function that gets passed the :class:`neo4j.Result` object resulting from the query and converts it to a different type. The @@ -805,10 +838,17 @@ def example(driver: neo4j.Driver) -> neo4j.Record:: :returns: the result of the ``result_transformer`` :rtype: T - **This is experimental.** (See :ref:`filter-warnings-ref`) + **This is experimental** (see :ref:`filter-warnings-ref`). It might be changed or removed any time even without prior notice. + We are looking for feedback on this feature. Please let us know what + you think about it here: + https://github.com/neo4j/neo4j-python-driver/discussions/896 + .. versionadded:: 5.5 + + .. versionchanged:: 5.8 + Added the ``auth_`` parameter. 
""" invalid_kwargs = [k for k in kwargs if k[-2:-1] != "_" and k[-1:] == "_"] @@ -830,9 +870,13 @@ def example(driver: neo4j.Driver) -> neo4j.Record:: warnings.filterwarnings("ignore", message=r".*\bbookmark_manager\b.*", category=ExperimentalWarning) + warnings.filterwarnings("ignore", + message=r"^User switching\b.*", + category=PreviewWarning) session = self.session(database=database_, impersonated_user=impersonated_user_, - bookmark_manager=bookmark_manager_) + bookmark_manager=bookmark_manager_, + auth=auth_) with session: if routing_ == RoutingControl.WRITERS: executor = session.execute_write @@ -869,7 +913,7 @@ def example(driver: neo4j.Driver) -> None: # (i.e., can read what was written by ) driver.execute_query("") - **This is experimental.** (See :ref:`filter-warnings-ref`) + **This is experimental** (see :ref:`filter-warnings-ref`). It might be changed or removed any time even without prior notice. .. versionadded:: 5.5 @@ -1063,6 +1107,7 @@ def verify_authentication( else: + @preview("User switching is a preview feature.") def verify_authentication( self, auth: t.Union[Auth, t.Tuple[t.Any, t.Any], None] = None, @@ -1096,7 +1141,12 @@ def verify_authentication( Use the exception to further understand the cause of the connectivity problem. - .. versionadded:: 5.x + **This is a preview** (see :ref:`filter-warnings-ref`). + It might be changed without following the deprecation policy. + See also + https://github.com/neo4j/neo4j-python-driver/wiki/preview-features + + .. 
versionadded:: 5.8 """ if config: experimental_warn( @@ -1109,7 +1159,13 @@ def verify_authentication( if "database" not in config: config["database"] = "system" try: - with self.session(**config) as session: + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", message=r"^User switching\b.*", + category=PreviewWarning + ) + session = self.session(**config) + with session as session: session._verify_authentication() except Neo4jError as exc: if exc.code in ( diff --git a/src/neo4j/_sync/io/_bolt.py b/src/neo4j/_sync/io/_bolt.py index 00f88b9d..f2b97e58 100644 --- a/src/neo4j/_sync/io/_bolt.py +++ b/src/neo4j/_sync/io/_bolt.py @@ -216,7 +216,7 @@ def supports_re_auth(self): def assert_re_auth_support(self): if not self.supports_re_auth: raise ConfigurationError( - "Session level authentication is not supported for Bolt " + "User switching is not supported for Bolt " f"Protocol {self.PROTOCOL_VERSION!r}. Server Agent " f"{self.server_info.agent!r}" ) diff --git a/src/neo4j/_sync/work/result.py b/src/neo4j/_sync/work/result.py index 1cc55fec..cfc6df05 100644 --- a/src/neo4j/_sync/work/result.py +++ b/src/neo4j/_sync/work/result.py @@ -624,7 +624,7 @@ def to_eager_result(self) -> EagerResult: was obtained has been closed or the Result has been explicitly consumed. - **This is experimental.** (See :ref:`filter-warnings-ref`) + **This is experimental** (see :ref:`filter-warnings-ref`). It might be changed or removed any time even without prior notice. .. 
versionadded:: 5.5 diff --git a/src/neo4j/_sync/work/session.py b/src/neo4j/_sync/work/session.py index 2bc54bcc..8f38739b 100644 --- a/src/neo4j/_sync/work/session.py +++ b/src/neo4j/_sync/work/session.py @@ -20,6 +20,7 @@ import asyncio import typing as t +import warnings from logging import getLogger from random import random from time import perf_counter @@ -27,7 +28,10 @@ from ..._async_compat import sleep from ..._async_compat.util import Util from ..._conf import SessionConfig -from ..._meta import deprecated +from ..._meta import ( + deprecated, + PreviewWarning, +) from ..._work import Query from ...api import ( Bookmarks, @@ -99,7 +103,14 @@ class Session(Workspace): def __init__(self, pool, session_config): assert isinstance(session_config, SessionConfig) if session_config.auth is not None: - session_config.auth = AuthManagers.static(session_config.auth) + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", message=r".*\bAuth managers\b.*", + category=PreviewWarning + ) + session_config.auth = AuthManagers.static( + session_config.auth + ) super().__init__(pool, session_config) self._config = session_config self._initialize_bookmarks(session_config.bookmarks) diff --git a/src/neo4j/_work/eager_result.py b/src/neo4j/_work/eager_result.py index 87abb410..d88e23f5 100644 --- a/src/neo4j/_work/eager_result.py +++ b/src/neo4j/_work/eager_result.py @@ -33,7 +33,7 @@ class EagerResult(t.NamedTuple): * keys - the list of keys returned by the query (see :attr:`AsyncResult.keys` and :attr:`.Result.keys`) - **This is experimental.** (See :ref:`filter-warnings-ref`) + **This is experimental** (see :ref:`filter-warnings-ref`). It might be changed or removed any time even without prior notice. .. 
seealso:: diff --git a/testkitbackend/_async/requests.py b/testkitbackend/_async/requests.py index ec867870..c2d2aeea 100644 --- a/testkitbackend/_async/requests.py +++ b/testkitbackend/_async/requests.py @@ -39,7 +39,10 @@ test_subtest_skips, totestkit, ) -from .._warning_check import warning_check +from .._warning_check import ( + warning_check, + warnings_check, +) from ..exceptions import MarkdAsDriverException @@ -104,8 +107,12 @@ async def GetFeatures(backend, data): async def NewDriver(backend, data): auth = fromtestkit.to_auth_token(data, "authorizationToken") + expected_warnings = [] if auth is None and data.get("authTokenManagerId") is not None: auth = backend.auth_token_managers[data["authTokenManagerId"]] + expected_warnings.append( + (neo4j.PreviewWarning, "Auth managers are a preview feature.") + ) else: data.mark_item_as_read_if_equals("authTokenManagerId", None) kwargs = {} @@ -147,9 +154,15 @@ async def NewDriver(backend, data): fromtestkit.set_notifications_config(kwargs, data) data.mark_item_as_read_if_equals("livenessCheckTimeoutMs", None) - driver = neo4j.AsyncGraphDatabase.driver( - data["uri"], auth=auth, user_agent=data["userAgent"], **kwargs, - ) + if expected_warnings: + with warnings_check(expected_warnings): + driver = neo4j.AsyncGraphDatabase.driver( + data["uri"], auth=auth, user_agent=data["userAgent"], **kwargs, + ) + else: + driver = neo4j.AsyncGraphDatabase.driver( + data["uri"], auth=auth, user_agent=data["userAgent"], **kwargs, + ) key = backend.next_key() backend.drivers[key] = driver await backend.send_response("Driver", {"id": key}) @@ -243,7 +256,9 @@ async def auth_token_provider(): ) return backend.expiring_auth_token_supplies.pop(key) - auth_manager = AsyncAuthManagers.expiration_based(auth_token_provider) + with warning_check(neo4j.PreviewWarning, + "Auth managers are a preview feature."): + auth_manager = AsyncAuthManagers.expiration_based(auth_token_provider) backend.auth_token_managers[auth_token_manager_id] = 
auth_manager await backend.send_response( "ExpirationBasedAuthTokenManager", {"id": auth_token_manager_id} @@ -260,7 +275,9 @@ async def ExpirationBasedAuthTokenProviderCompleted(backend, data): expires_in = temp_auth_data["expiresInMs"] / 1000 else: expires_in = None - expiring_auth = ExpiringAuth(auth_token, expires_in) + with warning_check(neo4j.PreviewWarning, + "Auth managers are a preview feature."): + expiring_auth = ExpiringAuth(auth_token, expires_in) backend.expiring_auth_token_supplies[data["requestId"]] = expiring_auth @@ -296,7 +313,9 @@ async def VerifyAuthentication(backend, data): driver_id = data["driverId"] driver = backend.drivers[driver_id] auth = fromtestkit.to_auth_token(data, "auth_token") - authenticated = await driver.verify_authentication(auth=auth) + with warning_check(neo4j.PreviewWarning, + "User switching is a preview feature."): + authenticated = await driver.verify_authentication(auth=auth) await backend.send_response("DriverIsAuthenticated", { "id": backend.next_key(), "authenticated": authenticated }) @@ -518,6 +537,7 @@ def __init__(self, session): async def NewSession(backend, data): driver = backend.drivers[data["driverId"]] access_mode = data["accessMode"] + expected_warnings = [] if access_mode == "r": access_mode = neo4j.READ_ACCESS elif access_mode == "w": @@ -536,6 +556,11 @@ async def NewSession(backend, data): config["bookmark_manager"] = backend.bookmark_managers[ data["bookmarkManagerId"] ] + expected_warnings.append(( + neo4j.ExperimentalWarning, + "The 'bookmark_manager' config key is experimental. It might be " + "changed or removed any time even without prior notice." 
+        ))
     for (conf_name, data_name) in ( ("fetch_size", "fetchSize"), ("impersonated_user", "impersonatedUser"), @@ -544,13 +569,12 @@ config[conf_name] = data[data_name] if data.get("authorizationToken"): config["auth"] = fromtestkit.to_auth_token(data, "authorizationToken") + expected_warnings.append( + (neo4j.PreviewWarning, "User switching is a preview feature.") + ) fromtestkit.set_notifications_config(config, data) - if "bookmark_manager" in config: - with warning_check( - neo4j.ExperimentalWarning, - "The 'bookmark_manager' config key is experimental. It might be " - "changed or removed any time even without prior notice." - ): + if expected_warnings: + with warnings_check(expected_warnings): session = driver.session(**config) else: session = driver.session(**config) diff --git a/testkitbackend/_sync/requests.py b/testkitbackend/_sync/requests.py index a6fdf45c..d3e13377 100644 --- a/testkitbackend/_sync/requests.py +++ b/testkitbackend/_sync/requests.py @@ -39,7 +39,10 @@ test_subtest_skips, totestkit, ) -from .._warning_check import warning_check +from .._warning_check import ( + warning_check, + warnings_check, +) from ..exceptions import MarkdAsDriverException @@ -104,8 +107,12 @@ def GetFeatures(backend, data): def NewDriver(backend, data): auth = fromtestkit.to_auth_token(data, "authorizationToken") + expected_warnings = [] if auth is None and data.get("authTokenManagerId") is not None: auth = backend.auth_token_managers[data["authTokenManagerId"]] + expected_warnings.append( + (neo4j.PreviewWarning, "Auth managers are a preview feature.") + ) else: data.mark_item_as_read_if_equals("authTokenManagerId", None) kwargs = {} @@ -147,9 +154,15 @@ def NewDriver(backend, data): fromtestkit.set_notifications_config(kwargs, data) data.mark_item_as_read_if_equals("livenessCheckTimeoutMs", None) - driver = neo4j.GraphDatabase.driver( - data["uri"], auth=auth, user_agent=data["userAgent"], **kwargs, - ) + if expected_warnings: + 
with warnings_check(expected_warnings): + driver = neo4j.GraphDatabase.driver( + data["uri"], auth=auth, user_agent=data["userAgent"], **kwargs, + ) + else: + driver = neo4j.GraphDatabase.driver( + data["uri"], auth=auth, user_agent=data["userAgent"], **kwargs, + ) key = backend.next_key() backend.drivers[key] = driver backend.send_response("Driver", {"id": key}) @@ -243,7 +256,9 @@ def auth_token_provider(): ) return backend.expiring_auth_token_supplies.pop(key) - auth_manager = AuthManagers.expiration_based(auth_token_provider) + with warning_check(neo4j.PreviewWarning, + "Auth managers are a preview feature."): + auth_manager = AuthManagers.expiration_based(auth_token_provider) backend.auth_token_managers[auth_token_manager_id] = auth_manager backend.send_response( "ExpirationBasedAuthTokenManager", {"id": auth_token_manager_id} @@ -260,7 +275,9 @@ def ExpirationBasedAuthTokenProviderCompleted(backend, data): expires_in = temp_auth_data["expiresInMs"] / 1000 else: expires_in = None - expiring_auth = ExpiringAuth(auth_token, expires_in) + with warning_check(neo4j.PreviewWarning, + "Auth managers are a preview feature."): + expiring_auth = ExpiringAuth(auth_token, expires_in) backend.expiring_auth_token_supplies[data["requestId"]] = expiring_auth @@ -296,7 +313,9 @@ def VerifyAuthentication(backend, data): driver_id = data["driverId"] driver = backend.drivers[driver_id] auth = fromtestkit.to_auth_token(data, "auth_token") - authenticated = driver.verify_authentication(auth=auth) + with warning_check(neo4j.PreviewWarning, + "User switching is a preview feature."): + authenticated = driver.verify_authentication(auth=auth) backend.send_response("DriverIsAuthenticated", { "id": backend.next_key(), "authenticated": authenticated }) @@ -518,6 +537,7 @@ def __init__(self, session): def NewSession(backend, data): driver = backend.drivers[data["driverId"]] access_mode = data["accessMode"] + expected_warnings = [] if access_mode == "r": access_mode = neo4j.READ_ACCESS elif 
access_mode == "w": @@ -536,6 +556,11 @@ def NewSession(backend, data): config["bookmark_manager"] = backend.bookmark_managers[ data["bookmarkManagerId"] ] + expected_warnings.append(( + neo4j.ExperimentalWarning, + "The 'bookmark_manager' config key is experimental. It might be " + "changed or removed any time even without prior notice." + )) for (conf_name, data_name) in ( ("fetch_size", "fetchSize"), ("impersonated_user", "impersonatedUser"), @@ -544,13 +569,12 @@ config[conf_name] = data[data_name] if data.get("authorizationToken"): config["auth"] = fromtestkit.to_auth_token(data, "authorizationToken") + expected_warnings.append( + (neo4j.PreviewWarning, "User switching is a preview feature.") + ) fromtestkit.set_notifications_config(config, data) - if "bookmark_manager" in config: - with warning_check( - neo4j.ExperimentalWarning, - "The 'bookmark_manager' config key is experimental. It might be " - "changed or removed any time even without prior notice."
- ): + if expected_warnings: + with warnings_check(expected_warnings): session = driver.session(**config) else: session = driver.session(**config) diff --git a/tests/unit/async_/io/conftest.py b/tests/unit/async_/io/conftest.py index 0d3afa85..d697d182 100644 --- a/tests/unit/async_/io/conftest.py +++ b/tests/unit/async_/io/conftest.py @@ -24,10 +24,12 @@ import pytest +from neo4j import PreviewWarning from neo4j._async.io._common import ( AsyncInbox, AsyncOutbox, ) +from neo4j.auth_management import AsyncAuthManagers class AsyncFakeSocket: @@ -145,3 +147,16 @@ def fake_socket_2(): @pytest.fixture def fake_socket_pair(): return AsyncFakeSocketPair + +@pytest.fixture +def static_auth(): + def inner(auth): + with pytest.warns(PreviewWarning, match="Auth managers"): + return AsyncAuthManagers.static(auth) + + return inner + + +@pytest.fixture +def none_auth(static_auth): + return static_auth(None) diff --git a/tests/unit/async_/io/test_class_bolt.py b/tests/unit/async_/io/test_class_bolt.py index 9b8cc0ff..735e7fd5 100644 --- a/tests/unit/async_/io/test_class_bolt.py +++ b/tests/unit/async_/io/test_class_bolt.py @@ -95,7 +95,7 @@ def test_magic_preamble(): @AsyncTestDecorators.mark_async_only_test -async def test_cancel_hello_in_open(mocker): +async def test_cancel_hello_in_open(mocker, none_auth): address = ("localhost", 7687) socket_mock = mocker.AsyncMock(spec=AsyncBoltSocket) @@ -113,10 +113,7 @@ async def test_cancel_hello_in_open(mocker): bolt_mock.local_port = 1234 with pytest.raises(asyncio.CancelledError): - await AsyncBolt.open( - address, - auth_manager=neo4j.auth_management.AsyncAuthManagers.static(None) - ) + await AsyncBolt.open(address, auth_manager=none_auth) bolt_mock.kill.assert_called_once_with() @@ -136,7 +133,9 @@ async def test_cancel_hello_in_open(mocker): ), ) @mark_async_test -async def test_version_negotiation(mocker, bolt_version, bolt_cls_path): +async def test_version_negotiation( + mocker, bolt_version, bolt_cls_path, none_auth +): 
address = ("localhost", 7687) socket_mock = mocker.AsyncMock(spec=AsyncBoltSocket) @@ -151,10 +150,7 @@ async def test_version_negotiation(mocker, bolt_version, bolt_cls_path): bolt_mock = bolt_cls_mock.return_value bolt_mock.socket = socket_mock - connection = await AsyncBolt.open( - address, - auth_manager=neo4j.auth_management.AsyncAuthManagers.static(None) - ) + connection = await AsyncBolt.open(address, auth_manager=none_auth) bolt_cls_mock.assert_called_once() assert connection is bolt_mock @@ -170,7 +166,7 @@ async def test_version_negotiation(mocker, bolt_version, bolt_cls_path): (6, 0), )) @mark_async_test -async def test_failing_version_negotiation(mocker, bolt_version): +async def test_failing_version_negotiation(mocker, bolt_version, none_auth): supported_protocols = \ "('3.0', '4.1', '4.2', '4.3', '4.4', '5.0', '5.1', '5.2')" @@ -185,10 +181,7 @@ async def test_failing_version_negotiation(mocker, bolt_version): socket_mock.getpeername.return_value = address with pytest.raises(BoltHandshakeError) as exc: - await AsyncBolt.open( - address, - auth_manager=neo4j.auth_management.AsyncAuthManagers.static(None) - ) + await AsyncBolt.open(address, auth_manager=none_auth) assert exc.match(supported_protocols) diff --git a/tests/unit/async_/io/test_class_bolt3.py b/tests/unit/async_/io/test_class_bolt3.py index 8f612655..e0954f47 100644 --- a/tests/unit/async_/io/test_class_bolt3.py +++ b/tests/unit/async_/io/test_class_bolt3.py @@ -160,7 +160,7 @@ def test_auth_message_raises_configuration_error(message, fake_socket): connection = AsyncBolt3(address, fake_socket(address), PoolConfig.max_connection_lifetime) with pytest.raises(ConfigurationError, - match="Session level authentication is not supported"): + match="User switching is not supported"): getattr(connection, message)() @@ -200,7 +200,7 @@ async def test_re_auth(auth1, auth2, fake_socket): connection = AsyncBolt3(address, fake_socket(address), PoolConfig.max_connection_lifetime, auth=auth1) with 
pytest.raises(ConfigurationError, - match="Session level authentication is not supported"): + match="User switching is not supported"): connection.re_auth(auth2, None) diff --git a/tests/unit/async_/io/test_class_bolt4x0.py b/tests/unit/async_/io/test_class_bolt4x0.py index 5ae9432a..ff1edcbd 100644 --- a/tests/unit/async_/io/test_class_bolt4x0.py +++ b/tests/unit/async_/io/test_class_bolt4x0.py @@ -256,7 +256,7 @@ def test_auth_message_raises_configuration_error(message, fake_socket): connection = AsyncBolt4x0(address, fake_socket(address), PoolConfig.max_connection_lifetime) with pytest.raises(ConfigurationError, - match="Session level authentication is not supported"): + match="User switching is not supported"): getattr(connection, message)() @@ -296,7 +296,7 @@ async def test_re_auth(auth1, auth2, fake_socket): connection = AsyncBolt4x0(address, fake_socket(address), PoolConfig.max_connection_lifetime, auth=auth1) with pytest.raises(ConfigurationError, - match="Session level authentication is not supported"): + match="User switching is not supported"): connection.re_auth(auth2, None) diff --git a/tests/unit/async_/io/test_class_bolt4x1.py b/tests/unit/async_/io/test_class_bolt4x1.py index 0e5f05c9..7766c559 100644 --- a/tests/unit/async_/io/test_class_bolt4x1.py +++ b/tests/unit/async_/io/test_class_bolt4x1.py @@ -273,7 +273,7 @@ def test_auth_message_raises_configuration_error(message, fake_socket): connection = AsyncBolt4x1(address, fake_socket(address), PoolConfig.max_connection_lifetime) with pytest.raises(ConfigurationError, - match="Session level authentication is not supported"): + match="User switching is not supported"): getattr(connection, message)() @@ -313,7 +313,7 @@ async def test_re_auth(auth1, auth2, fake_socket): connection = AsyncBolt4x1(address, fake_socket(address), PoolConfig.max_connection_lifetime, auth=auth1) with pytest.raises(ConfigurationError, - match="Session level authentication is not supported"): + match="User switching is not 
supported"): connection.re_auth(auth2, None) diff --git a/tests/unit/async_/io/test_class_bolt4x2.py b/tests/unit/async_/io/test_class_bolt4x2.py index 86cfc32c..0547606a 100644 --- a/tests/unit/async_/io/test_class_bolt4x2.py +++ b/tests/unit/async_/io/test_class_bolt4x2.py @@ -274,7 +274,7 @@ def test_auth_message_raises_configuration_error(message, fake_socket): connection = AsyncBolt4x2(address, fake_socket(address), PoolConfig.max_connection_lifetime) with pytest.raises(ConfigurationError, - match="Session level authentication is not supported"): + match="User switching is not supported"): getattr(connection, message)() @@ -314,7 +314,7 @@ async def test_re_auth(auth1, auth2, fake_socket): connection = AsyncBolt4x2(address, fake_socket(address), PoolConfig.max_connection_lifetime, auth=auth1) with pytest.raises(ConfigurationError, - match="Session level authentication is not supported"): + match="User switching is not supported"): connection.re_auth(auth2, None) diff --git a/tests/unit/async_/io/test_class_bolt4x3.py b/tests/unit/async_/io/test_class_bolt4x3.py index dddf994f..dca81144 100644 --- a/tests/unit/async_/io/test_class_bolt4x3.py +++ b/tests/unit/async_/io/test_class_bolt4x3.py @@ -301,7 +301,7 @@ def test_auth_message_raises_configuration_error(message, fake_socket): connection = AsyncBolt4x3(address, fake_socket(address), PoolConfig.max_connection_lifetime) with pytest.raises(ConfigurationError, - match="Session level authentication is not supported"): + match="User switching is not supported"): getattr(connection, message)() @@ -341,7 +341,7 @@ async def test_re_auth(auth1, auth2, fake_socket): connection = AsyncBolt4x3(address, fake_socket(address), PoolConfig.max_connection_lifetime, auth=auth1) with pytest.raises(ConfigurationError, - match="Session level authentication is not supported"): + match="User switching is not supported"): connection.re_auth(auth2, None) diff --git a/tests/unit/async_/io/test_class_bolt4x4.py 
b/tests/unit/async_/io/test_class_bolt4x4.py index 40fb7518..19cf8b5c 100644 --- a/tests/unit/async_/io/test_class_bolt4x4.py +++ b/tests/unit/async_/io/test_class_bolt4x4.py @@ -314,7 +314,7 @@ def test_auth_message_raises_configuration_error(message, fake_socket): connection = AsyncBolt4x4(address, fake_socket(address), PoolConfig.max_connection_lifetime) with pytest.raises(ConfigurationError, - match="Session level authentication is not supported"): + match="User switching is not supported"): getattr(connection, message)() @@ -354,7 +354,7 @@ async def test_re_auth(auth1, auth2, fake_socket): connection = AsyncBolt4x4(address, fake_socket(address), PoolConfig.max_connection_lifetime, auth=auth1) with pytest.raises(ConfigurationError, - match="Session level authentication is not supported"): + match="User switching is not supported"): connection.re_auth(auth2, None) diff --git a/tests/unit/async_/io/test_class_bolt5x0.py b/tests/unit/async_/io/test_class_bolt5x0.py index a5c3f72b..2937d442 100644 --- a/tests/unit/async_/io/test_class_bolt5x0.py +++ b/tests/unit/async_/io/test_class_bolt5x0.py @@ -314,7 +314,7 @@ def test_auth_message_raises_configuration_error(message, fake_socket): connection = AsyncBolt5x0(address, fake_socket(address), PoolConfig.max_connection_lifetime) with pytest.raises(ConfigurationError, - match="Session level authentication is not supported"): + match="User switching is not supported"): getattr(connection, message)() @@ -354,7 +354,7 @@ async def test_re_auth(auth1, auth2, fake_socket): connection = AsyncBolt5x0(address, fake_socket(address), PoolConfig.max_connection_lifetime, auth=auth1) with pytest.raises(ConfigurationError, - match="Session level authentication is not supported"): + match="User switching is not supported"): connection.re_auth(auth2, None) diff --git a/tests/unit/async_/io/test_class_bolt5x1.py b/tests/unit/async_/io/test_class_bolt5x1.py index 7eeb04a6..83dab776 100644 --- a/tests/unit/async_/io/test_class_bolt5x1.py 
+++ b/tests/unit/async_/io/test_class_bolt5x1.py @@ -287,9 +287,9 @@ async def test_logon(fake_socket_pair): @mark_async_test -async def test_re_auth(fake_socket_pair, mocker): +async def test_re_auth(fake_socket_pair, mocker, static_auth): auth = neo4j.Auth("basic", "alice123", "supersecret123") - auth_manager = AsyncAuthManagers.static(auth) + auth_manager = static_auth(auth) address = neo4j.Address(("127.0.0.1", 7687)) sockets = fake_socket_pair(address, packer_cls=AsyncBolt5x1.PACKER_CLS, diff --git a/tests/unit/async_/io/test_class_bolt5x2.py b/tests/unit/async_/io/test_class_bolt5x2.py index c9bc20fa..676eb5b9 100644 --- a/tests/unit/async_/io/test_class_bolt5x2.py +++ b/tests/unit/async_/io/test_class_bolt5x2.py @@ -275,9 +275,9 @@ async def test_logon(fake_socket_pair): @mark_async_test -async def test_re_auth(fake_socket_pair, mocker): +async def test_re_auth(fake_socket_pair, mocker, static_auth): auth = neo4j.Auth("basic", "alice123", "supersecret123") - auth_manager = AsyncAuthManagers.static(auth) + auth_manager = static_auth(auth) address = neo4j.Address(("127.0.0.1", 7687)) sockets = fake_socket_pair(address, packer_cls=AsyncBolt5x2.PACKER_CLS, diff --git a/tests/unit/async_/io/test_direct.py b/tests/unit/async_/io/test_direct.py index e8d17f5e..953287c6 100644 --- a/tests/unit/async_/io/test_direct.py +++ b/tests/unit/async_/io/test_direct.py @@ -18,6 +18,7 @@ import pytest +from neo4j import PreviewWarning from neo4j._async.io import AsyncBolt from neo4j._async.io._pool import AsyncIOPool from neo4j._conf import ( @@ -84,7 +85,7 @@ def timedout(self): class AsyncFakeBoltPool(AsyncIOPool): def __init__(self, address, *, auth=None, **config): - config["auth"] = AsyncAuthManagers.static(None) + config["auth"] = static_auth(None) self.pool_config, self.workspace_config = Config.consume_chain(config, PoolConfig, WorkspaceConfig) if config: raise ValueError("Unexpected config keys: %s" % ", ".join(config.keys())) @@ -103,17 +104,22 @@ async def acquire( 
self.address, auth, timeout, liveness_check_timeout ) +def static_auth(auth): + with pytest.warns(PreviewWarning, match="Auth managers"): + return AsyncAuthManagers.static(auth) + +@pytest.fixture +def auth_manager(): + static_auth(("test", "test")) @mark_async_test -async def test_bolt_connection_open(): - auth_manager = AsyncAuthManagers.static(("test", "test")) +async def test_bolt_connection_open(auth_manager): with pytest.raises(ServiceUnavailable): await AsyncBolt.open(("localhost", 9999), auth_manager=auth_manager) @mark_async_test -async def test_bolt_connection_open_timeout(): - auth_manager = AsyncAuthManagers.static(("test", "test")) +async def test_bolt_connection_open_timeout(auth_manager): with pytest.raises(ServiceUnavailable): await AsyncBolt.open( ("localhost", 9999), auth_manager=auth_manager, diff --git a/tests/unit/async_/io/test_neo4j_pool.py b/tests/unit/async_/io/test_neo4j_pool.py index 9cdeb9df..ca27d92f 100644 --- a/tests/unit/async_/io/test_neo4j_pool.py +++ b/tests/unit/async_/io/test_neo4j_pool.py @@ -21,6 +21,7 @@ import pytest from neo4j import ( + PreviewWarning, READ_ACCESS, WRITE_ACCESS, ) @@ -101,10 +102,15 @@ def opener(routing_failure_opener): def _pool_config(): pool_config = PoolConfig() - pool_config.auth = AsyncAuthManagers.static(("user", "pass")) + pool_config.auth = _auth_manager(("user", "pass")) return pool_config +def _auth_manager(auth): + with pytest.warns(PreviewWarning, match="Auth managers"): + return AsyncAuthManagers.static(auth) + + def _simple_pool(opener) -> AsyncNeo4jPool: return AsyncNeo4jPool( opener, _pool_config(), WorkspaceConfig(), ROUTER1_ADDRESS @@ -606,7 +612,7 @@ async def test_connection_error_callback( opener, error, marks_unauthenticated, fetches_new, mocker ): config = _pool_config() - auth_manager = AsyncAuthManagers.static(("user", "auth")) + auth_manager = _auth_manager(("user", "auth")) on_auth_expired_mock = mocker.patch.object(auth_manager, "on_auth_expired", autospec=True) config.auth = 
auth_manager diff --git a/tests/unit/async_/test_auth_manager.py b/tests/unit/async_/test_auth_manager.py index 9312824a..e29abacf 100644 --- a/tests/unit/async_/test_auth_manager.py +++ b/tests/unit/async_/test_auth_manager.py @@ -1,3 +1,21 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + import itertools import typing as t @@ -8,7 +26,9 @@ from neo4j import ( Auth, basic_auth, + PreviewWarning, ) +from neo4j._meta import copy_signature from neo4j.auth_management import ( AsyncAuthManager, AsyncAuthManagers, @@ -27,12 +47,29 @@ ) +@copy_signature(AsyncAuthManagers.static) +def static_auth_manager(*args, **kwargs): + with pytest.warns(PreviewWarning, match="Auth managers"): + return AsyncAuthManagers.static(*args, **kwargs) + +@copy_signature(AsyncAuthManagers.expiration_based) +def expiration_based_auth_manager(*args, **kwargs): + with pytest.warns(PreviewWarning, match="Auth managers"): + return AsyncAuthManagers.expiration_based(*args, **kwargs) + + +@copy_signature(ExpiringAuth) +def expiring_auth(*args, **kwargs): + with pytest.warns(PreviewWarning, match="Auth managers"): + return ExpiringAuth(*args, **kwargs) + + @mark_async_test @pytest.mark.parametrize("auth", SAMPLE_AUTHS) async def test_static_manager( auth ) -> None: - manager: AsyncAuthManager = AsyncAuthManagers.static(auth) + manager: AsyncAuthManager = static_auth_manager(auth) assert await 
manager.get_auth() is auth await manager.on_auth_expired(("something", "else")) @@ -46,25 +83,25 @@ async def test_static_manager( @pytest.mark.parametrize(("auth1", "auth2"), itertools.product(SAMPLE_AUTHS, repeat=2)) @pytest.mark.parametrize("expires_in", (None, -1, 1., 1, 1000.)) -async def test_temporal_manager_manual_expiry( +async def test_expiration_based_manager_manual_expiry( auth1: t.Union[t.Tuple[str, str], Auth, None], auth2: t.Union[t.Tuple[str, str], Auth, None], expires_in: t.Union[float, int], mocker ) -> None: if expires_in is None or expires_in >= 0: - temporal_auth = ExpiringAuth(auth1, expires_in) + temporal_auth = expiring_auth(auth1, expires_in) else: - temporal_auth = ExpiringAuth(auth1) + temporal_auth = expiring_auth(auth1) provider = mocker.AsyncMock(return_value=temporal_auth) - manager: AsyncAuthManager = AsyncAuthManagers.expiration_based(provider) + manager: AsyncAuthManager = expiration_based_auth_manager(provider) provider.assert_not_called() assert await manager.get_auth() is auth1 provider.assert_awaited_once() provider.reset_mock() - provider.return_value = ExpiringAuth(auth2) + provider.return_value = expiring_auth(auth2) await manager.on_auth_expired(("something", "else")) assert await manager.get_auth() is auth1 @@ -81,7 +118,7 @@ async def test_temporal_manager_manual_expiry( @pytest.mark.parametrize(("auth1", "auth2"), itertools.product(SAMPLE_AUTHS, repeat=2)) @pytest.mark.parametrize("expires_in", (None, -1, 1., 1, 1000.)) -async def test_temporal_manager_time_expiry( +async def test_expiration_based_manager_time_expiry( auth1: t.Union[t.Tuple[str, str], Auth, None], auth2: t.Union[t.Tuple[str, str], Auth, None], expires_in: t.Union[float, int, None], @@ -90,18 +127,18 @@ async def test_temporal_manager_time_expiry( with freeze_time() as frozen_time: assert isinstance(frozen_time, FrozenDateTimeFactory) if expires_in is None or expires_in >= 0: - temporal_auth = ExpiringAuth(auth1, expires_in) + temporal_auth = 
expiring_auth(auth1, expires_in) else: - temporal_auth = ExpiringAuth(auth1) + temporal_auth = expiring_auth(auth1) provider = mocker.AsyncMock(return_value=temporal_auth) - manager: AsyncAuthManager = AsyncAuthManagers.expiration_based(provider) + manager: AsyncAuthManager = expiration_based_auth_manager(provider) provider.assert_not_called() assert await manager.get_auth() is auth1 provider.assert_awaited_once() provider.reset_mock() - provider.return_value = ExpiringAuth(auth2) + provider.return_value = expiring_auth(auth2) if expires_in is None or expires_in < 0: frozen_time.tick(1_000_000) diff --git a/tests/unit/async_/test_driver.py b/tests/unit/async_/test_driver.py index 27639b95..62848ebb 100644 --- a/tests/unit/async_/test_driver.py +++ b/tests/unit/async_/test_driver.py @@ -971,7 +971,8 @@ async def test_execute_query_result_transformer( with assert_warns_execute_query_bmm_experimental(): bmm = driver.query_bookmark_manager res_custom = await driver.execute_query( - "", None, "w", None, None, bmm, result_transformer + "", None, "w", None, None, bmm, None, + result_transformer ) else: res_custom = await driver.execute_query( diff --git a/tests/unit/sync/io/conftest.py b/tests/unit/sync/io/conftest.py index beba1e27..dc286dbd 100644 --- a/tests/unit/sync/io/conftest.py +++ b/tests/unit/sync/io/conftest.py @@ -24,10 +24,12 @@ import pytest +from neo4j import PreviewWarning from neo4j._sync.io._common import ( Inbox, Outbox, ) +from neo4j.auth_management import AuthManagers class FakeSocket: @@ -145,3 +147,16 @@ def fake_socket_2(): @pytest.fixture def fake_socket_pair(): return FakeSocketPair + +@pytest.fixture +def static_auth(): + def inner(auth): + with pytest.warns(PreviewWarning, match="Auth managers"): + return AuthManagers.static(auth) + + return inner + + +@pytest.fixture +def none_auth(static_auth): + return static_auth(None) diff --git a/tests/unit/sync/io/test_class_bolt.py b/tests/unit/sync/io/test_class_bolt.py index 5bd6241d..f0306f96 100644 
--- a/tests/unit/sync/io/test_class_bolt.py +++ b/tests/unit/sync/io/test_class_bolt.py @@ -95,7 +95,7 @@ def test_magic_preamble(): @TestDecorators.mark_async_only_test -def test_cancel_hello_in_open(mocker): +def test_cancel_hello_in_open(mocker, none_auth): address = ("localhost", 7687) socket_mock = mocker.MagicMock(spec=BoltSocket) @@ -113,10 +113,7 @@ def test_cancel_hello_in_open(mocker): bolt_mock.local_port = 1234 with pytest.raises(asyncio.CancelledError): - Bolt.open( - address, - auth_manager=neo4j.auth_management.AuthManagers.static(None) - ) + Bolt.open(address, auth_manager=none_auth) bolt_mock.kill.assert_called_once_with() @@ -136,7 +133,9 @@ def test_cancel_hello_in_open(mocker): ), ) @mark_sync_test -def test_version_negotiation(mocker, bolt_version, bolt_cls_path): +def test_version_negotiation( + mocker, bolt_version, bolt_cls_path, none_auth +): address = ("localhost", 7687) socket_mock = mocker.MagicMock(spec=BoltSocket) @@ -151,10 +150,7 @@ def test_version_negotiation(mocker, bolt_version, bolt_cls_path): bolt_mock = bolt_cls_mock.return_value bolt_mock.socket = socket_mock - connection = Bolt.open( - address, - auth_manager=neo4j.auth_management.AuthManagers.static(None) - ) + connection = Bolt.open(address, auth_manager=none_auth) bolt_cls_mock.assert_called_once() assert connection is bolt_mock @@ -170,7 +166,7 @@ def test_version_negotiation(mocker, bolt_version, bolt_cls_path): (6, 0), )) @mark_sync_test -def test_failing_version_negotiation(mocker, bolt_version): +def test_failing_version_negotiation(mocker, bolt_version, none_auth): supported_protocols = \ "('3.0', '4.1', '4.2', '4.3', '4.4', '5.0', '5.1', '5.2')" @@ -185,10 +181,7 @@ def test_failing_version_negotiation(mocker, bolt_version): socket_mock.getpeername.return_value = address with pytest.raises(BoltHandshakeError) as exc: - Bolt.open( - address, - auth_manager=neo4j.auth_management.AuthManagers.static(None) - ) + Bolt.open(address, auth_manager=none_auth) assert 
exc.match(supported_protocols) diff --git a/tests/unit/sync/io/test_class_bolt3.py b/tests/unit/sync/io/test_class_bolt3.py index 2057f2b0..73999db0 100644 --- a/tests/unit/sync/io/test_class_bolt3.py +++ b/tests/unit/sync/io/test_class_bolt3.py @@ -160,7 +160,7 @@ def test_auth_message_raises_configuration_error(message, fake_socket): connection = Bolt3(address, fake_socket(address), PoolConfig.max_connection_lifetime) with pytest.raises(ConfigurationError, - match="Session level authentication is not supported"): + match="User switching is not supported"): getattr(connection, message)() @@ -200,7 +200,7 @@ def test_re_auth(auth1, auth2, fake_socket): connection = Bolt3(address, fake_socket(address), PoolConfig.max_connection_lifetime, auth=auth1) with pytest.raises(ConfigurationError, - match="Session level authentication is not supported"): + match="User switching is not supported"): connection.re_auth(auth2, None) diff --git a/tests/unit/sync/io/test_class_bolt4x0.py b/tests/unit/sync/io/test_class_bolt4x0.py index 79e381fc..6fb416e1 100644 --- a/tests/unit/sync/io/test_class_bolt4x0.py +++ b/tests/unit/sync/io/test_class_bolt4x0.py @@ -256,7 +256,7 @@ def test_auth_message_raises_configuration_error(message, fake_socket): connection = Bolt4x0(address, fake_socket(address), PoolConfig.max_connection_lifetime) with pytest.raises(ConfigurationError, - match="Session level authentication is not supported"): + match="User switching is not supported"): getattr(connection, message)() @@ -296,7 +296,7 @@ def test_re_auth(auth1, auth2, fake_socket): connection = Bolt4x0(address, fake_socket(address), PoolConfig.max_connection_lifetime, auth=auth1) with pytest.raises(ConfigurationError, - match="Session level authentication is not supported"): + match="User switching is not supported"): connection.re_auth(auth2, None) diff --git a/tests/unit/sync/io/test_class_bolt4x1.py b/tests/unit/sync/io/test_class_bolt4x1.py index 3b61d1d3..3ed433d9 100644 --- 
a/tests/unit/sync/io/test_class_bolt4x1.py +++ b/tests/unit/sync/io/test_class_bolt4x1.py @@ -273,7 +273,7 @@ def test_auth_message_raises_configuration_error(message, fake_socket): connection = Bolt4x1(address, fake_socket(address), PoolConfig.max_connection_lifetime) with pytest.raises(ConfigurationError, - match="Session level authentication is not supported"): + match="User switching is not supported"): getattr(connection, message)() @@ -313,7 +313,7 @@ def test_re_auth(auth1, auth2, fake_socket): connection = Bolt4x1(address, fake_socket(address), PoolConfig.max_connection_lifetime, auth=auth1) with pytest.raises(ConfigurationError, - match="Session level authentication is not supported"): + match="User switching is not supported"): connection.re_auth(auth2, None) diff --git a/tests/unit/sync/io/test_class_bolt4x2.py b/tests/unit/sync/io/test_class_bolt4x2.py index 98a3a975..d9172bad 100644 --- a/tests/unit/sync/io/test_class_bolt4x2.py +++ b/tests/unit/sync/io/test_class_bolt4x2.py @@ -274,7 +274,7 @@ def test_auth_message_raises_configuration_error(message, fake_socket): connection = Bolt4x2(address, fake_socket(address), PoolConfig.max_connection_lifetime) with pytest.raises(ConfigurationError, - match="Session level authentication is not supported"): + match="User switching is not supported"): getattr(connection, message)() @@ -314,7 +314,7 @@ def test_re_auth(auth1, auth2, fake_socket): connection = Bolt4x2(address, fake_socket(address), PoolConfig.max_connection_lifetime, auth=auth1) with pytest.raises(ConfigurationError, - match="Session level authentication is not supported"): + match="User switching is not supported"): connection.re_auth(auth2, None) diff --git a/tests/unit/sync/io/test_class_bolt4x3.py b/tests/unit/sync/io/test_class_bolt4x3.py index 8f47b100..c89d31d9 100644 --- a/tests/unit/sync/io/test_class_bolt4x3.py +++ b/tests/unit/sync/io/test_class_bolt4x3.py @@ -301,7 +301,7 @@ def test_auth_message_raises_configuration_error(message, 
fake_socket): connection = Bolt4x3(address, fake_socket(address), PoolConfig.max_connection_lifetime) with pytest.raises(ConfigurationError, - match="Session level authentication is not supported"): + match="User switching is not supported"): getattr(connection, message)() @@ -341,7 +341,7 @@ def test_re_auth(auth1, auth2, fake_socket): connection = Bolt4x3(address, fake_socket(address), PoolConfig.max_connection_lifetime, auth=auth1) with pytest.raises(ConfigurationError, - match="Session level authentication is not supported"): + match="User switching is not supported"): connection.re_auth(auth2, None) diff --git a/tests/unit/sync/io/test_class_bolt4x4.py b/tests/unit/sync/io/test_class_bolt4x4.py index 4f0c8faa..be7cc754 100644 --- a/tests/unit/sync/io/test_class_bolt4x4.py +++ b/tests/unit/sync/io/test_class_bolt4x4.py @@ -314,7 +314,7 @@ def test_auth_message_raises_configuration_error(message, fake_socket): connection = Bolt4x4(address, fake_socket(address), PoolConfig.max_connection_lifetime) with pytest.raises(ConfigurationError, - match="Session level authentication is not supported"): + match="User switching is not supported"): getattr(connection, message)() @@ -354,7 +354,7 @@ def test_re_auth(auth1, auth2, fake_socket): connection = Bolt4x4(address, fake_socket(address), PoolConfig.max_connection_lifetime, auth=auth1) with pytest.raises(ConfigurationError, - match="Session level authentication is not supported"): + match="User switching is not supported"): connection.re_auth(auth2, None) diff --git a/tests/unit/sync/io/test_class_bolt5x0.py b/tests/unit/sync/io/test_class_bolt5x0.py index f20cbf83..cf9dab99 100644 --- a/tests/unit/sync/io/test_class_bolt5x0.py +++ b/tests/unit/sync/io/test_class_bolt5x0.py @@ -314,7 +314,7 @@ def test_auth_message_raises_configuration_error(message, fake_socket): connection = Bolt5x0(address, fake_socket(address), PoolConfig.max_connection_lifetime) with pytest.raises(ConfigurationError, - match="Session level 
authentication is not supported"): + match="User switching is not supported"): getattr(connection, message)() @@ -354,7 +354,7 @@ def test_re_auth(auth1, auth2, fake_socket): connection = Bolt5x0(address, fake_socket(address), PoolConfig.max_connection_lifetime, auth=auth1) with pytest.raises(ConfigurationError, - match="Session level authentication is not supported"): + match="User switching is not supported"): connection.re_auth(auth2, None) diff --git a/tests/unit/sync/io/test_class_bolt5x1.py b/tests/unit/sync/io/test_class_bolt5x1.py index 5d10bb6b..9ef06704 100644 --- a/tests/unit/sync/io/test_class_bolt5x1.py +++ b/tests/unit/sync/io/test_class_bolt5x1.py @@ -287,9 +287,9 @@ def test_logon(fake_socket_pair): @mark_sync_test -def test_re_auth(fake_socket_pair, mocker): +def test_re_auth(fake_socket_pair, mocker, static_auth): auth = neo4j.Auth("basic", "alice123", "supersecret123") - auth_manager = AuthManagers.static(auth) + auth_manager = static_auth(auth) address = neo4j.Address(("127.0.0.1", 7687)) sockets = fake_socket_pair(address, packer_cls=Bolt5x1.PACKER_CLS, diff --git a/tests/unit/sync/io/test_class_bolt5x2.py b/tests/unit/sync/io/test_class_bolt5x2.py index 748bb75a..5eb2092c 100644 --- a/tests/unit/sync/io/test_class_bolt5x2.py +++ b/tests/unit/sync/io/test_class_bolt5x2.py @@ -275,9 +275,9 @@ def test_logon(fake_socket_pair): @mark_sync_test -def test_re_auth(fake_socket_pair, mocker): +def test_re_auth(fake_socket_pair, mocker, static_auth): auth = neo4j.Auth("basic", "alice123", "supersecret123") - auth_manager = AuthManagers.static(auth) + auth_manager = static_auth(auth) address = neo4j.Address(("127.0.0.1", 7687)) sockets = fake_socket_pair(address, packer_cls=Bolt5x2.PACKER_CLS, diff --git a/tests/unit/sync/io/test_direct.py b/tests/unit/sync/io/test_direct.py index a266e9dd..5349f5a2 100644 --- a/tests/unit/sync/io/test_direct.py +++ b/tests/unit/sync/io/test_direct.py @@ -18,6 +18,7 @@ import pytest +from neo4j import PreviewWarning from 
neo4j._conf import ( Config, PoolConfig, @@ -84,7 +85,7 @@ def timedout(self): class FakeBoltPool(IOPool): def __init__(self, address, *, auth=None, **config): - config["auth"] = AuthManagers.static(None) + config["auth"] = static_auth(None) self.pool_config, self.workspace_config = Config.consume_chain(config, PoolConfig, WorkspaceConfig) if config: raise ValueError("Unexpected config keys: %s" % ", ".join(config.keys())) @@ -103,17 +104,22 @@ def acquire( self.address, auth, timeout, liveness_check_timeout ) +def static_auth(auth): + with pytest.warns(PreviewWarning, match="Auth managers"): + return AuthManagers.static(auth) + +@pytest.fixture +def auth_manager(): + static_auth(("test", "test")) @mark_sync_test -def test_bolt_connection_open(): - auth_manager = AuthManagers.static(("test", "test")) +def test_bolt_connection_open(auth_manager): with pytest.raises(ServiceUnavailable): Bolt.open(("localhost", 9999), auth_manager=auth_manager) @mark_sync_test -def test_bolt_connection_open_timeout(): - auth_manager = AuthManagers.static(("test", "test")) +def test_bolt_connection_open_timeout(auth_manager): with pytest.raises(ServiceUnavailable): Bolt.open( ("localhost", 9999), auth_manager=auth_manager, diff --git a/tests/unit/sync/io/test_neo4j_pool.py b/tests/unit/sync/io/test_neo4j_pool.py index ff9635a8..cfaaf1f3 100644 --- a/tests/unit/sync/io/test_neo4j_pool.py +++ b/tests/unit/sync/io/test_neo4j_pool.py @@ -21,6 +21,7 @@ import pytest from neo4j import ( + PreviewWarning, READ_ACCESS, WRITE_ACCESS, ) @@ -101,10 +102,15 @@ def opener(routing_failure_opener): def _pool_config(): pool_config = PoolConfig() - pool_config.auth = AuthManagers.static(("user", "pass")) + pool_config.auth = _auth_manager(("user", "pass")) return pool_config +def _auth_manager(auth): + with pytest.warns(PreviewWarning, match="Auth managers"): + return AuthManagers.static(auth) + + def _simple_pool(opener) -> Neo4jPool: return Neo4jPool( opener, _pool_config(), WorkspaceConfig(), 
ROUTER1_ADDRESS @@ -606,7 +612,7 @@ def test_connection_error_callback( opener, error, marks_unauthenticated, fetches_new, mocker ): config = _pool_config() - auth_manager = AuthManagers.static(("user", "auth")) + auth_manager = _auth_manager(("user", "auth")) on_auth_expired_mock = mocker.patch.object(auth_manager, "on_auth_expired", autospec=True) config.auth = auth_manager diff --git a/tests/unit/sync/test_auth_manager.py b/tests/unit/sync/test_auth_manager.py index 0b91714f..15d8b598 100644 --- a/tests/unit/sync/test_auth_manager.py +++ b/tests/unit/sync/test_auth_manager.py @@ -1,3 +1,21 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + import itertools import typing as t @@ -8,7 +26,9 @@ from neo4j import ( Auth, basic_auth, + PreviewWarning, ) +from neo4j._meta import copy_signature from neo4j.auth_management import ( AuthManager, AuthManagers, @@ -27,12 +47,29 @@ ) +@copy_signature(AuthManagers.static) +def static_auth_manager(*args, **kwargs): + with pytest.warns(PreviewWarning, match="Auth managers"): + return AuthManagers.static(*args, **kwargs) + +@copy_signature(AuthManagers.expiration_based) +def expiration_based_auth_manager(*args, **kwargs): + with pytest.warns(PreviewWarning, match="Auth managers"): + return AuthManagers.expiration_based(*args, **kwargs) + + +@copy_signature(ExpiringAuth) +def expiring_auth(*args, **kwargs): + with pytest.warns(PreviewWarning, match="Auth managers"): + return ExpiringAuth(*args, **kwargs) + + @mark_sync_test @pytest.mark.parametrize("auth", SAMPLE_AUTHS) def test_static_manager( auth ) -> None: - manager: AuthManager = AuthManagers.static(auth) + manager: AuthManager = static_auth_manager(auth) assert manager.get_auth() is auth manager.on_auth_expired(("something", "else")) @@ -46,25 +83,25 @@ def test_static_manager( @pytest.mark.parametrize(("auth1", "auth2"), itertools.product(SAMPLE_AUTHS, repeat=2)) @pytest.mark.parametrize("expires_in", (None, -1, 1., 1, 1000.)) -def test_temporal_manager_manual_expiry( +def test_expiration_based_manager_manual_expiry( auth1: t.Union[t.Tuple[str, str], Auth, None], auth2: t.Union[t.Tuple[str, str], Auth, None], expires_in: t.Union[float, int], mocker ) -> None: if expires_in is None or expires_in >= 0: - temporal_auth = ExpiringAuth(auth1, expires_in) + temporal_auth = expiring_auth(auth1, expires_in) else: - temporal_auth = ExpiringAuth(auth1) + temporal_auth = expiring_auth(auth1) provider = mocker.MagicMock(return_value=temporal_auth) - manager: AuthManager = AuthManagers.expiration_based(provider) + manager: AuthManager = expiration_based_auth_manager(provider) provider.assert_not_called() assert 
manager.get_auth() is auth1 provider.assert_called_once() provider.reset_mock() - provider.return_value = ExpiringAuth(auth2) + provider.return_value = expiring_auth(auth2) manager.on_auth_expired(("something", "else")) assert manager.get_auth() is auth1 @@ -81,7 +118,7 @@ def test_temporal_manager_manual_expiry( @pytest.mark.parametrize(("auth1", "auth2"), itertools.product(SAMPLE_AUTHS, repeat=2)) @pytest.mark.parametrize("expires_in", (None, -1, 1., 1, 1000.)) -def test_temporal_manager_time_expiry( +def test_expiration_based_manager_time_expiry( auth1: t.Union[t.Tuple[str, str], Auth, None], auth2: t.Union[t.Tuple[str, str], Auth, None], expires_in: t.Union[float, int, None], @@ -90,18 +127,18 @@ def test_temporal_manager_time_expiry( with freeze_time() as frozen_time: assert isinstance(frozen_time, FrozenDateTimeFactory) if expires_in is None or expires_in >= 0: - temporal_auth = ExpiringAuth(auth1, expires_in) + temporal_auth = expiring_auth(auth1, expires_in) else: - temporal_auth = ExpiringAuth(auth1) + temporal_auth = expiring_auth(auth1) provider = mocker.MagicMock(return_value=temporal_auth) - manager: AuthManager = AuthManagers.expiration_based(provider) + manager: AuthManager = expiration_based_auth_manager(provider) provider.assert_not_called() assert manager.get_auth() is auth1 provider.assert_called_once() provider.reset_mock() - provider.return_value = ExpiringAuth(auth2) + provider.return_value = expiring_auth(auth2) if expires_in is None or expires_in < 0: frozen_time.tick(1_000_000) diff --git a/tests/unit/sync/test_driver.py b/tests/unit/sync/test_driver.py index 8809c0bf..10b7c6e1 100644 --- a/tests/unit/sync/test_driver.py +++ b/tests/unit/sync/test_driver.py @@ -970,7 +970,8 @@ def test_execute_query_result_transformer( with assert_warns_execute_query_bmm_experimental(): bmm = driver.query_bookmark_manager res_custom = driver.execute_query( - "", None, "w", None, None, bmm, result_transformer + "", None, "w", None, None, bmm, None, + 
result_transformer ) else: res_custom = driver.execute_query( From 06f0e4f8e8ace9b596a841b92c2b38616e7f83fa Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Wed, 5 Apr 2023 12:38:26 +0200 Subject: [PATCH 20/23] Update TestKit protocol --- testkitbackend/_async/requests.py | 2 +- testkitbackend/_sync/requests.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/testkitbackend/_async/requests.py b/testkitbackend/_async/requests.py index c2d2aeea..7493c7e5 100644 --- a/testkitbackend/_async/requests.py +++ b/testkitbackend/_async/requests.py @@ -312,7 +312,7 @@ async def CheckMultiDBSupport(backend, data): async def VerifyAuthentication(backend, data): driver_id = data["driverId"] driver = backend.drivers[driver_id] - auth = fromtestkit.to_auth_token(data, "auth_token") + auth = fromtestkit.to_auth_token(data, "authorizationToken") with warning_check(neo4j.PreviewWarning, "User switching is a preview feature."): authenticated = await driver.verify_authentication(auth=auth) diff --git a/testkitbackend/_sync/requests.py b/testkitbackend/_sync/requests.py index d3e13377..258b116a 100644 --- a/testkitbackend/_sync/requests.py +++ b/testkitbackend/_sync/requests.py @@ -312,7 +312,7 @@ def CheckMultiDBSupport(backend, data): def VerifyAuthentication(backend, data): driver_id = data["driverId"] driver = backend.drivers[driver_id] - auth = fromtestkit.to_auth_token(data, "auth_token") + auth = fromtestkit.to_auth_token(data, "authorizationToken") with warning_check(neo4j.PreviewWarning, "User switching is a preview feature."): authenticated = driver.verify_authentication(auth=auth) From 65581647a2b30b1b8ac5606e5676ab3f0ab2500d Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Thu, 6 Apr 2023 16:00:16 +0200 Subject: [PATCH 21/23] Fix example integration test --- tests/integration/examples/test_bearer_auth_example.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/integration/examples/test_bearer_auth_example.py 
b/tests/integration/examples/test_bearer_auth_example.py index bb6988df..0c29e0e9 100644 --- a/tests/integration/examples/test_bearer_auth_example.py +++ b/tests/integration/examples/test_bearer_auth_example.py @@ -17,6 +17,7 @@ import neo4j +from neo4j._sync.auth_management import StaticAuthManager from . import DriverSetupExample @@ -54,7 +55,8 @@ def test_example(uri, mocker): assert len(calls) == 1 args_, kwargs = calls[0] auth = kwargs.get("auth") - assert isinstance(auth, neo4j.Auth) + assert isinstance(auth, StaticAuthManager) + auth = auth._auth assert auth.scheme == "bearer" assert not hasattr(auth, "principal") assert auth.credentials == token From ac6e46a250d22976d2913f92034db54a54f99368 Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Tue, 11 Apr 2023 11:56:54 +0200 Subject: [PATCH 22/23] Do not allow sync auth managers in async driver --- src/neo4j/_async/driver.py | 5 +---- src/neo4j/_sync/driver.py | 4 +--- 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/src/neo4j/_async/driver.py b/src/neo4j/_async/driver.py index 88de88a6..81180bb7 100644 --- a/src/neo4j/_async/driver.py +++ b/src/neo4j/_async/driver.py @@ -79,7 +79,6 @@ from ..auth_management import ( AsyncAuthManager, AsyncAuthManagers, - AuthManager, ) from ..exceptions import Neo4jError from .bookmark_manager import ( @@ -132,7 +131,6 @@ def driver( # TAuth, t.Union[t.Tuple[t.Any, t.Any], Auth, None], AsyncAuthManager, - AuthManager ] = ..., max_connection_lifetime: float = ..., max_connection_pool_size: int = ..., @@ -183,7 +181,6 @@ def driver( # TAuth, t.Union[t.Tuple[t.Any, t.Any], Auth, None], AsyncAuthManager, - AuthManager ] = None, **config ) -> AsyncDriver: @@ -200,7 +197,7 @@ def driver( driver_type, security_type, parsed = parse_neo4j_uri(uri) - if not isinstance(auth, (AsyncAuthManager, AuthManager)): + if not isinstance(auth, AsyncAuthManager): with warnings.catch_warnings(): warnings.filterwarnings( "ignore", message=r".*\bAuth managers\b.*", diff --git 
a/src/neo4j/_sync/driver.py b/src/neo4j/_sync/driver.py index e9d9f046..69ccc904 100644 --- a/src/neo4j/_sync/driver.py +++ b/src/neo4j/_sync/driver.py @@ -130,7 +130,6 @@ def driver( # TAuth, t.Union[t.Tuple[t.Any, t.Any], Auth, None], AuthManager, - AuthManager ] = ..., max_connection_lifetime: float = ..., max_connection_pool_size: int = ..., @@ -181,7 +180,6 @@ def driver( # TAuth, t.Union[t.Tuple[t.Any, t.Any], Auth, None], AuthManager, - AuthManager ] = None, **config ) -> Driver: @@ -198,7 +196,7 @@ def driver( driver_type, security_type, parsed = parse_neo4j_uri(uri) - if not isinstance(auth, (AuthManager, AuthManager)): + if not isinstance(auth, AuthManager): with warnings.catch_warnings(): warnings.filterwarnings( "ignore", message=r".*\bAuth managers\b.*", From 7065b8456b0516d1cc871aba32f9d7faa33b1f78 Mon Sep 17 00:00:00 2001 From: Rouven Bauer Date: Tue, 11 Apr 2023 12:56:01 +0200 Subject: [PATCH 23/23] Clearer async auth API doc Signed-off-by: Andy Heap --- docs/source/async_api.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/async_api.rst b/docs/source/async_api.rst index 603a9c97..35e08a13 100644 --- a/docs/source/async_api.rst +++ b/docs/source/async_api.rst @@ -123,8 +123,8 @@ Each supported scheme maps to a particular :class:`neo4j.AsyncDriver` subclass t Async Auth ========== -Authentication mostly works the same as in the synchronous driver. -However, there are async equivalents of the synchronous constructs. +Authentication works the same as in the synchronous driver. +With the exception that when using AuthManagers, their asynchronous equivalents have to be used. .. autoclass:: neo4j.auth_management.AsyncAuthManager :members: