Skip to content

Improve address resolution #532

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Apr 21, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 72 additions & 4 deletions neo4j/addressing.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,46 @@
AF_INET,
AF_INET6,
)
import logging


class Address(tuple):
log = logging.getLogger("neo4j")


class _AddressMeta(type(tuple)):
    """Metaclass giving an address class access to its per-family subclasses.

    Every class created with this metaclass exposes two class-level
    properties, ``ipv4_cls`` and ``ipv6_cls``. Each one resolves to the single
    direct subclass that lives in the same module as the class itself and
    whose ``family`` attribute equals ``AF_INET`` / ``AF_INET6`` respectively.
    The lookup is performed lazily on first access and cached per class.
    """

    def __init__(self, *args, **kwargs):
        # Cooperate with ``type.__init__`` and with any other metaclass that
        # may sit in the MRO instead of silently cutting the chain short.
        super().__init__(*args, **kwargs)
        # Per-class caches; filled on first access of the properties below.
        self._ipv4_cls = None
        self._ipv6_cls = None

    def _subclass_by_family(self, family):
        """Return the one direct subclass whose ``family`` equals *family*.

        Only direct subclasses defined in the same module as this class are
        considered.

        :param family: address family constant to match against the
            ``family`` attribute of the candidate subclasses
        :return: the unique matching subclass
        :raise ValueError: if zero or more than one subclass matches
        """
        subclasses = [
            sc for sc in self.__subclasses__()
            if (sc.__module__ == self.__module__
                and getattr(sc, "family", None) == family)
        ]
        if len(subclasses) != 1:
            raise ValueError(
                "Class {} needs exactly one direct subclass with attribute "
                "`family == {}` within this module. "
                "Found: {}".format(self, family, subclasses)
            )
        return subclasses[0]

    @property
    def ipv4_cls(self):
        """The direct subclass with ``family == AF_INET`` (cached)."""
        if self._ipv4_cls is None:
            self._ipv4_cls = self._subclass_by_family(AF_INET)
        return self._ipv4_cls

    @property
    def ipv6_cls(self):
        """The direct subclass with ``family == AF_INET6`` (cached)."""
        if self._ipv6_cls is None:
            self._ipv6_cls = self._subclass_by_family(AF_INET6)
        return self._ipv6_cls


class Address(tuple, metaclass=_AddressMeta):

@classmethod
def from_socket(cls, socket):
Expand Down Expand Up @@ -75,9 +112,9 @@ def __new__(cls, iterable):
n_parts = len(iterable)
inst = tuple.__new__(cls, iterable)
if n_parts == 2:
inst.__class__ = IPv4Address
inst.__class__ = cls.ipv4_cls
elif n_parts == 4:
inst.__class__ = IPv6Address
inst.__class__ = cls.ipv6_cls
else:
raise ValueError("Addresses must consist of either "
"two parts (IPv4) or four parts (IPv6)")
Expand All @@ -89,6 +126,10 @@ def __new__(cls, iterable):
def __repr__(self):
return "{}({!r})".format(self.__class__.__name__, tuple(self))

@property
def host_name(self):
return self[0]

@property
def host(self):
return self[0]
Expand Down Expand Up @@ -118,7 +159,9 @@ def _dns_resolve(cls, address, family=0):
# as these appear to cause problems on some platforms
continue
if addr not in resolved:
resolved.append(Address(addr))
resolved.append(ResolvedAddress(
addr, host_name=address.host_name)
)
return resolved

