Skip to content

Commit 196e50c

Browse files
committed
Router and test_routing with boltkit
1 parent 2ab7384 commit 196e50c

File tree

6 files changed

+135
-55
lines changed

6 files changed

+135
-55
lines changed

neo4j/v1/exceptions.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,3 +51,8 @@ class TransactionError(Exception):
5151
class ResultError(Exception):
5252
""" Raised when an error occurs while consuming a result.
5353
"""
54+
55+
56+
class ServiceUnavailable(Exception):
57+
""" Raised when no database service is available.
58+
"""

neo4j/v1/session.py

Lines changed: 66 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,8 @@
4040
from .constants import DEFAULT_PORT, ENCRYPTION_DEFAULT, TRUST_DEFAULT, TRUST_SIGNED_CERTIFICATES, \
4141
TRUST_ON_FIRST_USE, READ_ACCESS, TRUST_SYSTEM_CA_SIGNED_CERTIFICATES, \
4242
TRUST_ALL_CERTIFICATES, TRUST_CUSTOM_CA_SIGNED_CERTIFICATES
43-
from .exceptions import CypherError, ProtocolError, ResultError, TransactionError
43+
from .exceptions import CypherError, ProtocolError, ResultError, TransactionError, \
44+
ServiceUnavailable
4445
from .ssl_compat import SSL_AVAILABLE, SSLContext, PROTOCOL_SSLv23, OP_NO_SSLv2, CERT_REQUIRED
4546
from .summary import ResultSummary
4647
from .types import hydrated
@@ -196,53 +197,88 @@ def session(self, access_mode=None):
196197
return Session(self.pool.acquire(self.address))
197198

198199

