Skip to content

Commit 4fb29ef

Browse files
authored
Improve address resolution (#532)
* Try all routers before giving up (including results of custom- and dns-resolver) * Add DNS resolution request/response to testkit back-end * Rewrite errors on connect into ServiceUnavailable
1 parent 29b1869 commit 4fb29ef

File tree

7 files changed

+210
-81
lines changed

7 files changed

+210
-81
lines changed

neo4j/addressing.py

Lines changed: 72 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,46 @@
2626
AF_INET,
2727
AF_INET6,
2828
)
29+
import logging
2930

3031

31-
class Address(tuple):
32+
log = logging.getLogger("neo4j")
33+
34+
35+
class _AddressMeta(type(tuple)):
36+
37+
def __init__(self, *args, **kwargs):
38+
self._ipv4_cls = None
39+
self._ipv6_cls = None
40+
41+
def _subclass_by_family(self, family):
42+
subclasses = [
43+
sc for sc in self.__subclasses__()
44+
if (sc.__module__ == self.__module__
45+
and getattr(sc, "family", None) == family)
46+
]
47+
if len(subclasses) != 1:
48+
raise ValueError(
49+
"Class {} needs exactly one direct subclass with attribute "
50+
"`family == {}` within this module. "
51+
"Found: {}".format(self, family, subclasses)
52+
)
53+
return subclasses[0]
54+
55+
@property
56+
def ipv4_cls(self):
57+
if self._ipv4_cls is None:
58+
self._ipv4_cls = self._subclass_by_family(AF_INET)
59+
return self._ipv4_cls
60+
61+
@property
62+
def ipv6_cls(self):
63+
if self._ipv6_cls is None:
64+
self._ipv6_cls = self._subclass_by_family(AF_INET6)
65+
return self._ipv6_cls
66+
67+
68+
class Address(tuple, metaclass=_AddressMeta):
3269

3370
@classmethod
3471
def from_socket(cls, socket):
@@ -75,9 +112,9 @@ def __new__(cls, iterable):
75112
n_parts = len(iterable)
76113
inst = tuple.__new__(cls, iterable)
77114
if n_parts == 2:
78-
inst.__class__ = IPv4Address
115+
inst.__class__ = cls.ipv4_cls
79116
elif n_parts == 4:
80-
inst.__class__ = IPv6Address
117+
inst.__class__ = cls.ipv6_cls
81118
else:
82119
raise ValueError("Addresses must consist of either "
83120
"two parts (IPv4) or four parts (IPv6)")
@@ -89,6 +126,10 @@ def __new__(cls, iterable):
89126
def __repr__(self):
90127
return "{}({!r})".format(self.__class__.__name__, tuple(self))
91128

129+
@property
130+
def host_name(self):
131+
return self[0]
132+
92133
@property
93134
def host(self):
94135
return self[0]
@@ -118,7 +159,9 @@ def _dns_resolve(cls, address, family=0):
118159
# as these appear to cause problems on some platforms
119160
continue
120161
if addr not in resolved:
121-
resolved.append(Address(addr))
162+
resolved.append(ResolvedAddress(
163+
addr, host_name=address.host_name)
164+
)
122165
return resolved
123166

