From a6275f7288d84718d533be91552327c1f8ed581e Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Thu, 11 Apr 2024 12:26:33 +0200 Subject: [PATCH 1/2] Invalidate writers per database This should improve the performance of the driver in multi database use-cases. The driver now only removes a server as a writer for a single database (before for all databases) if that server returns an error that notifies the driver that the server is no longer a writer (`Neo.ClientError.Cluster.NotALeader` or `Neo.ClientError.General.ForbiddenOnReadOnlyDatabase`). --- neo4j/io/__init__.py | 34 +++++++++-- neo4j/io/_bolt.py | 29 ++++++++++ neo4j/io/_bolt3.py | 89 ++++++++++++++++++++++------- neo4j/io/_bolt4.py | 34 +++++++++-- tests/unit/io/test_class_bolt3.py | 75 ++++++++++++++++++++++++ tests/unit/io/test_class_bolt4x0.py | 61 ++++++++++++++++++++ tests/unit/io/test_class_bolt4x1.py | 61 ++++++++++++++++++++ tests/unit/io/test_class_bolt4x2.py | 61 ++++++++++++++++++++ tests/unit/io/test_class_bolt4x3.py | 61 ++++++++++++++++++++ tests/unit/io/test_class_bolt4x4.py | 61 ++++++++++++++++++++ 10 files changed, 534 insertions(+), 32 deletions(-) create mode 100644 neo4j/io/_bolt.py diff --git a/neo4j/io/__init__.py b/neo4j/io/__init__.py index 7b8c63ee..4785fd66 100644 --- a/neo4j/io/__init__.py +++ b/neo4j/io/__init__.py @@ -105,6 +105,16 @@ log = getLogger("neo4j") +class ClientStateManagerBase(abc.ABC): + @abc.abstractmethod + def __init__(self, init_state, on_change=None): + ... + + @abc.abstractmethod + def transition(self, message): + ... + + class Bolt(abc.ABC): """ Server connection for Bolt protocol. @@ -125,6 +135,10 @@ class Bolt(abc.ABC): # The socket in_use = False + # The database name the connection was last used with + # (BEGIN for explicit transactions, RUN for auto-commit transactions) + last_database = None + # The socket _closing = False _closed = False @@ -399,6 +413,10 @@ def __del__(self): except OSError: pass + @abc.abstractmethod + def _get_client_state_manager(self): + ... + @abc.abstractmethod def route(self, database=None, imp_user=None, bookmarks=None): """ Fetch a routing table from the server for the given @@ -504,6 +522,8 @@ def _append(self, signature, fields=(), response=None): self.packer.pack_struct(signature, fields) self.outbox.wrap_message() self.responses.append(response) + if response: + self._get_client_state_manager().transition(response.message) def _send_all(self): with self.outbox.view() as data: @@ -867,8 +887,10 @@ def deactivate(self, address): if not self.connections[address]: del self.connections[address] - def on_write_failure(self, address): - raise WriteServiceUnavailable("No write service available for pool {}".format(self)) + def on_write_failure(self, address, database): + raise WriteServiceUnavailable( + "No write service available for pool {}".format(self) + ) def close(self): """ Close all connections and empty the pool. @@ -1342,12 +1364,14 @@ def deactivate(self, address): log.debug("[#0000] C: table=%r", self.routing_tables) super(Neo4jPool, self).deactivate(address) - def on_write_failure(self, address): + def on_write_failure(self, address, database): """ Remove a writer address from the routing table, if present. """ - log.debug("[#0000] C: Removing writer %r", address) + log.debug("[#0000] C: Removing writer %r for database %r", + address, database) with self.refresh_lock: - for database in self.routing_tables.keys(): + table = self.routing_tables.get(database) + if table is not None: self.routing_tables[database].writers.discard(address) log.debug("[#0000] C: table=%r", self.routing_tables) diff --git a/neo4j/io/_bolt.py b/neo4j/io/_bolt.py new file mode 100644 index 00000000..9916e7ef --- /dev/null +++ b/neo4j/io/_bolt.py @@ -0,0 +1,29 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import abc + + +class ClientStateManagerBase(abc.ABC): + @abc.abstractmethod + def __init__(self, init_state, on_change=None): + ... + + @abc.abstractmethod + def transition(self, message): + ... diff --git a/neo4j/io/_bolt3.py b/neo4j/io/_bolt3.py index 21f71b18..34f89c97 100644 --- a/neo4j/io/_bolt3.py +++ b/neo4j/io/_bolt3.py @@ -43,6 +43,7 @@ Bolt, check_supported_server_product, ) +from neo4j.io._bolt import ClientStateManagerBase from neo4j.io._common import ( CommitResponse, InitResponse, @@ -53,7 +54,7 @@ log = getLogger("neo4j") -class ServerStates(Enum): +class BoltStates(Enum): CONNECTED = "CONNECTED" READY = "READY" STREAMING = "STREAMING" @@ -63,25 +64,25 @@ class ServerStates(Enum): class ServerStateManager: _STATE_TRANSITIONS = { - ServerStates.CONNECTED: { - "hello": ServerStates.READY, + BoltStates.CONNECTED: { + "hello": BoltStates.READY, }, - ServerStates.READY: { - "run": ServerStates.STREAMING, - "begin": ServerStates.TX_READY_OR_TX_STREAMING, + BoltStates.READY: { + "run": BoltStates.STREAMING, + "begin": BoltStates.TX_READY_OR_TX_STREAMING, }, - ServerStates.STREAMING: { - "pull": ServerStates.READY, - "discard": ServerStates.READY, - "reset": ServerStates.READY, + BoltStates.STREAMING: { + "pull": BoltStates.READY, + "discard": BoltStates.READY, + "reset": BoltStates.READY, }, - ServerStates.TX_READY_OR_TX_STREAMING: { - "commit": ServerStates.READY, - "rollback": ServerStates.READY, - "reset": ServerStates.READY, + BoltStates.TX_READY_OR_TX_STREAMING: { + "commit": BoltStates.READY, + "rollback": BoltStates.READY, + "reset": BoltStates.READY, }, - ServerStates.FAILED: { - "reset": ServerStates.READY, + BoltStates.FAILED: { + "reset": BoltStates.READY, } } @@ -100,6 +101,39 @@ def transition(self, message, metadata): self._on_change(state_before, self.state) +class ClientStateManager(ClientStateManagerBase): + _STATE_TRANSITIONS = { + BoltStates.CONNECTED: { + "hello": BoltStates.READY, + }, + BoltStates.READY: { + "run": BoltStates.STREAMING, + "begin": BoltStates.TX_READY_OR_TX_STREAMING, + }, + BoltStates.STREAMING: { + "begin": BoltStates.TX_READY_OR_TX_STREAMING, + "reset": BoltStates.READY, + }, + BoltStates.TX_READY_OR_TX_STREAMING: { + "commit": BoltStates.READY, + "rollback": BoltStates.READY, + "reset": BoltStates.READY, + }, + } + + def __init__(self, init_state, on_change=None): + self.state = init_state + self._on_change = on_change + + def transition(self, message): + state_before = self.state + self.state = self._STATE_TRANSITIONS \ + .get(self.state, {}) \ + .get(message, self.state) + if state_before != self.state and callable(self._on_change): + self._on_change(state_before, self.state) + + class Bolt3(Bolt): """ Protocol handler for Bolt 3. @@ -115,13 +149,23 @@ class Bolt3(Bolt): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._server_state_manager = ServerStateManager( - ServerStates.CONNECTED, on_change=self._on_server_state_change + BoltStates.CONNECTED, on_change=self._on_server_state_change + ) + self._client_state_manager = ClientStateManager( + BoltStates.CONNECTED, on_change=self._on_client_state_change ) def _on_server_state_change(self, old_state, new_state): - log.debug("[#%04X] State: %s > %s", self.local_port, + log.debug("[#%04X] Server State: %s > %s", self.local_port, old_state.name, new_state.name) + def _on_client_state_change(self, old_state, new_state): + log.debug("[#%04X] Client state: %s > %s", + self.local_port, old_state.name, new_state.name) + + def _get_client_state_manager(self): + return self._client_state_manager + @property def is_reset(self): # We can't be sure of the server's state if there are still pending @@ -130,7 +174,7 @@ def is_reset(self): if (self.responses and self.responses[-1] and self.responses[-1].message == "reset"): return True - return self._server_state_manager.state == ServerStates.READY + return self._server_state_manager.state == BoltStates.READY @property def encrypted(self): @@ -349,7 +393,7 @@ def fetch_message(self): response.on_ignored(summary_metadata or {}) elif summary_signature == b"\x7F": log.debug("[#%04X] S: FAILURE %r", self.local_port, summary_metadata) - self._server_state_manager.state = ServerStates.FAILED + self._server_state_manager.state = BoltStates.FAILED try: response.on_failure(summary_metadata or {}) except (ServiceUnavailable, DatabaseUnavailable): @@ -358,7 +402,10 @@ def fetch_message(self): raise except (NotALeader, ForbiddenOnReadOnlyDatabase): if self.pool: - self.pool.on_write_failure(address=self.unresolved_address), + self.pool.on_write_failure( + address=self.unresolved_address, + database=self.last_database, + ), raise except Neo4jError as e: if self.pool and e.invalidates_all_connections(): diff --git a/neo4j/io/_bolt4.py b/neo4j/io/_bolt4.py index 2feeea14..eea8d7f9 100644 --- a/neo4j/io/_bolt4.py +++ b/neo4j/io/_bolt4.py @@ -50,8 +50,9 @@ Response, ) from neo4j.io._bolt3 import ( + BoltStates, + ClientStateManager, ServerStateManager, - ServerStates, ) @@ -73,13 +74,23 @@ class Bolt4x0(Bolt): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._server_state_manager = ServerStateManager( - ServerStates.CONNECTED, on_change=self._on_server_state_change + BoltStates.CONNECTED, on_change=self._on_server_state_change + ) + self._client_state_manager = ClientStateManager( + BoltStates.CONNECTED, on_change=self._on_client_state_change ) def _on_server_state_change(self, old_state, new_state): - log.debug("[#%04X] State: %s > %s", self.local_port, + log.debug("[#%04X] Server state: %s > %s", self.local_port, old_state.name, new_state.name) + def _on_client_state_change(self, old_state, new_state): + log.debug("[#%04X] Client state: %s > %s", + self.local_port, old_state.name, new_state.name) + + def _get_client_state_manager(self): + return self._client_state_manager + @property def is_reset(self): # We can't be sure of the server's state if there are still pending @@ -88,7 +99,7 @@ def is_reset(self): if (self.responses and self.responses[-1] and self.responses[-1].message == "reset"): return True - return self._server_state_manager.state == ServerStates.READY + return self._server_state_manager.state == BoltStates.READY @property def encrypted(self): @@ -168,6 +179,9 @@ def run(self, query, parameters=None, mode=None, bookmarks=None, extra["mode"] = "r" # It will default to mode "w" if nothing is specified if db: extra["db"] = db + client_state = self._client_state_manager.state + if client_state != BoltStates.TX_READY_OR_TX_STREAMING: + self.last_database = db if bookmarks: try: extra["bookmarks"] = list(bookmarks) @@ -219,6 +233,7 @@ def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, extra["mode"] = "r" # It will default to mode "w" if nothing is specified if db: extra["db"] = db + self.last_database = db if bookmarks: try: extra["bookmarks"] = list(bookmarks) @@ -301,7 +316,7 @@ def fetch_message(self): response.on_ignored(summary_metadata or {}) elif summary_signature == b"\x7F": log.debug("[#%04X] S: FAILURE %r", self.local_port, summary_metadata) - self._server_state_manager.state = ServerStates.FAILED + self._server_state_manager.state = BoltStates.FAILED try: response.on_failure(summary_metadata or {}) except (ServiceUnavailable, DatabaseUnavailable): @@ -310,7 +325,10 @@ def fetch_message(self): raise except (NotALeader, ForbiddenOnReadOnlyDatabase): if self.pool: - self.pool.on_write_failure(address=self.unresolved_address), + self.pool.on_write_failure( + address=self.unresolved_address, + database=self.last_database, + ), raise except Neo4jError as e: if self.pool and e.invalidates_all_connections(): @@ -478,6 +496,9 @@ def run(self, query, parameters=None, mode=None, bookmarks=None, extra["mode"] = "r" if db: extra["db"] = db + client_state = self._client_state_manager.state + if client_state != BoltStates.TX_READY_OR_TX_STREAMING: + self.last_database = db if imp_user: extra["imp_user"] = imp_user if bookmarks: @@ -513,6 +534,7 @@ def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, extra["mode"] = "r" if db: extra["db"] = db + self.last_database = db if imp_user: extra["imp_user"] = imp_user if bookmarks: diff --git a/tests/unit/io/test_class_bolt3.py b/tests/unit/io/test_class_bolt3.py index f7d63e85..5efe5696 100644 --- a/tests/unit/io/test_class_bolt3.py +++ b/tests/unit/io/test_class_bolt3.py @@ -18,10 +18,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +import contextlib +import itertools from unittest.mock import MagicMock import pytest +from neo4j import Address from neo4j.io._bolt3 import Bolt3 from neo4j.conf import PoolConfig from neo4j.exceptions import ( @@ -110,3 +113,75 @@ def test_hint_recv_timeout_seconds_gets_ignored(fake_socket_pair, recv_timeout): PoolConfig.max_connection_lifetime) connection.hello() sockets.client.settimeout.assert_not_called() + + +@pytest.mark.parametrize( + "actions", + itertools.combinations_with_replacement( + itertools.product( + ("run", "begin", "begin_run"), + ("reset", "commit", "rollback"), + (None, "some_db", "another_db"), + ), + 2 + ) +) +def test_tracks_last_database(fake_socket_pair, actions): + address = Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address) + connection = Bolt3(address, sockets.client, 0) + sockets.server.send_message(0x70, {"server": "Neo4j/1.2.3"}) + connection.hello() + assert connection.last_database is None + for action, finish, db in actions: + sockets.server.send_message(0x70, {}) + if action == "run": + with raises_if_db(db): + connection.run("RETURN 1", db=db) + elif action == "begin": + with raises_if_db(db): + connection.begin(db=db) + elif action == "begin_run": + with raises_if_db(db): + connection.begin(db=db) + assert connection.last_database is None + sockets.server.send_message(0x70, {}) + connection.run("RETURN 1") + else: + raise ValueError(action) + + assert connection.last_database is None + connection.send_all() + connection.fetch_all() + assert connection.last_database is None + + sockets.server.send_message(0x70, {}) + if finish == "reset": + connection.reset() + elif finish == "commit": + if action == "run": + connection.pull() + else: + connection.commit() + elif finish == "rollback": + if action == "run": + connection.pull() + else: + connection.rollback() + else: + raise ValueError(finish) + + connection.send_all() + connection.fetch_all() + + assert connection.last_database is None + + +@contextlib.contextmanager +def raises_if_db(db): + if db is None: + yield + else: + with pytest.raises(ConfigurationError, + match="selecting database is not supported"): + yield diff --git a/tests/unit/io/test_class_bolt4x0.py b/tests/unit/io/test_class_bolt4x0.py index 3879acb0..39a4340c 100644 --- a/tests/unit/io/test_class_bolt4x0.py +++ b/tests/unit/io/test_class_bolt4x0.py @@ -18,10 +18,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +import itertools from unittest.mock import MagicMock import pytest +from neo4j import Address from neo4j.io._bolt4 import Bolt4x0 from neo4j.conf import PoolConfig @@ -197,3 +199,62 @@ def test_hint_recv_timeout_seconds_gets_ignored(fake_socket_pair, recv_timeout): PoolConfig.max_connection_lifetime) connection.hello() sockets.client.settimeout.assert_not_called() + + +@pytest.mark.parametrize( + "actions", + itertools.combinations_with_replacement( + itertools.product( + ("run", "begin", "begin_run"), + ("reset", "commit", "rollback"), + (None, "some_db", "another_db"), + ), + 2 + ) +) +def test_tracks_last_database(fake_socket_pair, actions): + address = Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address) + connection = Bolt4x0(address, sockets.client, 0) + sockets.server.send_message(0x70, {"server": "Neo4j/1.2.3"}) + connection.hello() + assert connection.last_database is None + for action, finish, db in actions: + sockets.server.send_message(0x70, {}) + if action == "run": + connection.run("RETURN 1", db=db) + elif action == "begin": + connection.begin(db=db) + elif action == "begin_run": + connection.begin(db=db) + assert connection.last_database == db + sockets.server.send_message(0x70, {}) + connection.run("RETURN 1") + else: + raise ValueError(action) + + assert connection.last_database == db + connection.send_all() + connection.fetch_all() + assert connection.last_database == db + + sockets.server.send_message(0x70, {}) + if finish == "reset": + connection.reset() + elif finish == "commit": + if action == "run": + connection.pull() + else: + connection.commit() + elif finish == "rollback": + if action == "run": + connection.pull() + else: + connection.rollback() + else: + raise ValueError(finish) + + connection.send_all() + connection.fetch_all() + + assert connection.last_database == db diff --git a/tests/unit/io/test_class_bolt4x1.py b/tests/unit/io/test_class_bolt4x1.py index 663d3cbe..b4288999 100644 --- a/tests/unit/io/test_class_bolt4x1.py +++ b/tests/unit/io/test_class_bolt4x1.py @@ -18,10 +18,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +import itertools from unittest.mock import MagicMock import pytest +from neo4j import Address from neo4j.io._bolt4 import Bolt4x1 from neo4j.conf import PoolConfig @@ -210,3 +212,62 @@ def test_hint_recv_timeout_seconds_gets_ignored(fake_socket_pair, recv_timeout): PoolConfig.max_connection_lifetime) connection.hello() sockets.client.settimeout.assert_not_called() + + +@pytest.mark.parametrize( + "actions", + itertools.combinations_with_replacement( + itertools.product( + ("run", "begin", "begin_run"), + ("reset", "commit", "rollback"), + (None, "some_db", "another_db"), + ), + 2 + ) +) +def test_tracks_last_database(fake_socket_pair, actions): + address = Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address) + connection = Bolt4x1(address, sockets.client, 0) + sockets.server.send_message(0x70, {"server": "Neo4j/1.2.3"}) + connection.hello() + assert connection.last_database is None + for action, finish, db in actions: + sockets.server.send_message(0x70, {}) + if action == "run": + connection.run("RETURN 1", db=db) + elif action == "begin": + connection.begin(db=db) + elif action == "begin_run": + connection.begin(db=db) + assert connection.last_database == db + sockets.server.send_message(0x70, {}) + connection.run("RETURN 1") + else: + raise ValueError(action) + + assert connection.last_database == db + connection.send_all() + connection.fetch_all() + assert connection.last_database == db + + sockets.server.send_message(0x70, {}) + if finish == "reset": + connection.reset() + elif finish == "commit": + if action == "run": + connection.pull() + else: + connection.commit() + elif finish == "rollback": + if action == "run": + connection.pull() + else: + connection.rollback() + else: + raise ValueError(finish) + + connection.send_all() + connection.fetch_all() + + assert connection.last_database == db diff --git a/tests/unit/io/test_class_bolt4x2.py b/tests/unit/io/test_class_bolt4x2.py index 470adf5c..d2e05944 100644 --- a/tests/unit/io/test_class_bolt4x2.py +++ b/tests/unit/io/test_class_bolt4x2.py @@ -19,10 +19,12 @@ # limitations under the License. +import itertools from unittest.mock import MagicMock import pytest +from neo4j import Address from neo4j.io._bolt4 import Bolt4x2 from neo4j.conf import PoolConfig @@ -211,3 +213,62 @@ def test_hint_recv_timeout_seconds_gets_ignored(fake_socket_pair, recv_timeout): PoolConfig.max_connection_lifetime) connection.hello() sockets.client.settimeout.assert_not_called() + + +@pytest.mark.parametrize( + "actions", + itertools.combinations_with_replacement( + itertools.product( + ("run", "begin", "begin_run"), + ("reset", "commit", "rollback"), + (None, "some_db", "another_db"), + ), + 2 + ) +) +def test_tracks_last_database(fake_socket_pair, actions): + address = Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address) + connection = Bolt4x2(address, sockets.client, 0) + sockets.server.send_message(0x70, {"server": "Neo4j/1.2.3"}) + connection.hello() + assert connection.last_database is None + for action, finish, db in actions: + sockets.server.send_message(0x70, {}) + if action == "run": + connection.run("RETURN 1", db=db) + elif action == "begin": + connection.begin(db=db) + elif action == "begin_run": + connection.begin(db=db) + assert connection.last_database == db + sockets.server.send_message(0x70, {}) + connection.run("RETURN 1") + else: + raise ValueError(action) + + assert connection.last_database == db + connection.send_all() + connection.fetch_all() + assert connection.last_database == db + + sockets.server.send_message(0x70, {}) + if finish == "reset": + connection.reset() + elif finish == "commit": + if action == "run": + connection.pull() + else: + connection.commit() + elif finish == "rollback": + if action == "run": + connection.pull() + else: + connection.rollback() + else: + raise ValueError(finish) + + connection.send_all() + connection.fetch_all() + + assert connection.last_database == db diff --git a/tests/unit/io/test_class_bolt4x3.py b/tests/unit/io/test_class_bolt4x3.py index fc08f5b9..81bb585a 100644 --- a/tests/unit/io/test_class_bolt4x3.py +++ b/tests/unit/io/test_class_bolt4x3.py @@ -18,11 +18,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +import itertools import logging from unittest.mock import MagicMock import pytest +from neo4j import Address from neo4j.io._bolt4 import Bolt4x3 from neo4j.conf import PoolConfig @@ -237,3 +239,62 @@ def test_hint_recv_timeout_seconds(fake_socket_pair, hints, valid, and "recv_timeout_seconds" in msg and "invalid" in msg for msg in caplog.messages) + + +@pytest.mark.parametrize( + "actions", + itertools.combinations_with_replacement( + itertools.product( + ("run", "begin", "begin_run"), + ("reset", "commit", "rollback"), + (None, "some_db", "another_db"), + ), + 2 + ) +) +def test_tracks_last_database(fake_socket_pair, actions): + address = Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address) + connection = Bolt4x3(address, sockets.client, 0) + sockets.server.send_message(0x70, {"server": "Neo4j/1.2.3"}) + connection.hello() + assert connection.last_database is None + for action, finish, db in actions: + sockets.server.send_message(0x70, {}) + if action == "run": + connection.run("RETURN 1", db=db) + elif action == "begin": + connection.begin(db=db) + elif action == "begin_run": + connection.begin(db=db) + assert connection.last_database == db + sockets.server.send_message(0x70, {}) + connection.run("RETURN 1") + else: + raise ValueError(action) + + assert connection.last_database == db + connection.send_all() + connection.fetch_all() + assert connection.last_database == db + + sockets.server.send_message(0x70, {}) + if finish == "reset": + connection.reset() + elif finish == "commit": + if action == "run": + connection.pull() + else: + connection.commit() + elif finish == "rollback": + if action == "run": + connection.pull() + else: + connection.rollback() + else: + raise ValueError(finish) + + connection.send_all() + connection.fetch_all() + + assert connection.last_database == db diff --git a/tests/unit/io/test_class_bolt4x4.py b/tests/unit/io/test_class_bolt4x4.py index 19378a1c..ef3349b6 100644 --- a/tests/unit/io/test_class_bolt4x4.py +++ b/tests/unit/io/test_class_bolt4x4.py @@ -18,11 +18,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +import itertools import logging from unittest.mock import MagicMock import pytest +from neo4j import Address from neo4j.io._bolt4 import Bolt4x4 from neo4j.conf import PoolConfig @@ -249,3 +251,62 @@ def test_hint_recv_timeout_seconds(fake_socket_pair, hints, valid, and "recv_timeout_seconds" in msg and "invalid" in msg for msg in caplog.messages) + + +@pytest.mark.parametrize( + "actions", + itertools.combinations_with_replacement( + itertools.product( + ("run", "begin", "begin_run"), + ("reset", "commit", "rollback"), + (None, "some_db", "another_db"), + ), + 2 + ) +) +def test_tracks_last_database(fake_socket_pair, actions): + address = Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair(address) + connection = Bolt4x4(address, sockets.client, 0) + sockets.server.send_message(0x70, {"server": "Neo4j/1.2.3"}) + connection.hello() + assert connection.last_database is None + for action, finish, db in actions: + sockets.server.send_message(0x70, {}) + if action == "run": + connection.run("RETURN 1", db=db) + elif action == "begin": + connection.begin(db=db) + elif action == "begin_run": + connection.begin(db=db) + assert connection.last_database == db + sockets.server.send_message(0x70, {}) + connection.run("RETURN 1") + else: + raise ValueError(action) + + assert connection.last_database == db + connection.send_all() + connection.fetch_all() + assert connection.last_database == db + + sockets.server.send_message(0x70, {}) + if finish == "reset": + connection.reset() + elif finish == "commit": + if action == "run": + connection.pull() + else: + connection.commit() + elif finish == "rollback": + if action == "run": + connection.pull() + else: + connection.rollback() + else: + raise ValueError(finish) + + connection.send_all() + connection.fetch_all() + + assert connection.last_database == db From 4dcf9ef605447500b1fd49e0a78abdc7448fc977 Mon Sep 17 00:00:00 2001 From: Robsdedude Date: Fri, 12 Apr 2024 12:18:49 +0200 Subject: [PATCH 2/2] Minor code clean-up MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Antonio Barcélos --- neo4j/io/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/neo4j/io/__init__.py b/neo4j/io/__init__.py index 4785fd66..4378c7c2 100644 --- a/neo4j/io/__init__.py +++ b/neo4j/io/__init__.py @@ -1372,7 +1372,7 @@ def on_write_failure(self, address, database): with self.refresh_lock: table = self.routing_tables.get(database) if table is not None: - self.routing_tables[database].writers.discard(address) + table.writers.discard(address) log.debug("[#0000] C: table=%r", self.routing_tables)