Skip to content

[1.7.0] Custom resolver option #241

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jun 19, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions docs/source/driver.rst
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,25 @@ The maximum time to allow for retries to be attempted when using transaction fun
After this time, no more retries will be attempted.
This setting does not terminate running queries.

``resolver``
------------

A custom resolver function to resolve host and port values ahead of DNS resolution.
This function is called with a 2-tuple of (host, port) and should return an iterable of tuples as would be returned from ``getaddrinfo``.
If no custom resolver function is supplied, the internal resolver moves straight to regular DNS resolution.

For example::

def my_resolver(socket_address):
if socket_address == ("foo", 9999):
yield "::1", 7687
yield "127.0.0.1", 7687
else:
from socket import gaierror
raise gaierror("Unexpected socket address %r" % socket_address)

driver = GraphDatabase.driver("bolt+routing://foo:9999", auth=("neo4j", "password"), resolver=my_resolver)



Object Lifetime
Expand Down
64 changes: 44 additions & 20 deletions neo4j/addressing.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,9 @@
from collections import namedtuple
from socket import getaddrinfo, gaierror, SOCK_STREAM, IPPROTO_TCP

from neo4j.compat import urlparse
from neo4j.compat import urlparse, parse_qs
from neo4j.exceptions import AddressError

try:
from urllib.parse import parse_qs
except ImportError:
from urlparse import parse_qs


VALID_IPv4_SEGMENTS = [str(i).encode("latin1") for i in range(0x100)]
VALID_IPv6_SEGMENT_CHARS = b"0123456789abcdef"
Expand Down Expand Up @@ -102,17 +97,46 @@ def parse_routing_context(cls, uri):
return context


def resolve(socket_address):
try:
info = getaddrinfo(socket_address[0], socket_address[1], 0, SOCK_STREAM, IPPROTO_TCP)
except gaierror:
raise AddressError("Cannot resolve address {!r}".format(socket_address[0]))
else:
addresses = []
for _, _, _, _, address in info:
if len(address) == 4 and address[3] != 0:
# skip any IPv6 addresses with a non-zero scope id
# as these appear to cause problems on some platforms
continue
addresses.append(address)
return addresses
class Resolver(object):
""" A Resolver instance stores a list of addresses, each in a tuple, and
provides methods to perform resolution on these, thereby replacing them
with the resolved values.
"""

def __init__(self, custom_resolver=None):
self.addresses = []
self.custom_resolver = custom_resolver

def custom_resolve(self):
""" If a custom resolver is defined, perform custom resolution on
the contained addresses.

:return:
"""
if not callable(self.custom_resolver):
return
new_addresses = []
for address in self.addresses:
for new_address in self.custom_resolver(address):
new_addresses.append(new_address)
self.addresses = new_addresses

def dns_resolve(self):
""" Perform DNS resolution on the contained addresses.

:return:
"""
new_addresses = []
for address in self.addresses:
try:
info = getaddrinfo(address[0], address[1], 0, SOCK_STREAM, IPPROTO_TCP)
except gaierror:
raise AddressError("Cannot resolve address {!r}".format(address))
else:
for _, _, _, _, address in info:
if len(address) == 4 and address[3] != 0:
# skip any IPv6 addresses with a non-zero scope id
# as these appear to cause problems on some platforms
continue
new_addresses.append(address)
self.addresses = new_addresses
10 changes: 7 additions & 3 deletions neo4j/bolt/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from struct import pack as struct_pack, unpack as struct_unpack
from threading import RLock, Condition

from neo4j.addressing import SocketAddress, resolve
from neo4j.addressing import SocketAddress, Resolver
from neo4j.bolt.cert import KNOWN_HOSTS
from neo4j.bolt.response import InitResponse, AckFailureResponse, ResetResponse
from neo4j.compat.ssl import SSL_AVAILABLE, HAS_SNI, SSLError
Expand Down Expand Up @@ -685,12 +685,16 @@ def connect(address, ssl_context=None, error_handler=None, **config):
a protocol version can be agreed.
"""

last_error = None
# Establish a connection to the host and port specified
# Catches refused connections see:
# https://docs.python.org/2/library/errno.html
log_debug("~~ [RESOLVE] %s", address)
last_error = None
for resolved_address in resolve(address):
resolver = Resolver(custom_resolver=config.get("resolver"))
resolver.addresses.append(address)
resolver.custom_resolve()
resolver.dns_resolve()
for resolved_address in resolver.addresses:
log_debug("~~ [RESOLVED] %s -> %s", address, resolved_address)
try:
s = _connect(resolved_address, **config)
Expand Down
4 changes: 2 additions & 2 deletions neo4j/compat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,6 @@ def perf_counter():

# The location of urlparse varies between Python 2 and 3
try:
from urllib.parse import urlparse
from urllib.parse import urlparse, parse_qs
except ImportError:
from urlparse import urlparse
from urlparse import urlparse, parse_qs
15 changes: 14 additions & 1 deletion test/integration/test_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
# limitations under the License.


from neo4j.v1 import GraphDatabase, ServiceUnavailable
from neo4j.bolt import DEFAULT_PORT
from neo4j.v1 import GraphDatabase, Driver, ServiceUnavailable
from test.integration.tools import IntegrationTestCase


Expand All @@ -43,3 +44,15 @@ def test_fail_nicely_when_using_http_port(self):
with self.assertRaises(ServiceUnavailable):
with GraphDatabase.driver(uri, auth=self.auth_token, encrypted=False):
pass

def test_custom_resolver(self):

def my_resolver(socket_address):
self.assertEqual(socket_address, ("*", DEFAULT_PORT))
yield "99.99.99.99", self.bolt_port # this should be rejected as unable to connect
yield "127.0.0.1", self.bolt_port # this should succeed

with Driver("bolt://*", auth=self.auth_token, resolver=my_resolver) as driver:
with driver.session() as session:
summary = session.run("RETURN 1").summary()
self.assertEqual(summary.server.address, ("127.0.0.1", 7687))