diff --git a/neo4j/addressing.py b/neo4j/addressing.py index e294291b..908b84cb 100644 --- a/neo4j/addressing.py +++ b/neo4j/addressing.py @@ -26,9 +26,46 @@ AF_INET, AF_INET6, ) +import logging -class Address(tuple): +log = logging.getLogger("neo4j") + + +class _AddressMeta(type(tuple)): + + def __init__(self, *args, **kwargs): + self._ipv4_cls = None + self._ipv6_cls = None + + def _subclass_by_family(self, family): + subclasses = [ + sc for sc in self.__subclasses__() + if (sc.__module__ == self.__module__ + and getattr(sc, "family", None) == family) + ] + if len(subclasses) != 1: + raise ValueError( + "Class {} needs exactly one direct subclass with attribute " + "`family == {}` within this module. " + "Found: {}".format(self, family, subclasses) + ) + return subclasses[0] + + @property + def ipv4_cls(self): + if self._ipv4_cls is None: + self._ipv4_cls = self._subclass_by_family(AF_INET) + return self._ipv4_cls + + @property + def ipv6_cls(self): + if self._ipv6_cls is None: + self._ipv6_cls = self._subclass_by_family(AF_INET6) + return self._ipv6_cls + + +class Address(tuple, metaclass=_AddressMeta): @classmethod def from_socket(cls, socket): @@ -75,9 +112,9 @@ def __new__(cls, iterable): n_parts = len(iterable) inst = tuple.__new__(cls, iterable) if n_parts == 2: - inst.__class__ = IPv4Address + inst.__class__ = cls.ipv4_cls elif n_parts == 4: - inst.__class__ = IPv6Address + inst.__class__ = cls.ipv6_cls else: raise ValueError("Addresses must consist of either " "two parts (IPv4) or four parts (IPv6)") @@ -89,6 +126,10 @@ def __new__(cls, iterable): def __repr__(self): return "{}({!r})".format(self.__class__.__name__, tuple(self)) + @property + def host_name(self): + return self[0] + @property def host(self): return self[0] @@ -118,7 +159,9 @@ def _dns_resolve(cls, address, family=0): # as these appear to cause problems on some platforms continue if addr not in resolved: - resolved.append(Address(addr)) + resolved.append(ResolvedAddress( + addr, host_name=address.host_name) + ) return resolved def resolve(self, family=0, resolver=None): @@ -138,6 +181,8 @@ def resolve(self, family=0, resolver=None): :param resolver: optional customer resolver function to be called before regular DNS resolution """ + + log.debug("[#0000] C: %s", self) resolved = [] if resolver: for address in map(Address, resolver(self)): @@ -173,3 +218,26 @@ class IPv6Address(Address): def __str__(self): return "[{}]:{}".format(*self) + + +class ResolvedAddress(Address): + + @property + def host_name(self): + return self._host_name + + def resolve(self, family=0, resolver=None): + return [self] + + def __new__(cls, iterable, host_name=None): + new = super().__new__(cls, iterable) + new._host_name = host_name + return new + + +class ResolvedIPv4Address(IPv4Address, ResolvedAddress): + pass + + +class ResolvedIPv6Address(IPv6Address, ResolvedAddress): + pass diff --git a/neo4j/io/__init__.py b/neo4j/io/__init__.py index 41e5a0c2..23a9fc8b 100644 --- a/neo4j/io/__init__.py +++ b/neo4j/io/__init__.py @@ -48,6 +48,7 @@ timeout as SocketTimeout, ) from ssl import ( + CertificateError, HAS_SNI, SSLError, ) @@ -59,6 +60,7 @@ from time import perf_counter from neo4j._exceptions import ( + BoltError, BoltHandshakeError, BoltProtocolError, BoltRoutingError, @@ -77,6 +79,7 @@ from neo4j.exceptions import ( ClientError, ConfigurationError, + DriverError, ReadServiceUnavailable, ServiceUnavailable, SessionExpired, @@ -708,16 +711,15 @@ def fetch_routing_table(self, *, address, timeout, database, bookmarks): :return: a new RoutingTable instance or None if the given router is currently unable to provide routing information - - :raise neo4j.exceptions.ServiceUnavailable: if no writers are available - :raise neo4j._exceptions.BoltProtocolError: if the routing information received is unusable """ - new_routing_info = self.fetch_routing_info(address, database, bookmarks, - timeout) - if new_routing_info is None: + try: + new_routing_info = self.fetch_routing_info(address, database, + bookmarks, timeout) + except ServiceUnavailable: + new_routing_info = None + if not new_routing_info: + log.debug("Failed to fetch routing info %s", address) return None - elif not new_routing_info: - raise BoltRoutingError("Invalid routing table", address) else: servers = new_routing_info[0]["servers"] ttl = new_routing_info[0]["ttl"] @@ -733,11 +735,13 @@ def fetch_routing_table(self, *, address, timeout, database, bookmarks): # No routers if num_routers == 0: - raise BoltRoutingError("No routing servers returned from server", address) + log.debug("No routing servers returned from server %s", address) + return None # No readers if num_readers == 0: - raise BoltRoutingError("No read servers returned from server", address) + log.debug("No read servers returned from server %s", address) + return None # At least one of each is fine, so return this table return new_routing_table @@ -751,14 +755,20 @@ def update_routing_table_from(self, *routers, database=None, """ log.debug("Attempting to update routing table from {}".format(", ".join(map(repr, routers)))) for router in routers: - new_routing_table = self.fetch_routing_table( - address=router, timeout=self.pool_config.connection_timeout, - database=database, bookmarks=bookmarks - ) - if new_routing_table is not None: - self.routing_tables[database].update(new_routing_table) - log.debug("[#0000] C: address={!r} ({!r})".format(router, self.routing_tables[database])) - return True + for address in router.resolve(resolver=self.pool_config.resolver): + new_routing_table = self.fetch_routing_table( + address=address, + timeout=self.pool_config.connection_timeout, + database=database, bookmarks=bookmarks + ) + if new_routing_table is not None: + self.routing_tables[database].update(new_routing_table) + log.debug( + "[#0000] C: address=%r (%r)", + address, self.routing_tables[database] + ) + return True + self.deactivate(router) return False def update_routing_table(self, *, database, bookmarks): @@ -771,24 +781,26 @@ def update_routing_table(self, *, database, bookmarks): :raise neo4j.exceptions.ServiceUnavailable: """ # copied because it can be modified - existing_routers = list(self.routing_tables[database].routers) + existing_routers = set(self.routing_tables[database].routers) + + prefer_initial_routing_address = \ + self.routing_tables[database].missing_fresh_writer() - has_tried_initial_routers = False - if self.routing_tables[database].missing_fresh_writer(): + if prefer_initial_routing_address: # TODO: Test this state - has_tried_initial_routers = True if self.update_routing_table_from( self.first_initial_routing_address, database=database, bookmarks=bookmarks ): # Why is only the first initial routing address used? return - if self.update_routing_table_from(*existing_routers, database=database, - bookmarks=bookmarks): + if self.update_routing_table_from( + *(existing_routers - {self.first_initial_routing_address}), + database=database, bookmarks=bookmarks + ): return - if (not has_tried_initial_routers - and self.first_initial_routing_address not in existing_routers): + if not prefer_initial_routing_address: if self.update_routing_table_from( self.first_initial_routing_address, database=database, bookmarks=bookmarks @@ -956,21 +968,24 @@ def _secure(s, host, ssl_context): local_port = s.getsockname()[1] # Secure the connection if an SSL context has been provided if ssl_context: + last_error = None log.debug("[#%04X] C: %s", local_port, host) try: sni_host = host if HAS_SNI and host else None s = ssl_context.wrap_socket(s, server_hostname=sni_host) - except (SSLError, OSError) as cause: - _close_socket(s) - error = BoltSecurityError(message="Failed to establish encrypted connection.", address=(host, local_port)) - error.__cause__ = cause - raise error - else: - # Check that the server provides a certificate - der_encoded_server_certificate = s.getpeercert(binary_form=True) - if der_encoded_server_certificate is None: - s.close() - raise BoltProtocolError("When using an encrypted socket, the server should always provide a certificate", address=(host, local_port)) + except (OSError, SSLError, CertificateError) as cause: + raise BoltSecurityError( + message="Failed to establish encrypted connection.", + address=(host, local_port) + ) from cause + # Check that the server provides a certificate + der_encoded_server_certificate = s.getpeercert(binary_form=True) + if der_encoded_server_certificate is None: + raise BoltProtocolError( + "When using an encrypted socket, the server should always " + "provide a certificate", address=(host, local_port) + ) + return s return s @@ -1041,27 +1056,38 @@ def connect(address, *, timeout, custom_resolver, ssl_context, keep_alive): """ Connect and perform a handshake and return a valid Connection object, assuming a protocol version can be agreed. """ - last_error = None + errors = [] # Establish a connection to the host and port specified # Catches refused connections see: # https://docs.python.org/2/library/errno.html - log.debug("[#0000] C: %s", address) - for resolved_address in Address(address).resolve(resolver=custom_resolver): + resolved_addresses = Address(address).resolve(resolver=custom_resolver) + for resolved_address in resolved_addresses: s = None try: - host = address[0] s = _connect(resolved_address, timeout, keep_alive) - s = _secure(s, host, ssl_context) - return _handshake(s, address) - except Exception as error: + s = _secure(s, resolved_address.host_name, ssl_context) + return _handshake(s, resolved_address) + except (BoltError, DriverError, OSError) as error: if s: _close_socket(s) - last_error = error - if last_error is None: - raise ServiceUnavailable("Failed to resolve addresses for %s" % address) + errors.append(error) + except Exception: + if s: + _close_socket(s) + raise + if not errors: + raise ServiceUnavailable( + "Couldn't connect to %s (resolved to %s)" % ( + str(address), tuple(map(str, resolved_addresses))) + ) else: - raise last_error + raise ServiceUnavailable( + "Couldn't connect to %s (resolved to %s):\n%s" % ( + str(address), tuple(map(str, resolved_addresses)), + "\n".join(map(str, errors)) + ) + ) from errors[0] def check_supported_server_product(agent): diff --git a/testkitbackend/backend.py b/testkitbackend/backend.py index fb73af1a..731c776e 100644 --- a/testkitbackend/backend.py +++ b/testkitbackend/backend.py @@ -85,7 +85,8 @@ def __init__(self, rd, wr): self._rd = rd self._wr = wr self.drivers = {} - self.address_resolutions = {} + self.custom_resolutions = {} + self.dns_resolutions = {} self.sessions = {} self.results = {} self.errors = {} diff --git a/testkitbackend/requests.py b/testkitbackend/requests.py index bb1257a8..eb3fe73c 100644 --- a/testkitbackend/requests.py +++ b/testkitbackend/requests.py @@ -44,11 +44,14 @@ def NewDriver(backend, data): auth_token["scheme"], auth_token["principal"], auth_token["credentials"], realm=auth_token["realm"]) auth_token.mark_item_as_read_if_equals("ticket", "") - resolver = resolution_func(backend) if data["resolverRegistered"] else None + resolver = None + if data["resolverRegistered"] or data["domainNameResolverRegistered"]: + resolver = resolution_func(backend, data["resolverRegistered"], + data["domainNameResolverRegistered"]) connection_timeout = data.get("connectionTimeoutMs", None) if connection_timeout is not None: connection_timeout /= 1000 - data.mark_item_as_read_if_equals("domainNameResolverRegistered", False) + data.mark_item_as_read("domainNameResolverRegistered") driver = neo4j.GraphDatabase.driver( data["uri"], auth=auth, user_agent=data["userAgent"], resolver=resolver, connection_timeout=connection_timeout @@ -74,33 +77,66 @@ def CheckMultiDBSupport(backend, data): ) -def resolution_func(backend): +def resolution_func(backend, custom_resolver=False, custom_dns_resolver=False): + # This solution (putting custom resolution together with DNS resolution into + # one function only works because the Python driver calls the custom + # resolver function for every connection, which is not true for all drivers. + # Properly exposing a way to change the DNS lookup behavior is not possible + # without changing the driver's code. + assert custom_resolver or custom_dns_resolver + def resolve(address): - key = backend.next_key() - address = ":".join(map(str, address)) - backend.send_response("ResolverResolutionRequired", { - "id": key, - "address": address - }) - if not backend.process_request(): - # connection was closed before end of next message - return [] - if key not in backend.address_resolutions: - raise RuntimeError( - "Backend did not receive expected ResolverResolutionCompleted " - "message for id %s" % key - ) - resolution = backend.address_resolutions[key] - resolution = list(map(neo4j.Address.parse, resolution)) - # won't be needed anymore -> conserve memory - del backend.address_resolutions[key] - return resolution + addresses = [":".join(map(str, address))] + if custom_resolver: + key = backend.next_key() + backend.send_response("ResolverResolutionRequired", { + "id": key, + "address": addresses[0] + }) + if not backend.process_request(): + # connection was closed before end of next message + return [] + if key not in backend.custom_resolutions: + raise RuntimeError( + "Backend did not receive expected " + "ResolverResolutionCompleted message for id %s" % key + ) + addresses = backend.custom_resolutions.pop(key) + if custom_dns_resolver: + dns_resolved_addresses = [] + for address in addresses: + key = backend.next_key() + address = address.rsplit(":", 1) + backend.send_response("DomainNameResolutionRequired", { + "id": key, + "name": address[0] + }) + if not backend.process_request(): + # connection was closed before end of next message + return [] + if key not in backend.dns_resolutions: + raise RuntimeError( + "Backend did not receive expected " + "DomainNameResolutionCompleted message for id %s" % key + ) + dns_resolved_addresses += list(map( + lambda a: ":".join((a, *address[1:])), + backend.dns_resolutions.pop(key) + )) + + addresses = dns_resolved_addresses + + return list(map(neo4j.Address.parse, addresses)) return resolve def ResolverResolutionCompleted(backend, data): - backend.address_resolutions[data["requestId"]] = data["addresses"] + backend.custom_resolutions[data["requestId"]] = data["addresses"] + + +def DomainNameResolutionCompleted(backend, data): + backend.dns_resolutions[data["requestId"]] = data["addresses"] def DriverClose(backend, data): diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 671fe507..64aa1f3d 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -303,7 +303,7 @@ def neo4j_driver(target, auth): except ServiceUnavailable as error: if isinstance(error.__cause__, BoltHandshakeError): pytest.skip(error.args[0]) - elif error.args[0] == "Server does not support routing": + elif error.args[0] == "Unable to retrieve routing information": pytest.skip(error.args[0]) else: raise diff --git a/tests/performance/__init__.py b/tests/performance/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/stub/test_routingdriver.py b/tests/stub/test_routingdriver.py index acfec4ea..d1319d8c 100644 --- a/tests/stub/test_routingdriver.py +++ b/tests/stub/test_routingdriver.py @@ -34,13 +34,11 @@ ) from neo4j.exceptions import ( ServiceUnavailable, - ClientError, TransientError, SessionExpired, ConfigurationError, ) from neo4j._exceptions import ( - BoltRoutingError, BoltSecurityError, ) from tests.stub.conftest import StubCluster @@ -214,7 +212,7 @@ def test_cannot_discover_servers_on_non_router(driver_info, test_script): def test_cannot_discover_servers_on_silent_router(driver_info, test_script): # python -m pytest tests/stub/test_routingdriver.py -s -v -k test_cannot_discover_servers_on_silent_router with StubCluster(test_script): - with pytest.raises(BoltRoutingError): + with pytest.raises(ServiceUnavailable, match="routing"): with GraphDatabase.driver(driver_info["uri_neo4j"], auth=driver_info["auth_token"]) as driver: assert isinstance(driver, Neo4jDriver) driver._pool.update_routing_table(database=None, bookmarks=None) @@ -532,7 +530,7 @@ def test_should_serve_read_when_missing_writer(driver_info, test_scripts, test_r def test_should_error_when_missing_reader(driver_info, test_script): # python -m pytest tests/stub/test_routingdriver.py -s -v -k test_should_error_when_missing_reader with StubCluster(test_script): - with pytest.raises(BoltRoutingError): + with pytest.raises(ServiceUnavailable, match="routing"): with GraphDatabase.driver(driver_info["uri_neo4j"], auth=driver_info["auth_token"]) as driver: assert isinstance(driver, Neo4jDriver) driver._pool.update_routing_table(database=None, bookmarks=None)