124167
def resolve(self, family=0, resolver=None):
@@ -138,6 +181,8 @@ def resolve(self, family=0, resolver=None):
138181
:param resolver: optional customer resolver function to be
139182
called before regular DNS resolution
140183
"""
184+
185+
log.debug("[#0000] C: <RESOLVE> %s", self)
141186
resolved = []
142187
if resolver:
143188
for address in map(Address, resolver(self)):
@@ -173,3 +218,26 @@ class IPv6Address(Address):
173218

174219
def __str__(self):
175220
return "[{}]:{}".format(*self)
221+
222+
223+
class ResolvedAddress(Address):
224+
225+
@property
226+
def host_name(self):
227+
return self._host_name
228+
229+
def resolve(self, family=0, resolver=None):
230+
return [self]
231+
232+
def __new__(cls, iterable, host_name=None):
233+
new = super().__new__(cls, iterable)
234+
new._host_name = host_name
235+
return new
236+
237+
238+
class ResolvedIPv4Address(IPv4Address, ResolvedAddress):
239+
pass
240+
241+
242+
class ResolvedIPv6Address(IPv6Address, ResolvedAddress):
243+
pass

neo4j/io/__init__.py

Lines changed: 74 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
timeout as SocketTimeout,
4949
)
5050
from ssl import (
51+
CertificateError,
5152
HAS_SNI,
5253
SSLError,
5354
)
@@ -59,6 +60,7 @@
5960
from time import perf_counter
6061

6162
from neo4j._exceptions import (
63+
BoltError,
6264
BoltHandshakeError,
6365
BoltProtocolError,
6466
BoltRoutingError,
@@ -77,6 +79,7 @@
7779
from neo4j.exceptions import (
7880
ClientError,
7981
ConfigurationError,
82+
DriverError,
8083
ReadServiceUnavailable,
8184
ServiceUnavailable,
8285
SessionExpired,
@@ -708,16 +711,15 @@ def fetch_routing_table(self, *, address, timeout, database, bookmarks):
708711
709712
:return: a new RoutingTable instance or None if the given router is
710713
currently unable to provide routing information
711-
712-
:raise neo4j.exceptions.ServiceUnavailable: if no writers are available
713-
:raise neo4j._exceptions.BoltProtocolError: if the routing information received is unusable
714714
"""
715-
new_routing_info = self.fetch_routing_info(address, database, bookmarks,
716-
timeout)
717-
if new_routing_info is None:
715+
try:
716+
new_routing_info = self.fetch_routing_info(address, database,
717+
bookmarks, timeout)
718+
except ServiceUnavailable:
719+
new_routing_info = None
720+
if not new_routing_info:
721+
log.debug("Failed to fetch routing info %s", address)
718722
return None
719-
elif not new_routing_info:
720-
raise BoltRoutingError("Invalid routing table", address)
721723
else:
722724
servers = new_routing_info[0]["servers"]
723725
ttl = new_routing_info[0]["ttl"]
@@ -733,11 +735,13 @@ def fetch_routing_table(self, *, address, timeout, database, bookmarks):
733735

734736
# No routers
735737
if num_routers == 0:
736-
raise BoltRoutingError("No routing servers returned from server", address)
738+
log.debug("No routing servers returned from server %s", address)
739+
return None
737740

738741
# No readers
739742
if num_readers == 0:
740-
raise BoltRoutingError("No read servers returned from server", address)
743+
log.debug("No read servers returned from server %s", address)
744+
return None
741745

