From b6668840f00de5933fe281016326d020ee97dc25 Mon Sep 17 00:00:00 2001 From: lutovich Date: Fri, 21 Jul 2017 00:32:23 +0200 Subject: [PATCH 1/3] Routing driver forgets addresses on some errors This commit makes routing driver forget addresses and purge their connections on network and database errors. It also makes driver forget write server addresses when they fail as writers, connections remain in the pool because such machines can remain readers/routers. --- neo4j/bolt/connection.py | 19 +++++ neo4j/v1/api.py | 2 +- neo4j/v1/routing.py | 17 ++++ test/stub/scripts/database_unavailable.script | 12 +++ .../forbidden_on_read_only_database.script | 12 +++ test/stub/scripts/not_a_leader.script | 12 +++ test/stub/scripts/rude_reader.script | 7 ++ test/stub/test_routingdriver.py | 81 +++++++++++++++++-- 8 files changed, 156 insertions(+), 6 deletions(-) create mode 100644 test/stub/scripts/database_unavailable.script create mode 100644 test/stub/scripts/forbidden_on_read_only_database.script create mode 100644 test/stub/scripts/not_a_leader.script create mode 100644 test/stub/scripts/rude_reader.script diff --git a/neo4j/bolt/connection.py b/neo4j/bolt/connection.py index d35a603e..e60655fc 100644 --- a/neo4j/bolt/connection.py +++ b/neo4j/bolt/connection.py @@ -144,6 +144,9 @@ class Connection(object): #: Error class used for raising connection errors Error = ServiceUnavailable + #: The function to handle send and receive errors + error_handler = None + _supports_statement_reuse = False _last_run_statement = None @@ -237,6 +240,14 @@ def reset(self): self.sync() def send(self): + try: + self._send() + except Exception as error: + if self.error_handler is not None: + self.error_handler(error) + raise error + + def _send(self): """ Send all queued messages to the server. """ data = self.output_buffer.view() @@ -250,6 +261,14 @@ def send(self): self.output_buffer.clear() def fetch(self): + try: + return self._fetch() + except Exception as error: + if self.error_handler is not None: + self.error_handler(error) + raise error + + def _fetch(self): """ Receive at least one message from the server, if available. :return: 2-tuple of number of detail messages and number of summary messages fetched diff --git a/neo4j/v1/api.py b/neo4j/v1/api.py index 65b28a80..4e063da3 100644 --- a/neo4j/v1/api.py +++ b/neo4j/v1/api.py @@ -277,7 +277,7 @@ def _disconnect(self, sync): if sync: try: self._connection.sync() - except ServiceUnavailable: + except (SessionError, ServiceUnavailable): pass if self._connection: self._connection.in_use = False diff --git a/neo4j/v1/routing.py b/neo4j/v1/routing.py index 7d3ffbda..4d7b35d9 100644 --- a/neo4j/v1/routing.py +++ b/neo4j/v1/routing.py @@ -251,6 +251,9 @@ class RoutingConnectionPool(ConnectionPool): """ Connection pool with routing table. """ + FAILURE_CODES = ("Neo.TransientError.General.DatabaseUnavailable") + WRITE_FAILURE_CODES = ("Neo.ClientError.Cluster.NotALeader", "Neo.ClientError.General.ForbiddenOnReadOnlyDatabase") + def __init__(self, connector, initial_address, routing_context, *routers, **config): super(RoutingConnectionPool, self).__init__(connector) self.initial_address = initial_address @@ -399,12 +402,21 @@ def acquire(self, access_mode=None): try: connection = self.acquire_direct(address) # should always be a resolved address connection.Error = SessionExpired + connection.error_handler = lambda error: self._handle_connection_error(address, error) except ServiceUnavailable: self.remove(address) else: return connection raise SessionExpired("Failed to obtain connection towards '%s' server." % access_mode) + def _handle_connection_error(self, address, error): + """ Handle routing connection send or receive error. + """ + if isinstance(error, (SessionExpired, ServiceUnavailable)) or error.code in self.FAILURE_CODES: + self.remove(address) + elif error.code in self.WRITE_FAILURE_CODES: + self._remove_writer(address) + def remove(self, address): """ Remove an address from the connection pool, if present, closing all connections to that address. Also remove from the routing table. @@ -416,6 +428,11 @@ def remove(self, address): self.routing_table.writers.discard(address) super(RoutingConnectionPool, self).remove(address) + def _remove_writer(self, address): + """ Remove a writer address from the routing table, if present. + """ + self.routing_table.writers.discard(address) + class RoutingDriver(Driver): """ A :class:`.RoutingDriver` is created from a ``bolt+routing`` URI. The diff --git a/test/stub/scripts/database_unavailable.script b/test/stub/scripts/database_unavailable.script new file mode 100644 index 00000000..c482f17f --- /dev/null +++ b/test/stub/scripts/database_unavailable.script @@ -0,0 +1,12 @@ +!: AUTO INIT +!: AUTO RESET +!: AUTO PULL_ALL +!: AUTO ACK_FAILURE +!: AUTO RUN "ROLLBACK" {} +!: AUTO RUN "BEGIN" {} +!: AUTO RUN "COMMIT" {} + +C: RUN "RETURN 1" {} +C: PULL_ALL +S: FAILURE {"code": "Neo.TransientError.General.DatabaseUnavailable", "message": "Database is busy doing store copy"} +S: IGNORED diff --git a/test/stub/scripts/forbidden_on_read_only_database.script b/test/stub/scripts/forbidden_on_read_only_database.script new file mode 100644 index 00000000..3385b097 --- /dev/null +++ b/test/stub/scripts/forbidden_on_read_only_database.script @@ -0,0 +1,12 @@ +!: AUTO INIT +!: AUTO RESET +!: AUTO PULL_ALL +!: AUTO ACK_FAILURE +!: AUTO RUN "ROLLBACK" {} +!: AUTO RUN "BEGIN" {} +!: AUTO RUN "COMMIT" {} + +C: RUN "CREATE (n {name:'Bob'})" {} +C: PULL_ALL +S: FAILURE {"code": "Neo.ClientError.General.ForbiddenOnReadOnlyDatabase", "message": "Unable to write"} +S: IGNORED diff --git a/test/stub/scripts/not_a_leader.script b/test/stub/scripts/not_a_leader.script new file mode 100644 index 00000000..8716466c --- /dev/null +++ b/test/stub/scripts/not_a_leader.script @@ -0,0 +1,12 @@ +!: AUTO INIT +!: AUTO RESET +!: AUTO PULL_ALL +!: AUTO ACK_FAILURE +!: AUTO RUN "ROLLBACK" {} +!: AUTO RUN "BEGIN" {} +!: AUTO RUN "COMMIT" {} + +C: RUN "CREATE (n {name:'Bob'})" {} +C: PULL_ALL +S: FAILURE {"code": "Neo.ClientError.Cluster.NotALeader", "message": "Leader switched has happened"} +S: IGNORED diff --git a/test/stub/scripts/rude_reader.script b/test/stub/scripts/rude_reader.script new file mode 100644 index 00000000..1b1f7d48 --- /dev/null +++ b/test/stub/scripts/rude_reader.script @@ -0,0 +1,7 @@ +!: AUTO INIT +!: AUTO RESET + +C: RUN "RETURN 1" {} + PULL_ALL +S: + diff --git a/test/stub/test_routingdriver.py b/test/stub/test_routingdriver.py index 0edefe66..ae12fa8a 100644 --- a/test/stub/test_routingdriver.py +++ b/test/stub/test_routingdriver.py @@ -21,7 +21,7 @@ from neo4j.v1 import GraphDatabase, READ_ACCESS, WRITE_ACCESS, SessionExpired, \ RoutingDriver, RoutingConnectionPool, LeastConnectedLoadBalancingStrategy, LOAD_BALANCING_STRATEGY_ROUND_ROBIN, \ - RoundRobinLoadBalancingStrategy + RoundRobinLoadBalancingStrategy, TransientError, ClientError from neo4j.bolt import ProtocolError, ServiceUnavailable from test.stub.tools import StubTestCase, StubCluster @@ -236,8 +236,79 @@ def test_can_select_round_robin_load_balancing_strategy(self): self.assertIsInstance(driver._pool.load_balancing_strategy, RoundRobinLoadBalancingStrategy) def test_no_other_load_balancing_strategies_are_available(self): - with StubCluster({9001: "router.script"}): + uri = "bolt+routing://127.0.0.1:9001" + with self.assertRaises(ValueError): + with GraphDatabase.driver(uri, auth=self.auth_token, encrypted=False, load_balancing_strategy=-1): + pass + + def test_forgets_address_on_not_a_leader_error(self): + with StubCluster({9001: "router.script", 9006: "not_a_leader.script"}): uri = "bolt+routing://127.0.0.1:9001" - with self.assertRaises(ValueError): - with GraphDatabase.driver(uri, auth=self.auth_token, encrypted=False, load_balancing_strategy=-1): - pass + with GraphDatabase.driver(uri, auth=self.auth_token, encrypted=False) as driver: + with driver.session(WRITE_ACCESS) as session: + with self.assertRaises(ClientError): + _ = session.run("CREATE (n {name:'Bob'})") + + pool = driver._pool + table = pool.routing_table + + # address might still have connections in the pool, failed instance just can't serve writes + assert ('127.0.0.1', 9006) in pool.connections + assert table.routers == {('127.0.0.1', 9001), ('127.0.0.1', 9002), ('127.0.0.1', 9003)} + assert table.readers == {('127.0.0.1', 9004), ('127.0.0.1', 9005)} + # writer 127.0.0.1:9006 should've been forgotten because of an error + assert len(table.writers) == 0 + + def test_forgets_address_on_forbidden_on_read_only_database_error(self): + with StubCluster({9001: "router.script", 9006: "forbidden_on_read_only_database.script"}): + uri = "bolt+routing://127.0.0.1:9001" + with GraphDatabase.driver(uri, auth=self.auth_token, encrypted=False) as driver: + with driver.session(WRITE_ACCESS) as session: + with self.assertRaises(ClientError): + _ = session.run("CREATE (n {name:'Bob'})") + + pool = driver._pool + table = pool.routing_table + + # address might still have connections in the pool, failed instance just can't serve writes + assert ('127.0.0.1', 9006) in pool.connections + assert table.routers == {('127.0.0.1', 9001), ('127.0.0.1', 9002), ('127.0.0.1', 9003)} + assert table.readers == {('127.0.0.1', 9004), ('127.0.0.1', 9005)} + # writer 127.0.0.1:9006 should've been forgotten because of an error + assert len(table.writers) == 0 + + def test_forgets_address_on_service_unavailable_error(self): + with StubCluster({9001: "router.script", 9004: "rude_reader.script"}): + uri = "bolt+routing://127.0.0.1:9001" + with GraphDatabase.driver(uri, auth=self.auth_token, encrypted=False) as driver: + with driver.session(READ_ACCESS) as session: + with self.assertRaises(SessionExpired): + _ = session.run("RETURN 1") + + pool = driver._pool + table = pool.routing_table + + # address should not have connections in the pool, it has failed + assert ('127.0.0.1', 9004) not in pool.connections + assert table.routers == {('127.0.0.1', 9001), ('127.0.0.1', 9002), ('127.0.0.1', 9003)} + # reader 127.0.0.1:9004 should've been forgotten because of an error + assert table.readers == {('127.0.0.1', 9005)} + assert table.writers == {('127.0.0.1', 9006)} + + def test_forgets_address_on_database_unavailable_error(self): + with StubCluster({9001: "router.script", 9004: "database_unavailable.script"}): + uri = "bolt+routing://127.0.0.1:9001" + with GraphDatabase.driver(uri, auth=self.auth_token, encrypted=False) as driver: + with driver.session(READ_ACCESS) as session: + with self.assertRaises(TransientError): + _ = session.run("RETURN 1") + + pool = driver._pool + table = pool.routing_table + + # address should not have connections in the pool, it has failed + assert ('127.0.0.1', 9004) not in pool.connections + assert table.routers == {('127.0.0.1', 9001), ('127.0.0.1', 9002), ('127.0.0.1', 9003)} + # reader 127.0.0.1:9004 should've been forgotten because of an error + assert table.readers == {('127.0.0.1', 9005)} + assert table.writers == {('127.0.0.1', 9006)} From e41b65b262233684ad848284bd0293da5dd9daf8 Mon Sep 17 00:00:00 2001 From: lutovich Date: Fri, 4 Aug 2017 12:36:44 +0200 Subject: [PATCH 2/3] Add new more specific error types To avoid comparing strings when handling routing errors. --- neo4j/exceptions.py | 59 ++++++++++++++++++++++++++++++++++++--------- neo4j/v1/routing.py | 10 ++++---- 2 files changed, 52 insertions(+), 17 deletions(-) diff --git a/neo4j/exceptions.py b/neo4j/exceptions.py index 1afae696..069b0176 100644 --- a/neo4j/exceptions.py +++ b/neo4j/exceptions.py @@ -65,17 +65,9 @@ def hydrate(cls, message=None, code=None, **metadata): classification = "DatabaseError" category = "General" title = "UnknownError" - if classification == "ClientError": - try: - error_class = client_errors[code] - except KeyError: - error_class = ClientError - elif classification == "DatabaseError": - error_class = DatabaseError - elif classification == "TransientError": - error_class = TransientError - else: - error_class = cls + + error_class = cls._extract_error_class(classification, code) + inst = error_class(message) inst.message = message inst.code = code @@ -85,6 +77,26 @@ def hydrate(cls, message=None, code=None, **metadata): inst.metadata = metadata return inst + @classmethod + def _extract_error_class(cls, classification, code): + if classification == "ClientError": + try: + return client_errors[code] + except KeyError: + return ClientError + + elif classification == "TransientError": + try: + return transient_errors[code] + except KeyError: + return TransientError + + elif classification == "DatabaseError": + return DatabaseError + + else: + return cls + class ClientError(CypherError): """ The Client sent a bad request - changing the request might yield a successful outcome. @@ -101,6 +113,11 @@ class TransientError(CypherError): """ +class DatabaseUnavailableError(TransientError): + """ + """ + + class ConstraintError(ClientError): """ """ @@ -116,11 +133,21 @@ class CypherTypeError(ClientError): """ +class NotALeaderError(ClientError): + """ + """ + + class Forbidden(ClientError, SecurityError): """ """ +class ForbiddenOnReadOnlyDatabaseError(Forbidden): + """ + """ + + class AuthError(ClientError, SecurityError): """ Raised when authentication failure occurs. """ @@ -144,7 +171,7 @@ class AuthError(ClientError, SecurityError): "Neo.ClientError.Statement.TypeError": CypherTypeError, # Forbidden - "Neo.ClientError.General.ForbiddenOnReadOnlyDatabase": Forbidden, + "Neo.ClientError.General.ForbiddenOnReadOnlyDatabase": ForbiddenOnReadOnlyDatabaseError, "Neo.ClientError.General.ReadOnly": Forbidden, "Neo.ClientError.Schema.ForbiddenOnConstraintIndex": Forbidden, "Neo.ClientError.Schema.IndexBelongsToConstraint": Forbidden, @@ -155,4 +182,12 @@ class AuthError(ClientError, SecurityError): "Neo.ClientError.Security.AuthorizationFailed": AuthError, "Neo.ClientError.Security.Unauthorized": AuthError, + # NotALeaderError + "Neo.ClientError.Cluster.NotALeader": NotALeaderError +} + +transient_errors = { + + # DatabaseUnavailableError + "Neo.TransientError.General.DatabaseUnavailable": DatabaseUnavailableError } diff --git a/neo4j/v1/routing.py b/neo4j/v1/routing.py index 4d7b35d9..48b548b1 100644 --- a/neo4j/v1/routing.py +++ b/neo4j/v1/routing.py @@ -25,7 +25,7 @@ from neo4j.addressing import SocketAddress, resolve from neo4j.bolt import ConnectionPool, ServiceUnavailable, ProtocolError, DEFAULT_PORT, connect from neo4j.compat.collections import MutableSet, OrderedDict -from neo4j.exceptions import CypherError +from neo4j.exceptions import CypherError, DatabaseUnavailableError, NotALeaderError, ForbiddenOnReadOnlyDatabaseError from neo4j.util import ServerVersion from neo4j.v1.api import Driver, READ_ACCESS, WRITE_ACCESS, fix_statement, fix_parameters from neo4j.v1.exceptions import SessionExpired @@ -251,8 +251,8 @@ class RoutingConnectionPool(ConnectionPool): """ Connection pool with routing table. """ - FAILURE_CODES = ("Neo.TransientError.General.DatabaseUnavailable") - WRITE_FAILURE_CODES = ("Neo.ClientError.Cluster.NotALeader", "Neo.ClientError.General.ForbiddenOnReadOnlyDatabase") + CLUSTER_MEMBER_FAILURE_ERRORS = (ServiceUnavailable, SessionExpired, DatabaseUnavailableError) + WRITE_FAILURE_ERRORS = (NotALeaderError, ForbiddenOnReadOnlyDatabaseError) def __init__(self, connector, initial_address, routing_context, *routers, **config): super(RoutingConnectionPool, self).__init__(connector) @@ -412,9 +412,9 @@ def acquire(self, access_mode=None): def _handle_connection_error(self, address, error): """ Handle routing connection send or receive error. """ - if isinstance(error, (SessionExpired, ServiceUnavailable)) or error.code in self.FAILURE_CODES: + if isinstance(error, self.CLUSTER_MEMBER_FAILURE_ERRORS): self.remove(address) - elif error.code in self.WRITE_FAILURE_CODES: + elif isinstance(error, self.WRITE_FAILURE_ERRORS): self._remove_writer(address) def remove(self, address): From 9edcf42750cdf7a58d87b735056d3840a533ed60 Mon Sep 17 00:00:00 2001 From: lutovich Date: Fri, 4 Aug 2017 18:57:40 +0200 Subject: [PATCH 3/3] Improve connection error handling infra Connection now takes a dedicated error handler object in constructor instead of having one dynamically attached after creation. --- neo4j/bolt/connection.py | 47 ++++++++++++++++++++--------- neo4j/v1/direct.py | 18 +++++++++-- neo4j/v1/routing.py | 37 ++++++++++++----------- test/integration/test_connection.py | 10 ++++-- test/stub/test_routing.py | 4 +-- test/unit/test_routing.py | 4 +-- 6 files changed, 78 insertions(+), 42 deletions(-) diff --git a/neo4j/bolt/connection.py b/neo4j/bolt/connection.py index e60655fc..28e4adca 100644 --- a/neo4j/bolt/connection.py +++ b/neo4j/bolt/connection.py @@ -119,6 +119,26 @@ def supports_bytes(self): return self.version_info() >= (3, 2) +class ConnectionErrorHandler(object): + """ A handler for send and receive errors. + """ + + def __init__(self, handlers_by_error_class=None): + if handlers_by_error_class is None: + handlers_by_error_class = {} + + self.handlers_by_error_class = handlers_by_error_class + self.known_errors = tuple(handlers_by_error_class.keys()) + + def handle(self, error, address): + try: + error_class = error.__class__ + handler = self.handlers_by_error_class[error_class] + handler(address) + except KeyError: + pass + + class Connection(object): """ Server connection for Bolt protocol v1. @@ -144,15 +164,14 @@ class Connection(object): #: Error class used for raising connection errors Error = ServiceUnavailable - #: The function to handle send and receive errors - error_handler = None - _supports_statement_reuse = False _last_run_statement = None - def __init__(self, sock, **config): + def __init__(self, address, sock, error_handler, **config): + self.address = address self.socket = sock + self.error_handler = error_handler self.server = ServerInfo(SocketAddress.from_socket(sock)) self.input_buffer = ChunkedInputBuffer() self.output_buffer = ChunkedOutputBuffer() @@ -242,9 +261,8 @@ def reset(self): def send(self): try: self._send() - except Exception as error: - if self.error_handler is not None: - self.error_handler(error) + except self.error_handler.known_errors as error: + self.error_handler.handle(error, self.address) raise error def _send(self): @@ -263,9 +281,8 @@ def _send(self): def fetch(self): try: return self._fetch() - except Exception as error: - if self.error_handler is not None: - self.error_handler(error) + except self.error_handler.known_errors as error: + self.error_handler.handle(error, self.address) raise error def _fetch(self): @@ -379,8 +396,9 @@ class ConnectionPool(object): _closed = False - def __init__(self, connector): + def __init__(self, connector, connection_error_handler): self.connector = connector + self.connection_error_handler = connection_error_handler self.connections = {} self.lock = RLock() @@ -414,7 +432,7 @@ def acquire_direct(self, address): connection.in_use = True return connection try: - connection = self.connector(address) + connection = self.connector(address, self.connection_error_handler) except ServiceUnavailable: self.remove(address) raise @@ -476,7 +494,7 @@ def closed(self): return self._closed -def connect(address, ssl_context=None, **config): +def connect(address, ssl_context=None, error_handler=None, **config): """ Connect and perform a handshake and return a valid Connection object, assuming a protocol version can be agreed. """ @@ -582,7 +600,8 @@ def connect(address, ssl_context=None, **config): s.shutdown(SHUT_RDWR) s.close() elif agreed_version == 1: - return Connection(s, der_encoded_server_certificate=der_encoded_server_certificate, **config) + return Connection(address, s, der_encoded_server_certificate=der_encoded_server_certificate, + error_handler=error_handler, **config) elif agreed_version == 0x48545450: log_error("S: [CLOSE]") s.close() diff --git a/neo4j/v1/direct.py b/neo4j/v1/direct.py index a5910db7..c63419ee 100644 --- a/neo4j/v1/direct.py +++ b/neo4j/v1/direct.py @@ -20,17 +20,25 @@ from neo4j.addressing import SocketAddress, resolve -from neo4j.bolt import DEFAULT_PORT, ConnectionPool, connect +from neo4j.bolt import DEFAULT_PORT, ConnectionPool, connect, ConnectionErrorHandler from neo4j.exceptions import ServiceUnavailable from neo4j.v1.api import Driver from neo4j.v1.security import SecurityPlan from neo4j.v1.session import BoltSession +class DirectConnectionErrorHandler(ConnectionErrorHandler): + """ Handler for errors in direct driver connections. + """ + + def __init__(self): + super(DirectConnectionErrorHandler, self).__init__({}) # does not need to handle errors + + class DirectConnectionPool(ConnectionPool): def __init__(self, connector, address): - super(DirectConnectionPool, self).__init__(connector) + super(DirectConnectionPool, self).__init__(connector, DirectConnectionErrorHandler()) self.address = address def acquire(self, access_mode=None): @@ -61,7 +69,11 @@ def __init__(self, uri, **config): self.address = SocketAddress.from_uri(uri, DEFAULT_PORT) self.security_plan = security_plan = SecurityPlan.build(**config) self.encrypted = security_plan.encrypted - pool = DirectConnectionPool(lambda a: connect(a, security_plan.ssl_context, **config), self.address) + + def connector(address, error_handler): + return connect(address, security_plan.ssl_context, error_handler, **config) + + pool = DirectConnectionPool(connector, self.address) pool.release(pool.acquire()) Driver.__init__(self, pool, **config) diff --git a/neo4j/v1/routing.py b/neo4j/v1/routing.py index 48b548b1..4bfe688c 100644 --- a/neo4j/v1/routing.py +++ b/neo4j/v1/routing.py @@ -23,7 +23,7 @@ from time import clock from neo4j.addressing import SocketAddress, resolve -from neo4j.bolt import ConnectionPool, ServiceUnavailable, ProtocolError, DEFAULT_PORT, connect +from neo4j.bolt import ConnectionPool, ServiceUnavailable, ProtocolError, DEFAULT_PORT, connect, ConnectionErrorHandler from neo4j.compat.collections import MutableSet, OrderedDict from neo4j.exceptions import CypherError, DatabaseUnavailableError, NotALeaderError, ForbiddenOnReadOnlyDatabaseError from neo4j.util import ServerVersion @@ -32,7 +32,6 @@ from neo4j.v1.security import SecurityPlan from neo4j.v1.session import BoltSession - LOAD_BALANCING_STRATEGY_LEAST_CONNECTED = 0 LOAD_BALANCING_STRATEGY_ROUND_ROBIN = 1 LOAD_BALANCING_STRATEGY_DEFAULT = LOAD_BALANCING_STRATEGY_LEAST_CONNECTED @@ -247,15 +246,26 @@ def _select(self, offset, addresses): return least_connected_address +class RoutingConnectionErrorHandler(ConnectionErrorHandler): + """ Handler for errors in routing driver connections. + """ + + def __init__(self, pool): + super(RoutingConnectionErrorHandler, self).__init__({ + SessionExpired: lambda address: pool.remove(address), + ServiceUnavailable: lambda address: pool.remove(address), + DatabaseUnavailableError: lambda address: pool.remove(address), + NotALeaderError: lambda address: pool.remove_writer(address), + ForbiddenOnReadOnlyDatabaseError: lambda address: pool.remove_writer(address) + }) + + class RoutingConnectionPool(ConnectionPool): """ Connection pool with routing table. """ - CLUSTER_MEMBER_FAILURE_ERRORS = (ServiceUnavailable, SessionExpired, DatabaseUnavailableError) - WRITE_FAILURE_ERRORS = (NotALeaderError, ForbiddenOnReadOnlyDatabaseError) - def __init__(self, connector, initial_address, routing_context, *routers, **config): - super(RoutingConnectionPool, self).__init__(connector) + super(RoutingConnectionPool, self).__init__(connector, RoutingConnectionErrorHandler(self)) self.initial_address = initial_address self.routing_context = routing_context self.routing_table = RoutingTable(routers) @@ -402,21 +412,12 @@ def acquire(self, access_mode=None): try: connection = self.acquire_direct(address) # should always be a resolved address connection.Error = SessionExpired - connection.error_handler = lambda error: self._handle_connection_error(address, error) except ServiceUnavailable: self.remove(address) else: return connection raise SessionExpired("Failed to obtain connection towards '%s' server." % access_mode) - def _handle_connection_error(self, address, error): - """ Handle routing connection send or receive error. - """ - if isinstance(error, self.CLUSTER_MEMBER_FAILURE_ERRORS): - self.remove(address) - elif isinstance(error, self.WRITE_FAILURE_ERRORS): - self._remove_writer(address) - def remove(self, address): """ Remove an address from the connection pool, if present, closing all connections to that address. Also remove from the routing table. @@ -428,7 +429,7 @@ def remove(self, address): self.routing_table.writers.discard(address) super(RoutingConnectionPool, self).remove(address) - def _remove_writer(self, address): + def remove_writer(self, address): """ Remove a writer address from the routing table, if present. """ self.routing_table.writers.discard(address) @@ -450,8 +451,8 @@ def __init__(self, uri, **config): # scenario right now raise ValueError("TRUST_ON_FIRST_USE is not compatible with routing") - def connector(a): - return connect(a, security_plan.ssl_context, **config) + def connector(address, error_handler): + return connect(address, security_plan.ssl_context, error_handler, **config) pool = RoutingConnectionPool(connector, initial_address, routing_context, *resolve(initial_address), **config) try: diff --git a/test/integration/test_connection.py b/test/integration/test_connection.py index 1a28c2ff..703f97df 100644 --- a/test/integration/test_connection.py +++ b/test/integration/test_connection.py @@ -21,7 +21,7 @@ from socket import create_connection -from neo4j.v1 import ConnectionPool, ServiceUnavailable +from neo4j.v1 import ConnectionPool, ServiceUnavailable, DirectConnectionErrorHandler from test.integration.tools import IntegrationTestCase @@ -45,10 +45,14 @@ def defunct(self): return False +def connector(address, _): + return QuickConnection(create_connection(address)) + + class ConnectionPoolTestCase(IntegrationTestCase): def setUp(self): - self.pool = ConnectionPool(lambda a: QuickConnection(create_connection(a))) + self.pool = ConnectionPool(connector, DirectConnectionErrorHandler()) def tearDown(self): self.pool.close() @@ -104,7 +108,7 @@ def test_releasing_twice(self): self.assert_pool_size(address, 0, 1) def test_cannot_acquire_after_close(self): - with ConnectionPool(lambda a: QuickConnection(create_connection(a))) as pool: + with ConnectionPool(lambda a: QuickConnection(create_connection(a)), DirectConnectionErrorHandler()) as pool: pool.close() with self.assertRaises(ServiceUnavailable): _ = pool.acquire_direct("X") diff --git a/test/stub/test_routing.py b/test/stub/test_routing.py index 2c9a23ce..05f2bfca 100644 --- a/test/stub/test_routing.py +++ b/test/stub/test_routing.py @@ -50,8 +50,8 @@ UNREACHABLE_ADDRESS = ("127.0.0.1", 8080) -def connector(address): - return connect(address, auth=basic_auth("neotest", "neotest")) +def connector(address, error_handler): + return connect(address, error_handler=error_handler, auth=basic_auth("neotest", "neotest")) def RoutingPool(*routers): diff --git a/test/unit/test_routing.py b/test/unit/test_routing.py index 92ca5507..a7d12c4d 100644 --- a/test/unit/test_routing.py +++ b/test/unit/test_routing.py @@ -52,8 +52,8 @@ } -def connector(address): - return connect(address, auth=basic_auth("neotest", "neotest")) +def connector(address, error_handler): + return connect(address, error_handler=error_handler, auth=basic_auth("neotest", "neotest")) class RoundRobinSetTestCase(TestCase):