Skip to content

Commit 6fdb4e1

Browse files
committed
Basic routing tests
1 parent 196e50c commit 6fdb4e1

File tree

8 files changed

+122
-48
lines changed

8 files changed

+122
-48
lines changed

neo4j/v1/session.py

Lines changed: 23 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -208,11 +208,11 @@ class Router(object):
208208

209209
timer = monotonic
210210

211-
def __init__(self, pool, initial_address):
211+
def __init__(self, pool, *routers):
212212
self.pool = pool
213213
self.lock = Lock()
214214
self.expiry_time = None
215-
self.routers = RoundRobinSet([initial_address])
215+
self.routers = RoundRobinSet(routers)
216216
self.readers = RoundRobinSet()
217217
self.writers = RoundRobinSet()
218218

@@ -225,26 +225,27 @@ def discover(self):
225225
if not self.routers:
226226
raise ServiceUnavailable("No routers available")
227227
for router in list(self.routers):
228-
session = Session(self.pool.acquire(router))
229-
try:
230-
record = session.run("CALL dbms.cluster.routing.getServers").single()
231-
except CypherError as error:
232-
if error.code == "Neo.ClientError.Procedure.ProcedureNotFound":
233-
raise ServiceUnavailable("Server does not support routing")
234-
raise
235-
except ResultError:
236-
raise RuntimeError("TODO")
237-
new_expiry_time = self.timer() + record["ttl"]
238-
servers = record["servers"]
239-
new_routers = [s["addresses"] for s in servers if s["role"] == "ROUTE"][0]
240-
new_readers = [s["addresses"] for s in servers if s["role"] == "READ"][0]
241-
new_writers = [s["addresses"] for s in servers if s["role"] == "WRITE"][0]
242-
if new_routers and new_readers and new_writers:
243-
self.expiry_time = new_expiry_time
244-
self.routers.replace(map(parse_address, new_routers))
245-
self.readers.replace(map(parse_address, new_readers))
246-
self.writers.replace(map(parse_address, new_writers))
247-
return
228+
connection = self.pool.acquire(router)
229+
with Session(connection) as session:
230+
try:
231+
record = session.run("CALL dbms.cluster.routing.getServers").single()
232+
except CypherError as error:
233+
if error.code == "Neo.ClientError.Procedure.ProcedureNotFound":
234+
raise ServiceUnavailable("Server does not support routing")
235+
raise
236+
except ResultError:
237+
raise RuntimeError("TODO")
238+
new_expiry_time = self.timer() + record["ttl"]
239+
servers = record["servers"]
240+
new_routers = [s["addresses"] for s in servers if s["role"] == "ROUTE"][0]
241+
new_readers = [s["addresses"] for s in servers if s["role"] == "READ"][0]
242+
new_writers = [s["addresses"] for s in servers if s["role"] == "WRITE"][0]
243+
if new_routers and new_readers and new_writers:
244+
self.expiry_time = new_expiry_time
245+
self.routers.replace(map(parse_address, new_routers))
246+
self.readers.replace(map(parse_address, new_readers))
247+
self.writers.replace(map(parse_address, new_writers))
248+
return router
248249
raise ServiceUnavailable("Unable to establish routing information")
249250

250251
def acquire_read_connection(self):

test/resources/bad_router.script

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
!: AUTO INIT
2+
!: AUTO RESET
3+
4+
C: RUN "CALL dbms.cluster.routing.getServers" {}
5+
PULL_ALL
6+
S: SUCCESS {"fields": ["ttl", "servers"]}
7+
RECORD [300, [{"role":"ROUTE","addresses":[]},{"role":"READ","addresses":[]},{"role":"WRITE","addresses":[]}]]
8+
SUCCESS {}

test/resources/create_a.script

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
!: AUTO INIT
2+
!: AUTO RESET
3+
4+
C: RUN "CREATE (a $x)" {"x": {"name": "Alice"}}
5+
PULL_ALL
6+
S: SUCCESS {"fields": []}
7+
SUCCESS {}

test/resources/discover_servers.script