742746
# At least one of each is fine, so return this table
743747
return new_routing_table
@@ -751,14 +755,20 @@ def update_routing_table_from(self, *routers, database=None,
751755
"""
752756
log.debug("Attempting to update routing table from {}".format(", ".join(map(repr, routers))))
753757
for router in routers:
754-
new_routing_table = self.fetch_routing_table(
755-
address=router, timeout=self.pool_config.connection_timeout,
756-
database=database, bookmarks=bookmarks
757-
)
758-
if new_routing_table is not None:
759-
self.routing_tables[database].update(new_routing_table)
760-
log.debug("[#0000] C: <UPDATE ROUTING TABLE> address={!r} ({!r})".format(router, self.routing_tables[database]))
761-
return True
758+
for address in router.resolve(resolver=self.pool_config.resolver):
759+
new_routing_table = self.fetch_routing_table(
760+
address=address,
761+
timeout=self.pool_config.connection_timeout,
762+
database=database, bookmarks=bookmarks
763+
)
764+
if new_routing_table is not None:
765+
self.routing_tables[database].update(new_routing_table)
766+
log.debug(
767+
"[#0000] C: <UPDATE ROUTING TABLE> address=%r (%r)",
768+
address, self.routing_tables[database]
769+
)
770+
return True
771+
self.deactivate(router)
762772
return False
763773

764774
def update_routing_table(self, *, database, bookmarks):
@@ -771,24 +781,26 @@ def update_routing_table(self, *, database, bookmarks):
771781
:raise neo4j.exceptions.ServiceUnavailable:
772782
"""
773783
# copied because it can be modified
774-
existing_routers = list(self.routing_tables[database].routers)
784+
existing_routers = set(self.routing_tables[database].routers)
785+
786+
prefer_initial_routing_address = \
787+
self.routing_tables[database].missing_fresh_writer()
775788

776-
has_tried_initial_routers = False
777-
if self.routing_tables[database].missing_fresh_writer():
789+
if prefer_initial_routing_address:
778790
# TODO: Test this state
779-
has_tried_initial_routers = True
780791
if self.update_routing_table_from(
781792
self.first_initial_routing_address, database=database,
782793
bookmarks=bookmarks
783794
):
784795
# Why is only the first initial routing address used?
785796
return
786-
if self.update_routing_table_from(*existing_routers, database=database,
787-
bookmarks=bookmarks):
797+
if self.update_routing_table_from(
798+
*(existing_routers - {self.first_initial_routing_address}),
799+
database=database, bookmarks=bookmarks
800+
):
788801
return
789802

790-
if (not has_tried_initial_routers
791-
and self.first_initial_routing_address not in existing_routers):
803+
if not prefer_initial_routing_address:
792804
if self.update_routing_table_from(
793805
self.first_initial_routing_address, database=database,
794806
bookmarks=bookmarks
@@ -956,21 +968,24 @@ def _secure(s, host, ssl_context):
956968
local_port = s.getsockname()[1]
957969
# Secure the connection if an SSL context has been provided
958970
if ssl_context:
971+
last_error = None
959972
log.debug("[#%04X] C: <SECURE> %s", local_port, host)
960973
try:
961974
sni_host = host if HAS_SNI and host else None
962975
s = ssl_context.wrap_socket(s, server_hostname=sni_host)
963-
except (SSLError, OSError) as cause:
964-
_close_socket(s)
965-
error = BoltSecurityError(message="Failed to establish encrypted connection.", address=(host, local_port))
966-
error.__cause__ = cause
967-
raise error
968-
else:
969-
# Check that the server provides a certificate
970-
der_encoded_server_certificate = s.getpeercert(binary_form=True)
971-
if der_encoded_server_certificate is None:
972-
s.close()
973-
raise BoltProtocolError("When using an encrypted socket, the server should always provide a certificate", address=(host, local_port))
976+
except (OSError, SSLError, CertificateError) as cause:
977+
raise BoltSecurityError(
978+
message="Failed to establish encrypted connection.",
979+
address=(host, local_port)
980+
) from cause
981+
# Check that the server provides a certificate
982+
der_encoded_server_certificate = s.getpeercert(binary_form=True)
983+
if der_encoded_server_certificate is None:
984+
raise BoltProtocolError(
985+
"When using an encrypted socket, the server should always "
986+
"provide a certificate", address=(host, local_port)
987+
)
988+
return s
974989
return s
975990

976991

@@ -1041,27 +1056,38 @@ def connect(address, *, timeout, custom_resolver, ssl_context, keep_alive):
10411056
""" Connect and perform a handshake and return a valid Connection object,
10421057
assuming a protocol version can be agreed.
10431058
"""
1044-
last_error = None
1059+
errors = []
10451060
# Establish a connection to the host and port specified
10461061
# Catches refused connections see:
10471062
# https://docs.python.org/2/library/errno.html
1048-
log.debug("[#0000] C: <RESOLVE> %s", address)
10491063

1050-
for resolved_address in Address(address).resolve(resolver=custom_resolver):
1064+
resolved_addresses = Address(address).resolve(resolver=custom_resolver)
1065+
for resolved_address in resolved_addresses:
10511066
s = None
10521067
try:
1053-
host = address[0]
10541068
s = _connect(resolved_address, timeout, keep_alive)
1055-
s = _secure(s, host, ssl_context)
1056-
return _handshake(s, address)
1057-
except Exception as error:
1069+
s = _secure(s, resolved_address.host_name, ssl_context)
1070+
return _handshake(s, resolved_address)
1071+
except (BoltError, DriverError, OSError) as error:
10581072
if s:
10591073
_close_socket(s)
1060-
last_error = error
1061-
if last_error is None:
1062-
raise ServiceUnavailable("Failed to resolve addresses for %s" % address)
1074+
errors.append(error)
1075+
except Exception:
1076+
if s:
1077+
_close_socket(s)
1078+
raise
1079+
if not errors:
1080+
raise ServiceUnavailable(
1081+
"Couldn't connect to %s (resolved to %s)" % (
1082+
str(address), tuple(map(str, resolved_addresses)))
1083+
)
10631084
else:
1064-
raise last_error
1085+
raise ServiceUnavailable(
1086+
"Couldn't connect to %s (resolved to %s):\n%s" % (
1087+
str(address), tuple(map(str, resolved_addresses)),
1088+
"\n".join(map(str, errors))
1089+
)
1090+
) from errors[0]
10651091

10661092

10671093
def check_supported_server_product(agent):

testkitbackend/backend.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,8 @@ def __init__(self, rd, wr):
8585
self._rd = rd
8686
self._wr = wr
8787
self.drivers = {}
88-
self.address_resolutions = {}
88+
self.custom_resolutions = {}
89+
self.dns_resolutions = {}
8990
self.sessions = {}
9091
self.results = {}
9192
self.errors = {}

0 commit comments

Comments
 (0)