From 99dba90a7c7bb326e0c2d0a8460002aa592f1869 Mon Sep 17 00:00:00 2001 From: Rouven Bauer Date: Fri, 30 Apr 2021 13:51:13 +0200 Subject: [PATCH 1/2] Handle Neo.ClientError.Security.AuthorizationExpired - Close the connection and mark all idle connections in the pool as expired - Raise the error to the user or do a retry if inside transaction function Refactoring shared code of Bolt 3 and Bolt 4 connection into base class. --- neo4j/exceptions.py | 5 + neo4j/io/__init__.py | 227 +++++++++++++++++++++++++--- neo4j/io/_bolt3.py | 174 +-------------------- neo4j/io/_bolt4.py | 161 +------------------- neo4j/io/_common.py | 6 +- neo4j/work/result.py | 6 +- neo4j/work/simple.py | 4 +- neo4j/work/transaction.py | 4 +- testkitbackend/skipped_tests.json | 8 +- tests/unit/io/test_class_bolt3.py | 21 ++- tests/unit/io/test_class_bolt4x0.py | 21 ++- tests/unit/io/test_class_bolt4x1.py | 21 ++- tests/unit/io/test_class_bolt4x2.py | 21 ++- tests/unit/io/test_class_bolt4x3.py | 197 ++++++++++++++++++++++++ tests/unit/io/test_direct.py | 7 + 15 files changed, 512 insertions(+), 371 deletions(-) create mode 100644 tests/unit/io/test_class_bolt4x3.py diff --git a/neo4j/exceptions.py b/neo4j/exceptions.py index 6a98bcb3..d8541a31 100644 --- a/neo4j/exceptions.py +++ b/neo4j/exceptions.py @@ -88,6 +88,8 @@ def hydrate(cls, message=None, code=None, **metadata): code = code or "Neo.DatabaseError.General.UnknownError" try: _, classification, category, title = code.split(".") + if code == "Neo.ClientError.Security.AuthorizationExpired": + classification = CLASSIFICATION_TRANSIENT except ValueError: classification = CLASSIFICATION_DATABASE category = "General" @@ -124,6 +126,9 @@ def _extract_error_class(cls, classification, code): else: return cls + def invalidates_all_connections(self): + return self.code == "Neo.ClientError.Security.AuthorizationExpired" + def __str__(self): return "{{code: {code}}} {{message: {message}}}".format(code=self.code, message=self.message) diff --git a/neo4j/io/__init__.py b/neo4j/io/__init__.py index 23a9fc8b..8404b355 100644 --- a/neo4j/io/__init__.py +++ b/neo4j/io/__init__.py @@ -33,7 +33,7 @@ "check_supported_server_product", ] - +import abc from collections import deque from logging import getLogger from random import choice @@ -69,6 +69,7 @@ from neo4j.addressing import Address from neo4j.api import ( READ_ACCESS, + ServerInfo, Version, WRITE_ACCESS, ) @@ -77,23 +78,36 @@ WorkspaceConfig, ) from neo4j.exceptions import ( + AuthError, ClientError, ConfigurationError, DriverError, + IncompleteCommit, ReadServiceUnavailable, ServiceUnavailable, SessionExpired, UnsupportedServerProduct, WriteServiceUnavailable, ) +from neo4j.io._common import ( + CommitResponse, + Inbox, + InitResponse, + Outbox, + Response, +) +from neo4j.meta import get_user_agent +from neo4j.packstream import ( + Packer, + Unpacker, +) from neo4j.routing import RoutingTable - # Set up logger log = getLogger("neo4j") -class Bolt: +class Bolt(abc.ABC): """ Server connection for Bolt protocol. A :class:`.Bolt` should be constructed following a @@ -107,6 +121,69 @@ class Bolt: PROTOCOL_VERSION = None + # The socket + in_use = False + + # The socket + _closed = False + + # The socket + _defunct = False + + #: The pool of which this connection is a member + pool = None + + def __init__(self, unresolved_address, sock, max_connection_lifetime, *, auth=None, user_agent=None, routing_context=None): + self.unresolved_address = unresolved_address + self.socket = sock + self.server_info = ServerInfo(Address(sock.getpeername()), self.PROTOCOL_VERSION) + self.outbox = Outbox() + self.inbox = Inbox(self.socket, on_error=self._set_defunct_read) + self.packer = Packer(self.outbox) + self.unpacker = Unpacker(self.inbox) + self.responses = deque() + self._max_connection_lifetime = max_connection_lifetime + self._creation_timestamp = perf_counter() + self._is_reset = True + self.routing_context = routing_context + + # Determine the user agent + if user_agent: + self.user_agent = user_agent + else: + self.user_agent = get_user_agent() + + # Determine auth details + if not auth: + self.auth_dict = {} + elif isinstance(auth, tuple) and 2 <= len(auth) <= 3: + from neo4j import Auth + self.auth_dict = vars(Auth("basic", *auth)) + else: + try: + self.auth_dict = vars(auth) + except (KeyError, TypeError): + raise AuthError("Cannot determine auth details from %r" % auth) + + # Check for missing password + try: + credentials = self.auth_dict["credentials"] + except KeyError: + pass + else: + if credentials is None: + raise AuthError("Password cannot be None") + + @property + @abc.abstractmethod + def supports_multiple_results(self): + pass + + @property + @abc.abstractmethod + def supports_multiple_databases(self): + pass + @classmethod def protocol_handlers(cls, protocol_version=None): """ Return a dictionary of available Bolt protocol handlers, @@ -258,19 +335,23 @@ def open(cls, address, *, auth=None, timeout=None, routing_context=None, **pool_ return connection @property + @abc.abstractmethod def encrypted(self): - raise NotImplementedError + pass @property + @abc.abstractmethod def der_encoded_server_certificate(self): - raise NotImplementedError + pass @property + @abc.abstractmethod def local_port(self): - raise NotImplementedError + pass + @abc.abstractmethod def hello(self): - raise NotImplementedError + pass def __del__(self): try: @@ -278,6 +359,7 @@ def __del__(self): except OSError: pass + @abc.abstractmethod def route(self, database=None, bookmarks=None): """ Fetch a routing table from the server for the given `database`. For Bolt 4.3 and above, this appends a ROUTE @@ -290,7 +372,9 @@ def route(self, database=None, bookmarks=None): transaction should begin :return: dictionary of raw routing data """ + pass + @abc.abstractmethod def run(self, query, parameters=None, mode=None, bookmarks=None, metadata=None, timeout=None, db=None, **handlers): """ Appends a RUN message to the output stream. @@ -305,7 +389,9 @@ def run(self, query, parameters=None, mode=None, bookmarks=None, metadata=None, :param handlers: handler functions passed into the returned Response object :return: Response object """ + pass + @abc.abstractmethod def discard(self, n=-1, qid=-1, **handlers): """ Appends a DISCARD message to the output stream. @@ -314,7 +400,9 @@ def discard(self, n=-1, qid=-1, **handlers): :param handlers: handler functions passed into the returned Response object :return: Response object """ + pass + @abc.abstractmethod def pull(self, n=-1, qid=-1, **handlers): """ Appends a PULL message to the output stream. @@ -323,7 +411,9 @@ def pull(self, n=-1, qid=-1, **handlers): :param handlers: handler functions passed into the returned Response object :return: Response object """ + pass + @abc.abstractmethod def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, db=None, **handlers): """ Appends a BEGIN message to the output stream. @@ -335,34 +425,65 @@ def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, db=None, :param handlers: handler functions passed into the returned Response object :return: Response object """ + pass + @abc.abstractmethod def commit(self, **handlers): - raise NotImplementedError + pass + @abc.abstractmethod def rollback(self, **handlers): - raise NotImplementedError + pass + @abc.abstractmethod def reset(self): """ Add a RESET message to the outgoing queue, send it and consume all remaining messages. """ - raise NotImplementedError + pass + + def _append(self, signature, fields=(), response=None): + """ Add a message to the outgoing queue. + + :arg signature: the signature of the message + :arg fields: the fields of the message as a tuple + :arg response: a response object to handle callbacks + """ + self.packer.pack_struct(signature, fields) + self.outbox.chunk() + self.outbox.chunk() + self.responses.append(response) + + def _send_all(self): + data = self.outbox.view() + if data: + try: + self.socket.sendall(data) + except OSError as error: + self._set_defunct_write(error) + self.outbox.clear() def send_all(self): """ Send all queued messages to the server. """ - raise NotImplementedError + if self.closed(): + raise ServiceUnavailable("Failed to write to closed connection {!r} ({!r})".format( + self.unresolved_address, self.server_info.address)) + + if self.defunct(): + raise ServiceUnavailable("Failed to write to defunct connection {!r} ({!r})".format( + self.unresolved_address, self.server_info.address)) + + self._send_all() + @abc.abstractmethod def fetch_message(self): """ Receive at least one message from the server, if available. :return: 2-tuple of number of detail messages and number of summary messages fetched """ - raise NotImplementedError - - def timedout(self): - raise NotImplementedError + pass def fetch_all(self): """ Fetch all outstanding messages. @@ -370,7 +491,73 @@ def fetch_all(self): :return: 2-tuple of number of detail messages and number of summary messages fetched """ - raise NotImplementedError + detail_count = summary_count = 0 + while self.responses: + response = self.responses[0] + while not response.complete: + detail_delta, summary_delta = self.fetch_message() + detail_count += detail_delta + summary_count += summary_delta + return detail_count, summary_count + + def _set_defunct_read(self, error=None, silent=False): + message = "Failed to read from defunct connection {!r} ({!r})".format( + self.unresolved_address, self.server_info.address + ) + self._set_defunct(message, error=error, silent=silent) + + def _set_defunct_write(self, error=None, silent=False): + message = "Failed to write data to connection {!r} ({!r})".format( + self.unresolved_address, self.server_info.address + ) + self._set_defunct(message, error=error, silent=silent) + + def _set_defunct(self, message, error=None, silent=False): + direct_driver = isinstance(self.pool, BoltPool) + + if error: + log.error(str(error)) + log.error(message) + # We were attempting to receive data but the connection + # has unexpectedly terminated. So, we need to close the + # connection from the client side, and remove the address + # from the connection pool. + self._defunct = True + self.close() + if self.pool: + self.pool.deactivate(address=self.unresolved_address) + # Iterate through the outstanding responses, and if any correspond + # to COMMIT requests then raise an error to signal that we are + # unable to confirm that the COMMIT completed successfully. + if silent: + return + for response in self.responses: + if isinstance(response, CommitResponse): + if error: + raise IncompleteCommit(message) from error + else: + raise IncompleteCommit(message) + + if direct_driver: + if error: + raise ServiceUnavailable(message) from error + else: + raise ServiceUnavailable(message) + else: + if error: + raise SessionExpired(message) from error + else: + raise SessionExpired(message) + + def stale(self): + return (self._stale + or (0 <= self._max_connection_lifetime + <= perf_counter()- self._creation_timestamp)) + + _stale = False + + def set_stale(self): + self._stale = True def close(self): """ Close the connection. @@ -430,7 +617,7 @@ def time_remaining(): while True: # try to find a free connection in pool for connection in list(connections): - if connection.closed() or connection.defunct() or connection.timedout(): + if connection.closed() or connection.defunct() or connection.stale(): connections.remove(connection) continue if not connection.in_use: @@ -497,6 +684,12 @@ def in_use_connection_count(self, address): else: return sum(1 if connection.in_use else 0 for connection in connections) + def mark_all_stale(self): + with self.lock: + for address in self.connections: + for connection in self.connections[address]: + connection.set_stale() + def deactivate(self, address): """ Deactivate an address from the connection pool, if present, closing all idle connection to that address diff --git a/neo4j/io/_bolt3.py b/neo4j/io/_bolt3.py index 06ac3328..57edcc24 100644 --- a/neo4j/io/_bolt3.py +++ b/neo4j/io/_bolt3.py @@ -18,49 +18,35 @@ # See the License for the specific language governing permissions and # limitations under the License. -from collections import deque from logging import getLogger from ssl import SSLSocket -from time import perf_counter from neo4j._exceptions import ( BoltError, BoltProtocolError, ) -from neo4j.addressing import Address from neo4j.api import ( READ_ACCESS, - ServerInfo, Version, ) from neo4j.exceptions import ( - AuthError, ConfigurationError, DatabaseUnavailable, DriverError, ForbiddenOnReadOnlyDatabase, - IncompleteCommit, + Neo4jError, NotALeader, ServiceUnavailable, - SessionExpired, ) from neo4j.io import ( Bolt, - BoltPool, check_supported_server_product, ) from neo4j.io._common import ( CommitResponse, - Inbox, InitResponse, - Outbox, Response, ) -from neo4j.meta import get_user_agent -from neo4j.packstream import ( - Packer, - Unpacker, -) log = getLogger("neo4j") @@ -74,60 +60,9 @@ class Bolt3(Bolt): PROTOCOL_VERSION = Version(3, 0) - # The socket - in_use = False - - # The socket - _closed = False - - # The socket - _defunct = False - - #: The pool of which this connection is a member - pool = None - - def __init__(self, unresolved_address, sock, max_connection_lifetime, *, auth=None, user_agent=None, routing_context=None): - self.unresolved_address = unresolved_address - self.socket = sock - self.server_info = ServerInfo(Address(sock.getpeername()), self.PROTOCOL_VERSION) - self.outbox = Outbox() - self.inbox = Inbox(self.socket, on_error=self._set_defunct_read) - self.packer = Packer(self.outbox) - self.unpacker = Unpacker(self.inbox) - self.responses = deque() - self._max_connection_lifetime = max_connection_lifetime - self._creation_timestamp = perf_counter() - self.supports_multiple_results = False - self.supports_multiple_databases = False - self._is_reset = True - self.routing_context = routing_context + supports_multiple_results = False - # Determine the user agent - if user_agent: - self.user_agent = user_agent - else: - self.user_agent = get_user_agent() - - # Determine auth details - if not auth: - self.auth_dict = {} - elif isinstance(auth, tuple) and 2 <= len(auth) <= 3: - from neo4j import Auth - self.auth_dict = vars(Auth("basic", *auth)) - else: - try: - self.auth_dict = vars(auth) - except (KeyError, TypeError): - raise AuthError("Cannot determine auth details from %r" % auth) - - # Check for missing password - try: - credentials = self.auth_dict["credentials"] - except KeyError: - pass - else: - if credentials is None: - raise AuthError("Password cannot be None") + supports_multiple_databases = False @property def encrypted(self): @@ -269,18 +204,6 @@ def rollback(self, **handlers): log.debug("[#%04X] C: ROLLBACK", self.local_port) self._append(b"\x13", (), Response(self, **handlers)) - def _append(self, signature, fields=(), response=None): - """ Add a message to the outgoing queue. - - :arg signature: the signature of the message - :arg fields: the fields of the message as a tuple - :arg response: a response object to handle callbacks - """ - self.packer.pack_struct(signature, fields) - self.outbox.chunk() - self.outbox.chunk() - self.responses.append(response) - def reset(self): """ Add a RESET message to the outgoing queue, send it and consume all remaining messages. @@ -295,28 +218,6 @@ def fail(metadata): self.fetch_all() self._is_reset = True - def _send_all(self): - data = self.outbox.view() - if data: - try: - self.socket.sendall(data) - except OSError as error: - self._set_defunct_write(error) - self.outbox.clear() - - def send_all(self): - """ Send all queued messages to the server. - """ - if self.closed(): - raise ServiceUnavailable("Failed to write to closed connection {!r} ({!r})".format( - self.unresolved_address, self.server_info.address)) - - if self.defunct(): - raise ServiceUnavailable("Failed to write to defunct connection {!r} ({!r})".format( - self.unresolved_address, self.server_info.address)) - - self._send_all() - def fetch_message(self): """ Receive at least one message from the server, if available. @@ -364,76 +265,15 @@ def fetch_message(self): if self.pool: self.pool.on_write_failure(address=self.unresolved_address), raise + except Neo4jError as e: + if self.pool and e.invalidates_all_connections(): + self.pool.mark_all_stale() + raise else: raise BoltProtocolError("Unexpected response message with signature %02X" % summary_signature, address=self.unresolved_address) return len(details), 1 - def _set_defunct_read(self, error=None): - message = "Failed to read from defunct connection {!r} ({!r})".format( - self.unresolved_address, self.server_info.address - ) - self._set_defunct(message, error=error) - - def _set_defunct_write(self, error=None): - message = "Failed to write data to connection {!r} ({!r})".format( - self.unresolved_address, self.server_info.address - ) - self._set_defunct(message, error=error) - - def _set_defunct(self, message, error=None): - direct_driver = isinstance(self.pool, BoltPool) - - if error: - log.error(str(error)) - log.error(message) - # We were attempting to receive data but the connection - # has unexpectedly terminated. So, we need to close the - # connection from the client side, and remove the address - # from the connection pool. - self._defunct = True - self.close() - if self.pool: - self.pool.deactivate(address=self.unresolved_address) - # Iterate through the outstanding responses, and if any correspond - # to COMMIT requests then raise an error to signal that we are - # unable to confirm that the COMMIT completed successfully. - for response in self.responses: - if isinstance(response, CommitResponse): - if error: - raise IncompleteCommit(message) from error - else: - raise IncompleteCommit(message) - - if direct_driver: - if error: - raise ServiceUnavailable(message) from error - else: - raise ServiceUnavailable(message) - else: - if error: - raise SessionExpired(message) from error - else: - raise SessionExpired(message) - - def timedout(self): - return 0 <= self._max_connection_lifetime <= perf_counter() - self._creation_timestamp - - def fetch_all(self): - """ Fetch all outstanding messages. - - :return: 2-tuple of number of detail messages and number of summary - messages fetched - """ - detail_count = summary_count = 0 - while self.responses: - response = self.responses[0] - while not response.complete: - detail_delta, summary_delta = self.fetch_message() - detail_count += detail_delta - summary_count += summary_delta - return detail_count, summary_count - def close(self): """ Close the connection. """ diff --git a/neo4j/io/_bolt4.py b/neo4j/io/_bolt4.py index 3051f499..2ae05c47 100644 --- a/neo4j/io/_bolt4.py +++ b/neo4j/io/_bolt4.py @@ -18,49 +18,36 @@ # See the License for the specific language governing permissions and # limitations under the License. -from collections import deque from logging import getLogger from ssl import SSLSocket -from time import perf_counter from neo4j._exceptions import ( BoltError, BoltProtocolError, ) -from neo4j.addressing import Address from neo4j.api import ( READ_ACCESS, SYSTEM_DATABASE, Version, ) -from neo4j.api import ServerInfo from neo4j.exceptions import ( - AuthError, DatabaseUnavailable, DriverError, ForbiddenOnReadOnlyDatabase, - IncompleteCommit, + Neo4jError, NotALeader, ServiceUnavailable, SessionExpired, ) from neo4j.io import ( Bolt, - BoltPool, check_supported_server_product, ) from neo4j.io._common import ( CommitResponse, - Inbox, InitResponse, - Outbox, Response, ) -from neo4j.meta import get_user_agent -from neo4j.packstream import ( - Unpacker, - Packer, -) log = getLogger("neo4j") @@ -86,48 +73,9 @@ class Bolt4x0(Bolt): #: The pool of which this connection is a member pool = None - def __init__(self, unresolved_address, sock, max_connection_lifetime, *, auth=None, user_agent=None, routing_context=None): - self.unresolved_address = unresolved_address - self.socket = sock - self.server_info = ServerInfo(Address(sock.getpeername()), self.PROTOCOL_VERSION) - self.outbox = Outbox() - self.inbox = Inbox(self.socket, on_error=self._set_defunct_read) - self.packer = Packer(self.outbox) - self.unpacker = Unpacker(self.inbox) - self.responses = deque() - self._max_connection_lifetime = max_connection_lifetime # self.pool_config.max_connection_lifetime - self._creation_timestamp = perf_counter() - self.supports_multiple_results = True - self.supports_multiple_databases = True - self._is_reset = True - self.routing_context = routing_context - - # Determine the user agent - if user_agent: - self.user_agent = user_agent - else: - self.user_agent = get_user_agent() - - # Determine auth details - if not auth: - self.auth_dict = {} - elif isinstance(auth, tuple) and 2 <= len(auth) <= 3: - from neo4j import Auth - self.auth_dict = vars(Auth("basic", *auth)) - else: - try: - self.auth_dict = vars(auth) - except (KeyError, TypeError): - raise AuthError("Cannot determine auth details from %r" % auth) + supports_multiple_results = True - # Check for missing password - try: - credentials = self.auth_dict["credentials"] - except KeyError: - pass - else: - if credentials is None: - raise AuthError("Password cannot be None") + supports_multiple_databases = True @property def encrypted(self): @@ -280,18 +228,6 @@ def rollback(self, **handlers): log.debug("[#%04X] C: ROLLBACK", self.local_port) self._append(b"\x13", (), Response(self, **handlers)) - def _append(self, signature, fields=(), response=None): - """ Add a message to the outgoing queue. - - :arg signature: the signature of the message - :arg fields: the fields of the message as a tuple - :arg response: a response object to handle callbacks - """ - self.packer.pack_struct(signature, fields) - self.outbox.chunk() - self.outbox.chunk() - self.responses.append(response) - def reset(self): """ Add a RESET message to the outgoing queue, send it and consume all remaining messages. @@ -306,28 +242,6 @@ def fail(metadata): self.fetch_all() self._is_reset = True - def _send_all(self): - data = self.outbox.view() - if data: - try: - self.socket.sendall(data) - except OSError as error: - self._set_defunct_write(error) - self.outbox.clear() - - def send_all(self): - """ Send all queued messages to the server. - """ - if self.closed(): - raise ServiceUnavailable("Failed to write to closed connection {!r} ({!r})".format( - self.unresolved_address, self.server_info.address)) - - if self.defunct(): - raise ServiceUnavailable("Failed to write to defunct connection {!r} ({!r})".format( - self.unresolved_address, self.server_info.address)) - - self._send_all() - def fetch_message(self): """ Receive at least one message from the server, if available. @@ -375,77 +289,16 @@ def fetch_message(self): if self.pool: self.pool.on_write_failure(address=self.unresolved_address), raise + except Neo4jError as e: + if self.pool and e.invalidates_all_connections(): + self.pool.mark_all_stale() + raise else: raise BoltProtocolError("Unexpected response message with signature " "%02X" % ord(summary_signature), self.unresolved_address) return len(details), 1 - def _set_defunct_read(self, error=None): - message = "Failed to read from defunct connection {!r} ({!r})".format( - self.unresolved_address, self.server_info.address - ) - self._set_defunct(message, error=error) - - def _set_defunct_write(self, error=None): - message = "Failed to write data to connection {!r} ({!r})".format( - self.unresolved_address, self.server_info.address - ) - self._set_defunct(message, error=error) - - def _set_defunct(self, message, error=None): - direct_driver = isinstance(self.pool, BoltPool) - - if error: - log.error(str(error)) - log.error(message) - # We were attempting to receive data but the connection - # has unexpectedly terminated. So, we need to close the - # connection from the client side, and remove the address - # from the connection pool. - self._defunct = True - self.close() - if self.pool: - self.pool.deactivate(address=self.unresolved_address) - # Iterate through the outstanding responses, and if any correspond - # to COMMIT requests then raise an error to signal that we are - # unable to confirm that the COMMIT completed successfully. - for response in self.responses: - if isinstance(response, CommitResponse): - if error: - raise IncompleteCommit(message) from error - else: - raise IncompleteCommit(message) - - if direct_driver: - if error: - raise ServiceUnavailable(message) from error - else: - raise ServiceUnavailable(message) - else: - if error: - raise SessionExpired(message) from error - else: - raise SessionExpired(message) - - def timedout(self): - return 0 <= self._max_connection_lifetime <= perf_counter() - self._creation_timestamp - - def fetch_all(self): - """ Fetch all outstanding messages. - - :return: 2-tuple of number of detail messages and number of summary - messages fetched - """ - detail_count = summary_count = 0 - while self.responses: - response = self.responses[0] - while not response.complete: - detail_delta, summary_delta = self.fetch_message() - detail_count += detail_delta - summary_count += summary_delta - return detail_count, summary_count - def close(self): """ Close the connection. """ diff --git a/neo4j/io/_common.py b/neo4j/io/_common.py index 213b900f..38e2dfaa 100644 --- a/neo4j/io/_common.py +++ b/neo4j/io/_common.py @@ -25,6 +25,7 @@ AuthError, Neo4jError, ServiceUnavailable, + SessionExpired, ) from neo4j.packstream import ( UnpackableBuffer, @@ -169,7 +170,10 @@ def on_success(self, metadata): def on_failure(self, metadata): """ Called when a FAILURE message has been received. """ - self.connection.reset() + try: + self.connection.reset() + except (SessionExpired, ServiceUnavailable): + pass handler = self.handlers.get("on_failure") if callable(handler): handler(metadata) diff --git a/neo4j/work/result.py b/neo4j/work/result.py index 6e8a097e..60012bfc 100644 --- a/neo4j/work/result.py +++ b/neo4j/work/result.py @@ -60,9 +60,9 @@ def outer(func): def inner(*args, **kwargs): try: func(*args, **kwargs) - except (SessionExpired, ServiceUnavailable) as error: - self._on_network_error(error) - raise + finally: + if self._connection.defunct(): + self._on_network_error() return inner return outer(connection_attr) diff --git a/neo4j/work/simple.py b/neo4j/work/simple.py index 1a73089d..18715fdb 100644 --- a/neo4j/work/simple.py +++ b/neo4j/work/simple.py @@ -137,7 +137,7 @@ def _result_closed(self): self._autoResult = None self._disconnect() - def _result_network_error(self, error): + def _result_network_error(self): if self._autoResult: self._autoResult = None self._disconnect() @@ -261,7 +261,7 @@ def _transaction_closed_handler(self): self._transaction = None self._disconnect() - def _transaction_network_error_handler(self, error): + def _transaction_network_error_handler(self): if self._transaction: self._transaction = None self._disconnect() diff --git a/neo4j/work/transaction.py b/neo4j/work/transaction.py index 1061f576..c05aa2fb 100644 --- a/neo4j/work/transaction.py +++ b/neo4j/work/transaction.py @@ -65,9 +65,9 @@ def _begin(self, database, bookmarks, access_mode, metadata, timeout): def _result_on_closed_handler(self): pass - def _result_on_network_error_handler(self, error): + def _result_on_network_error_handler(self): self._closed = True - self._on_network_error(error) + self._on_network_error() def _consume_results(self): for result in self._results: diff --git a/testkitbackend/skipped_tests.json b/testkitbackend/skipped_tests.json index 5546e392..e49f78f1 100644 --- a/testkitbackend/skipped_tests.json +++ b/testkitbackend/skipped_tests.json @@ -66,5 +66,11 @@ "stub.retry.TestRetryClustering.test_retry_ForbiddenOnReadOnlyDatabase_ChangingWriter": "Test makes assumptions about how verify_connectivity is implemented", "stub.disconnected.SessionRunDisconnected.test_fail_on_reset": - "It is not reseting connection when putting back to pool" + "It is not reseting connection when putting back to pool", + "stub.authorization.AuthorizationTests.test_should_retry_on_auth_expired_on_begin_using_tx_function": + "Flaky: test requires the driver to contact servers in a specific order", + "stub.authorization.AuthorizationTestsV3.test_should_retry_on_auth_expired_on_begin_using_tx_function": + "Flaky: test requires the driver to contact servers in a specific order", + "stub.authorization.AuthorizationTestsV4.test_should_retry_on_auth_expired_on_begin_using_tx_function": + "Flaky: test requires the driver to contact servers in a specific order" } diff --git a/tests/unit/io/test_class_bolt3.py b/tests/unit/io/test_class_bolt3.py index eba6957a..79bc0cf4 100644 --- a/tests/unit/io/test_class_bolt3.py +++ b/tests/unit/io/test_class_bolt3.py @@ -30,25 +30,34 @@ # python -m pytest tests/unit/io/test_class_bolt3.py -s -v -def test_conn_timed_out(fake_socket): +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_stale(fake_socket, set_stale): address = ("127.0.0.1", 7687) max_connection_lifetime = 0 connection = Bolt3(address, fake_socket(address), max_connection_lifetime) - assert connection.timedout() is True + if set_stale: + connection.set_stale() + assert connection.stale() is True -def test_conn_not_timed_out_if_not_enabled(fake_socket): +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): address = ("127.0.0.1", 7687) max_connection_lifetime = -1 connection = Bolt3(address, fake_socket(address), max_connection_lifetime) - assert connection.timedout() is False + if set_stale: + connection.set_stale() + assert connection.stale() is set_stale -def test_conn_not_timed_out(fake_socket): +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_not_stale(fake_socket, set_stale): address = ("127.0.0.1", 7687) max_connection_lifetime = 999999999 connection = Bolt3(address, fake_socket(address), max_connection_lifetime) - assert connection.timedout() is False + if set_stale: + connection.set_stale() + assert connection.stale() is set_stale def test_db_extra_not_supported_in_begin(fake_socket): diff --git a/tests/unit/io/test_class_bolt4x0.py b/tests/unit/io/test_class_bolt4x0.py index 09787b0e..333fc158 100644 --- a/tests/unit/io/test_class_bolt4x0.py +++ b/tests/unit/io/test_class_bolt4x0.py @@ -25,25 +25,34 @@ from neo4j.conf import PoolConfig -def test_conn_timed_out(fake_socket): +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_stale(fake_socket, set_stale): address = ("127.0.0.1", 7687) max_connection_lifetime = 0 connection = Bolt4x0(address, fake_socket(address), max_connection_lifetime) - assert connection.timedout() is True + if set_stale: + connection.set_stale() + assert connection.stale() is True -def test_conn_not_timed_out_if_not_enabled(fake_socket): +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): address = ("127.0.0.1", 7687) max_connection_lifetime = -1 connection = Bolt4x0(address, fake_socket(address), max_connection_lifetime) - assert connection.timedout() is False + if set_stale: + connection.set_stale() + assert connection.stale() is set_stale -def test_conn_not_timed_out(fake_socket): +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_not_stale(fake_socket, set_stale): address = ("127.0.0.1", 7687) max_connection_lifetime = 999999999 connection = Bolt4x0(address, fake_socket(address), max_connection_lifetime) - assert connection.timedout() is False + if set_stale: + connection.set_stale() + assert connection.stale() is set_stale def test_db_extra_in_begin(fake_socket): diff --git a/tests/unit/io/test_class_bolt4x1.py b/tests/unit/io/test_class_bolt4x1.py index a9a11004..aee69e68 100644 --- a/tests/unit/io/test_class_bolt4x1.py +++ b/tests/unit/io/test_class_bolt4x1.py @@ -25,25 +25,34 @@ from neo4j.conf import PoolConfig -def test_conn_timed_out(fake_socket): +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_stale(fake_socket, set_stale): address = ("127.0.0.1", 7687) max_connection_lifetime = 0 connection = Bolt4x1(address, fake_socket(address), max_connection_lifetime) - assert connection.timedout() is True + if set_stale: + connection.set_stale() + assert connection.stale() is True -def test_conn_not_timed_out_if_not_enabled(fake_socket): +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): address = ("127.0.0.1", 7687) max_connection_lifetime = -1 connection = Bolt4x1(address, fake_socket(address), max_connection_lifetime) - assert connection.timedout() is False + if set_stale: + connection.set_stale() + assert connection.stale() is set_stale -def test_conn_not_timed_out(fake_socket): +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_not_stale(fake_socket, set_stale): address = ("127.0.0.1", 7687) max_connection_lifetime = 999999999 connection = Bolt4x1(address, fake_socket(address), max_connection_lifetime) - assert connection.timedout() is False + if set_stale: + connection.set_stale() + assert connection.stale() is set_stale def test_db_extra_in_begin(fake_socket): diff --git a/tests/unit/io/test_class_bolt4x2.py b/tests/unit/io/test_class_bolt4x2.py index 10a91981..0c0b1a9a 100644 --- a/tests/unit/io/test_class_bolt4x2.py +++ b/tests/unit/io/test_class_bolt4x2.py @@ -25,25 +25,34 @@ from neo4j.conf import PoolConfig -def test_conn_timed_out(fake_socket): +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_stale(fake_socket, set_stale): address = ("127.0.0.1", 7687) max_connection_lifetime = 0 connection = Bolt4x2(address, fake_socket(address), max_connection_lifetime) - assert connection.timedout() is True + if set_stale: + connection.set_stale() + assert connection.stale() is True -def test_conn_not_timed_out_if_not_enabled(fake_socket): +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): address = ("127.0.0.1", 7687) max_connection_lifetime = -1 connection = Bolt4x2(address, fake_socket(address), max_connection_lifetime) - assert connection.timedout() is False + if set_stale: + connection.set_stale() + assert connection.stale() is set_stale -def test_conn_not_timed_out(fake_socket): +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_not_stale(fake_socket, set_stale): address = ("127.0.0.1", 7687) max_connection_lifetime = 999999999 connection = Bolt4x2(address, fake_socket(address), max_connection_lifetime) - assert connection.timedout() is False + if set_stale: + connection.set_stale() + assert connection.stale() is set_stale def test_db_extra_in_begin(fake_socket): diff --git a/tests/unit/io/test_class_bolt4x3.py b/tests/unit/io/test_class_bolt4x3.py new file mode 100644 index 00000000..b82a4f0b --- /dev/null +++ b/tests/unit/io/test_class_bolt4x3.py @@ -0,0 +1,197 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import pytest + +from neo4j.io._bolt4 import Bolt4x3 +from neo4j.conf import PoolConfig + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_stale(fake_socket, set_stale): + address = ("127.0.0.1", 7687) + max_connection_lifetime = 0 + connection = Bolt4x3(address, fake_socket(address), max_connection_lifetime) + if set_stale: + connection.set_stale() + assert connection.stale() is True + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): + address = ("127.0.0.1", 7687) + max_connection_lifetime = -1 + connection = Bolt4x3(address, fake_socket(address), max_connection_lifetime) + if set_stale: + connection.set_stale() + assert connection.stale() is set_stale + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_not_stale(fake_socket, set_stale): + address = ("127.0.0.1", 7687) + max_connection_lifetime = 999999999 + connection = Bolt4x3(address, fake_socket(address), max_connection_lifetime) + if set_stale: + connection.set_stale() + assert connection.stale() is set_stale + + +def test_db_extra_in_begin(fake_socket): + address = ("127.0.0.1", 7687) + socket = fake_socket(address) + connection = Bolt4x3(address, socket, PoolConfig.max_connection_lifetime) + connection.begin(db="something") + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x11" + assert len(fields) == 1 + assert fields[0] == {"db": "something"} + + +def test_db_extra_in_run(fake_socket): + address = ("127.0.0.1", 7687) + socket = fake_socket(address) + connection = Bolt4x3(address, socket, PoolConfig.max_connection_lifetime) + connection.run("", {}, db="something") + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x10" + assert len(fields) == 3 + assert fields[0] == "" + assert fields[1] == {} + assert fields[2] == {"db": "something"} + + +def test_n_extra_in_discard(fake_socket): + address = ("127.0.0.1", 7687) + socket = fake_socket(address) + connection = Bolt4x3(address, socket, PoolConfig.max_connection_lifetime) + connection.discard(n=666) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x2F" + assert len(fields) == 1 + assert fields[0] == {"n": 666} + + +@pytest.mark.parametrize( + "test_input, expected", + [ + (666, {"n": -1, "qid": 666}), + (-1, {"n": -1}), + ] +) +def test_qid_extra_in_discard(fake_socket, test_input, expected): + address = ("127.0.0.1", 7687) + socket = fake_socket(address) + connection = Bolt4x3(address, socket, PoolConfig.max_connection_lifetime) + connection.discard(qid=test_input) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x2F" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + "test_input, expected", + [ + (777, {"n": 666, "qid": 777}), + (-1, {"n": 666}), + ] +) +def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): + # python -m pytest tests/unit/io/test_class_bolt4x0.py -s -k test_n_and_qid_extras_in_discard + address = ("127.0.0.1", 7687) + socket = fake_socket(address) + connection = Bolt4x3(address, socket, PoolConfig.max_connection_lifetime) + connection.discard(n=666, qid=test_input) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x2F" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + "test_input, expected", + [ + (666, {"n": 666}), + (-1, {"n": -1}), + ] +) +def test_n_extra_in_pull(fake_socket, test_input, expected): + address = ("127.0.0.1", 7687) + socket = fake_socket(address) + connection = Bolt4x3(address, socket, PoolConfig.max_connection_lifetime) + connection.pull(n=test_input) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x3F" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + "test_input, expected", + [ + (777, {"n": -1, "qid": 777}), + (-1, {"n": -1}), + ] +) +def test_qid_extra_in_pull(fake_socket, test_input, expected): + # python -m pytest tests/unit/io/test_class_bolt4x0.py -s -k test_qid_extra_in_pull + address = ("127.0.0.1", 7687) + socket = fake_socket(address) + connection = Bolt4x3(address, socket, PoolConfig.max_connection_lifetime) + connection.pull(qid=test_input) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x3F" + assert len(fields) == 1 + assert fields[0] == expected + + +def test_n_and_qid_extras_in_pull(fake_socket): + address = ("127.0.0.1", 7687) + socket = fake_socket(address) + connection = Bolt4x3(address, socket, PoolConfig.max_connection_lifetime) + connection.pull(n=666, qid=777) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x3F" + assert len(fields) == 1 + assert fields[0] == {"n": 666, "qid": 777} + + +def test_hello_passes_routing_metadata(fake_socket_pair): + address = ("127.0.0.1", 7687) + sockets = fake_socket_pair(address) + sockets.server.send_message(0x70, {"server": "Neo4j/4.2.0"}) + connection = Bolt4x3(address, sockets.client, + PoolConfig.max_connection_lifetime, + routing_context={"foo": "bar"}) + connection.hello() + tag, fields = sockets.server.pop_message() + assert tag == 0x01 + assert len(fields) == 1 + assert fields[0]["routing"] == {"foo": "bar"} diff --git a/tests/unit/io/test_direct.py b/tests/unit/io/test_direct.py index 1664bfa1..6cfcffa4 100644 --- a/tests/unit/io/test_direct.py +++ b/tests/unit/io/test_direct.py @@ -61,6 +61,13 @@ def __init__(self, socket): self.socket = socket self.address = socket.getpeername() + @property + def is_reset(self): + return True + + def stale(self): + return False + def reset(self): pass From ae8007aa18b70a88a20955a4c1d9bb38f6d1960a Mon Sep 17 00:00:00 2001 From: Rouven Bauer Date: Fri, 30 Apr 2021 15:27:16 +0200 Subject: [PATCH 2/2] Adding/updating docstrings in Bolt base class --- neo4j/io/__init__.py | 32 ++++++++++++++++++++++---------- 1 file changed, 22 insertions(+), 10 deletions(-) diff --git a/neo4j/io/__init__.py b/neo4j/io/__init__.py index 8404b355..f660214b 100644 --- a/neo4j/io/__init__.py +++ b/neo4j/io/__init__.py @@ -177,11 +177,18 @@ def __init__(self, unresolved_address, sock, max_connection_lifetime, *, auth=No @property @abc.abstractmethod def supports_multiple_results(self): + """ Boolean flag to indicate if the connection version supports multiple + queries to be buffered on the server side (True) or if all results need + to be eagerly pulled before sending the next RUN (False). + """ pass @property @abc.abstractmethod def supports_multiple_databases(self): + """ Boolean flag to indicate if the connection version supports multiple + databases. + """ pass @classmethod @@ -351,6 +358,9 @@ def local_port(self): @abc.abstractmethod def hello(self): + """ Appends a HELLO message to the outgoing queue, sends it and consumes + all remaining messages. + """ pass def __del__(self): @@ -377,7 +387,7 @@ def route(self, database=None, bookmarks=None): @abc.abstractmethod def run(self, query, parameters=None, mode=None, bookmarks=None, metadata=None, timeout=None, db=None, **handlers): - """ Appends a RUN message to the output stream. + """ Appends a RUN message to the output queue. :param query: Cypher query string :param parameters: dictionary of Cypher parameters @@ -393,7 +403,7 @@ def run(self, query, parameters=None, mode=None, bookmarks=None, metadata=None, @abc.abstractmethod def discard(self, n=-1, qid=-1, **handlers): - """ Appends a DISCARD message to the output stream. + """ Appends a DISCARD message to the output queue. :param n: number of records to discard, default = -1 (ALL) :param qid: query ID to discard for, default = -1 (last query) @@ -404,7 +414,7 @@ def discard(self, n=-1, qid=-1, **handlers): @abc.abstractmethod def pull(self, n=-1, qid=-1, **handlers): - """ Appends a PULL message to the output stream. + """ Appends a PULL message to the output queue. :param n: number of records to pull, default = -1 (ALL) :param qid: query ID to pull for, default = -1 (last query) @@ -415,7 +425,7 @@ def pull(self, n=-1, qid=-1, **handlers): @abc.abstractmethod def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, db=None, **handlers): - """ Appends a BEGIN message to the output stream. + """ Appends a BEGIN message to the output queue. :param mode: access mode for routing - "READ" or "WRITE" (default) :param bookmarks: iterable of bookmark values after which this transaction should begin @@ -429,25 +439,27 @@ def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, db=None, @abc.abstractmethod def commit(self, **handlers): + """ Appends a COMMIT message to the output queue.""" pass @abc.abstractmethod def rollback(self, **handlers): + """ Appends a ROLLBACK message to the output queue.""" pass @abc.abstractmethod def reset(self): - """ Add a RESET message to the outgoing queue, send - it and consume all remaining messages. + """ Appends a RESET message to the outgoing queue, sends it and consumes + all remaining messages. """ pass def _append(self, signature, fields=(), response=None): - """ Add a message to the outgoing queue. + """ Appends a message to the outgoing queue. - :arg signature: the signature of the message - :arg fields: the fields of the message as a tuple - :arg response: a response object to handle callbacks + :param signature: the signature of the message + :param fields: the fields of the message as a tuple + :param response: a response object to handle callbacks """ self.packer.pack_struct(signature, fields) self.outbox.chunk()