diff --git a/src/neo4j/_async/io/_bolt.py b/src/neo4j/_async/io/_bolt.py index 6acfafd0..f39fa1c8 100644 --- a/src/neo4j/_async/io/_bolt.py +++ b/src/neo4j/_async/io/_bolt.py @@ -20,6 +20,7 @@ import abc import asyncio +import typing as t from collections import deque from logging import getLogger from time import perf_counter @@ -74,6 +75,16 @@ def failed(self): ... +class ClientStateManagerBase(abc.ABC): + @abc.abstractmethod + def __init__(self, init_state, on_change=None): + ... + + @abc.abstractmethod + def transition(self, message): + ... + + class AsyncBolt: """ Server connection for Bolt protocol. @@ -103,12 +114,13 @@ class AsyncBolt: # When the connection was last put back into the pool idle_since = float("-inf") + # The database name the connection was last used with + # (BEGIN for explicit transactions, RUN for auto-commit transactions) + last_database: t.Optional[str] = None # The socket _closing = False _closed = False - - # The socket _defunct = False #: The pool of which this connection is a member @@ -173,6 +185,10 @@ def __del__(self): def _get_server_state_manager(self) -> ServerStateManagerBase: ... + @abc.abstractmethod + def _get_client_state_manager(self) -> ClientStateManagerBase: + ... + @classmethod def _to_auth_dict(cls, auth): # Determine auth details @@ -753,6 +769,8 @@ def _append(self, signature, fields=(), response=None, """ self.outbox.append_message(signature, fields, dehydration_hooks) self.responses.append(response) + if response: + self._get_client_state_manager().transition(response.message) async def _send_all(self): if await self.outbox.flush(): diff --git a/src/neo4j/_async/io/_bolt3.py b/src/neo4j/_async/io/_bolt3.py index 3e948b02..d272c30d 100644 --- a/src/neo4j/_async/io/_bolt3.py +++ b/src/neo4j/_async/io/_bolt3.py @@ -38,6 +38,7 @@ ) from ._bolt import ( AsyncBolt, + ClientStateManagerBase, ServerStateManagerBase, tx_timeout_as_ms, ) @@ -52,7 +53,7 @@ log = getLogger("neo4j") -class ServerStates(Enum): +class BoltStates(Enum): CONNECTED = "CONNECTED" READY = "READY" STREAMING = "STREAMING" @@ -62,25 +63,25 @@ class ServerStates(Enum): class ServerStateManager(ServerStateManagerBase): _STATE_TRANSITIONS: t.Dict[Enum, t.Dict[str, Enum]] = { - ServerStates.CONNECTED: { - "hello": ServerStates.READY, + BoltStates.CONNECTED: { + "hello": BoltStates.READY, }, - ServerStates.READY: { - "run": ServerStates.STREAMING, - "begin": ServerStates.TX_READY_OR_TX_STREAMING, + BoltStates.READY: { + "run": BoltStates.STREAMING, + "begin": BoltStates.TX_READY_OR_TX_STREAMING, }, - ServerStates.STREAMING: { - "pull": ServerStates.READY, - "discard": ServerStates.READY, - "reset": ServerStates.READY, + BoltStates.STREAMING: { + "pull": BoltStates.READY, + "discard": BoltStates.READY, + "reset": BoltStates.READY, }, - ServerStates.TX_READY_OR_TX_STREAMING: { - "commit": ServerStates.READY, - "rollback": ServerStates.READY, - "reset": ServerStates.READY, + BoltStates.TX_READY_OR_TX_STREAMING: { + "commit": BoltStates.READY, + "rollback": BoltStates.READY, + "reset": BoltStates.READY, }, - ServerStates.FAILED: { - "reset": ServerStates.READY, + BoltStates.FAILED: { + "reset": BoltStates.READY, } } @@ -99,7 +100,40 @@ def transition(self, message, metadata): self._on_change(state_before, self.state) def failed(self): - return self.state == ServerStates.FAILED + return self.state == BoltStates.FAILED + + +class ClientStateManager(ClientStateManagerBase): + _STATE_TRANSITIONS: t.Dict[Enum, t.Dict[str, Enum]] = { + BoltStates.CONNECTED: { + "hello": BoltStates.READY, + }, + BoltStates.READY: { + "run": BoltStates.STREAMING, + "begin": BoltStates.TX_READY_OR_TX_STREAMING, + }, + BoltStates.STREAMING: { + "begin": BoltStates.TX_READY_OR_TX_STREAMING, + "reset": BoltStates.READY, + }, + BoltStates.TX_READY_OR_TX_STREAMING: { + "commit": BoltStates.READY, + "rollback": BoltStates.READY, + "reset": BoltStates.READY, + }, + } + + def __init__(self, init_state, on_change=None): + self.state = init_state + self._on_change = on_change + + def transition(self, message): + state_before = self.state + self.state = self._STATE_TRANSITIONS \ + .get(self.state, {}) \ + .get(message, self.state) + if state_before != self.state and callable(self._on_change): + self._on_change(state_before, self.state) class AsyncBolt3(AsyncBolt): @@ -121,25 +155,34 @@ class AsyncBolt3(AsyncBolt): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._server_state_manager = ServerStateManager( - ServerStates.CONNECTED, on_change=self._on_server_state_change + BoltStates.CONNECTED, on_change=self._on_server_state_change + ) + self._client_state_manager = ClientStateManager( + BoltStates.CONNECTED, on_change=self._on_client_state_change ) def _on_server_state_change(self, old_state, new_state): - log.debug("[#%04X] _: state: %s > %s", self.local_port, - old_state.name, new_state.name) + log.debug("[#%04X] _: server state: %s > %s", + self.local_port, old_state.name, new_state.name) def _get_server_state_manager(self) -> ServerStateManagerBase: return self._server_state_manager + def _on_client_state_change(self, old_state, new_state): + log.debug("[#%04X] _: client state: %s > %s", + self.local_port, old_state.name, new_state.name) + + def _get_client_state_manager(self) -> ClientStateManagerBase: + return self._client_state_manager + @property def is_reset(self): # We can't be sure of the server's state if there are still pending # responses. Unless the last message we sent was RESET. In that case # the server state will always be READY when we're done. - if (self.responses and self.responses[-1] - and self.responses[-1].message == "reset"): - return True - return self._server_state_manager.state == ServerStates.READY + if self.responses: + return self.responses[-1] and self.responses[-1].message == "reset" + return self._server_state_manager.state == BoltStates.READY @property def encrypted(self): @@ -216,7 +259,7 @@ async def route( hydration_hooks=hydration_hooks, on_success=metadata.update ) - self.pull(dehydration_hooks = None, hydration_hooks = None, + self.pull(dehydration_hooks=None, hydration_hooks=None, on_success=metadata.update, on_records=records.extend) await self.send_all() await self.fetch_all() @@ -398,7 +441,7 @@ async def _process_message(self, tag, fields): await response.on_ignored(summary_metadata or {}) elif summary_signature == b"\x7F": log.debug("[#%04X] S: FAILURE %r", self.local_port, summary_metadata) - self._server_state_manager.state = ServerStates.FAILED + self._server_state_manager.state = BoltStates.FAILED try: await response.on_failure(summary_metadata or {}) except (ServiceUnavailable, DatabaseUnavailable): @@ -408,7 +451,8 @@ async def _process_message(self, tag, fields): except (NotALeader, ForbiddenOnReadOnlyDatabase): if self.pool: await self.pool.on_write_failure( - address=self.unresolved_address + address=self.unresolved_address, + database=self.last_database, ) raise except Neo4jError as e: diff --git a/src/neo4j/_async/io/_bolt4.py b/src/neo4j/_async/io/_bolt4.py index f4f43378..149fd786 100644 --- a/src/neo4j/_async/io/_bolt4.py +++ b/src/neo4j/_async/io/_bolt4.py @@ -35,12 +35,14 @@ ) from ._bolt import ( AsyncBolt, + ClientStateManagerBase, ServerStateManagerBase, tx_timeout_as_ms, ) from ._bolt3 import ( + BoltStates, + ClientStateManager, ServerStateManager, - ServerStates, ) from ._common import ( check_supported_server_product, @@ -72,25 +74,34 @@ class AsyncBolt4x0(AsyncBolt): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._server_state_manager = ServerStateManager( - ServerStates.CONNECTED, on_change=self._on_server_state_change + BoltStates.CONNECTED, on_change=self._on_server_state_change + ) + self._client_state_manager = ClientStateManager( + BoltStates.CONNECTED, on_change=self._on_client_state_change ) def _on_server_state_change(self, old_state, new_state): - log.debug("[#%04X] _: state: %s > %s", self.local_port, - old_state.name, new_state.name) + log.debug("[#%04X] _: server state: %s > %s", + self.local_port, old_state.name, new_state.name) def _get_server_state_manager(self) -> ServerStateManagerBase: return self._server_state_manager + def _on_client_state_change(self, old_state, new_state): + log.debug("[#%04X] _: client state: %s > %s", + self.local_port, old_state.name, new_state.name) + + def _get_client_state_manager(self) -> ClientStateManagerBase: + return self._client_state_manager + @property def is_reset(self): # We can't be sure of the server's state if there are still pending # responses. Unless the last message we sent was RESET. In that case # the server state will always be READY when we're done. - if (self.responses and self.responses[-1] - and self.responses[-1].message == "reset"): - return True - return self._server_state_manager.state == ServerStates.READY + if self.responses: + return self.responses[-1] and self.responses[-1].message == "reset" + return self._server_state_manager.state == BoltStates.READY @property def encrypted(self): @@ -202,6 +213,8 @@ def run(self, query, parameters=None, mode=None, bookmarks=None, extra["mode"] = "r" # It will default to mode "w" if nothing is specified if db: extra["db"] = db + if self._client_state_manager.state != BoltStates.TX_READY_OR_TX_STREAMING: + self.last_database = db if bookmarks: try: extra["bookmarks"] = list(bookmarks) @@ -261,6 +274,7 @@ def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, extra["mode"] = "r" # It will default to mode "w" if nothing is specified if db: extra["db"] = db + self.last_database = db if bookmarks: try: extra["bookmarks"] = list(bookmarks) @@ -347,7 +361,7 @@ async def _process_message(self, tag, fields): await response.on_ignored(summary_metadata or {}) elif summary_signature == b"\x7F": log.debug("[#%04X] S: FAILURE %r", self.local_port, summary_metadata) - self._server_state_manager.state = ServerStates.FAILED + self._server_state_manager.state = BoltStates.FAILED try: await response.on_failure(summary_metadata or {}) except (ServiceUnavailable, DatabaseUnavailable): @@ -357,7 +371,8 @@ async def _process_message(self, tag, fields): except (NotALeader, ForbiddenOnReadOnlyDatabase): if self.pool: await self.pool.on_write_failure( - address=self.unresolved_address + address=self.unresolved_address, + database=self.last_database ) raise except Neo4jError as e: @@ -535,6 +550,11 @@ def run(self, query, parameters=None, mode=None, bookmarks=None, extra["mode"] = "r" if db: extra["db"] = db + if ( + self._client_state_manager.state + != BoltStates.TX_READY_OR_TX_STREAMING + ): + self.last_database = db if imp_user: extra["imp_user"] = imp_user if bookmarks: @@ -571,6 +591,7 @@ def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, extra["mode"] = "r" if db: extra["db"] = db + self.last_database = db if imp_user: extra["imp_user"] = imp_user if bookmarks: diff --git a/src/neo4j/_async/io/_bolt5.py b/src/neo4j/_async/io/_bolt5.py index 89a06f99..d9b6071e 100644 --- a/src/neo4j/_async/io/_bolt5.py +++ b/src/neo4j/_async/io/_bolt5.py @@ -37,12 +37,14 @@ ) from ._bolt import ( AsyncBolt, + ClientStateManagerBase, ServerStateManagerBase, tx_timeout_as_ms, ) from ._bolt3 import ( + BoltStates, + ClientStateManager, ServerStateManager, - ServerStates, ) from ._common import ( check_supported_server_product, @@ -71,31 +73,41 @@ class AsyncBolt5x0(AsyncBolt): supports_notification_filtering = False - server_states: t.Any = ServerStates + bolt_states: t.Any = BoltStates def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._server_state_manager = ServerStateManager( - self.server_states.CONNECTED, + self.bolt_states.CONNECTED, on_change=self._on_server_state_change ) + self._client_state_manager = ClientStateManager( + self.bolt_states.CONNECTED, + on_change=self._on_client_state_change + ) def _on_server_state_change(self, old_state, new_state): - log.debug("[#%04X] _: state: %s > %s", self.local_port, - old_state.name, new_state.name) + log.debug("[#%04X] _: server state: %s > %s", + self.local_port, old_state.name, new_state.name) def _get_server_state_manager(self) -> ServerStateManagerBase: return self._server_state_manager + def _on_client_state_change(self, old_state, new_state): + log.debug("[#%04X] _: client state: %s > %s", + self.local_port, old_state.name, new_state.name) + + def _get_client_state_manager(self) -> ClientStateManagerBase: + return self._client_state_manager + @property def is_reset(self): # We can't be sure of the server's state if there are still pending # responses. Unless the last message we sent was RESET. In that case # the server state will always be READY when we're done. - if (self.responses and self.responses[-1] - and self.responses[-1].message == "reset"): - return True - return self._server_state_manager.state == self.server_states.READY + if self.responses: + return self.responses[-1] and self.responses[-1].message == "reset" + return self._server_state_manager.state == self.bolt_states.READY @property def encrypted(self): @@ -197,6 +209,11 @@ def run(self, query, parameters=None, mode=None, bookmarks=None, extra["mode"] = "r" if db: extra["db"] = db + if ( + self._client_state_manager.state + != self.bolt_states.TX_READY_OR_TX_STREAMING + ): + self.last_database = db if imp_user: extra["imp_user"] = imp_user if bookmarks: @@ -253,6 +270,7 @@ def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, extra["mode"] = "r" if db: extra["db"] = db + self.last_database = db if imp_user: extra["imp_user"] = imp_user if bookmarks: @@ -347,7 +365,7 @@ async def _process_message(self, tag, fields): elif summary_signature == b"\x7F": log.debug("[#%04X] S: FAILURE %r", self.local_port, summary_metadata) - self._server_state_manager.state = self.server_states.FAILED + self._server_state_manager.state = self.bolt_states.FAILED try: await response.on_failure(summary_metadata or {}) except (ServiceUnavailable, DatabaseUnavailable): @@ -357,7 +375,8 @@ async def _process_message(self, tag, fields): except (NotALeader, ForbiddenOnReadOnlyDatabase): if self.pool: await self.pool.on_write_failure( - address=self.unresolved_address + address=self.unresolved_address, + database=self.last_database ) raise except Neo4jError as e: @@ -374,7 +393,7 @@ async def _process_message(self, tag, fields): return len(details), 1 -class ServerStates5x1(Enum): +class BoltStates5x1(Enum): CONNECTED = "CONNECTED" READY = "READY" STREAMING = "STREAMING" @@ -385,35 +404,60 @@ class ServerStates5x1(Enum): class ServerStateManager5x1(ServerStateManager): _STATE_TRANSITIONS = { # type: ignore - ServerStates5x1.CONNECTED: { - "hello": ServerStates5x1.AUTHENTICATION, + BoltStates5x1.CONNECTED: { + "hello": BoltStates5x1.AUTHENTICATION, }, - ServerStates5x1.AUTHENTICATION: { - "logon": ServerStates5x1.READY, + BoltStates5x1.AUTHENTICATION: { + "logon": BoltStates5x1.READY, }, - ServerStates5x1.READY: { - "run": ServerStates5x1.STREAMING, - "begin": ServerStates5x1.TX_READY_OR_TX_STREAMING, - "logoff": ServerStates5x1.AUTHENTICATION, + BoltStates5x1.READY: { + "run": BoltStates5x1.STREAMING, + "begin": BoltStates5x1.TX_READY_OR_TX_STREAMING, + "logoff": BoltStates5x1.AUTHENTICATION, }, - ServerStates5x1.STREAMING: { - "pull": ServerStates5x1.READY, - "discard": ServerStates5x1.READY, - "reset": ServerStates5x1.READY, + BoltStates5x1.STREAMING: { + "pull": BoltStates5x1.READY, + "discard": BoltStates5x1.READY, + "reset": BoltStates5x1.READY, }, - ServerStates5x1.TX_READY_OR_TX_STREAMING: { - "commit": ServerStates5x1.READY, - "rollback": ServerStates5x1.READY, - "reset": ServerStates5x1.READY, + BoltStates5x1.TX_READY_OR_TX_STREAMING: { + "commit": BoltStates5x1.READY, + "rollback": BoltStates5x1.READY, + "reset": BoltStates5x1.READY, }, - ServerStates5x1.FAILED: { - "reset": ServerStates5x1.READY, + BoltStates5x1.FAILED: { + "reset": BoltStates5x1.READY, } } - def failed(self): - return self.state == ServerStates5x1.FAILED + return self.state == BoltStates5x1.FAILED + + +class ClientStateManager5x1(ClientStateManager): + _STATE_TRANSITIONS = { # type: ignore + BoltStates5x1.CONNECTED: { + "hello": BoltStates5x1.AUTHENTICATION, + }, + BoltStates5x1.AUTHENTICATION: { + "logon": BoltStates5x1.READY, + }, + BoltStates5x1.READY: { + "run": BoltStates5x1.STREAMING, + "begin": BoltStates5x1.TX_READY_OR_TX_STREAMING, + "logoff": BoltStates5x1.AUTHENTICATION, + }, + BoltStates5x1.STREAMING: { + "begin": BoltStates5x1.TX_READY_OR_TX_STREAMING, + "logoff": BoltStates5x1.AUTHENTICATION, + "reset": BoltStates5x1.READY, + }, + BoltStates5x1.TX_READY_OR_TX_STREAMING: { + "commit": BoltStates5x1.READY, + "rollback": BoltStates5x1.READY, + "reset": BoltStates5x1.READY, + }, + } class AsyncBolt5x1(AsyncBolt5x0): @@ -423,12 +467,15 @@ class AsyncBolt5x1(AsyncBolt5x0): supports_re_auth = True - server_states = ServerStates5x1 + bolt_states = BoltStates5x1 def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._server_state_manager = ServerStateManager5x1( - ServerStates5x1.CONNECTED, on_change=self._on_server_state_change + BoltStates5x1.CONNECTED, on_change=self._on_server_state_change + ) + self._client_state_manager = ClientStateManager5x1( + BoltStates5x1.CONNECTED, on_change=self._on_client_state_change ) async def hello(self, dehydration_hooks=None, hydration_hooks=None): @@ -541,6 +588,11 @@ def run(self, query, parameters=None, mode=None, bookmarks=None, extra["mode"] = "r" if db: extra["db"] = db + if ( + self._client_state_manager.state + != self.bolt_states.TX_READY_OR_TX_STREAMING + ): + self.last_database = db if imp_user: extra["imp_user"] = imp_user if notifications_min_severity is not None: @@ -578,6 +630,7 @@ def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, extra["mode"] = "r" if db: extra["db"] = db + self.last_database = db if imp_user: extra["imp_user"] = imp_user if bookmarks: diff --git a/src/neo4j/_async/io/_pool.py b/src/neo4j/_async/io/_pool.py index 193329f4..218b49a4 100644 --- a/src/neo4j/_async/io/_pool.py +++ b/src/neo4j/_async/io/_pool.py @@ -447,7 +447,7 @@ async def deactivate(self, address): await self._close_connections(closable_connections) - async def on_write_failure(self, address): + async def on_write_failure(self, address, database): raise WriteServiceUnavailable( "No write service available for pool {}".format(self) ) @@ -956,12 +956,13 @@ async def deactivate(self, address): log.debug("[#0000] _: table=%r", self.routing_tables) await super(AsyncNeo4jPool, self).deactivate(address) - async def on_write_failure(self, address): + async def on_write_failure(self, address, database): """ Remove a writer address from the routing table, if present. """ - # FIXME: only need to remove the writer for a specific database - log.debug("[#0000] _: removing writer %r", address) + log.debug("[#0000] _: removing writer %r for database %r", + address, database) async with self.refresh_lock: - for database in self.routing_tables.keys(): + table = self.routing_tables.get(database) + if table is not None: self.routing_tables[database].writers.discard(address) log.debug("[#0000] _: table=%r", self.routing_tables) diff --git a/src/neo4j/_sync/io/_bolt.py b/src/neo4j/_sync/io/_bolt.py index e8e620e6..24336fa8 100644 --- a/src/neo4j/_sync/io/_bolt.py +++ b/src/neo4j/_sync/io/_bolt.py @@ -20,6 +20,7 @@ import abc import asyncio +import typing as t from collections import deque from logging import getLogger from time import perf_counter @@ -74,6 +75,16 @@ def failed(self): ... +class ClientStateManagerBase(abc.ABC): + @abc.abstractmethod + def __init__(self, init_state, on_change=None): + ... + + @abc.abstractmethod + def transition(self, message): + ... + + class Bolt: """ Server connection for Bolt protocol. @@ -103,12 +114,13 @@ class Bolt: # When the connection was last put back into the pool idle_since = float("-inf") + # The database name the connection was last used with + # (BEGIN for explicit transactions, RUN for auto-commit transactions) + last_database: t.Optional[str] = None # The socket _closing = False _closed = False - - # The socket _defunct = False #: The pool of which this connection is a member @@ -173,6 +185,10 @@ def __del__(self): def _get_server_state_manager(self) -> ServerStateManagerBase: ... + @abc.abstractmethod + def _get_client_state_manager(self) -> ClientStateManagerBase: + ... + @classmethod def _to_auth_dict(cls, auth): # Determine auth details @@ -753,6 +769,8 @@ def _append(self, signature, fields=(), response=None, """ self.outbox.append_message(signature, fields, dehydration_hooks) self.responses.append(response) + if response: + self._get_client_state_manager().transition(response.message) def _send_all(self): if self.outbox.flush(): diff --git a/src/neo4j/_sync/io/_bolt3.py b/src/neo4j/_sync/io/_bolt3.py index 06fccadb..eff57949 100644 --- a/src/neo4j/_sync/io/_bolt3.py +++ b/src/neo4j/_sync/io/_bolt3.py @@ -38,6 +38,7 @@ ) from ._bolt import ( Bolt, + ClientStateManagerBase, ServerStateManagerBase, tx_timeout_as_ms, ) @@ -52,7 +53,7 @@ log = getLogger("neo4j") -class ServerStates(Enum): +class BoltStates(Enum): CONNECTED = "CONNECTED" READY = "READY" STREAMING = "STREAMING" @@ -62,25 +63,25 @@ class ServerStates(Enum): class ServerStateManager(ServerStateManagerBase): _STATE_TRANSITIONS: t.Dict[Enum, t.Dict[str, Enum]] = { - ServerStates.CONNECTED: { - "hello": ServerStates.READY, + BoltStates.CONNECTED: { + "hello": BoltStates.READY, }, - ServerStates.READY: { - "run": ServerStates.STREAMING, - "begin": ServerStates.TX_READY_OR_TX_STREAMING, + BoltStates.READY: { + "run": BoltStates.STREAMING, + "begin": BoltStates.TX_READY_OR_TX_STREAMING, }, - ServerStates.STREAMING: { - "pull": ServerStates.READY, - "discard": ServerStates.READY, - "reset": ServerStates.READY, + BoltStates.STREAMING: { + "pull": BoltStates.READY, + "discard": BoltStates.READY, + "reset": BoltStates.READY, }, - ServerStates.TX_READY_OR_TX_STREAMING: { - "commit": ServerStates.READY, - "rollback": ServerStates.READY, - "reset": ServerStates.READY, + BoltStates.TX_READY_OR_TX_STREAMING: { + "commit": BoltStates.READY, + "rollback": BoltStates.READY, + "reset": BoltStates.READY, }, - ServerStates.FAILED: { - "reset": ServerStates.READY, + BoltStates.FAILED: { + "reset": BoltStates.READY, } } @@ -99,7 +100,40 @@ def transition(self, message, metadata): self._on_change(state_before, self.state) def failed(self): - return self.state == ServerStates.FAILED + return self.state == BoltStates.FAILED + + +class ClientStateManager(ClientStateManagerBase): + _STATE_TRANSITIONS: t.Dict[Enum, t.Dict[str, Enum]] = { + BoltStates.CONNECTED: { + "hello": BoltStates.READY, + }, + BoltStates.READY: { + "run": BoltStates.STREAMING, + "begin": BoltStates.TX_READY_OR_TX_STREAMING, + }, + BoltStates.STREAMING: { + "begin": BoltStates.TX_READY_OR_TX_STREAMING, + "reset": BoltStates.READY, + }, + BoltStates.TX_READY_OR_TX_STREAMING: { + "commit": BoltStates.READY, + "rollback": BoltStates.READY, + "reset": BoltStates.READY, + }, + } + + def __init__(self, init_state, on_change=None): + self.state = init_state + self._on_change = on_change + + def transition(self, message): + state_before = self.state + self.state = self._STATE_TRANSITIONS \ + .get(self.state, {}) \ + .get(message, self.state) + if state_before != self.state and callable(self._on_change): + self._on_change(state_before, self.state) class Bolt3(Bolt): @@ -121,25 +155,34 @@ class Bolt3(Bolt): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._server_state_manager = ServerStateManager( - ServerStates.CONNECTED, on_change=self._on_server_state_change + BoltStates.CONNECTED, on_change=self._on_server_state_change + ) + self._client_state_manager = ClientStateManager( + BoltStates.CONNECTED, on_change=self._on_client_state_change ) def _on_server_state_change(self, old_state, new_state): - log.debug("[#%04X] _: state: %s > %s", self.local_port, - old_state.name, new_state.name) + log.debug("[#%04X] _: server state: %s > %s", + self.local_port, old_state.name, new_state.name) def _get_server_state_manager(self) -> ServerStateManagerBase: return self._server_state_manager + def _on_client_state_change(self, old_state, new_state): + log.debug("[#%04X] _: client state: %s > %s", + self.local_port, old_state.name, new_state.name) + + def _get_client_state_manager(self) -> ClientStateManagerBase: + return self._client_state_manager + @property def is_reset(self): # We can't be sure of the server's state if there are still pending # responses. Unless the last message we sent was RESET. In that case # the server state will always be READY when we're done. - if (self.responses and self.responses[-1] - and self.responses[-1].message == "reset"): - return True - return self._server_state_manager.state == ServerStates.READY + if self.responses: + return self.responses[-1] and self.responses[-1].message == "reset" + return self._server_state_manager.state == BoltStates.READY @property def encrypted(self): @@ -216,7 +259,7 @@ def route( hydration_hooks=hydration_hooks, on_success=metadata.update ) - self.pull(dehydration_hooks = None, hydration_hooks = None, + self.pull(dehydration_hooks=None, hydration_hooks=None, on_success=metadata.update, on_records=records.extend) self.send_all() self.fetch_all() @@ -398,7 +441,7 @@ def _process_message(self, tag, fields): response.on_ignored(summary_metadata or {}) elif summary_signature == b"\x7F": log.debug("[#%04X] S: FAILURE %r", self.local_port, summary_metadata) - self._server_state_manager.state = ServerStates.FAILED + self._server_state_manager.state = BoltStates.FAILED try: response.on_failure(summary_metadata or {}) except (ServiceUnavailable, DatabaseUnavailable): @@ -408,7 +451,8 @@ def _process_message(self, tag, fields): except (NotALeader, ForbiddenOnReadOnlyDatabase): if self.pool: self.pool.on_write_failure( - address=self.unresolved_address + address=self.unresolved_address, + database=self.last_database, ) raise except Neo4jError as e: diff --git a/src/neo4j/_sync/io/_bolt4.py b/src/neo4j/_sync/io/_bolt4.py index 4eea7300..44f012f1 100644 --- a/src/neo4j/_sync/io/_bolt4.py +++ b/src/neo4j/_sync/io/_bolt4.py @@ -35,12 +35,14 @@ ) from ._bolt import ( Bolt, + ClientStateManagerBase, ServerStateManagerBase, tx_timeout_as_ms, ) from ._bolt3 import ( + BoltStates, + ClientStateManager, ServerStateManager, - ServerStates, ) from ._common import ( check_supported_server_product, @@ -72,25 +74,34 @@ class Bolt4x0(Bolt): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._server_state_manager = ServerStateManager( - ServerStates.CONNECTED, on_change=self._on_server_state_change + BoltStates.CONNECTED, on_change=self._on_server_state_change + ) + self._client_state_manager = ClientStateManager( + BoltStates.CONNECTED, on_change=self._on_client_state_change ) def _on_server_state_change(self, old_state, new_state): - log.debug("[#%04X] _: state: %s > %s", self.local_port, - old_state.name, new_state.name) + log.debug("[#%04X] _: server state: %s > %s", + self.local_port, old_state.name, new_state.name) def _get_server_state_manager(self) -> ServerStateManagerBase: return self._server_state_manager + def _on_client_state_change(self, old_state, new_state): + log.debug("[#%04X] _: client state: %s > %s", + self.local_port, old_state.name, new_state.name) + + def _get_client_state_manager(self) -> ClientStateManagerBase: + return self._client_state_manager + @property def is_reset(self): # We can't be sure of the server's state if there are still pending # responses. Unless the last message we sent was RESET. In that case # the server state will always be READY when we're done. - if (self.responses and self.responses[-1] - and self.responses[-1].message == "reset"): - return True - return self._server_state_manager.state == ServerStates.READY + if self.responses: + return self.responses[-1] and self.responses[-1].message == "reset" + return self._server_state_manager.state == BoltStates.READY @property def encrypted(self): @@ -202,6 +213,8 @@ def run(self, query, parameters=None, mode=None, bookmarks=None, extra["mode"] = "r" # It will default to mode "w" if nothing is specified if db: extra["db"] = db + if self._client_state_manager.state != BoltStates.TX_READY_OR_TX_STREAMING: + self.last_database = db if bookmarks: try: extra["bookmarks"] = list(bookmarks) @@ -261,6 +274,7 @@ def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, extra["mode"] = "r" # It will default to mode "w" if nothing is specified if db: extra["db"] = db + self.last_database = db if bookmarks: try: extra["bookmarks"] = list(bookmarks) @@ -347,7 +361,7 @@ def _process_message(self, tag, fields): response.on_ignored(summary_metadata or {}) elif summary_signature == b"\x7F": log.debug("[#%04X] S: FAILURE %r", self.local_port, summary_metadata) - self._server_state_manager.state = ServerStates.FAILED + self._server_state_manager.state = BoltStates.FAILED try: response.on_failure(summary_metadata or {}) except (ServiceUnavailable, DatabaseUnavailable): @@ -357,7 +371,8 @@ def _process_message(self, tag, fields): except (NotALeader, ForbiddenOnReadOnlyDatabase): if self.pool: self.pool.on_write_failure( - address=self.unresolved_address + address=self.unresolved_address, + database=self.last_database ) raise except Neo4jError as e: @@ -535,6 +550,11 @@ def run(self, query, parameters=None, mode=None, bookmarks=None, extra["mode"] = "r" if db: extra["db"] = db + if ( + self._client_state_manager.state + != BoltStates.TX_READY_OR_TX_STREAMING + ): + self.last_database = db if imp_user: extra["imp_user"] = imp_user if bookmarks: @@ -571,6 +591,7 @@ def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, extra["mode"] = "r" if db: extra["db"] = db + self.last_database = db if imp_user: extra["imp_user"] = imp_user if bookmarks: diff --git a/src/neo4j/_sync/io/_bolt5.py b/src/neo4j/_sync/io/_bolt5.py index bea6eabc..73882c17 100644 --- a/src/neo4j/_sync/io/_bolt5.py +++ b/src/neo4j/_sync/io/_bolt5.py @@ -37,12 +37,14 @@ ) from ._bolt import ( Bolt, + ClientStateManagerBase, ServerStateManagerBase, tx_timeout_as_ms, ) from ._bolt3 import ( + BoltStates, + ClientStateManager, ServerStateManager, - ServerStates, ) from ._common import ( check_supported_server_product, @@ -71,31 +73,41 @@ class Bolt5x0(Bolt): supports_notification_filtering = False - server_states: t.Any = ServerStates + bolt_states: t.Any = BoltStates def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._server_state_manager = ServerStateManager( - self.server_states.CONNECTED, + self.bolt_states.CONNECTED, on_change=self._on_server_state_change ) + self._client_state_manager = ClientStateManager( + self.bolt_states.CONNECTED, + on_change=self._on_client_state_change + ) def _on_server_state_change(self, old_state, new_state): - log.debug("[#%04X] _: state: %s > %s", self.local_port, - old_state.name, new_state.name) + log.debug("[#%04X] _: server state: %s > %s", + self.local_port, old_state.name, new_state.name) def _get_server_state_manager(self) -> ServerStateManagerBase: return self._server_state_manager + def _on_client_state_change(self, old_state, new_state): + log.debug("[#%04X] _: client state: %s > %s", + self.local_port, old_state.name, new_state.name) + + def _get_client_state_manager(self) -> ClientStateManagerBase: + return self._client_state_manager + @property def is_reset(self): # We can't be sure of the server's state if there are still pending # responses. Unless the last message we sent was RESET. In that case # the server state will always be READY when we're done. - if (self.responses and self.responses[-1] - and self.responses[-1].message == "reset"): - return True - return self._server_state_manager.state == self.server_states.READY + if self.responses: + return self.responses[-1] and self.responses[-1].message == "reset" + return self._server_state_manager.state == self.bolt_states.READY @property def encrypted(self): @@ -197,6 +209,11 @@ def run(self, query, parameters=None, mode=None, bookmarks=None, extra["mode"] = "r" if db: extra["db"] = db + if ( + self._client_state_manager.state + != self.bolt_states.TX_READY_OR_TX_STREAMING + ): + self.last_database = db if imp_user: extra["imp_user"] = imp_user if bookmarks: @@ -253,6 +270,7 @@ def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, extra["mode"] = "r" if db: extra["db"] = db + self.last_database = db if imp_user: extra["imp_user"] = imp_user if bookmarks: @@ -347,7 +365,7 @@ def _process_message(self, tag, fields): elif summary_signature == b"\x7F": log.debug("[#%04X] S: FAILURE %r", self.local_port, summary_metadata) - self._server_state_manager.state = self.server_states.FAILED + self._server_state_manager.state = self.bolt_states.FAILED try: response.on_failure(summary_metadata or {}) except (ServiceUnavailable, DatabaseUnavailable): @@ -357,7 +375,8 @@ def _process_message(self, tag, fields): except (NotALeader, ForbiddenOnReadOnlyDatabase): if self.pool: self.pool.on_write_failure( - address=self.unresolved_address + address=self.unresolved_address, + database=self.last_database ) raise except Neo4jError as e: @@ -374,7 +393,7 @@ def _process_message(self, tag, fields): return len(details), 1 -class ServerStates5x1(Enum): +class BoltStates5x1(Enum): CONNECTED = "CONNECTED" READY = "READY" STREAMING = "STREAMING" @@ -385,35 +404,60 @@ class ServerStates5x1(Enum): class ServerStateManager5x1(ServerStateManager): _STATE_TRANSITIONS = { # type: ignore - ServerStates5x1.CONNECTED: { - "hello": ServerStates5x1.AUTHENTICATION, + BoltStates5x1.CONNECTED: { + "hello": BoltStates5x1.AUTHENTICATION, }, - ServerStates5x1.AUTHENTICATION: { - "logon": ServerStates5x1.READY, + BoltStates5x1.AUTHENTICATION: { + "logon": BoltStates5x1.READY, }, - ServerStates5x1.READY: { - "run": ServerStates5x1.STREAMING, - "begin": ServerStates5x1.TX_READY_OR_TX_STREAMING, - "logoff": ServerStates5x1.AUTHENTICATION, + BoltStates5x1.READY: { + "run": BoltStates5x1.STREAMING, + "begin": BoltStates5x1.TX_READY_OR_TX_STREAMING, + "logoff": BoltStates5x1.AUTHENTICATION, }, - ServerStates5x1.STREAMING: { - "pull": ServerStates5x1.READY, - "discard": ServerStates5x1.READY, - "reset": ServerStates5x1.READY, + BoltStates5x1.STREAMING: { + "pull": BoltStates5x1.READY, + "discard": BoltStates5x1.READY, + "reset": BoltStates5x1.READY, }, - ServerStates5x1.TX_READY_OR_TX_STREAMING: { - "commit": ServerStates5x1.READY, - "rollback": ServerStates5x1.READY, - "reset": ServerStates5x1.READY, + BoltStates5x1.TX_READY_OR_TX_STREAMING: { + "commit": BoltStates5x1.READY, + "rollback": BoltStates5x1.READY, + "reset": BoltStates5x1.READY, }, - ServerStates5x1.FAILED: { - "reset": ServerStates5x1.READY, + BoltStates5x1.FAILED: { + "reset": BoltStates5x1.READY, } } - def failed(self): - return self.state == ServerStates5x1.FAILED + return self.state == BoltStates5x1.FAILED + + +class ClientStateManager5x1(ClientStateManager): + _STATE_TRANSITIONS = { # type: ignore + BoltStates5x1.CONNECTED: { + "hello": BoltStates5x1.AUTHENTICATION, + }, + BoltStates5x1.AUTHENTICATION: { + "logon": BoltStates5x1.READY, + }, + BoltStates5x1.READY: { + "run": BoltStates5x1.STREAMING, + "begin": BoltStates5x1.TX_READY_OR_TX_STREAMING, + "logoff": BoltStates5x1.AUTHENTICATION, + }, + BoltStates5x1.STREAMING: { + "begin": BoltStates5x1.TX_READY_OR_TX_STREAMING, + "logoff": BoltStates5x1.AUTHENTICATION, + "reset": BoltStates5x1.READY, + }, + BoltStates5x1.TX_READY_OR_TX_STREAMING: { + "commit": BoltStates5x1.READY, + "rollback": BoltStates5x1.READY, + "reset": BoltStates5x1.READY, + }, + } class Bolt5x1(Bolt5x0): @@ -423,12 +467,15 @@ class Bolt5x1(Bolt5x0): supports_re_auth = True - server_states = ServerStates5x1 + bolt_states = BoltStates5x1 def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._server_state_manager = ServerStateManager5x1( - ServerStates5x1.CONNECTED, on_change=self._on_server_state_change + BoltStates5x1.CONNECTED, on_change=self._on_server_state_change + ) + self._client_state_manager = ClientStateManager5x1( + BoltStates5x1.CONNECTED, on_change=self._on_client_state_change ) def hello(self, dehydration_hooks=None, hydration_hooks=None): @@ -541,6 +588,11 @@ def run(self, query, parameters=None, mode=None, bookmarks=None, extra["mode"] = "r" if db: extra["db"] = db + if ( + self._client_state_manager.state + != self.bolt_states.TX_READY_OR_TX_STREAMING + ): + self.last_database = db if imp_user: extra["imp_user"] = imp_user if notifications_min_severity is not None: @@ -578,6 +630,7 @@ def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, extra["mode"] = "r" if db: extra["db"] = db + self.last_database = db if imp_user: extra["imp_user"] = imp_user if bookmarks: diff --git a/src/neo4j/_sync/io/_pool.py b/src/neo4j/_sync/io/_pool.py index 695edd94..7f24a8d6 100644 --- a/src/neo4j/_sync/io/_pool.py +++ b/src/neo4j/_sync/io/_pool.py @@ -444,7 +444,7 @@ def deactivate(self, address): self._close_connections(closable_connections) - def on_write_failure(self, address): + def on_write_failure(self, address, database): raise WriteServiceUnavailable( "No write service available for pool {}".format(self) ) @@ -953,12 +953,13 @@ def deactivate(self, address): log.debug("[#0000] _: table=%r", self.routing_tables) super(Neo4jPool, self).deactivate(address) - def on_write_failure(self, address): + def on_write_failure(self, address, database): """ Remove a writer address from the routing table, if present. """ - # FIXME: only need to remove the writer for a specific database - log.debug("[#0000] _: removing writer %r", address) + log.debug("[#0000] _: removing writer %r for database %r", + address, database) with self.refresh_lock: - for database in self.routing_tables.keys(): + table = self.routing_tables.get(database) + if table is not None: self.routing_tables[database].writers.discard(address) log.debug("[#0000] _: table=%r", self.routing_tables) diff --git a/tests/unit/async_/io/conftest.py b/tests/unit/async_/io/conftest.py index d697d182..314e69d3 100644 --- a/tests/unit/async_/io/conftest.py +++ b/tests/unit/async_/io/conftest.py @@ -102,7 +102,7 @@ async def sendall(self, data): if callable(self.on_send): self.on_send(data) - def close(self): + async def close(self): return def kill(self): diff --git a/tests/unit/async_/io/test_class_bolt3.py b/tests/unit/async_/io/test_class_bolt3.py index 2ae6e840..635d9e1e 100644 --- a/tests/unit/async_/io/test_class_bolt3.py +++ b/tests/unit/async_/io/test_class_bolt3.py @@ -14,10 +14,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - - +import contextlib +import itertools import logging -from itertools import permutations import pytest @@ -186,7 +185,7 @@ async def test_re_auth_noop(auth, fake_socket, mocker): @pytest.mark.parametrize( ("auth1", "auth2"), - permutations( + itertools.permutations( ( None, neo4j.Auth("scheme", "principal", "credentials", "realm"), @@ -359,3 +358,78 @@ async def test_tx_timeout( assert "tx_timeout" not in extra else: assert extra["tx_timeout"] == res + + +@pytest.mark.parametrize( + "actions", + itertools.combinations_with_replacement( + itertools.product( + ("run", "begin", "begin_run"), + ("reset", "commit", "rollback"), + (None, "some_db", "another_db"), + ), + 2 + ) +) +@mark_async_test +async def test_tracks_last_database(fake_socket_pair, actions): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=AsyncBolt3.PACKER_CLS, + unpacker_cls=AsyncBolt3.UNPACKER_CLS) + connection = AsyncBolt3(address, sockets.client, 0) + await sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) + await connection.hello() + assert connection.last_database is None + for action, finish, db in actions: + await sockets.server.send_message(b"\x70", {}) + if action == "run": + with raises_if_db(db): + connection.run("RETURN 1", db=db) + elif action == "begin": + with raises_if_db(db): + connection.begin(db=db) + elif action == "begin_run": + with raises_if_db(db): + connection.begin(db=db) + assert connection.last_database is None + await sockets.server.send_message(b"\x70", {}) + connection.run("RETURN 1") + else: + raise ValueError(action) + + assert connection.last_database is None + await connection.send_all() + await connection.fetch_all() + assert connection.last_database is None + + await sockets.server.send_message(b"\x70", {}) + if finish == "reset": + await connection.reset() + elif finish == "commit": + if action == "run": + connection.pull() + else: + connection.commit() + elif finish == "rollback": + if action == "run": + connection.pull() + else: + connection.rollback() + else: + raise ValueError(finish) + + await connection.send_all() + await connection.fetch_all() + + assert connection.last_database is None + + +@contextlib.contextmanager +def raises_if_db(db): + if db is None: + yield + else: + with pytest.raises(ConfigurationError, + match="selecting database is not supported"): + yield diff --git a/tests/unit/async_/io/test_class_bolt4x0.py b/tests/unit/async_/io/test_class_bolt4x0.py index 979f4fb2..b667ba81 100644 --- a/tests/unit/async_/io/test_class_bolt4x0.py +++ b/tests/unit/async_/io/test_class_bolt4x0.py @@ -16,8 +16,9 @@ # limitations under the License. +import contextlib +import itertools import logging -from itertools import permutations import pytest @@ -282,7 +283,7 @@ async def test_re_auth_noop(auth, fake_socket, mocker): @pytest.mark.parametrize( ("auth1", "auth2"), - permutations( + itertools.permutations( ( None, neo4j.Auth("scheme", "principal", "credentials", "realm"), @@ -455,3 +456,65 @@ async def test_tx_timeout( assert "tx_timeout" not in extra else: assert extra["tx_timeout"] == res + + +@pytest.mark.parametrize( + "actions", + itertools.combinations_with_replacement( + itertools.product( + ("run", "begin", "begin_run"), + ("reset", "commit", "rollback"), + (None, "some_db", "another_db"), + ), + 2 + ) +) +@mark_async_test +async def test_tracks_last_database(fake_socket_pair, actions): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=AsyncBolt4x0.PACKER_CLS, + unpacker_cls=AsyncBolt4x0.UNPACKER_CLS) + connection = AsyncBolt4x0(address, sockets.client, 0) + await sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) + await connection.hello() + assert connection.last_database is None + for action, finish, db in actions: + await sockets.server.send_message(b"\x70", {}) + if action == "run": + connection.run("RETURN 1", db=db) + elif action == "begin": + connection.begin(db=db) + elif action == "begin_run": + connection.begin(db=db) + assert connection.last_database == db + await sockets.server.send_message(b"\x70", {}) + connection.run("RETURN 1") + else: + raise ValueError(action) + + assert connection.last_database == db + await connection.send_all() + await connection.fetch_all() + assert connection.last_database == db + + await sockets.server.send_message(b"\x70", {}) + if finish == "reset": + await connection.reset() + elif finish == "commit": + if action == "run": + connection.pull() + else: + connection.commit() + elif finish == "rollback": + if action == "run": + connection.pull() + else: + connection.rollback() + else: + raise ValueError(finish) + + await connection.send_all() + await connection.fetch_all() + + assert connection.last_database == db diff --git a/tests/unit/async_/io/test_class_bolt4x1.py b/tests/unit/async_/io/test_class_bolt4x1.py index 4af0bddb..a2858046 100644 --- a/tests/unit/async_/io/test_class_bolt4x1.py +++ b/tests/unit/async_/io/test_class_bolt4x1.py @@ -16,8 +16,8 @@ # limitations under the License. +import itertools import logging -from itertools import permutations import pytest @@ -299,7 +299,7 @@ async def test_re_auth_noop(auth, fake_socket, mocker): @pytest.mark.parametrize( ("auth1", "auth2"), - permutations( + itertools.permutations( ( None, neo4j.Auth("scheme", "principal", "credentials", "realm"), @@ -472,3 +472,65 @@ async def test_tx_timeout( assert "tx_timeout" not in extra else: assert extra["tx_timeout"] == res + + +@pytest.mark.parametrize( + "actions", + itertools.combinations_with_replacement( + itertools.product( + ("run", "begin", "begin_run"), + ("reset", "commit", "rollback"), + (None, "some_db", "another_db"), + ), + 2 + ) +) +@mark_async_test +async def test_tracks_last_database(fake_socket_pair, actions): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=AsyncBolt4x1.PACKER_CLS, + unpacker_cls=AsyncBolt4x1.UNPACKER_CLS) + connection = AsyncBolt4x1(address, sockets.client, 0) + await sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) + await connection.hello() + assert connection.last_database is None + for action, finish, db in actions: + await sockets.server.send_message(b"\x70", {}) + if action == "run": + connection.run("RETURN 1", db=db) + elif action == "begin": + connection.begin(db=db) + elif action == "begin_run": + connection.begin(db=db) + assert connection.last_database == db + await sockets.server.send_message(b"\x70", {}) + connection.run("RETURN 1") + else: + raise ValueError(action) + + assert connection.last_database == db + await connection.send_all() + await connection.fetch_all() + assert connection.last_database == db + + await sockets.server.send_message(b"\x70", {}) + if finish == "reset": + await connection.reset() + elif finish == "commit": + if action == "run": + connection.pull() + else: + connection.commit() + elif finish == "rollback": + if action == "run": + connection.pull() + else: + connection.rollback() + else: + raise ValueError(finish) + + await connection.send_all() + await connection.fetch_all() + + assert connection.last_database == db diff --git a/tests/unit/async_/io/test_class_bolt4x2.py b/tests/unit/async_/io/test_class_bolt4x2.py index 6256edd5..41d732f2 100644 --- a/tests/unit/async_/io/test_class_bolt4x2.py +++ b/tests/unit/async_/io/test_class_bolt4x2.py @@ -16,8 +16,8 @@ # limitations under the License. +import itertools import logging -from itertools import permutations import pytest @@ -300,7 +300,7 @@ async def test_re_auth_noop(auth, fake_socket, mocker): @pytest.mark.parametrize( ("auth1", "auth2"), - permutations( + itertools.permutations( ( None, neo4j.Auth("scheme", "principal", "credentials", "realm"), @@ -473,3 +473,65 @@ async def test_tx_timeout( assert "tx_timeout" not in extra else: assert extra["tx_timeout"] == res + + +@pytest.mark.parametrize( + "actions", + itertools.combinations_with_replacement( + itertools.product( + ("run", "begin", "begin_run"), + ("reset", "commit", "rollback"), + (None, "some_db", "another_db"), + ), + 2 + ) +) +@mark_async_test +async def test_tracks_last_database(fake_socket_pair, actions): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=AsyncBolt4x2.PACKER_CLS, + unpacker_cls=AsyncBolt4x2.UNPACKER_CLS) + connection = AsyncBolt4x2(address, sockets.client, 0) + await sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) + await connection.hello() + assert connection.last_database is None + for action, finish, db in actions: + await sockets.server.send_message(b"\x70", {}) + if action == "run": + connection.run("RETURN 1", db=db) + elif action == "begin": + connection.begin(db=db) + elif action == "begin_run": + connection.begin(db=db) + assert connection.last_database == db + await sockets.server.send_message(b"\x70", {}) + connection.run("RETURN 1") + else: + raise ValueError(action) + + assert connection.last_database == db + await connection.send_all() + await connection.fetch_all() + assert connection.last_database == db + + await sockets.server.send_message(b"\x70", {}) + if finish == "reset": + await connection.reset() + elif finish == "commit": + if action == "run": + connection.pull() + else: + connection.commit() + elif finish == "rollback": + if action == "run": + connection.pull() + else: + connection.rollback() + else: + raise ValueError(finish) + + await connection.send_all() + await connection.fetch_all() + + assert connection.last_database == db diff --git a/tests/unit/async_/io/test_class_bolt4x3.py b/tests/unit/async_/io/test_class_bolt4x3.py index 70a2e0d5..a05656d5 100644 --- a/tests/unit/async_/io/test_class_bolt4x3.py +++ b/tests/unit/async_/io/test_class_bolt4x3.py @@ -16,8 +16,8 @@ # limitations under the License. +import itertools import logging -from itertools import permutations import pytest @@ -327,7 +327,7 @@ async def test_re_auth_noop(auth, fake_socket, mocker): @pytest.mark.parametrize( ("auth1", "auth2"), - permutations( + itertools.permutations( ( None, neo4j.Auth("scheme", "principal", "credentials", "realm"), @@ -500,3 +500,65 @@ async def test_tx_timeout( assert "tx_timeout" not in extra else: assert extra["tx_timeout"] == res + + +@pytest.mark.parametrize( + "actions", + itertools.combinations_with_replacement( + itertools.product( + ("run", "begin", "begin_run"), + ("reset", "commit", "rollback"), + (None, "some_db", "another_db"), + ), + 2 + ) +) +@mark_async_test +async def test_tracks_last_database(fake_socket_pair, actions): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=AsyncBolt4x3.PACKER_CLS, + unpacker_cls=AsyncBolt4x3.UNPACKER_CLS) + connection = AsyncBolt4x3(address, sockets.client, 0) + await sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) + await connection.hello() + assert connection.last_database is None + for action, finish, db in actions: + await sockets.server.send_message(b"\x70", {}) + if action == "run": + connection.run("RETURN 1", db=db) + elif action == "begin": + connection.begin(db=db) + elif action == "begin_run": + connection.begin(db=db) + assert connection.last_database == db + await sockets.server.send_message(b"\x70", {}) + connection.run("RETURN 1") + else: + raise ValueError(action) + + assert connection.last_database == db + await connection.send_all() + await connection.fetch_all() + assert connection.last_database == db + + await sockets.server.send_message(b"\x70", {}) + if finish == "reset": + await connection.reset() + elif finish == "commit": + if action == "run": + connection.pull() + else: + connection.commit() + elif finish == "rollback": + if action == "run": + connection.pull() + else: + connection.rollback() + else: + raise ValueError(finish) + + await connection.send_all() + await connection.fetch_all() + + assert connection.last_database == db diff --git a/tests/unit/async_/io/test_class_bolt4x4.py b/tests/unit/async_/io/test_class_bolt4x4.py index baa85260..9923c504 100644 --- a/tests/unit/async_/io/test_class_bolt4x4.py +++ b/tests/unit/async_/io/test_class_bolt4x4.py @@ -16,8 +16,8 @@ # limitations under the License. +import itertools import logging -from itertools import permutations import pytest @@ -340,7 +340,7 @@ async def test_re_auth_noop(auth, fake_socket, mocker): @pytest.mark.parametrize( ("auth1", "auth2"), - permutations( + itertools.permutations( ( None, neo4j.Auth("scheme", "principal", "credentials", "realm"), diff --git a/tests/unit/async_/io/test_class_bolt5x0.py b/tests/unit/async_/io/test_class_bolt5x0.py index 8b2b9f0a..c6f4d329 100644 --- a/tests/unit/async_/io/test_class_bolt5x0.py +++ b/tests/unit/async_/io/test_class_bolt5x0.py @@ -16,8 +16,8 @@ # limitations under the License. +import itertools import logging -from itertools import permutations import pytest @@ -340,7 +340,7 @@ async def test_re_auth_noop(auth, fake_socket, mocker): @pytest.mark.parametrize( ("auth1", "auth2"), - permutations( + itertools.permutations( ( None, neo4j.Auth("scheme", "principal", "credentials", "realm"), @@ -513,3 +513,65 @@ async def test_tx_timeout( assert "tx_timeout" not in extra else: assert extra["tx_timeout"] == res + + +@pytest.mark.parametrize( + "actions", + itertools.combinations_with_replacement( + itertools.product( + ("run", "begin", "begin_run"), + ("reset", "commit", "rollback"), + (None, "some_db", "another_db"), + ), + 2 + ) +) +@mark_async_test +async def test_tracks_last_database(fake_socket_pair, actions): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=AsyncBolt5x0.PACKER_CLS, + unpacker_cls=AsyncBolt5x0.UNPACKER_CLS) + connection = AsyncBolt5x0(address, sockets.client, 0) + await sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) + await connection.hello() + assert connection.last_database is None + for action, finish, db in actions: + await sockets.server.send_message(b"\x70", {}) + if action == "run": + connection.run("RETURN 1", db=db) + elif action == "begin": + connection.begin(db=db) + elif action == "begin_run": + connection.begin(db=db) + assert connection.last_database == db + await sockets.server.send_message(b"\x70", {}) + connection.run("RETURN 1") + else: + raise ValueError(action) + + assert connection.last_database == db + await connection.send_all() + await connection.fetch_all() + assert connection.last_database == db + + await sockets.server.send_message(b"\x70", {}) + if finish == "reset": + await connection.reset() + elif finish == "commit": + if action == "run": + connection.pull() + else: + connection.commit() + elif finish == "rollback": + if action == "run": + connection.pull() + else: + connection.rollback() + else: + raise ValueError(finish) + + await connection.send_all() + await connection.fetch_all() + + assert connection.last_database == db diff --git a/tests/unit/async_/io/test_class_bolt5x1.py b/tests/unit/async_/io/test_class_bolt5x1.py index 8897ce70..2ee26b13 100644 --- a/tests/unit/async_/io/test_class_bolt5x1.py +++ b/tests/unit/async_/io/test_class_bolt5x1.py @@ -16,6 +16,7 @@ # limitations under the License. +import itertools import logging import pytest @@ -567,3 +568,66 @@ async def test_tx_timeout( assert "tx_timeout" not in extra else: assert extra["tx_timeout"] == res + + +@pytest.mark.parametrize( + "actions", + itertools.combinations_with_replacement( + itertools.product( + ("run", "begin", "begin_run"), + ("reset", "commit", "rollback"), + (None, "some_db", "another_db"), + ), + 2 + ) +) +@mark_async_test +async def test_tracks_last_database(fake_socket_pair, actions): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=AsyncBolt5x1.PACKER_CLS, + unpacker_cls=AsyncBolt5x1.UNPACKER_CLS) + connection = AsyncBolt5x1(address, sockets.client, 0) + await sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) + await sockets.server.send_message(b"\x70", {}) + await connection.hello() + assert connection.last_database is None + for action, finish, db in actions: + await sockets.server.send_message(b"\x70", {}) + if action == "run": + connection.run("RETURN 1", db=db) + elif action == "begin": + connection.begin(db=db) + elif action == "begin_run": + connection.begin(db=db) + assert connection.last_database == db + await sockets.server.send_message(b"\x70", {}) + connection.run("RETURN 1") + else: + raise ValueError(action) + + assert connection.last_database == db + await connection.send_all() + await connection.fetch_all() + assert connection.last_database == db + + await sockets.server.send_message(b"\x70", {}) + if finish == "reset": + await connection.reset() + elif finish == "commit": + if action == "run": + connection.pull() + else: + connection.commit() + elif finish == "rollback": + if action == "run": + connection.pull() + else: + connection.rollback() + else: + raise ValueError(finish) + + await connection.send_all() + await connection.fetch_all() + + assert connection.last_database == db diff --git a/tests/unit/async_/io/test_class_bolt5x2.py b/tests/unit/async_/io/test_class_bolt5x2.py index 0de7e387..429b0268 100644 --- a/tests/unit/async_/io/test_class_bolt5x2.py +++ b/tests/unit/async_/io/test_class_bolt5x2.py @@ -585,3 +585,66 @@ async def test_tx_timeout( assert "tx_timeout" not in extra else: assert extra["tx_timeout"] == res + + +@pytest.mark.parametrize( + "actions", + itertools.combinations_with_replacement( + itertools.product( + ("run", "begin", "begin_run"), + ("reset", "commit", "rollback"), + (None, "some_db", "another_db"), + ), + 2 + ) +) +@mark_async_test +async def test_tracks_last_database(fake_socket_pair, actions): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=AsyncBolt5x2.PACKER_CLS, + unpacker_cls=AsyncBolt5x2.UNPACKER_CLS) + connection = AsyncBolt5x2(address, sockets.client, 0) + await sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) + await sockets.server.send_message(b"\x70", {}) + await connection.hello() + assert connection.last_database is None + for action, finish, db in actions: + await sockets.server.send_message(b"\x70", {}) + if action == "run": + connection.run("RETURN 1", db=db) + elif action == "begin": + connection.begin(db=db) + elif action == "begin_run": + connection.begin(db=db) + assert connection.last_database == db + await sockets.server.send_message(b"\x70", {}) + connection.run("RETURN 1") + else: + raise ValueError(action) + + assert connection.last_database == db + await connection.send_all() + await connection.fetch_all() + assert connection.last_database == db + + await sockets.server.send_message(b"\x70", {}) + if finish == "reset": + await connection.reset() + elif finish == "commit": + if action == "run": + connection.pull() + else: + connection.commit() + elif finish == "rollback": + if action == "run": + connection.pull() + else: + connection.rollback() + else: + raise ValueError(finish) + + await connection.send_all() + await connection.fetch_all() + + assert connection.last_database == db diff --git a/tests/unit/async_/io/test_class_bolt5x3.py b/tests/unit/async_/io/test_class_bolt5x3.py index 0247bad1..a11c2ccc 100644 --- a/tests/unit/async_/io/test_class_bolt5x3.py +++ b/tests/unit/async_/io/test_class_bolt5x3.py @@ -496,3 +496,66 @@ async def test_tx_timeout( assert "tx_timeout" not in extra else: assert extra["tx_timeout"] == res + + +@pytest.mark.parametrize( + "actions", + itertools.combinations_with_replacement( + itertools.product( + ("run", "begin", "begin_run"), + ("reset", "commit", "rollback"), + (None, "some_db", "another_db"), + ), + 2 + ) +) +@mark_async_test +async def test_tracks_last_database(fake_socket_pair, actions): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=AsyncBolt5x3.PACKER_CLS, + unpacker_cls=AsyncBolt5x3.UNPACKER_CLS) + connection = AsyncBolt5x3(address, sockets.client, 0) + await sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) + await sockets.server.send_message(b"\x70", {}) + await connection.hello() + assert connection.last_database is None + for action, finish, db in actions: + await sockets.server.send_message(b"\x70", {}) + if action == "run": + connection.run("RETURN 1", db=db) + elif action == "begin": + connection.begin(db=db) + elif action == "begin_run": + connection.begin(db=db) + assert connection.last_database == db + await sockets.server.send_message(b"\x70", {}) + connection.run("RETURN 1") + else: + raise ValueError(action) + + assert connection.last_database == db + await connection.send_all() + await connection.fetch_all() + assert connection.last_database == db + + await sockets.server.send_message(b"\x70", {}) + if finish == "reset": + await connection.reset() + elif finish == "commit": + if action == "run": + connection.pull() + else: + connection.commit() + elif finish == "rollback": + if action == "run": + connection.pull() + else: + connection.rollback() + else: + raise ValueError(finish) + + await connection.send_all() + await connection.fetch_all() + + assert connection.last_database == db diff --git a/tests/unit/sync/io/test_class_bolt3.py b/tests/unit/sync/io/test_class_bolt3.py index 1527bd86..0ebdfb14 100644 --- a/tests/unit/sync/io/test_class_bolt3.py +++ b/tests/unit/sync/io/test_class_bolt3.py @@ -14,10 +14,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - - +import contextlib +import itertools import logging -from itertools import permutations import pytest @@ -186,7 +185,7 @@ def test_re_auth_noop(auth, fake_socket, mocker): @pytest.mark.parametrize( ("auth1", "auth2"), - permutations( + itertools.permutations( ( None, neo4j.Auth("scheme", "principal", "credentials", "realm"), @@ -359,3 +358,78 @@ def test_tx_timeout( assert "tx_timeout" not in extra else: assert extra["tx_timeout"] == res + + +@pytest.mark.parametrize( + "actions", + itertools.combinations_with_replacement( + itertools.product( + ("run", "begin", "begin_run"), + ("reset", "commit", "rollback"), + (None, "some_db", "another_db"), + ), + 2 + ) +) +@mark_sync_test +def test_tracks_last_database(fake_socket_pair, actions): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=Bolt3.PACKER_CLS, + unpacker_cls=Bolt3.UNPACKER_CLS) + connection = Bolt3(address, sockets.client, 0) + sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) + connection.hello() + assert connection.last_database is None + for action, finish, db in actions: + sockets.server.send_message(b"\x70", {}) + if action == "run": + with raises_if_db(db): + connection.run("RETURN 1", db=db) + elif action == "begin": + with raises_if_db(db): + connection.begin(db=db) + elif action == "begin_run": + with raises_if_db(db): + connection.begin(db=db) + assert connection.last_database is None + sockets.server.send_message(b"\x70", {}) + connection.run("RETURN 1") + else: + raise ValueError(action) + + assert connection.last_database is None + connection.send_all() + connection.fetch_all() + assert connection.last_database is None + + sockets.server.send_message(b"\x70", {}) + if finish == "reset": + connection.reset() + elif finish == "commit": + if action == "run": + connection.pull() + else: + connection.commit() + elif finish == "rollback": + if action == "run": + connection.pull() + else: + connection.rollback() + else: + raise ValueError(finish) + + connection.send_all() + connection.fetch_all() + + assert connection.last_database is None + + +@contextlib.contextmanager +def raises_if_db(db): + if db is None: + yield + else: + with pytest.raises(ConfigurationError, + match="selecting database is not supported"): + yield diff --git a/tests/unit/sync/io/test_class_bolt4x0.py b/tests/unit/sync/io/test_class_bolt4x0.py index 90af5f07..497723cd 100644 --- a/tests/unit/sync/io/test_class_bolt4x0.py +++ b/tests/unit/sync/io/test_class_bolt4x0.py @@ -16,8 +16,9 @@ # limitations under the License. +import contextlib +import itertools import logging -from itertools import permutations import pytest @@ -282,7 +283,7 @@ def test_re_auth_noop(auth, fake_socket, mocker): @pytest.mark.parametrize( ("auth1", "auth2"), - permutations( + itertools.permutations( ( None, neo4j.Auth("scheme", "principal", "credentials", "realm"), @@ -455,3 +456,65 @@ def test_tx_timeout( assert "tx_timeout" not in extra else: assert extra["tx_timeout"] == res + + +@pytest.mark.parametrize( + "actions", + itertools.combinations_with_replacement( + itertools.product( + ("run", "begin", "begin_run"), + ("reset", "commit", "rollback"), + (None, "some_db", "another_db"), + ), + 2 + ) +) +@mark_sync_test +def test_tracks_last_database(fake_socket_pair, actions): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=Bolt4x0.PACKER_CLS, + unpacker_cls=Bolt4x0.UNPACKER_CLS) + connection = Bolt4x0(address, sockets.client, 0) + sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) + connection.hello() + assert connection.last_database is None + for action, finish, db in actions: + sockets.server.send_message(b"\x70", {}) + if action == "run": + connection.run("RETURN 1", db=db) + elif action == "begin": + connection.begin(db=db) + elif action == "begin_run": + connection.begin(db=db) + assert connection.last_database == db + sockets.server.send_message(b"\x70", {}) + connection.run("RETURN 1") + else: + raise ValueError(action) + + assert connection.last_database == db + connection.send_all() + connection.fetch_all() + assert connection.last_database == db + + sockets.server.send_message(b"\x70", {}) + if finish == "reset": + connection.reset() + elif finish == "commit": + if action == "run": + connection.pull() + else: + connection.commit() + elif finish == "rollback": + if action == "run": + connection.pull() + else: + connection.rollback() + else: + raise ValueError(finish) + + connection.send_all() + connection.fetch_all() + + assert connection.last_database == db diff --git a/tests/unit/sync/io/test_class_bolt4x1.py b/tests/unit/sync/io/test_class_bolt4x1.py index 8df441f9..a6f4adfb 100644 --- a/tests/unit/sync/io/test_class_bolt4x1.py +++ b/tests/unit/sync/io/test_class_bolt4x1.py @@ -16,8 +16,8 @@ # limitations under the License. +import itertools import logging -from itertools import permutations import pytest @@ -299,7 +299,7 @@ def test_re_auth_noop(auth, fake_socket, mocker): @pytest.mark.parametrize( ("auth1", "auth2"), - permutations( + itertools.permutations( ( None, neo4j.Auth("scheme", "principal", "credentials", "realm"), @@ -472,3 +472,65 @@ def test_tx_timeout( assert "tx_timeout" not in extra else: assert extra["tx_timeout"] == res + + +@pytest.mark.parametrize( + "actions", + itertools.combinations_with_replacement( + itertools.product( + ("run", "begin", "begin_run"), + ("reset", "commit", "rollback"), + (None, "some_db", "another_db"), + ), + 2 + ) +) +@mark_sync_test +def test_tracks_last_database(fake_socket_pair, actions): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=Bolt4x1.PACKER_CLS, + unpacker_cls=Bolt4x1.UNPACKER_CLS) + connection = Bolt4x1(address, sockets.client, 0) + sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) + connection.hello() + assert connection.last_database is None + for action, finish, db in actions: + sockets.server.send_message(b"\x70", {}) + if action == "run": + connection.run("RETURN 1", db=db) + elif action == "begin": + connection.begin(db=db) + elif action == "begin_run": + connection.begin(db=db) + assert connection.last_database == db + sockets.server.send_message(b"\x70", {}) + connection.run("RETURN 1") + else: + raise ValueError(action) + + assert connection.last_database == db + connection.send_all() + connection.fetch_all() + assert connection.last_database == db + + sockets.server.send_message(b"\x70", {}) + if finish == "reset": + connection.reset() + elif finish == "commit": + if action == "run": + connection.pull() + else: + connection.commit() + elif finish == "rollback": + if action == "run": + connection.pull() + else: + connection.rollback() + else: + raise ValueError(finish) + + connection.send_all() + connection.fetch_all() + + assert connection.last_database == db diff --git a/tests/unit/sync/io/test_class_bolt4x2.py b/tests/unit/sync/io/test_class_bolt4x2.py index 2a6ceee9..820fb733 100644 --- a/tests/unit/sync/io/test_class_bolt4x2.py +++ b/tests/unit/sync/io/test_class_bolt4x2.py @@ -16,8 +16,8 @@ # limitations under the License. +import itertools import logging -from itertools import permutations import pytest @@ -300,7 +300,7 @@ def test_re_auth_noop(auth, fake_socket, mocker): @pytest.mark.parametrize( ("auth1", "auth2"), - permutations( + itertools.permutations( ( None, neo4j.Auth("scheme", "principal", "credentials", "realm"), @@ -473,3 +473,65 @@ def test_tx_timeout( assert "tx_timeout" not in extra else: assert extra["tx_timeout"] == res + + +@pytest.mark.parametrize( + "actions", + itertools.combinations_with_replacement( + itertools.product( + ("run", "begin", "begin_run"), + ("reset", "commit", "rollback"), + (None, "some_db", "another_db"), + ), + 2 + ) +) +@mark_sync_test +def test_tracks_last_database(fake_socket_pair, actions): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=Bolt4x2.PACKER_CLS, + unpacker_cls=Bolt4x2.UNPACKER_CLS) + connection = Bolt4x2(address, sockets.client, 0) + sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) + connection.hello() + assert connection.last_database is None + for action, finish, db in actions: + sockets.server.send_message(b"\x70", {}) + if action == "run": + connection.run("RETURN 1", db=db) + elif action == "begin": + connection.begin(db=db) + elif action == "begin_run": + connection.begin(db=db) + assert connection.last_database == db + sockets.server.send_message(b"\x70", {}) + connection.run("RETURN 1") + else: + raise ValueError(action) + + assert connection.last_database == db + connection.send_all() + connection.fetch_all() + assert connection.last_database == db + + sockets.server.send_message(b"\x70", {}) + if finish == "reset": + connection.reset() + elif finish == "commit": + if action == "run": + connection.pull() + else: + connection.commit() + elif finish == "rollback": + if action == "run": + connection.pull() + else: + connection.rollback() + else: + raise ValueError(finish) + + connection.send_all() + connection.fetch_all() + + assert connection.last_database == db diff --git a/tests/unit/sync/io/test_class_bolt4x3.py b/tests/unit/sync/io/test_class_bolt4x3.py index 58d9b01a..02cf9139 100644 --- a/tests/unit/sync/io/test_class_bolt4x3.py +++ b/tests/unit/sync/io/test_class_bolt4x3.py @@ -16,8 +16,8 @@ # limitations under the License. +import itertools import logging -from itertools import permutations import pytest @@ -327,7 +327,7 @@ def test_re_auth_noop(auth, fake_socket, mocker): @pytest.mark.parametrize( ("auth1", "auth2"), - permutations( + itertools.permutations( ( None, neo4j.Auth("scheme", "principal", "credentials", "realm"), @@ -500,3 +500,65 @@ def test_tx_timeout( assert "tx_timeout" not in extra else: assert extra["tx_timeout"] == res + + +@pytest.mark.parametrize( + "actions", + itertools.combinations_with_replacement( + itertools.product( + ("run", "begin", "begin_run"), + ("reset", "commit", "rollback"), + (None, "some_db", "another_db"), + ), + 2 + ) +) +@mark_sync_test +def test_tracks_last_database(fake_socket_pair, actions): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=Bolt4x3.PACKER_CLS, + unpacker_cls=Bolt4x3.UNPACKER_CLS) + connection = Bolt4x3(address, sockets.client, 0) + sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) + connection.hello() + assert connection.last_database is None + for action, finish, db in actions: + sockets.server.send_message(b"\x70", {}) + if action == "run": + connection.run("RETURN 1", db=db) + elif action == "begin": + connection.begin(db=db) + elif action == "begin_run": + connection.begin(db=db) + assert connection.last_database == db + sockets.server.send_message(b"\x70", {}) + connection.run("RETURN 1") + else: + raise ValueError(action) + + assert connection.last_database == db + connection.send_all() + connection.fetch_all() + assert connection.last_database == db + + sockets.server.send_message(b"\x70", {}) + if finish == "reset": + connection.reset() + elif finish == "commit": + if action == "run": + connection.pull() + else: + connection.commit() + elif finish == "rollback": + if action == "run": + connection.pull() + else: + connection.rollback() + else: + raise ValueError(finish) + + connection.send_all() + connection.fetch_all() + + assert connection.last_database == db diff --git a/tests/unit/sync/io/test_class_bolt4x4.py b/tests/unit/sync/io/test_class_bolt4x4.py index 6ac7d792..b417a4c1 100644 --- a/tests/unit/sync/io/test_class_bolt4x4.py +++ b/tests/unit/sync/io/test_class_bolt4x4.py @@ -16,8 +16,8 @@ # limitations under the License. +import itertools import logging -from itertools import permutations import pytest @@ -340,7 +340,7 @@ def test_re_auth_noop(auth, fake_socket, mocker): @pytest.mark.parametrize( ("auth1", "auth2"), - permutations( + itertools.permutations( ( None, neo4j.Auth("scheme", "principal", "credentials", "realm"), diff --git a/tests/unit/sync/io/test_class_bolt5x0.py b/tests/unit/sync/io/test_class_bolt5x0.py index a4416a97..3beff035 100644 --- a/tests/unit/sync/io/test_class_bolt5x0.py +++ b/tests/unit/sync/io/test_class_bolt5x0.py @@ -16,8 +16,8 @@ # limitations under the License. +import itertools import logging -from itertools import permutations import pytest @@ -340,7 +340,7 @@ def test_re_auth_noop(auth, fake_socket, mocker): @pytest.mark.parametrize( ("auth1", "auth2"), - permutations( + itertools.permutations( ( None, neo4j.Auth("scheme", "principal", "credentials", "realm"), @@ -513,3 +513,65 @@ def test_tx_timeout( assert "tx_timeout" not in extra else: assert extra["tx_timeout"] == res + + +@pytest.mark.parametrize( + "actions", + itertools.combinations_with_replacement( + itertools.product( + ("run", "begin", "begin_run"), + ("reset", "commit", "rollback"), + (None, "some_db", "another_db"), + ), + 2 + ) +) +@mark_sync_test +def test_tracks_last_database(fake_socket_pair, actions): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=Bolt5x0.PACKER_CLS, + unpacker_cls=Bolt5x0.UNPACKER_CLS) + connection = Bolt5x0(address, sockets.client, 0) + sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) + connection.hello() + assert connection.last_database is None + for action, finish, db in actions: + sockets.server.send_message(b"\x70", {}) + if action == "run": + connection.run("RETURN 1", db=db) + elif action == "begin": + connection.begin(db=db) + elif action == "begin_run": + connection.begin(db=db) + assert connection.last_database == db + sockets.server.send_message(b"\x70", {}) + connection.run("RETURN 1") + else: + raise ValueError(action) + + assert connection.last_database == db + connection.send_all() + connection.fetch_all() + assert connection.last_database == db + + sockets.server.send_message(b"\x70", {}) + if finish == "reset": + connection.reset() + elif finish == "commit": + if action == "run": + connection.pull() + else: + connection.commit() + elif finish == "rollback": + if action == "run": + connection.pull() + else: + connection.rollback() + else: + raise ValueError(finish) + + connection.send_all() + connection.fetch_all() + + assert connection.last_database == db diff --git a/tests/unit/sync/io/test_class_bolt5x1.py b/tests/unit/sync/io/test_class_bolt5x1.py index e6315df6..02178cfe 100644 --- a/tests/unit/sync/io/test_class_bolt5x1.py +++ b/tests/unit/sync/io/test_class_bolt5x1.py @@ -16,6 +16,7 @@ # limitations under the License. +import itertools import logging import pytest @@ -567,3 +568,66 @@ def test_tx_timeout( assert "tx_timeout" not in extra else: assert extra["tx_timeout"] == res + + +@pytest.mark.parametrize( + "actions", + itertools.combinations_with_replacement( + itertools.product( + ("run", "begin", "begin_run"), + ("reset", "commit", "rollback"), + (None, "some_db", "another_db"), + ), + 2 + ) +) +@mark_sync_test +def test_tracks_last_database(fake_socket_pair, actions): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=Bolt5x1.PACKER_CLS, + unpacker_cls=Bolt5x1.UNPACKER_CLS) + connection = Bolt5x1(address, sockets.client, 0) + sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) + sockets.server.send_message(b"\x70", {}) + connection.hello() + assert connection.last_database is None + for action, finish, db in actions: + sockets.server.send_message(b"\x70", {}) + if action == "run": + connection.run("RETURN 1", db=db) + elif action == "begin": + connection.begin(db=db) + elif action == "begin_run": + connection.begin(db=db) + assert connection.last_database == db + sockets.server.send_message(b"\x70", {}) + connection.run("RETURN 1") + else: + raise ValueError(action) + + assert connection.last_database == db + connection.send_all() + connection.fetch_all() + assert connection.last_database == db + + sockets.server.send_message(b"\x70", {}) + if finish == "reset": + connection.reset() + elif finish == "commit": + if action == "run": + connection.pull() + else: + connection.commit() + elif finish == "rollback": + if action == "run": + connection.pull() + else: + connection.rollback() + else: + raise ValueError(finish) + + connection.send_all() + connection.fetch_all() + + assert connection.last_database == db diff --git a/tests/unit/sync/io/test_class_bolt5x2.py b/tests/unit/sync/io/test_class_bolt5x2.py index 9dec5bac..99f8daee 100644 --- a/tests/unit/sync/io/test_class_bolt5x2.py +++ b/tests/unit/sync/io/test_class_bolt5x2.py @@ -585,3 +585,66 @@ def test_tx_timeout( assert "tx_timeout" not in extra else: assert extra["tx_timeout"] == res + + +@pytest.mark.parametrize( + "actions", + itertools.combinations_with_replacement( + itertools.product( + ("run", "begin", "begin_run"), + ("reset", "commit", "rollback"), + (None, "some_db", "another_db"), + ), + 2 + ) +) +@mark_sync_test +def test_tracks_last_database(fake_socket_pair, actions): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=Bolt5x2.PACKER_CLS, + unpacker_cls=Bolt5x2.UNPACKER_CLS) + connection = Bolt5x2(address, sockets.client, 0) + sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) + sockets.server.send_message(b"\x70", {}) + connection.hello() + assert connection.last_database is None + for action, finish, db in actions: + sockets.server.send_message(b"\x70", {}) + if action == "run": + connection.run("RETURN 1", db=db) + elif action == "begin": + connection.begin(db=db) + elif action == "begin_run": + connection.begin(db=db) + assert connection.last_database == db + sockets.server.send_message(b"\x70", {}) + connection.run("RETURN 1") + else: + raise ValueError(action) + + assert connection.last_database == db + connection.send_all() + connection.fetch_all() + assert connection.last_database == db + + sockets.server.send_message(b"\x70", {}) + if finish == "reset": + connection.reset() + elif finish == "commit": + if action == "run": + connection.pull() + else: + connection.commit() + elif finish == "rollback": + if action == "run": + connection.pull() + else: + connection.rollback() + else: + raise ValueError(finish) + + connection.send_all() + connection.fetch_all() + + assert connection.last_database == db diff --git a/tests/unit/sync/io/test_class_bolt5x3.py b/tests/unit/sync/io/test_class_bolt5x3.py index 9437a348..7a3dd884 100644 --- a/tests/unit/sync/io/test_class_bolt5x3.py +++ b/tests/unit/sync/io/test_class_bolt5x3.py @@ -496,3 +496,66 @@ def test_tx_timeout( assert "tx_timeout" not in extra else: assert extra["tx_timeout"] == res + + +@pytest.mark.parametrize( + "actions", + itertools.combinations_with_replacement( + itertools.product( + ("run", "begin", "begin_run"), + ("reset", "commit", "rollback"), + (None, "some_db", "another_db"), + ), + 2 + ) +) +@mark_sync_test +def test_tracks_last_database(fake_socket_pair, actions): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address, + packer_cls=Bolt5x3.PACKER_CLS, + unpacker_cls=Bolt5x3.UNPACKER_CLS) + connection = Bolt5x3(address, sockets.client, 0) + sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) + sockets.server.send_message(b"\x70", {}) + connection.hello() + assert connection.last_database is None + for action, finish, db in actions: + sockets.server.send_message(b"\x70", {}) + if action == "run": + connection.run("RETURN 1", db=db) + elif action == "begin": + connection.begin(db=db) + elif action == "begin_run": + connection.begin(db=db) + assert connection.last_database == db + sockets.server.send_message(b"\x70", {}) + connection.run("RETURN 1") + else: + raise ValueError(action) + + assert connection.last_database == db + connection.send_all() + connection.fetch_all() + assert connection.last_database == db + + sockets.server.send_message(b"\x70", {}) + if finish == "reset": + connection.reset() + elif finish == "commit": + if action == "run": + connection.pull() + else: + connection.commit() + elif finish == "rollback": + if action == "run": + connection.pull() + else: + connection.rollback() + else: + raise ValueError(finish) + + connection.send_all() + connection.fetch_all() + + assert connection.last_database == db