From e489ba3d64638c85ca87d4df6a016a722fadd9df Mon Sep 17 00:00:00 2001 From: Rouven Bauer Date: Wed, 7 Apr 2021 17:23:24 +0200 Subject: [PATCH 01/10] Try all routers before giving up --- neo4j/io/__init__.py | 19 +++++++++++++------ tests/stub/test_routingdriver.py | 6 ++---- 2 files changed, 15 insertions(+), 10 deletions(-) diff --git a/neo4j/io/__init__.py b/neo4j/io/__init__.py index 5d323d45..9b5ed842 100644 --- a/neo4j/io/__init__.py +++ b/neo4j/io/__init__.py @@ -754,10 +754,13 @@ def update_routing_table_from(self, *routers, database=None, """ log.debug("Attempting to update routing table from {}".format(", ".join(map(repr, routers)))) for router in routers: - new_routing_table = self.fetch_routing_table( - address=router, timeout=self.pool_config.connection_timeout, - database=database, bookmarks=bookmarks - ) + try: + new_routing_table = self.fetch_routing_table( + address=router, timeout=self.pool_config.connection_timeout, + database=database, bookmarks=bookmarks + ) + except BoltRoutingError: + continue if new_routing_table is not None: self.routing_tables[database].update(new_routing_table) log.debug("[#0000] C: address={!r} ({!r})".format(router, self.routing_tables[database])) @@ -786,8 +789,12 @@ def update_routing_table(self, *, database, bookmarks): ): # Why is only the first initial routing address used? return - if self.update_routing_table_from(*existing_routers, database=database, - bookmarks=bookmarks): + if self.update_routing_table_from( + *[r for r in existing_routers + if (not has_tried_initial_routers + or r != self.first_initial_routing_address)], + database=database, bookmarks=bookmarks + ): return if (not has_tried_initial_routers diff --git a/tests/stub/test_routingdriver.py b/tests/stub/test_routingdriver.py index acfec4ea..d1319d8c 100644 --- a/tests/stub/test_routingdriver.py +++ b/tests/stub/test_routingdriver.py @@ -34,13 +34,11 @@ ) from neo4j.exceptions import ( ServiceUnavailable, - ClientError, TransientError, SessionExpired, ConfigurationError, ) from neo4j._exceptions import ( - BoltRoutingError, BoltSecurityError, ) from tests.stub.conftest import StubCluster @@ -214,7 +212,7 @@ def test_cannot_discover_servers_on_non_router(driver_info, test_script): def test_cannot_discover_servers_on_silent_router(driver_info, test_script): # python -m pytest tests/stub/test_routingdriver.py -s -v -k test_cannot_discover_servers_on_silent_router with StubCluster(test_script): - with pytest.raises(BoltRoutingError): + with pytest.raises(ServiceUnavailable, match="routing"): with GraphDatabase.driver(driver_info["uri_neo4j"], auth=driver_info["auth_token"]) as driver: assert isinstance(driver, Neo4jDriver) driver._pool.update_routing_table(database=None, bookmarks=None) @@ -532,7 +530,7 @@ def test_should_serve_read_when_missing_writer(driver_info, test_scripts, test_r def test_should_error_when_missing_reader(driver_info, test_script): # python -m pytest tests/stub/test_routingdriver.py -s -v -k test_should_error_when_missing_reader with StubCluster(test_script): - with pytest.raises(BoltRoutingError): + with pytest.raises(ServiceUnavailable, match="routing"): with GraphDatabase.driver(driver_info["uri_neo4j"], auth=driver_info["auth_token"]) as driver: assert isinstance(driver, Neo4jDriver) driver._pool.update_routing_table(database=None, bookmarks=None) From 3055ec4910e3dedc18d62d71eee9fffdf6ce784a Mon Sep 17 00:00:00 2001 From: Rouven Bauer Date: Fri, 9 Apr 2021 12:38:54 +0200 Subject: [PATCH 02/10] Add DNS resolution request/response to testkit back-end --- testkitbackend/backend.py | 3 +- testkitbackend/requests.py | 82 +++++++++++++++++++++++++++----------- 2 files changed, 61 insertions(+), 24 deletions(-) diff --git a/testkitbackend/backend.py b/testkitbackend/backend.py index 0a406a4a..df810169 100644 --- a/testkitbackend/backend.py +++ b/testkitbackend/backend.py @@ -85,7 +85,8 @@ def __init__(self, rd, wr): self._rd = rd self._wr = wr self.drivers = {} - self.address_resolutions = {} + self.custom_resolutions = {} + self.dns_resolutions = {} self.sessions = {} self.results = {} self.errors = {} diff --git a/testkitbackend/requests.py b/testkitbackend/requests.py index bb1257a8..eb3fe73c 100644 --- a/testkitbackend/requests.py +++ b/testkitbackend/requests.py @@ -44,11 +44,14 @@ def NewDriver(backend, data): auth_token["scheme"], auth_token["principal"], auth_token["credentials"], realm=auth_token["realm"]) auth_token.mark_item_as_read_if_equals("ticket", "") - resolver = resolution_func(backend) if data["resolverRegistered"] else None + resolver = None + if data["resolverRegistered"] or data["domainNameResolverRegistered"]: + resolver = resolution_func(backend, data["resolverRegistered"], + data["domainNameResolverRegistered"]) connection_timeout = data.get("connectionTimeoutMs", None) if connection_timeout is not None: connection_timeout /= 1000 - data.mark_item_as_read_if_equals("domainNameResolverRegistered", False) + data.mark_item_as_read("domainNameResolverRegistered") driver = neo4j.GraphDatabase.driver( data["uri"], auth=auth, user_agent=data["userAgent"], resolver=resolver, connection_timeout=connection_timeout @@ -74,33 +77,66 @@ def CheckMultiDBSupport(backend, data): ) -def resolution_func(backend): +def resolution_func(backend, custom_resolver=False, custom_dns_resolver=False): + # This solution (putting custom resolution together with DNS resolution into + # one function only works because the Python driver calls the custom + # resolver function for every connection, which is not true for all drivers. + # Properly exposing a way to change the DNS lookup behavior is not possible + # without changing the driver's code. + assert custom_resolver or custom_dns_resolver + def resolve(address): - key = backend.next_key() - address = ":".join(map(str, address)) - backend.send_response("ResolverResolutionRequired", { - "id": key, - "address": address - }) - if not backend.process_request(): - # connection was closed before end of next message - return [] - if key not in backend.address_resolutions: - raise RuntimeError( - "Backend did not receive expected ResolverResolutionCompleted " - "message for id %s" % key - ) - resolution = backend.address_resolutions[key] - resolution = list(map(neo4j.Address.parse, resolution)) - # won't be needed anymore -> conserve memory - del backend.address_resolutions[key] - return resolution + addresses = [":".join(map(str, address))] + if custom_resolver: + key = backend.next_key() + backend.send_response("ResolverResolutionRequired", { + "id": key, + "address": addresses[0] + }) + if not backend.process_request(): + # connection was closed before end of next message + return [] + if key not in backend.custom_resolutions: + raise RuntimeError( + "Backend did not receive expected " + "ResolverResolutionCompleted message for id %s" % key + ) + addresses = backend.custom_resolutions.pop(key) + if custom_dns_resolver: + dns_resolved_addresses = [] + for address in addresses: + key = backend.next_key() + address = address.rsplit(":", 1) + backend.send_response("DomainNameResolutionRequired", { + "id": key, + "name": address[0] + }) + if not backend.process_request(): + # connection was closed before end of next message + return [] + if key not in backend.dns_resolutions: + raise RuntimeError( + "Backend did not receive expected " + "DomainNameResolutionCompleted message for id %s" % key + ) + dns_resolved_addresses += list(map( + lambda a: ":".join((a, *address[1:])), + backend.dns_resolutions.pop(key) + )) + + addresses = dns_resolved_addresses + + return list(map(neo4j.Address.parse, addresses)) return resolve def ResolverResolutionCompleted(backend, data): - backend.address_resolutions[data["requestId"]] = data["addresses"] + backend.custom_resolutions[data["requestId"]] = data["addresses"] + + +def DomainNameResolutionCompleted(backend, data): + backend.dns_resolutions[data["requestId"]] = data["addresses"] def DriverClose(backend, data): From bd1cf90265e5de38b97ae3240ee80dc030a38b61 Mon Sep 17 00:00:00 2001 From: Rouven Bauer Date: Fri, 9 Apr 2021 15:01:03 +0200 Subject: [PATCH 03/10] Try to fetch RT from all resolution results --- neo4j/addressing.py | 60 +++++++++++++++++++++++++++++++++++++++++--- neo4j/io/__init__.py | 50 ++++++++++++++++++++---------------- 2 files changed, 84 insertions(+), 26 deletions(-) diff --git a/neo4j/addressing.py b/neo4j/addressing.py index e294291b..09ce0c37 100644 --- a/neo4j/addressing.py +++ b/neo4j/addressing.py @@ -26,9 +26,46 @@ AF_INET, AF_INET6, ) +import logging -class Address(tuple): +log = logging.getLogger("neo4j") + + +class _AddressMeta(type(tuple)): + + def __init__(self, *args, **kwargs): + self._ipv4_cls = None + self._ipv6_cls = None + + def _subclass_by_family(self, family): + subclasses = [ + sc for sc in self.__subclasses__() + if (sc.__module__ == "neo4j.addressing" + and getattr(sc, "family", None) == family) + ] + if len(subclasses) != 1: + raise ValueError( + "Class {} needs exactly one direct subclass with attribute " + "`family == {}` within this module. " + "Found: {}".format(self, family, subclasses) + ) + return subclasses[0] + + @property + def ipv4_cls(self): + if self._ipv4_cls is None: + self._ipv4_cls = self._subclass_by_family(AF_INET) + return self._ipv4_cls + + @property + def ipv6_cls(self): + if self._ipv6_cls is None: + self._ipv6_cls = self._subclass_by_family(AF_INET6) + return self._ipv6_cls + + +class Address(tuple, metaclass=_AddressMeta): @classmethod def from_socket(cls, socket): @@ -75,9 +112,9 @@ def __new__(cls, iterable): n_parts = len(iterable) inst = tuple.__new__(cls, iterable) if n_parts == 2: - inst.__class__ = IPv4Address + inst.__class__ = cls.ipv4_cls elif n_parts == 4: - inst.__class__ = IPv6Address + inst.__class__ = cls.ipv6_cls else: raise ValueError("Addresses must consist of either " "two parts (IPv4) or four parts (IPv6)") @@ -118,7 +155,7 @@ def _dns_resolve(cls, address, family=0): # as these appear to cause problems on some platforms continue if addr not in resolved: - resolved.append(Address(addr)) + resolved.append(ResolvedAddress(addr)) return resolved def resolve(self, family=0, resolver=None): @@ -138,6 +175,8 @@ def resolve(self, family=0, resolver=None): :param resolver: optional customer resolver function to be called before regular DNS resolution """ + + log.debug("[#0000] C: %s", self) resolved = [] if resolver: for address in map(Address, resolver(self)): @@ -173,3 +212,16 @@ class IPv6Address(Address): def __str__(self): return "[{}]:{}".format(*self) + + +class ResolvedAddress(Address): + def resolve(self, family=0, resolver=None): + return [self] + + +class ResolvedIPv4Address(IPv4Address, ResolvedAddress): + pass + + +class ResolvedIPv6Address(IPv6Address, ResolvedAddress): + pass diff --git a/neo4j/io/__init__.py b/neo4j/io/__init__.py index 5d323d45..0cc268b2 100644 --- a/neo4j/io/__init__.py +++ b/neo4j/io/__init__.py @@ -711,16 +711,15 @@ def fetch_routing_table(self, *, address, timeout, database, bookmarks): :return: a new RoutingTable instance or None if the given router is currently unable to provide routing information - - :raise neo4j.exceptions.ServiceUnavailable: if no writers are available - :raise neo4j._exceptions.BoltProtocolError: if the routing information received is unusable """ - new_routing_info = self.fetch_routing_info(address, database, bookmarks, - timeout) - if new_routing_info is None: + try: + new_routing_info = self.fetch_routing_info(address, database, + bookmarks, timeout) + except ServiceUnavailable: + new_routing_info = None + if not new_routing_info: + log.debug("Failed to fetch routing info %s", address) return None - elif not new_routing_info: - raise BoltRoutingError("Invalid routing table", address) else: servers = new_routing_info[0]["servers"] ttl = new_routing_info[0]["ttl"] @@ -736,11 +735,13 @@ def fetch_routing_table(self, *, address, timeout, database, bookmarks): # No routers if num_routers == 0: - raise BoltRoutingError("No routing servers returned from server", address) + log.debug("No routing servers returned from server %s", address) + return None # No readers if num_readers == 0: - raise BoltRoutingError("No read servers returned from server", address) + log.debug("No read servers returned from server %s", address) + return None # At least one of each is fine, so return this table return new_routing_table @@ -754,14 +755,19 @@ def update_routing_table_from(self, *routers, database=None, """ log.debug("Attempting to update routing table from {}".format(", ".join(map(repr, routers)))) for router in routers: - new_routing_table = self.fetch_routing_table( - address=router, timeout=self.pool_config.connection_timeout, - database=database, bookmarks=bookmarks - ) - if new_routing_table is not None: - self.routing_tables[database].update(new_routing_table) - log.debug("[#0000] C: address={!r} ({!r})".format(router, self.routing_tables[database])) - return True + for address in router.resolve(resolver=self.pool_config.resolver): + new_routing_table = self.fetch_routing_table( + address=address, + timeout=self.pool_config.connection_timeout, + database=database, bookmarks=bookmarks + ) + if new_routing_table is not None: + self.routing_tables[database].update(new_routing_table) + log.debug( + "[#0000] C: address=%r (%r)", + address, self.routing_tables[database] + ) + return True return False def update_routing_table(self, *, database, bookmarks): @@ -1048,21 +1054,21 @@ def connect(address, *, timeout, custom_resolver, ssl_context, keep_alive): # Establish a connection to the host and port specified # Catches refused connections see: # https://docs.python.org/2/library/errno.html - log.debug("[#0000] C: %s", address) for resolved_address in Address(address).resolve(resolver=custom_resolver): s = None try: - host = address[0] + host = resolved_address[0] s = _connect(resolved_address, timeout, keep_alive) s = _secure(s, host, ssl_context) - return _handshake(s, address) + return _handshake(s, resolved_address) except Exception as error: if s: _close_socket(s) last_error = error if last_error is None: - raise ServiceUnavailable("Failed to resolve addresses for %s" % address) + raise ServiceUnavailable("Failed to resolve addresses for %s" % + str(address)) else: raise last_error From 9452ebd87defe0c9d4a3a3acb44514d772b3908d Mon Sep 17 00:00:00 2001 From: Rouven Bauer Date: Fri, 9 Apr 2021 16:02:07 +0200 Subject: [PATCH 04/10] Adjust expected errors in tests --- tests/performance/__init__.py | 0 tests/stub/test_routingdriver.py | 6 ++---- 2 files changed, 2 insertions(+), 4 deletions(-) create mode 100644 tests/performance/__init__.py diff --git a/tests/performance/__init__.py b/tests/performance/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/stub/test_routingdriver.py b/tests/stub/test_routingdriver.py index acfec4ea..d1319d8c 100644 --- a/tests/stub/test_routingdriver.py +++ b/tests/stub/test_routingdriver.py @@ -34,13 +34,11 @@ ) from neo4j.exceptions import ( ServiceUnavailable, - ClientError, TransientError, SessionExpired, ConfigurationError, ) from neo4j._exceptions import ( - BoltRoutingError, BoltSecurityError, ) from tests.stub.conftest import StubCluster @@ -214,7 +212,7 @@ def test_cannot_discover_servers_on_non_router(driver_info, test_script): def test_cannot_discover_servers_on_silent_router(driver_info, test_script): # python -m pytest tests/stub/test_routingdriver.py -s -v -k test_cannot_discover_servers_on_silent_router with StubCluster(test_script): - with pytest.raises(BoltRoutingError): + with pytest.raises(ServiceUnavailable, match="routing"): with GraphDatabase.driver(driver_info["uri_neo4j"], auth=driver_info["auth_token"]) as driver: assert isinstance(driver, Neo4jDriver) driver._pool.update_routing_table(database=None, bookmarks=None) @@ -532,7 +530,7 @@ def test_should_serve_read_when_missing_writer(driver_info, test_scripts, test_r def test_should_error_when_missing_reader(driver_info, test_script): # python -m pytest tests/stub/test_routingdriver.py -s -v -k test_should_error_when_missing_reader with StubCluster(test_script): - with pytest.raises(BoltRoutingError): + with pytest.raises(ServiceUnavailable, match="routing"): with GraphDatabase.driver(driver_info["uri_neo4j"], auth=driver_info["auth_token"]) as driver: assert isinstance(driver, Neo4jDriver) driver._pool.update_routing_table(database=None, bookmarks=None) From 8436c71ad9c023a59222f2b406196892b704996b Mon Sep 17 00:00:00 2001 From: Rouven Bauer Date: Fri, 9 Apr 2021 17:06:12 +0200 Subject: [PATCH 05/10] Rewrite errors on connect into ServiceUnavailable --- neo4j/io/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/neo4j/io/__init__.py b/neo4j/io/__init__.py index c442b966..6e290755 100644 --- a/neo4j/io/__init__.py +++ b/neo4j/io/__init__.py @@ -1072,7 +1072,8 @@ def connect(address, *, timeout, custom_resolver, ssl_context, keep_alive): raise ServiceUnavailable("Failed to resolve addresses for %s" % str(address)) else: - raise last_error + raise ServiceUnavailable("Failed to resolve addresses for %s" % + str(address)) from last_error def check_supported_server_product(agent): From f0bda5008b86a15cfee78a69d3b5840499dad011 Mon Sep 17 00:00:00 2001 From: Rouven Bauer Date: Fri, 9 Apr 2021 17:35:59 +0200 Subject: [PATCH 06/10] Adjust integration tests --- tests/integration/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 671fe507..64aa1f3d 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -303,7 +303,7 @@ def neo4j_driver(target, auth): except ServiceUnavailable as error: if isinstance(error.__cause__, BoltHandshakeError): pytest.skip(error.args[0]) - elif error.args[0] == "Server does not support routing": + elif error.args[0] == "Unable to retrieve routing information": pytest.skip(error.args[0]) else: raise From 89ca053d2bcb17d56f2235390230efddc346e71d Mon Sep 17 00:00:00 2001 From: Rouven Bauer Date: Mon, 12 Apr 2021 12:30:35 +0200 Subject: [PATCH 07/10] TLS: verify host name from pre and post custom resolver --- neo4j/addressing.py | 23 +++++++++++++-- neo4j/io/__init__.py | 69 +++++++++++++++++++++++++++++--------------- 2 files changed, 66 insertions(+), 26 deletions(-) diff --git a/neo4j/addressing.py b/neo4j/addressing.py index 09ce0c37..dd0bfbc7 100644 --- a/neo4j/addressing.py +++ b/neo4j/addressing.py @@ -126,6 +126,10 @@ def __new__(cls, iterable): def __repr__(self): return "{}({!r})".format(self.__class__.__name__, tuple(self)) + @property + def host_names(self): + return self[0], + @property def host(self): return self[0] @@ -155,7 +159,9 @@ def _dns_resolve(cls, address, family=0): # as these appear to cause problems on some platforms continue if addr not in resolved: - resolved.append(ResolvedAddress(addr)) + resolved.append(ResolvedAddress( + addr, host_names=address.host_names) + ) return resolved def resolve(self, family=0, resolver=None): @@ -179,7 +185,10 @@ def resolve(self, family=0, resolver=None): log.debug("[#0000] C: %s", self) resolved = [] if resolver: - for address in map(Address, resolver(self)): + for address in resolver(self): + address = ResolvedAddress( + address, host_names={self.host_names} + ) resolved.extend(self._dns_resolve(address, family)) else: resolved.extend(self._dns_resolve(self, family)) @@ -215,9 +224,19 @@ def __str__(self): class ResolvedAddress(Address): + + @property + def host_names(self): + return tuple(self._host_names) + def resolve(self, family=0, resolver=None): return [self] + def __new__(cls, iterable, host_names=None): + new = super().__new__(cls, iterable) + new._host_names = set() if host_names is None else set(host_names) + return new + class ResolvedIPv4Address(IPv4Address, ResolvedAddress): pass diff --git a/neo4j/io/__init__.py b/neo4j/io/__init__.py index 6e290755..45e90766 100644 --- a/neo4j/io/__init__.py +++ b/neo4j/io/__init__.py @@ -48,6 +48,7 @@ timeout as SocketTimeout, ) from ssl import ( + CertificateError, HAS_SNI, SSLError, ) @@ -59,6 +60,7 @@ from time import perf_counter from neo4j._exceptions import ( + BoltError, BoltHandshakeError, BoltProtocolError, BoltRoutingError, @@ -77,6 +79,7 @@ from neo4j.exceptions import ( ClientError, ConfigurationError, + DriverError, ReadServiceUnavailable, ServiceUnavailable, SessionExpired, @@ -768,6 +771,7 @@ def update_routing_table_from(self, *routers, database=None, address, self.routing_tables[database] ) return True + self.deactivate(router) return False def update_routing_table(self, *, database, bookmarks): @@ -963,25 +967,34 @@ def _connect(resolved_address, timeout, keep_alive): return s -def _secure(s, host, ssl_context): +def _secure(s, hosts, ssl_context): local_port = s.getsockname()[1] # Secure the connection if an SSL context has been provided + if not hosts: + hosts = [None] if ssl_context: - log.debug("[#%04X] C: %s", local_port, host) - try: - sni_host = host if HAS_SNI and host else None - s = ssl_context.wrap_socket(s, server_hostname=sni_host) - except (SSLError, OSError) as cause: - _close_socket(s) - error = BoltSecurityError(message="Failed to establish encrypted connection.", address=(host, local_port)) - error.__cause__ = cause - raise error - else: - # Check that the server provides a certificate - der_encoded_server_certificate = s.getpeercert(binary_form=True) - if der_encoded_server_certificate is None: - s.close() - raise BoltProtocolError("When using an encrypted socket, the server should always provide a certificate", address=(host, local_port)) + last_error = None + for host in hosts: + log.debug("[#%04X] C: %s", local_port, host) + try: + sni_host = host if HAS_SNI and host else None + s = ssl_context.wrap_socket(s, server_hostname=sni_host) + except (SSLError, CertificateError) as cause: + last_error = cause + continue + except OSError as cause: + # No sense in trying another host name with a broken socket + last_error = cause + break + else: + # Check that the server provides a certificate + der_encoded_server_certificate = s.getpeercert(binary_form=True) + if der_encoded_server_certificate is None: + raise BoltProtocolError("When using an encrypted socket, the server should always provide a certificate", address=(host, local_port)) + return s + raise BoltSecurityError( + message="Failed to establish encrypted connection.", + address=(hosts[0], local_port)) from last_error return s @@ -1057,23 +1070,31 @@ def connect(address, *, timeout, custom_resolver, ssl_context, keep_alive): # Catches refused connections see: # https://docs.python.org/2/library/errno.html - for resolved_address in Address(address).resolve(resolver=custom_resolver): + resolved_addresses = Address(address).resolve(resolver=custom_resolver) + for resolved_address in resolved_addresses: s = None try: - host = resolved_address[0] s = _connect(resolved_address, timeout, keep_alive) - s = _secure(s, host, ssl_context) + s = _secure(s, resolved_address.host_names, ssl_context) return _handshake(s, resolved_address) - except Exception as error: + except (BoltError, DriverError, OSError) as error: if s: _close_socket(s) last_error = error + except Exception: + if s: + _close_socket(s) + raise if last_error is None: - raise ServiceUnavailable("Failed to resolve addresses for %s" % - str(address)) + raise ServiceUnavailable( + "Couldn't connect to %s (resolved to %s)" % ( + str(address), tuple(map(str, resolved_addresses))) + ) else: - raise ServiceUnavailable("Failed to resolve addresses for %s" % - str(address)) from last_error + raise ServiceUnavailable( + "Couldn't connect to %s (resolved to %s)" % ( + str(address), tuple(map(str, resolved_addresses))) + ) from last_error def check_supported_server_product(agent): From 51d6430401ee4defbc0ed2ebf631a1cdbf9270d6 Mon Sep 17 00:00:00 2001 From: Rouven Bauer Date: Mon, 12 Apr 2021 13:36:32 +0200 Subject: [PATCH 08/10] Improve error message --- neo4j/io/__init__.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/neo4j/io/__init__.py b/neo4j/io/__init__.py index 45e90766..0e78a32b 100644 --- a/neo4j/io/__init__.py +++ b/neo4j/io/__init__.py @@ -1065,7 +1065,7 @@ def connect(address, *, timeout, custom_resolver, ssl_context, keep_alive): """ Connect and perform a handshake and return a valid Connection object, assuming a protocol version can be agreed. """ - last_error = None + errors = [] # Establish a connection to the host and port specified # Catches refused connections see: # https://docs.python.org/2/library/errno.html @@ -1080,21 +1080,23 @@ def connect(address, *, timeout, custom_resolver, ssl_context, keep_alive): except (BoltError, DriverError, OSError) as error: if s: _close_socket(s) - last_error = error + errors.append(error) except Exception: if s: _close_socket(s) raise - if last_error is None: + if not errors: raise ServiceUnavailable( "Couldn't connect to %s (resolved to %s)" % ( str(address), tuple(map(str, resolved_addresses))) ) else: raise ServiceUnavailable( - "Couldn't connect to %s (resolved to %s)" % ( - str(address), tuple(map(str, resolved_addresses))) - ) from last_error + "Couldn't connect to %s (resolved to %s):\n%s" % ( + str(address), tuple(map(str, resolved_addresses)), + "\n".join(map(str, errors)) + ) + ) from errors[0] def check_supported_server_product(agent): From 463679bcd71996475fdb14081796ed96dcbb6476 Mon Sep 17 00:00:00 2001 From: Rouven Bauer Date: Wed, 21 Apr 2021 11:57:05 +0200 Subject: [PATCH 09/10] Don't hard-code module name in _AddressMeta --- neo4j/addressing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/neo4j/addressing.py b/neo4j/addressing.py index dd0bfbc7..ae37d08f 100644 --- a/neo4j/addressing.py +++ b/neo4j/addressing.py @@ -41,7 +41,7 @@ def __init__(self, *args, **kwargs): def _subclass_by_family(self, family): subclasses = [ sc for sc in self.__subclasses__() - if (sc.__module__ == "neo4j.addressing" + if (sc.__module__ == self.__module__ and getattr(sc, "family", None) == family) ] if len(subclasses) != 1: From 199a7ca5ea3c402b9a22cc50aaa64bc723bf9821 Mon Sep 17 00:00:00 2001 From: Rouven Bauer Date: Wed, 21 Apr 2021 12:27:03 +0200 Subject: [PATCH 10/10] Only use host name after custom resolver for SSL --- neo4j/addressing.py | 19 ++++++++----------- neo4j/io/__init__.py | 44 +++++++++++++++++++------------------------- 2 files changed, 27 insertions(+), 36 deletions(-) diff --git a/neo4j/addressing.py b/neo4j/addressing.py index ae37d08f..908b84cb 100644 --- a/neo4j/addressing.py +++ b/neo4j/addressing.py @@ -127,8 +127,8 @@ def __repr__(self): return "{}({!r})".format(self.__class__.__name__, tuple(self)) @property - def host_names(self): - return self[0], + def host_name(self): + return self[0] @property def host(self): @@ -160,7 +160,7 @@ def _dns_resolve(cls, address, family=0): continue if addr not in resolved: resolved.append(ResolvedAddress( - addr, host_names=address.host_names) + addr, host_name=address.host_name) ) return resolved @@ -185,10 +185,7 @@ def resolve(self, family=0, resolver=None): log.debug("[#0000] C: %s", self) resolved = [] if resolver: - for address in resolver(self): - address = ResolvedAddress( - address, host_names={self.host_names} - ) + for address in map(Address, resolver(self)): resolved.extend(self._dns_resolve(address, family)) else: resolved.extend(self._dns_resolve(self, family)) @@ -226,15 +223,15 @@ def __str__(self): class ResolvedAddress(Address): @property - def host_names(self): - return tuple(self._host_names) + def host_name(self): + return self._host_name def resolve(self, family=0, resolver=None): return [self] - def __new__(cls, iterable, host_names=None): + def __new__(cls, iterable, host_name=None): new = super().__new__(cls, iterable) - new._host_names = set() if host_names is None else set(host_names) + new._host_name = host_name return new diff --git a/neo4j/io/__init__.py b/neo4j/io/__init__.py index 0c0842cf..23a9fc8b 100644 --- a/neo4j/io/__init__.py +++ b/neo4j/io/__init__.py @@ -964,34 +964,28 @@ def _connect(resolved_address, timeout, keep_alive): return s -def _secure(s, hosts, ssl_context): +def _secure(s, host, ssl_context): local_port = s.getsockname()[1] # Secure the connection if an SSL context has been provided - if not hosts: - hosts = [None] if ssl_context: last_error = None - for host in hosts: - log.debug("[#%04X] C: %s", local_port, host) - try: - sni_host = host if HAS_SNI and host else None - s = ssl_context.wrap_socket(s, server_hostname=sni_host) - except (SSLError, CertificateError) as cause: - last_error = cause - continue - except OSError as cause: - # No sense in trying another host name with a broken socket - last_error = cause - break - else: - # Check that the server provides a certificate - der_encoded_server_certificate = s.getpeercert(binary_form=True) - if der_encoded_server_certificate is None: - raise BoltProtocolError("When using an encrypted socket, the server should always provide a certificate", address=(host, local_port)) - return s - raise BoltSecurityError( - message="Failed to establish encrypted connection.", - address=(hosts[0], local_port)) from last_error + log.debug("[#%04X] C: %s", local_port, host) + try: + sni_host = host if HAS_SNI and host else None + s = ssl_context.wrap_socket(s, server_hostname=sni_host) + except (OSError, SSLError, CertificateError) as cause: + raise BoltSecurityError( + message="Failed to establish encrypted connection.", + address=(host, local_port) + ) from cause + # Check that the server provides a certificate + der_encoded_server_certificate = s.getpeercert(binary_form=True) + if der_encoded_server_certificate is None: + raise BoltProtocolError( + "When using an encrypted socket, the server should always " + "provide a certificate", address=(host, local_port) + ) + return s return s @@ -1072,7 +1066,7 @@ def connect(address, *, timeout, custom_resolver, ssl_context, keep_alive): s = None try: s = _connect(resolved_address, timeout, keep_alive) - s = _secure(s, resolved_address.host_names, ssl_context) + s = _secure(s, resolved_address.host_name, ssl_context) return _handshake(s, resolved_address) except (BoltError, DriverError, OSError) as error: if s: