diff --git a/neo4j/io/__init__.py b/neo4j/io/__init__.py index 3bb2167e..5d323d45 100644 --- a/neo4j/io/__init__.py +++ b/neo4j/io/__init__.py @@ -38,60 +38,52 @@ from logging import getLogger from random import choice from select import select -from time import perf_counter - from socket import ( + AF_INET, + AF_INET6, + SHUT_RDWR, + SO_KEEPALIVE, socket, SOL_SOCKET, - SO_KEEPALIVE, - SHUT_RDWR, timeout as SocketTimeout, - AF_INET, - AF_INET6, ) - from ssl import ( HAS_SNI, SSLError, ) - -from struct import ( - pack as struct_pack, -) - from threading import ( + Condition, Lock, RLock, - Condition, ) +from time import perf_counter -from neo4j.addressing import Address -from neo4j.conf import PoolConfig from neo4j._exceptions import ( + BoltHandshakeError, + BoltProtocolError, BoltRoutingError, BoltSecurityError, - BoltProtocolError, - BoltHandshakeError, ) -from neo4j.exceptions import ( - ServiceUnavailable, - ClientError, - SessionExpired, - ReadServiceUnavailable, - WriteServiceUnavailable, - ConfigurationError, - UnsupportedServerProduct, +from neo4j.addressing import Address +from neo4j.api import ( + READ_ACCESS, + Version, + WRITE_ACCESS, ) -from neo4j.routing import RoutingTable from neo4j.conf import ( PoolConfig, WorkspaceConfig, ) -from neo4j.api import ( - READ_ACCESS, - WRITE_ACCESS, - Version, +from neo4j.exceptions import ( + ClientError, + ConfigurationError, + ReadServiceUnavailable, + ServiceUnavailable, + SessionExpired, + UnsupportedServerProduct, + WriteServiceUnavailable, ) +from neo4j.routing import RoutingTable # Set up logger log = getLogger("neo4j") @@ -258,7 +250,7 @@ def open(cls, address, *, auth=None, timeout=None, routing_context=None, **pool_ except Exception as error: log.debug("[#%04X] C: %s", s.getsockname()[1], str(error)) _close_socket(s) - raise error + raise return connection @@ -522,7 +514,7 @@ def deactivate(self, address): connections.remove(conn) try: conn.close() - except IOError: + except OSError: pass if not connections: self.remove(address) @@ -538,7 +530,7 @@ def remove(self, address): for connection in self.connections.pop(address, ()): try: connection.close() - except IOError: + except OSError: pass def close(self): diff --git a/neo4j/io/_bolt3.py b/neo4j/io/_bolt3.py index 764cd699..ad90a652 100644 --- a/neo4j/io/_bolt3.py +++ b/neo4j/io/_bolt3.py @@ -19,44 +19,49 @@ # limitations under the License. from collections import deque +from logging import getLogger from ssl import SSLSocket from time import perf_counter + +from neo4j._exceptions import ( + BoltError, + BoltProtocolError, +) +from neo4j.addressing import Address from neo4j.api import ( - Version, READ_ACCESS, + ServerInfo, + Version, ) -from neo4j.io._common import ( - Inbox, - Outbox, - Response, - InitResponse, - CommitResponse, -) -from neo4j.meta import get_user_agent from neo4j.exceptions import ( AuthError, - DatabaseUnavailable, ConfigurationError, + DatabaseUnavailable, + DriverError, ForbiddenOnReadOnlyDatabase, IncompleteCommit, NotALeader, ServiceUnavailable, SessionExpired, ) -from neo4j._exceptions import BoltProtocolError -from neo4j.packstream import ( - Unpacker, - Packer, -) from neo4j.io import ( + check_supported_server_product, Bolt, BoltPool, - check_supported_server_product, ) -from neo4j.api import ServerInfo -from neo4j.addressing import Address +from neo4j.io._common import ( + CommitResponse, + Inbox, + InitResponse, + Outbox, + Response, +) +from neo4j.meta import get_user_agent +from neo4j.packstream import ( + Packer, + Unpacker, +) -from logging import getLogger log = getLogger("neo4j") @@ -85,7 +90,7 @@ def __init__(self, unresolved_address, sock, max_connection_lifetime, *, auth=No self.socket = sock self.server_info = ServerInfo(Address(sock.getpeername()), self.PROTOCOL_VERSION) self.outbox = Outbox() - self.inbox = Inbox(self.socket, on_error=self._set_defunct) + self.inbox = Inbox(self.socket, on_error=self._set_defunct_read) self.packer = Packer(self.outbox) self.unpacker = Unpacker(self.inbox) self.responses = deque() @@ -135,7 +140,7 @@ def der_encoded_server_certificate(self): def local_port(self): try: return self.socket.getsockname()[1] - except IOError: + except OSError: return 0 def get_base_headers(self): @@ -292,7 +297,10 @@ def fail(metadata): def _send_all(self): data = self.outbox.view() if data: - self.socket.sendall(data) + try: + self.socket.sendall(data) + except OSError as error: + self._set_defunct_write(error) self.outbox.clear() def send_all(self): @@ -306,17 +314,7 @@ def send_all(self): raise ServiceUnavailable("Failed to write to defunct connection {!r} ({!r})".format( self.unresolved_address, self.server_info.address)) - try: - self._send_all() - except (IOError, OSError) as error: - log.error("Failed to write data to connection " - "{!r} ({!r}); ({!r})". - format(self.unresolved_address, - self.server_info.address, - "; ".join(map(repr, error.args)))) - if self.pool: - self.pool.deactivate(address=self.unresolved_address) - raise + self._send_all() def fetch_message(self): """ Receive at least one message from the server, if available. @@ -336,17 +334,7 @@ def fetch_message(self): return 0, 0 # Receive exactly one message - try: - details, summary_signature, summary_metadata = next(self.inbox) - except (IOError, OSError) as error: - log.error("Failed to read data from connection " - "{!r} ({!r}); ({!r})". - format(self.unresolved_address, - self.server_info.address, - "; ".join(map(repr, error.args)))) - if self.pool: - self.pool.deactivate(address=self.unresolved_address) - raise + details, summary_signature, summary_metadata = next(self.inbox) if details: log.debug("[#%04X] S: RECORD * %d", self.local_port, len(details)) # Do not log any data @@ -380,11 +368,20 @@ def fetch_message(self): return len(details), 1 - def _set_defunct(self, error=None): - direct_driver = isinstance(self.pool, BoltPool) + def _set_defunct_read(self, error=None): + message = "Failed to read from defunct connection {!r} ({!r})".format( + self.unresolved_address, self.server_info.address + ) + self._set_defunct(message, error=error) - message = ("Failed to read from defunct connection {!r} ({!r})".format( - self.unresolved_address, self.server_info.address)) + def _set_defunct_write(self, error=None): + message = "Failed to write data to connection {!r} ({!r})".format( + self.unresolved_address, self.server_info.address + ) + self._set_defunct(message, error=error) + + def _set_defunct(self, message, error=None): + direct_driver = isinstance(self.pool, BoltPool) if error: log.error(str(error)) @@ -445,12 +442,12 @@ def close(self): self._append(b"\x02", ()) try: self._send_all() - except: + except (OSError, BoltError, DriverError): pass log.debug("[#%04X] C: ", self.local_port) try: self.socket.close() - except IOError: + except OSError: pass finally: self._closed = True diff --git a/neo4j/io/_bolt4.py b/neo4j/io/_bolt4.py index 9b6ef6e7..ee4c9137 100644 --- a/neo4j/io/_bolt4.py +++ b/neo4j/io/_bolt4.py @@ -19,44 +19,49 @@ # limitations under the License. from collections import deque +from logging import getLogger from ssl import SSLSocket from time import perf_counter + +from neo4j._exceptions import ( + BoltError, + BoltProtocolError, +) +from neo4j.addressing import Address from neo4j.api import ( - Version, READ_ACCESS, SYSTEM_DATABASE, + Version, ) -from neo4j.io._common import ( - Inbox, - Outbox, - Response, - InitResponse, - CommitResponse, -) -from neo4j.meta import get_user_agent +from neo4j.api import ServerInfo from neo4j.exceptions import ( AuthError, DatabaseUnavailable, + DriverError, ForbiddenOnReadOnlyDatabase, IncompleteCommit, NotALeader, ServiceUnavailable, SessionExpired, ) -from neo4j._exceptions import BoltProtocolError -from neo4j.packstream import ( - Unpacker, - Packer, -) from neo4j.io import ( Bolt, BoltPool, check_supported_server_product, ) -from neo4j.api import ServerInfo -from neo4j.addressing import Address +from neo4j.io._common import ( + CommitResponse, + Inbox, + InitResponse, + Outbox, + Response, +) +from neo4j.meta import get_user_agent +from neo4j.packstream import ( + Unpacker, + Packer, +) -from logging import getLogger log = getLogger("neo4j") @@ -85,7 +90,7 @@ def __init__(self, unresolved_address, sock, max_connection_lifetime, *, auth=No self.socket = sock self.server_info = ServerInfo(Address(sock.getpeername()), self.PROTOCOL_VERSION) self.outbox = Outbox() - self.inbox = Inbox(self.socket, on_error=self._set_defunct) + self.inbox = Inbox(self.socket, on_error=self._set_defunct_read) self.packer = Packer(self.outbox) self.unpacker = Unpacker(self.inbox) self.responses = deque() @@ -135,7 +140,7 @@ def der_encoded_server_certificate(self): def local_port(self): try: return self.socket.getsockname()[1] - except IOError: + except OSError: return 0 def get_base_headers(self): @@ -303,7 +308,10 @@ def fail(metadata): def _send_all(self): data = self.outbox.view() if data: - self.socket.sendall(data) + try: + self.socket.sendall(data) + except OSError as error: + self._set_defunct_write(error) self.outbox.clear() def send_all(self): @@ -317,17 +325,7 @@ def send_all(self): raise ServiceUnavailable("Failed to write to defunct connection {!r} ({!r})".format( self.unresolved_address, self.server_info.address)) - try: - self._send_all() - except (IOError, OSError) as error: - log.error("Failed to write data to connection " - "{!r} ({!r}); ({!r})". - format(self.unresolved_address, - self.server_info.address, - "; ".join(map(repr, error.args)))) - if self.pool: - self.pool.deactivate(address=self.unresolved_address) - raise + self._send_all() def fetch_message(self): """ Receive at least one message from the server, if available. @@ -347,17 +345,7 @@ def fetch_message(self): return 0, 0 # Receive exactly one message - try: - details, summary_signature, summary_metadata = next(self.inbox) - except (IOError, OSError) as error: - log.error("Failed to read data from connection " - "{!r} ({!r}); ({!r})". - format(self.unresolved_address, - self.server_info.address, - "; ".join(map(repr, error.args)))) - if self.pool: - self.pool.deactivate(address=self.unresolved_address) - raise + details, summary_signature, summary_metadata = next(self.inbox) if details: log.debug("[#%04X] S: RECORD * %d", self.local_port, len(details)) # Do not log any data @@ -392,11 +380,20 @@ def fetch_message(self): return len(details), 1 - def _set_defunct(self, error=None): - direct_driver = isinstance(self.pool, BoltPool) + def _set_defunct_read(self, error=None): + message = "Failed to read from defunct connection {!r} ({!r})".format( + self.unresolved_address, self.server_info.address + ) + self._set_defunct(message, error=error) - message = ("Failed to read from defunct connection {!r} ({!r})".format( - self.unresolved_address, self.server_info.address)) + def _set_defunct_write(self, error=None): + message = "Failed to write data to connection {!r} ({!r})".format( + self.unresolved_address, self.server_info.address + ) + self._set_defunct(message, error=error) + + def _set_defunct(self, message, error=None): + direct_driver = isinstance(self.pool, BoltPool) if error: log.error(str(error)) @@ -457,12 +454,12 @@ def close(self): self._append(b"\x02", ()) try: self._send_all() - except: + except (OSError, BoltError, DriverError): pass log.debug("[#%04X] C: ", self.local_port) try: self.socket.close() - except IOError: + except OSError: pass finally: self._closed = True diff --git a/neo4j/io/_common.py b/neo4j/io/_common.py index e7dd3653..213b900f 100644 --- a/neo4j/io/_common.py +++ b/neo4j/io/_common.py @@ -22,8 +22,8 @@ from struct import pack as struct_pack from neo4j.exceptions import ( - Neo4jError, AuthError, + Neo4jError, ServiceUnavailable, ) from neo4j.packstream import ( diff --git a/neo4j/work/__init__.py b/neo4j/work/__init__.py index a516004c..e5fd1659 100644 --- a/neo4j/work/__init__.py +++ b/neo4j/work/__init__.py @@ -61,7 +61,7 @@ def _disconnect(self, sync): except (WorkspaceError, ServiceUnavailable): pass if self._connection: - self._connection.in_use = False + self._pool.release(self._connection) self._connection = None self._connection_access_mode = None diff --git a/neo4j/work/result.py b/neo4j/work/result.py index cf68741e..19d768b4 100644 --- a/neo4j/work/result.py +++ b/neo4j/work/result.py @@ -23,17 +23,61 @@ from warnings import warn from neo4j.data import DataDehydrator +from neo4j.exceptions import ( + ServiceUnavailable, + SessionExpired, +) from neo4j.work.summary import ResultSummary +class _ConnectionErrorHandler: + """ + Wrapper class for handling connection errors. + + The class will wrap each method to invoke a callback if the method raises + SessionExpired or ServiceUnavailable. + The error will be re-raised after the callback. + """ + + def __init__(self, connection, on_network_error): + """ + :param connection the connection object to warp + :type connection Bolt + :param on_network_error the function to be called when a method of + connection raises of of the caught errors. The callback takes the + error as argument. + :type on_network_error callable + + """ + self._connection = connection + self._on_network_error = on_network_error + + def __getattr__(self, item): + connection_attr = getattr(self._connection, item) + if not callable(connection_attr): + return connection_attr + + def outer(func): + def inner(*args, **kwargs): + try: + func(*args, **kwargs) + except (SessionExpired, ServiceUnavailable) as error: + self._on_network_error(error) + raise + return inner + + return outer(connection_attr) + + class Result: """A handler for the result of Cypher query execution. Instances of this class are typically constructed and returned by :meth:`.Session.run` and :meth:`.Transaction.run`. """ - def __init__(self, connection, hydrant, fetch_size, on_closed): - self._connection = connection + def __init__(self, connection, hydrant, fetch_size, on_closed, + on_network_error): + self._connection = _ConnectionErrorHandler(connection, on_network_error) self._hydrant = hydrant self._on_closed = on_closed self._metadata = None diff --git a/neo4j/work/simple.py b/neo4j/work/simple.py index ec512e6a..d2f17c0e 100644 --- a/neo4j/work/simple.py +++ b/neo4j/work/simple.py @@ -124,7 +124,7 @@ def _connect(self, access_mode, database): def _disconnect(self): if self._connection: - self._connection.in_use = False + self._pool.release(self._connection) self._connection = None def _collect_bookmark(self, bookmark): @@ -137,6 +137,11 @@ def _result_closed(self): self._autoResult = None self._disconnect() + def _result_network_error(self, error): + if self._autoResult: + self._autoResult = None + self._disconnect() + def close(self): """Close the session. This will release any borrowed resources, such as connections, and will roll back any outstanding transactions. """ @@ -220,8 +225,14 @@ def run(self, query, parameters=None, **kwparameters): hydrant = DataHydrator() - self._autoResult = Result(cx, hydrant, self._config.fetch_size, self._result_closed) - self._autoResult._run(query, parameters, self._config.database, self._config.default_access_mode, self._bookmarks, **kwparameters) + self._autoResult = Result( + cx, hydrant, self._config.fetch_size, self._result_closed, + self._result_network_error + ) + self._autoResult._run( + query, parameters, self._config.database, + self._config.default_access_mode, self._bookmarks, **kwparameters + ) return self._autoResult @@ -250,10 +261,21 @@ def _transaction_closed_handler(self): self._transaction = None self._disconnect() - def _open_transaction(self, *, access_mode, database, metadata=None, timeout=None): + def _transaction_network_error_handler(self, error): + if self._transaction: + self._transaction = None + self._disconnect() + + def _open_transaction(self, *, access_mode, database, metadata=None, + timeout=None): self._connect(access_mode=access_mode, database=database) - self._transaction = Transaction(self._connection, self._config.fetch_size, self._transaction_closed_handler) - self._transaction._begin(database, self._bookmarks, access_mode, metadata, timeout) + self._transaction = Transaction( + self._connection, self._config.fetch_size, + self._transaction_closed_handler, + self._transaction_network_error_handler + ) + self._transaction._begin(database, self._bookmarks, access_mode, + metadata, timeout) def begin_transaction(self, metadata=None, timeout=None): """ Begin a new unmanaged transaction. Creates a new :class:`.Transaction` within this session. diff --git a/neo4j/work/transaction.py b/neo4j/work/transaction.py index 7c11139c..3ff5739f 100644 --- a/neo4j/work/transaction.py +++ b/neo4j/work/transaction.py @@ -38,13 +38,14 @@ class Transaction: """ - def __init__(self, connection, fetch_size, on_closed): + def __init__(self, connection, fetch_size, on_closed, on_network_error): self._connection = connection self._bookmark = None self._results = [] self._closed = False self._fetch_size = fetch_size self._on_closed = on_closed + self._on_network_error = on_network_error def __enter__(self): return self @@ -58,11 +59,16 @@ def __exit__(self, exception_type, exception_value, traceback): self.close() def _begin(self, database, bookmarks, access_mode, metadata, timeout): - self._connection.begin(bookmarks=bookmarks, metadata=metadata, timeout=timeout, mode=access_mode, db=database) + self._connection.begin(bookmarks=bookmarks, metadata=metadata, + timeout=timeout, mode=access_mode, db=database) def _result_on_closed_handler(self): pass + def _result_on_network_error_handler(self, error): + self._closed = True + self._on_network_error(error) + def _consume_results(self): for result in self._results: result.consume() @@ -107,11 +113,18 @@ def run(self, query, parameters=None, **kwparameters): if self._closed: raise TransactionError("Transaction closed") - if self._results and self._connection.supports_multiple_results is False: + if (self._results + and self._connection.supports_multiple_results is False): # Bolt 3 Support - self._results[-1]._buffer_all() # Buffer upp all records for the previous Result because it does not have any qid to fetch in batches. - - result = Result(self._connection, DataHydrator(), self._fetch_size, self._result_on_closed_handler) + # Buffer up all records for the previous Result because it does not + # have any qid to fetch in batches. + self._results[-1]._buffer_all() + + result = Result( + self._connection, DataHydrator(), self._fetch_size, + self._result_on_closed_handler, + self._result_on_network_error_handler + ) self._results.append(result) result._tx_ready_run(query, parameters, **kwparameters) diff --git a/tests/stub/test_routingdriver.py b/tests/stub/test_routingdriver.py index da048347..acfec4ea 100644 --- a/tests/stub/test_routingdriver.py +++ b/tests/stub/test_routingdriver.py @@ -617,7 +617,8 @@ def test_forgets_address_on_service_unavailable_error(driver_info, test_scripts, conns = pool.connections[('127.0.0.1', 9004)] conn = conns[0] assert conn._closed is True - assert conn.in_use is True + assert conn.in_use is False + assert session._connection is None assert table.routers == {('127.0.0.1', 9001), ('127.0.0.1', 9002), ('127.0.0.1', 9003)} # reader 127.0.0.1:9004 should've been forgotten because of an error assert not table.readers