def resolve(self, family=0, resolver=None):
Expand All @@ -138,6 +181,8 @@ def resolve(self, family=0, resolver=None):
:param resolver: optional custom resolver function to be
called before regular DNS resolution
"""

log.debug("[#0000] C: <RESOLVE> %s", self)
resolved = []
if resolver:
for address in map(Address, resolver(self)):
Expand Down Expand Up @@ -173,3 +218,26 @@ class IPv6Address(Address):

def __str__(self):
return "[{}]:{}".format(*self)


class ResolvedAddress(Address):
    """An address that additionally remembers a host name.

    The host name is supplied at construction time and exposed through
    ``host_name``; resolving such an address yields the address itself.
    """

    def __new__(cls, iterable, host_name=None):
        """Build the address from *iterable* and attach *host_name* to it.

        :param iterable: the address parts, forwarded to the parent class
        :param host_name: host name to remember alongside the address
        """
        instance = super().__new__(cls, iterable)
        instance._host_name = host_name
        return instance

    @property
    def host_name(self):
        """The host name that was passed in when the address was created."""
        return self._host_name

    def resolve(self, family=0, resolver=None):
        """Return a one-element list containing this address.

        Both *family* and *resolver* are accepted for signature
        compatibility and are not used.
        """
        return [self]


class ResolvedIPv4Address(IPv4Address, ResolvedAddress):
    """Combines ``IPv4Address`` with ``ResolvedAddress``; adds nothing new."""


class ResolvedIPv6Address(IPv6Address, ResolvedAddress):
    """Combines ``IPv6Address`` with ``ResolvedAddress``; adds nothing new."""
122 changes: 74 additions & 48 deletions neo4j/io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
timeout as SocketTimeout,
)
from ssl import (
CertificateError,
HAS_SNI,
SSLError,
)
Expand All @@ -59,6 +60,7 @@
from time import perf_counter

from neo4j._exceptions import (
BoltError,
BoltHandshakeError,
BoltProtocolError,
BoltRoutingError,
Expand All @@ -77,6 +79,7 @@
from neo4j.exceptions import (
ClientError,
ConfigurationError,
DriverError,
ReadServiceUnavailable,
ServiceUnavailable,
SessionExpired,
Expand Down Expand Up @@ -708,16 +711,15 @@ def fetch_routing_table(self, *, address, timeout, database, bookmarks):

:return: a new RoutingTable instance or None if the given router is
currently unable to provide routing information

:raise neo4j.exceptions.ServiceUnavailable: if no writers are available
:raise neo4j._exceptions.BoltProtocolError: if the routing information received is unusable
"""
new_routing_info = self.fetch_routing_info(address, database, bookmarks,
timeout)
if new_routing_info is None:
try:
new_routing_info = self.fetch_routing_info(address, database,
bookmarks, timeout)
except ServiceUnavailable:
new_routing_info = None
if not new_routing_info:
log.debug("Failed to fetch routing info %s", address)
return None
elif not new_routing_info:
raise BoltRoutingError("Invalid routing table", address)
else:
servers = new_routing_info[0]["servers"]
ttl = new_routing_info[0]["ttl"]
Expand All @@ -733,11 +735,13 @@ def fetch_routing_table(self, *, address, timeout, database, bookmarks):

# No routers
if num_routers == 0:
raise BoltRoutingError("No routing servers returned from server", address)
log.debug("No routing servers returned from server %s", address)
return None

# No readers
if num_readers == 0:
raise BoltRoutingError("No read servers returned from server", address)
log.debug("No read servers returned from server %s", address)
return None