Lines changed: 0 additions & 8 deletions
This file was deleted.

test/resources/non_router.script

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
!: AUTO INIT
2+
!: AUTO RESET
3+
4+
C: RUN "CALL dbms.cluster.routing.getServers" {}
5+
PULL_ALL
6+
S: FAILURE {"code": "Neo.ClientError.Procedure.ProcedureNotFound", "message": "Not a router"}
7+
IGNORED
8+
C: ACK_FAILURE
9+
S: SUCCESS {}

test/resources/return_1.script

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,6 @@
11
!: AUTO INIT
22
!: AUTO RESET
33

4-
C: RUN "CALL dbms.cluster.routing.getServers" {}
5-
PULL_ALL
6-
S: SUCCESS {"fields": ["ttl", "servers"]}
7-
RECORD [300, [{"addresses": ["127.0.0.1:9001"],"role": "WRITE"}, {"addresses": ["127.0.0.1:9001"], "role": "READ"},{"addresses": ["127.0.0.1:9001", "127.0.0.1:9002"], "role": "ROUTE"}]]
8-
SUCCESS {}
94
C: RUN "RETURN $x" {"x": 1}
105
PULL_ALL
116
S: SUCCESS {"fields": ["x"]}

test/resources/router.script

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
!: AUTO INIT
2+
!: AUTO RESET
3+
4+
C: RUN "CALL dbms.cluster.routing.getServers" {}
5+
PULL_ALL
6+
S: SUCCESS {"fields": ["ttl", "servers"]}
7+
RECORD [300, [{"role":"ROUTE","addresses":["127.0.0.1:9001","127.0.0.1:9002","127.0.0.1:9003"]},{"role":"READ","addresses":["127.0.0.1:9004","127.0.0.1:9005"]},{"role":"WRITE","addresses":["127.0.0.1:9006"]}]]
8+
SUCCESS {}

test/test_routing.py

Lines changed: 67 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@
1919
# limitations under the License.
2020

2121

22-
from neo4j.v1 import basic_auth, GraphDatabase, Router, ConnectionPool, connect, READ_ACCESS
22+
from neo4j.v1 import basic_auth, GraphDatabase, Router, ConnectionPool, connect, READ_ACCESS, \
23+
WRITE_ACCESS, ServiceUnavailable
2324
from test.util import ServerTestCase
2425

2526

@@ -33,35 +34,88 @@ def test_router_is_initially_stale(self):
3334
assert router.stale()
3435

3536
def test_discovery(self):
36-
self.start_stub_server(9001, "discover_servers.script")
37+
self.start_stub_server(9001, "router.script")
3738
router = Router(self.pool, ("127.0.0.1", 9001))
3839
router.timer = lambda: 0
3940
router.discover()
4041
assert router.expiry_time == 300
41-
assert router.routers == {'127.0.0.1:9001', '127.0.0.1:9002', '127.0.0.1:9003'}
42-
assert router.readers == {'127.0.0.1:9002', '127.0.0.1:9003'}
43-
assert router.writers == {'127.0.0.1:9001'}
42+
assert router.routers == {('127.0.0.1', 9001), ('127.0.0.1', 9002), ('127.0.0.1', 9003)}
43+
assert router.readers == {('127.0.0.1', 9004), ('127.0.0.1', 9005)}
44+
assert router.writers == {('127.0.0.1', 9006)}
4445

46+
def test_discovery_after_bad_discovery(self):
47+
self.start_stub_server(9001, "bad_router.script")
48+
self.start_stub_server(9002, "router.script")
49+
router = Router(self.pool, ("127.0.0.1", 9001), ("127.0.0.1", 9002))
50+
router.timer = lambda: 0
51+
router.discover()
52+
assert router.expiry_time == 300
53+
assert router.routers == {('127.0.0.1', 9001), ('127.0.0.1', 9002), ('127.0.0.1', 9003)}
54+
assert router.readers == {('127.0.0.1', 9004), ('127.0.0.1', 9005)}
55+
assert router.writers == {('127.0.0.1', 9006)}
56+
57+
def test_discovery_against_non_router(self):
58+
self.start_stub_server(9001, "non_router.script")
59+
router = Router(self.pool, ("127.0.0.1", 9001))
60+
with self.assertRaises(ServiceUnavailable):
61+
router.discover()
4562