199-
class RoutingDriver(Driver):
200-
""" A :class:`.RoutingDriver` is created from a `bolt+routing` URI.
200+
def parse_address(address):
201+
""" Convert an address string to a tuple.
201202
"""
203+
host, _, port = address.partition(":")
204+
return host, int(port)
202205

203-
def __init__(self, address, **config):
204-
self.address = address
205-
self.security_plan = security_plan = SecurityPlan.build(address, **config)
206-
self.encrypted = security_plan.encrypted
207-
if not security_plan.routing_compatible:
208-
# this error message is case-specific as there is only one incompatible
209-
# scenario right now
210-
raise RuntimeError("TRUST_ON_FIRST_USE is not compatible with routing")
211-
Driver.__init__(self, lambda a: connect(a, security_plan.ssl_context, **config))
212-
self._lock = Lock()
213-
self._expiry_time = None
214-
self._routers = RoundRobinSet([address])
215-
self._readers = RoundRobinSet()
216-
self._writers = RoundRobinSet()
217-
self.discover()
206+
207+
class Router(object):
208+
209+
timer = monotonic
210+
211+
def __init__(self, pool, initial_address):
212+
self.pool = pool
213+
self.lock = Lock()
214+
self.expiry_time = None
215+
self.routers = RoundRobinSet([initial_address])
216+
self.readers = RoundRobinSet()
217+
self.writers = RoundRobinSet()
218+
219+
def stale(self):
220+
expired = self.expiry_time is None or self.expiry_time <= self.timer()
221+
return expired or len(self.routers) <= 1 or not self.readers or not self.writers
218222

219223
def discover(self):
220-
with self._lock:
221-
for router in list(self._routers):
224+
with self.lock:
225+
if not self.routers:
226+
raise ServiceUnavailable("No routers available")
227+
for router in list(self.routers):
222228
session = Session(self.pool.acquire(router))
223229
try:
224230
record = session.run("CALL dbms.cluster.routing.getServers").single()
231+
except CypherError as error:
232+
if error.code == "Neo.ClientError.Procedure.ProcedureNotFound":
233+
raise ServiceUnavailable("Server does not support routing")
234+
raise
225235
except ResultError:
226236
raise RuntimeError("TODO")
227-
new_expiry_time = monotonic() + record["ttl"]
237+
new_expiry_time = self.timer() + record["ttl"]
228238
servers = record["servers"]
229239
new_routers = [s["addresses"] for s in servers if s["role"] == "ROUTE"][0]
230240
new_readers = [s["addresses"] for s in servers if s["role"] == "READ"][0]
231241
new_writers = [s["addresses"] for s in servers if s["role"] == "WRITE"][0]
232242
if new_routers and new_readers and new_writers:
233-
self._expiry_time = new_expiry_time
234-
self._routers.replace(new_routers)
235-
self._readers.replace(new_readers)
236-
self._writers.replace(new_writers)
237-
else:
238-
raise RuntimeError("TODO")
243+
self.expiry_time = new_expiry_time
244+
self.routers.replace(map(parse_address, new_routers))
245+
self.readers.replace(map(parse_address, new_readers))
246+
self.writers.replace(map(parse_address, new_writers))
247+
return
248+
raise ServiceUnavailable("Unable to establish routing information")
249+
250+
def acquire_read_connection(self):
251+
if self.stale():
252+
self.discover()
253+
return self.pool.acquire(next(self.readers))
254+
255+
def acquire_write_connection(self):
256+
if self.stale():
257+
self.discover()
258+
return self.pool.acquire(next(self.writers))
259+
260+
261+
class RoutingDriver(Driver):
262+
""" A :class:`.RoutingDriver` is created from a `bolt+routing` URI.
263+
"""
264+
265+
def __init__(self, address, **config):
266+
self.security_plan = security_plan = SecurityPlan.build(address, **config)
267+
self.encrypted = security_plan.encrypted
268+
if not security_plan.routing_compatible:
269+
# this error message is case-specific as there is only one incompatible
270+
# scenario right now
271+
raise RuntimeError("TRUST_ON_FIRST_USE is not compatible with routing")
272+
Driver.__init__(self, lambda a: connect(a, security_plan.ssl_context, **config))
273+
self.router = Router(self.pool, address)
274+
self.router.discover()
239275

240276
def session(self, access_mode=None):
241277
if access_mode == READ_ACCESS:
242-
address = next(self._readers)
278+
connection = self.router.acquire_read_connection()
243279
else:
244-
address = next(self._writers)
245-
return Session(self.pool.acquire(address))
280+
connection = self.router.acquire_write_connection()
281+
return Session(connection)
246282

247283

248284
class StatementResult(object):

test/resources/discover_servers.script

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,5 +4,5 @@
44
C: RUN "CALL dbms.cluster.routing.getServers" {}
55
PULL_ALL
66
S: SUCCESS {"fields": ["ttl", "servers"]}
7-
RECORD [9223372036854775807, [{"addresses": ["127.0.0.1:9001"],"role": "WRITE"}, {"addresses": ["127.0.0.1:9002","127.0.0.1:9003"], "role": "READ"},{"addresses": ["127.0.0.1:9001","127.0.0.1:9002","127.0.0.1:9003"], "role": "ROUTE"}]]
7+
RECORD [300, [{"addresses": ["127.0.0.1:9001"],"role": "WRITE"}, {"addresses": ["127.0.0.1:9002","127.0.0.1:9003"], "role": "READ"},{"addresses": ["127.0.0.1:9001","127.0.0.1:9002","127.0.0.1:9003"], "role": "ROUTE"}]]
88
SUCCESS {}

test/resources/return_1.script

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
!: AUTO INIT
2+
!: AUTO RESET
3+
4+
C: RUN "CALL dbms.cluster.routing.getServers" {}
5+
PULL_ALL
6+
S: SUCCESS {"fields": ["ttl", "servers"]}
7+
RECORD [300, [{"addresses": ["127.0.0.1:9001"],"role": "WRITE"}, {"addresses": ["127.0.0.1:9001"], "role": "READ"},{"addresses": ["127.0.0.1:9001", "127.0.0.1:9002"], "role": "ROUTE"}]]
8+
SUCCESS {}
9+
C: RUN "RETURN $x" {"x": 1}
10+
PULL_ALL
11+
S: SUCCESS {"fields": ["x"]}
12+
RECORD [1]
13+
SUCCESS {}

test/test_routing.py

Lines changed: 33 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -19,39 +19,49 @@
1919
# limitations under the License.
2020

2121

22-
# from os.path import join as join_path, normpath
23-
# from subprocess import check_call
24-
from unittest import TestCase
22+
from neo4j.v1 import basic_auth, GraphDatabase, Router, ConnectionPool, connect, READ_ACCESS
23+
from test.util import ServerTestCase
2524

26-
from neo4j.v1 import basic_auth, GraphDatabase, RoutingDriver, READ_ACCESS
2725

26+
class RouterTestCase(ServerTestCase):
2827

29-
# from os.path import dirname
28+
def setUp(self):
29+
self.pool = ConnectionPool(lambda a: connect(a, auth=basic_auth("neo4j", "password")))
3030

31-
# TEST = normpath(dirname(__file__))
32-
# TEST_RESOURCES = join_path(TEST, "resources")
33-
# BOLT_ROUTING_URI = "bolt+routing://127.0.0.1:9001"
34-
# AUTH_TOKEN = basic_auth("neotest", "neotest")
31+
def test_router_is_initially_stale(self):
32+
router = Router(self.pool, ("127.0.0.1", 7687))
33+
assert router.stale()
3534

35+
def test_discovery(self):
36+
self.start_stub_server(9001, "discover_servers.script")
37+
router = Router(self.pool, ("127.0.0.1", 9001))
38+
router.timer = lambda: 0
39+
router.discover()
40+
assert router.expiry_time == 300
41+
assert router.routers == {'127.0.0.1:9001', '127.0.0.1:9002', '127.0.0.1:9003'}
42+
assert router.readers == {'127.0.0.1:9002', '127.0.0.1:9003'}
43+
assert router.writers == {'127.0.0.1:9001'}
3644

37-
class LocalClusterIntegrationTestCase(TestCase):
45+
46+
class LocalClusterIntegrationTestCase(ServerTestCase):
47+
48+
def test_should_discover_servers_on_driver_construction(self):
49+
self.start_stub_server(9001, "discover_servers.script")
50+
uri = "bolt+routing://127.0.0.1:9001"
51+
driver = GraphDatabase.driver(uri, auth=basic_auth("neo4j", "password"), encrypted=False)
52+
assert driver.router.routers == {('127.0.0.1', 9001), ('127.0.0.1', 9002), ('127.0.0.1', 9003)}
53+
assert driver.router.readers == {('127.0.0.1', 9002), ('127.0.0.1', 9003)}
54+
assert driver.router.writers == {('127.0.0.1', 9001)}
3855

3956
def test_should_be_able_to_run_cypher(self):
40-
uri = "bolt+routing://ec2-54-78-203-70.eu-west-1.compute.amazonaws.com:26000"
41-
driver = GraphDatabase.driver(uri, auth=basic_auth("neo4j", "password"))
57+
self.start_stub_server(9001, "return_1.script")
58+
uri = "bolt+routing://127.0.0.1:9001"
59+
driver = GraphDatabase.driver(uri, auth=basic_auth("neo4j", "password"), encrypted=False)
4260
try:
4361
with driver.session(READ_ACCESS) as session:
44-
result = session.run("UNWIND range(1, 3) AS n RETURN n")
62+
result = session.run("RETURN $x", {"x": 1})
4563
for record in result:
46-
print(record)
47-
print(result.summary.metadata)
48-
print(session.connection.address)
64+
assert record["x"] == 1
65+
assert session.connection.address == ('127.0.0.1', 9001)
4966
finally:
5067
driver.close()
51-
52-
def test_should_discover_servers_on_driver_construction(self):
53-
uri = "bolt+routing://ec2-54-78-203-70.eu-west-1.compute.amazonaws.com:26000"
54-
driver = GraphDatabase.driver(uri, auth=basic_auth("neo4j", "password"))
55-
print(driver._routers)
56-
print(driver._readers)
57-
print(driver._writers)

test/util.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,10 @@
2121

2222
import functools
2323
from os import getenv, remove, rename
24-
from os.path import isfile
24+
from os.path import isfile, dirname, join as path_join
2525
from socket import create_connection
2626
from subprocess import check_call, CalledProcessError
27+
from threading import Thread
2728
from time import sleep
2829
from unittest import TestCase
2930

@@ -96,3 +97,18 @@ def tearDown(self):
9697
if isfile(self.known_hosts):
9798
remove(self.known_hosts)
9899
rename(self.known_hosts_backup, self.known_hosts)
100+
101+
def start_stub_server(self, port, script):
102+
StubServer(port, script).start()
103+
sleep(0.5)
104+
105+
106+
class StubServer(Thread):
107+
108+
def __init__(self, port, script):
109+
super(StubServer, self).__init__()
110+
self.port = port
111+
self.script = path_join(dirname(__file__), "resources", script)
112+
113+
def run(self):
114+
check_call(["boltstub", str(self.port), self.script])

0 commit comments

Comments
 (0)