# At least one of each is fine, so return this table
return new_routing_table
Expand All @@ -751,14 +755,20 @@ def update_routing_table_from(self, *routers, database=None,
"""
log.debug("Attempting to update routing table from {}".format(", ".join(map(repr, routers))))
for router in routers:
new_routing_table = self.fetch_routing_table(
address=router, timeout=self.pool_config.connection_timeout,
database=database, bookmarks=bookmarks
)
if new_routing_table is not None:
self.routing_tables[database].update(new_routing_table)
log.debug("[#0000] C: <UPDATE ROUTING TABLE> address={!r} ({!r})".format(router, self.routing_tables[database]))
return True
for address in router.resolve(resolver=self.pool_config.resolver):
new_routing_table = self.fetch_routing_table(
address=address,
timeout=self.pool_config.connection_timeout,
database=database, bookmarks=bookmarks
)
if new_routing_table is not None:
self.routing_tables[database].update(new_routing_table)
log.debug(
"[#0000] C: <UPDATE ROUTING TABLE> address=%r (%r)",
address, self.routing_tables[database]
)
return True
self.deactivate(router)
return False

def update_routing_table(self, *, database, bookmarks):
Expand All @@ -771,24 +781,26 @@ def update_routing_table(self, *, database, bookmarks):
:raise neo4j.exceptions.ServiceUnavailable:
"""
# copied because it can be modified
existing_routers = list(self.routing_tables[database].routers)
existing_routers = set(self.routing_tables[database].routers)

prefer_initial_routing_address = \
self.routing_tables[database].missing_fresh_writer()

has_tried_initial_routers = False
if self.routing_tables[database].missing_fresh_writer():
if prefer_initial_routing_address:
# TODO: Test this state
has_tried_initial_routers = True
if self.update_routing_table_from(
self.first_initial_routing_address, database=database,
bookmarks=bookmarks
):
# Why is only the first initial routing address used?
return
if self.update_routing_table_from(*existing_routers, database=database,
bookmarks=bookmarks):
if self.update_routing_table_from(
*(existing_routers - {self.first_initial_routing_address}),
database=database, bookmarks=bookmarks
):
return

if (not has_tried_initial_routers
and self.first_initial_routing_address not in existing_routers):
if not prefer_initial_routing_address:
if self.update_routing_table_from(
self.first_initial_routing_address, database=database,
bookmarks=bookmarks
Expand Down Expand Up @@ -956,21 +968,24 @@ def _secure(s, host, ssl_context):
local_port = s.getsockname()[1]
# Secure the connection if an SSL context has been provided
if ssl_context:
last_error = None
log.debug("[#%04X] C: <SECURE> %s", local_port, host)
try:
sni_host = host if HAS_SNI and host else None
s = ssl_context.wrap_socket(s, server_hostname=sni_host)
except (SSLError, OSError) as cause:
_close_socket(s)
error = BoltSecurityError(message="Failed to establish encrypted connection.", address=(host, local_port))
error.__cause__ = cause
raise error
else:
# Check that the server provides a certificate
der_encoded_server_certificate = s.getpeercert(binary_form=True)
if der_encoded_server_certificate is None:
s.close()
raise BoltProtocolError("When using an encrypted socket, the server should always provide a certificate", address=(host, local_port))
except (OSError, SSLError, CertificateError) as cause:
raise BoltSecurityError(
message="Failed to establish encrypted connection.",
address=(host, local_port)
) from cause
# Check that the server provides a certificate
der_encoded_server_certificate = s.getpeercert(binary_form=True)
if der_encoded_server_certificate is None:
raise BoltProtocolError(
"When using an encrypted socket, the server should always "
"provide a certificate", address=(host, local_port)
)
return s
return s


Expand Down Expand Up @@ -1041,27 +1056,38 @@ def connect(address, *, timeout, custom_resolver, ssl_context, keep_alive):
""" Connect and perform a handshake and return a valid Connection object,
assuming a protocol version can be agreed.
"""
last_error = None
errors = []
# Establish a connection to the host and port specified
# Catches refused connections see:
# https://docs.python.org/2/library/errno.html
log.debug("[#0000] C: <RESOLVE> %s", address)

for resolved_address in Address(address).resolve(resolver=custom_resolver):
resolved_addresses = Address(address).resolve(resolver=custom_resolver)
for resolved_address in resolved_addresses:
s = None
try:
host = address[0]
s = _connect(resolved_address, timeout, keep_alive)
s = _secure(s, host, ssl_context)
return _handshake(s, address)
except Exception as error:
s = _secure(s, resolved_address.host_name, ssl_context)
return _handshake(s, resolved_address)
except (BoltError, DriverError, OSError) as error:
if s:
_close_socket(s)
last_error = error
if last_error is None:
raise ServiceUnavailable("Failed to resolve addresses for %s" % address)
errors.append(error)
except Exception:
if s:
_close_socket(s)
raise
if not errors:
raise ServiceUnavailable(
"Couldn't connect to %s (resolved to %s)" % (
str(address), tuple(map(str, resolved_addresses)))
)
else:
raise last_error
raise ServiceUnavailable(
"Couldn't connect to %s (resolved to %s):\n%s" % (
str(address), tuple(map(str, resolved_addresses)),
"\n".join(map(str, errors))
)
) from errors[0]


def check_supported_server_product(agent):
Expand Down
3 changes: 2 additions & 1 deletion testkitbackend/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,8 @@ def __init__(self, rd, wr):
self._rd = rd
self._wr = wr
self.drivers = {}
self.address_resolutions = {}
self.custom_resolutions = {}
self.dns_resolutions = {}
self.sessions = {}
self.results = {}
self.errors = {}
Expand Down
Loading