46-
class LocalClusterIntegrationTestCase(ServerTestCase):
63+
def test_running_out_of_good_routers_on_discovery(self):
64+
self.start_stub_server(9001, "bad_router.script")
65+
self.start_stub_server(9002, "bad_router.script")
66+
self.start_stub_server(9003, "bad_router.script")
67+
router = Router(self.pool, ("127.0.0.1", 9001), ("127.0.0.1", 9002), ("127.0.0.1", 9003))
68+
with self.assertRaises(ServiceUnavailable):
69+
router.discover()
70+
71+
72+
class RoutingDriverTestCase(ServerTestCase):
4773

4874
def test_should_discover_servers_on_driver_construction(self):
49-
self.start_stub_server(9001, "discover_servers.script")
75+
self.start_stub_server(9001, "router.script")
5076
uri = "bolt+routing://127.0.0.1:9001"
5177
driver = GraphDatabase.driver(uri, auth=basic_auth("neo4j", "password"), encrypted=False)
52-
assert driver.router.routers == {('127.0.0.1', 9001), ('127.0.0.1', 9002), ('127.0.0.1', 9003)}
53-
assert driver.router.readers == {('127.0.0.1', 9002), ('127.0.0.1', 9003)}
54-
assert driver.router.writers == {('127.0.0.1', 9001)}
78+
router = driver.router
79+
assert router.routers == {('127.0.0.1', 9001), ('127.0.0.1', 9002), ('127.0.0.1', 9003)}
80+
assert router.readers == {('127.0.0.1', 9004), ('127.0.0.1', 9005)}
81+
assert router.writers == {('127.0.0.1', 9006)}
5582

56-
def test_should_be_able_to_run_cypher(self):
57-
self.start_stub_server(9001, "return_1.script")
83+
def test_should_be_able_to_read(self):
84+
self.start_stub_server(9001, "router.script")
85+
self.start_stub_server(9004, "return_1.script")
5886
uri = "bolt+routing://127.0.0.1:9001"
5987
driver = GraphDatabase.driver(uri, auth=basic_auth("neo4j", "password"), encrypted=False)
6088
try:
6189
with driver.session(READ_ACCESS) as session:
6290
result = session.run("RETURN $x", {"x": 1})
6391
for record in result:
6492
assert record["x"] == 1
65-
assert session.connection.address == ('127.0.0.1', 9001)
93+
assert session.connection.address == ('127.0.0.1', 9004)
94+
finally:
95+
driver.close()
96+
97+
def test_should_be_able_to_write(self):
98+
self.start_stub_server(9001, "router.script")
99+
self.start_stub_server(9006, "create_a.script")
100+
uri = "bolt+routing://127.0.0.1:9001"
101+
driver = GraphDatabase.driver(uri, auth=basic_auth("neo4j", "password"), encrypted=False)
102+
try:
103+
with driver.session(WRITE_ACCESS) as session:
104+
result = session.run("CREATE (a $x)", {"x": {"name": "Alice"}})
105+
assert not list(result)
106+
assert session.connection.address == ('127.0.0.1', 9006)
107+
finally:
108+
driver.close()
109+
110+
def test_should_be_able_to_write_as_default(self):
111+
self.start_stub_server(9001, "router.script")
112+
self.start_stub_server(9006, "create_a.script")
113+
uri = "bolt+routing://127.0.0.1:9001"
114+
driver = GraphDatabase.driver(uri, auth=basic_auth("neo4j", "password"), encrypted=False)
115+
try:
116+
with driver.session() as session:
117+
result = session.run("CREATE (a $x)", {"x": {"name": "Alice"}})
118+
assert not list(result)
119+
assert session.connection.address == ('127.0.0.1', 9006)
66120
finally:
67121
driver.close()

0 commit comments

Comments
 (0)