Skip to content

Commit 89ca053

Browse files
committed
TLS: verify host name from pre and post custom resolver
1 parent f0bda50 commit 89ca053

File tree

2 files changed

+66
-26
lines changed

2 files changed

+66
-26
lines changed

neo4j/addressing.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,10 @@ def __new__(cls, iterable):
126126
def __repr__(self):
127127
return "{}({!r})".format(self.__class__.__name__, tuple(self))
128128

129+
@property
130+
def host_names(self):
131+
return self[0],
132+
129133
@property
130134
def host(self):
131135
return self[0]
@@ -155,7 +159,9 @@ def _dns_resolve(cls, address, family=0):
155159
# as these appear to cause problems on some platforms
156160
continue
157161
if addr not in resolved:
158-
resolved.append(ResolvedAddress(addr))
162+
resolved.append(ResolvedAddress(
163+
addr, host_names=address.host_names)
164+
)
159165
return resolved
160166

161167
def resolve(self, family=0, resolver=None):
@@ -179,7 +185,10 @@ def resolve(self, family=0, resolver=None):
179185
log.debug("[#0000] C: <RESOLVE> %s", self)
180186
resolved = []
181187
if resolver:
182-
for address in map(Address, resolver(self)):
188+
for address in resolver(self):
189+
address = ResolvedAddress(
190+
address, host_names={self.host_names}
191+
)
183192
resolved.extend(self._dns_resolve(address, family))
184193
else:
185194
resolved.extend(self._dns_resolve(self, family))
@@ -215,9 +224,19 @@ def __str__(self):
215224

216225

217226
class ResolvedAddress(Address):
227+
228+
@property
229+
def host_names(self):
230+
return tuple(self._host_names)
231+
218232
def resolve(self, family=0, resolver=None):
219233
return [self]
220234

235+
def __new__(cls, iterable, host_names=None):
236+
new = super().__new__(cls, iterable)
237+
new._host_names = set() if host_names is None else set(host_names)
238+
return new
239+
221240

222241
class ResolvedIPv4Address(IPv4Address, ResolvedAddress):
223242
pass

neo4j/io/__init__.py

Lines changed: 45 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
timeout as SocketTimeout,
4949
)
5050
from ssl import (
51+
CertificateError,
5152
HAS_SNI,
5253
SSLError,
5354
)
@@ -59,6 +60,7 @@
5960
from time import perf_counter
6061

6162
from neo4j._exceptions import (
63+
BoltError,
6264
BoltHandshakeError,
6365
BoltProtocolError,
6466
BoltRoutingError,
@@ -77,6 +79,7 @@
7779
from neo4j.exceptions import (
7880
ClientError,
7981
ConfigurationError,
82+
DriverError,
8083
ReadServiceUnavailable,
8184
ServiceUnavailable,
8285
SessionExpired,
@@ -768,6 +771,7 @@ def update_routing_table_from(self, *routers, database=None,
768771
address, self.routing_tables[database]
769772
)
770773
return True
774+
self.deactivate(router)
771775
return False
772776

773777
def update_routing_table(self, *, database, bookmarks):
@@ -963,25 +967,34 @@ def _connect(resolved_address, timeout, keep_alive):
963967
return s
964968

965969

966-
def _secure(s, host, ssl_context):
970+
def _secure(s, hosts, ssl_context):
967971
local_port = s.getsockname()[1]
968972
# Secure the connection if an SSL context has been provided
973+
if not hosts:
974+
hosts = [None]
969975
if ssl_context:
970-
log.debug("[#%04X] C: <SECURE> %s", local_port, host)
971-
try:
972-
sni_host = host if HAS_SNI and host else None
973-
s = ssl_context.wrap_socket(s, server_hostname=sni_host)
974-
except (SSLError, OSError) as cause:
975-
_close_socket(s)
976-
error = BoltSecurityError(message="Failed to establish encrypted connection.", address=(host, local_port))
977-
error.__cause__ = cause
978-
raise error
979-
else:
980-
# Check that the server provides a certificate
981-
der_encoded_server_certificate = s.getpeercert(binary_form=True)
982-
if der_encoded_server_certificate is None:
983-
s.close()
984-
raise BoltProtocolError("When using an encrypted socket, the server should always provide a certificate", address=(host, local_port))
976+
last_error = None
977+
for host in hosts:
978+
log.debug("[#%04X] C: <SECURE> %s", local_port, host)
979+
try:
980+
sni_host = host if HAS_SNI and host else None
981+
s = ssl_context.wrap_socket(s, server_hostname=sni_host)
982+
except (SSLError, CertificateError) as cause:
983+
last_error = cause
984+
continue
985+
except OSError as cause:
986+
# No sense in trying another host name with a broken socket
987+
last_error = cause
988+
break
989+
else:
990+
# Check that the server provides a certificate
991+
der_encoded_server_certificate = s.getpeercert(binary_form=True)
992+
if der_encoded_server_certificate is None:
993+
raise BoltProtocolError("When using an encrypted socket, the server should always provide a certificate", address=(host, local_port))
994+
return s
995+
raise BoltSecurityError(
996+
message="Failed to establish encrypted connection.",
997+
address=(hosts[0], local_port)) from last_error
985998
return s
986999

9871000

@@ -1057,23 +1070,31 @@ def connect(address, *, timeout, custom_resolver, ssl_context, keep_alive):
10571070
# Catches refused connections see:
10581071
# https://docs.python.org/2/library/errno.html
10591072

1060-
for resolved_address in Address(address).resolve(resolver=custom_resolver):
1073+
resolved_addresses = Address(address).resolve(resolver=custom_resolver)
1074+
for resolved_address in resolved_addresses:
10611075
s = None
10621076
try:
1063-
host = resolved_address[0]
10641077
s = _connect(resolved_address, timeout, keep_alive)
1065-
s = _secure(s, host, ssl_context)
1078+
s = _secure(s, resolved_address.host_names, ssl_context)
10661079
return _handshake(s, resolved_address)
1067-
except Exception as error:
1080+
except (BoltError, DriverError, OSError) as error:
10681081
if s:
10691082
_close_socket(s)
10701083
last_error = error
1084+
except Exception:
1085+
if s:
1086+
_close_socket(s)
1087+
raise
10711088
if last_error is None:
1072-
raise ServiceUnavailable("Failed to resolve addresses for %s" %
1073-
str(address))
1089+
raise ServiceUnavailable(
1090+
"Couldn't connect to %s (resolved to %s)" % (
1091+
str(address), tuple(map(str, resolved_addresses)))
1092+
)
10741093
else:
1075-
raise ServiceUnavailable("Failed to resolve addresses for %s" %
1076-
str(address)) from last_error
1094+
raise ServiceUnavailable(
1095+
"Couldn't connect to %s (resolved to %s)" % (
1096+
str(address), tuple(map(str, resolved_addresses)))
1097+
) from last_error
10771098

10781099

10791100
def check_supported_server_product(agent):

0 commit comments

Comments
 (0)