Skip to content

Commit 8d1a8d2

Browse files
committed
More tests
1 parent ebb633b commit 8d1a8d2

File tree

6 files changed

+292
-262
lines changed

6 files changed

+292
-262
lines changed

neo4j/v1/session.py

Lines changed: 18 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -146,9 +146,6 @@ def __init__(self, requires_encryption, ssl_context, routing_compatible):
146146
self.ssl_context = ssl_context
147147
self.routing_compatible = routing_compatible
148148

149-
def __repr__(self):
150-
return repr(vars(self))
151-
152149

153150
class Driver(object):
154151
""" A :class:`.Driver` is an accessor for a specific graph database
@@ -180,7 +177,7 @@ def session(self, access_mode=None):
180177
not thread safe, therefore a session should generally be short-lived
181178
within a single thread.
182179
"""
183-
raise NotImplementedError()
180+
pass
184181

185182
def close(self):
186183
if self.pool:
@@ -202,7 +199,7 @@ def session(self, access_mode=None):
202199
return Session(self, self.pool.acquire(self.address))
203200

204201

205-
class Router(object):
202+
class ConnectionRouter(object):
206203
""" The `Router` class contains logic for discovering servers within a
207204
cluster that supports routing.
208205
"""
@@ -241,18 +238,20 @@ def discover(self):
241238
"routing" % (router,))
242239
raise
243240
except ResultError:
244-
raise RuntimeError("TODO")
245-
new_expiry_time = self.timer() + record["ttl"]
246-
servers = record["servers"]
247-
new_routers = [s["addresses"] for s in servers if s["role"] == "ROUTE"][0]
248-
new_readers = [s["addresses"] for s in servers if s["role"] == "READ"][0]
249-
new_writers = [s["addresses"] for s in servers if s["role"] == "WRITE"][0]
250-
if new_routers and new_readers and new_writers:
251-
self.expiry_time = new_expiry_time
252-
self.routers.replace(map(parse_address, new_routers))
253-
self.readers.replace(map(parse_address, new_readers))
254-
self.writers.replace(map(parse_address, new_writers))
255-
return router
241+
raise ServiceUnavailable("Server %r returned no record from "
242+
"discovery procedure" % (router,))
243+
else:
244+
new_expiry_time = self.timer() + record["ttl"]
245+
servers = record["servers"]
246+
new_routers = [s["addresses"] for s in servers if s["role"] == "ROUTE"][0]
247+
new_readers = [s["addresses"] for s in servers if s["role"] == "READ"][0]
248+
new_writers = [s["addresses"] for s in servers if s["role"] == "WRITE"][0]
249+
if new_routers and new_readers and new_writers:
250+
self.expiry_time = new_expiry_time
251+
self.routers.replace(map(parse_address, new_routers))
252+
self.readers.replace(map(parse_address, new_readers))
253+
self.writers.replace(map(parse_address, new_writers))
254+
return router
256255
raise ServiceUnavailable("Unable to establish routing information")
257256

258257
def acquire_read_connection(self):
@@ -280,9 +279,9 @@ def __init__(self, address, **config):
280279
if not security_plan.routing_compatible:
281280
# this error message is case-specific as there is only one incompatible
282281
# scenario right now
283-
raise RuntimeError("TRUST_ON_FIRST_USE is not compatible with routing")
282+
raise ValueError("TRUST_ON_FIRST_USE is not compatible with routing")
284283
Driver.__init__(self, lambda a: connect(a, security_plan.ssl_context, **config))
285-
self.router = Router(self.pool, address)
284+
self.router = ConnectionRouter(self.pool, address)
286285
self.router.discover()
287286

288287
def session(self, access_mode=None):

test/resources/silent_router.script

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
!: AUTO INIT
2+
!: AUTO RESET
3+
4+
C: RUN "CALL dbms.cluster.routing.getServers" {}
5+
PULL_ALL
6+
S: SUCCESS {"fields": ["ttl", "servers"]}
7+
SUCCESS {}

test/test_connection.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
from socket import create_connection
2222

23-
from neo4j.v1.bolt import ConnectionPool
23+
from neo4j.v1 import basic_auth, ConnectionRouter, ConnectionPool, connect, ServiceUnavailable
2424

2525
from test.util import ServerTestCase
2626

@@ -98,3 +98,48 @@ def test_releasing_twice(self):
9898
self.assert_pool_size(address, 0, 1)
9999
self.pool.release(connection)
100100
self.assert_pool_size(address, 0, 1)
101+
102+
103+
class RouterTestCase(ServerTestCase):
104+
105+
def setUp(self):
106+
self.pool = ConnectionPool(lambda a: connect(a, auth=basic_auth("neo4j", "password")))
107+
108+
def test_router_is_initially_stale(self):
109+
router = ConnectionRouter(self.pool, ("127.0.0.1", 7687))
110+
assert router.stale()
111+
112+
def test_discovery(self):
113+
self.start_stub_server(9001, "router.script")
114+
router = ConnectionRouter(self.pool, ("127.0.0.1", 9001))
115+
router.timer = lambda: 0
116+
router.discover()
117+
assert router.expiry_time == 300
118+
assert router.routers == {('127.0.0.1', 9001), ('127.0.0.1', 9002), ('127.0.0.1', 9003)}
119+
assert router.readers == {('127.0.0.1', 9004), ('127.0.0.1', 9005)}
120+
assert router.writers == {('127.0.0.1', 9006)}
121+
122+
def test_discovery_after_bad_discovery(self):
123+
self.start_stub_server(9001, "bad_router.script")
124+
self.start_stub_server(9002, "router.script")
125+
router = ConnectionRouter(self.pool, ("127.0.0.1", 9001), ("127.0.0.1", 9002))
126+
router.timer = lambda: 0
127+
router.discover()
128+
assert router.expiry_time == 300
129+
assert router.routers == {('127.0.0.1', 9001), ('127.0.0.1', 9002), ('127.0.0.1', 9003)}
130+
assert router.readers == {('127.0.0.1', 9004), ('127.0.0.1', 9005)}
131+
assert router.writers == {('127.0.0.1', 9006)}
132+
133+
def test_discovery_against_non_router(self):
134+
self.start_stub_server(9001, "non_router.script")
135+
router = ConnectionRouter(self.pool, ("127.0.0.1", 9001))
136+
with self.assertRaises(ServiceUnavailable):
137+
router.discover()
138+
139+
def test_running_out_of_good_routers_on_discovery(self):
140+
self.start_stub_server(9001, "bad_router.script")
141+
self.start_stub_server(9002, "bad_router.script")
142+
self.start_stub_server(9003, "bad_router.script")
143+
router = ConnectionRouter(self.pool, ("127.0.0.1", 9001), ("127.0.0.1", 9002), ("127.0.0.1", 9003))
144+
with self.assertRaises(ServiceUnavailable):
145+
router.discover()

test/test_driver.py

Lines changed: 219 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,219 @@
1+
#!/usr/bin/env python
2+
# -*- encoding: utf-8 -*-
3+
4+
# Copyright (c) 2002-2016 "Neo Technology,"
5+
# Network Engine for Objects in Lund AB [http://neotechnology.com]
6+
#
7+
# This file is part of Neo4j.
8+
#
9+
# Licensed under the Apache License, Version 2.0 (the "License");
10+
# you may not use this file except in compliance with the License.
11+
# You may obtain a copy of the License at
12+
#
13+
# http://www.apache.org/licenses/LICENSE-2.0
14+
#
15+
# Unless required by applicable law or agreed to in writing, software
16+
# distributed under the License is distributed on an "AS IS" BASIS,
17+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18+
# See the License for the specific language governing permissions and
19+
# limitations under the License.
20+
21+
22+
from socket import socket
23+
from ssl import SSLSocket
24+
from unittest import skipUnless
25+
26+
from neo4j.v1 import ServiceUnavailable, ProtocolError, READ_ACCESS, WRITE_ACCESS, \
27+
TRUST_ON_FIRST_USE, TRUST_CUSTOM_CA_SIGNED_CERTIFICATES, GraphDatabase, basic_auth, \
28+
SSL_AVAILABLE, SessionExpired, DirectDriver
29+
from test.util import ServerTestCase
30+
31+
BOLT_URI = "bolt://localhost:7687"
32+
BOLT_ROUTING_URI = "bolt+routing://localhost:7687"
33+
AUTH_TOKEN = basic_auth("neotest", "neotest")
34+
35+
36+
class DriverTestCase(ServerTestCase):
37+
38+
def test_driver_with_block(self):
39+
with GraphDatabase.driver(BOLT_URI, auth=AUTH_TOKEN, encrypted=False) as driver:
40+
assert isinstance(driver, DirectDriver)
41+
42+
def test_must_use_valid_url_scheme(self):
43+
with self.assertRaises(ProtocolError):
44+
GraphDatabase.driver("x://xxx", auth=AUTH_TOKEN)
45+
46+
def test_connections_are_reused(self):
47+
driver = GraphDatabase.driver(BOLT_URI, auth=AUTH_TOKEN)
48+
session_1 = driver.session()
49+
connection_1 = session_1.connection
50+
session_1.close()
51+
session_2 = driver.session()
52+
connection_2 = session_2.connection
53+
session_2.close()
54+
assert connection_1 is connection_2
55+
56+
def test_connections_are_not_shared_between_sessions(self):
57+
driver = GraphDatabase.driver(BOLT_URI, auth=AUTH_TOKEN)
58+
session_1 = driver.session()
59+
session_2 = driver.session()
60+
try:
61+
assert session_1.connection is not session_2.connection
62+
finally:
63+
session_1.close()
64+
session_2.close()
65+
66+
def test_fail_nicely_when_connecting_to_http_port(self):
67+
driver = GraphDatabase.driver("bolt://localhost:7474", auth=AUTH_TOKEN, encrypted=False)
68+
with self.assertRaises(ServiceUnavailable) as context:
69+
driver.session()
70+
71+
72+
class DirectDriverTestCase(ServerTestCase):
73+
74+
def test_direct_disconnect_on_run(self):
75+
self.start_stub_server(9001, "disconnect_on_run.script")
76+
uri = "bolt://127.0.0.1:9001"
77+
driver = GraphDatabase.driver(uri, auth=basic_auth("neo4j", "password"), encrypted=False)
78+
try:
79+
with driver.session() as session:
80+
with self.assertRaises(ServiceUnavailable):
81+
session.run("RETURN $x", {"x": 1}).consume()
82+
finally:
83+
driver.close()
84+
85+
def test_direct_disconnect_on_pull_all(self):
86+
self.start_stub_server(9001, "disconnect_on_pull_all.script")
87+
uri = "bolt://127.0.0.1:9001"
88+
driver = GraphDatabase.driver(uri, auth=basic_auth("neo4j", "password"), encrypted=False)
89+
try:
90+
with driver.session() as session:
91+
with self.assertRaises(ServiceUnavailable):
92+
session.run("RETURN $x", {"x": 1}).consume()
93+
finally:
94+
driver.close()
95+
96+
97+
class RoutingDriverTestCase(ServerTestCase):
98+
99+
def test_cannot_discover_servers_on_non_router(self):
100+
self.start_stub_server(9001, "non_router.script")
101+
uri = "bolt+routing://127.0.0.1:9001"
102+
with self.assertRaises(ServiceUnavailable):
103+
GraphDatabase.driver(uri, auth=basic_auth("neo4j", "password"), encrypted=False)
104+
105+
def test_cannot_discover_servers_on_silent_router(self):
106+
self.start_stub_server(9001, "silent_router.script")
107+
uri = "bolt+routing://127.0.0.1:9001"
108+
with self.assertRaises(ServiceUnavailable):
109+
GraphDatabase.driver(uri, auth=basic_auth("neo4j", "password"), encrypted=False)
110+
111+
def test_should_discover_servers_on_driver_construction(self):
112+
self.start_stub_server(9001, "router.script")
113+
uri = "bolt+routing://127.0.0.1:9001"
114+
driver = GraphDatabase.driver(uri, auth=basic_auth("neo4j", "password"), encrypted=False)
115+
router = driver.router
116+
assert router.routers == {('127.0.0.1', 9001), ('127.0.0.1', 9002), ('127.0.0.1', 9003)}
117+
assert router.readers == {('127.0.0.1', 9004), ('127.0.0.1', 9005)}
118+
assert router.writers == {('127.0.0.1', 9006)}
119+
120+
def test_should_be_able_to_read(self):
121+
self.start_stub_server(9001, "router.script")
122+
self.start_stub_server(9004, "return_1.script")
123+
uri = "bolt+routing://127.0.0.1:9001"
124+
driver = GraphDatabase.driver(uri, auth=basic_auth("neo4j", "password"), encrypted=False)
125+
try:
126+
with driver.session(READ_ACCESS) as session:
127+
result = session.run("RETURN $x", {"x": 1})
128+
for record in result:
129+
assert record["x"] == 1
130+
assert session.connection.address == ('127.0.0.1', 9004)
131+
finally:
132+
driver.close()
133+
134+
def test_should_be_able_to_write(self):
135+
self.start_stub_server(9001, "router.script")
136+
self.start_stub_server(9006, "create_a.script")
137+
uri = "bolt+routing://127.0.0.1:9001"
138+
driver = GraphDatabase.driver(uri, auth=basic_auth("neo4j", "password"), encrypted=False)
139+
try:
140+
with driver.session(WRITE_ACCESS) as session:
141+
result = session.run("CREATE (a $x)", {"x": {"name": "Alice"}})
142+
assert not list(result)
143+
assert session.connection.address == ('127.0.0.1', 9006)
144+
finally:
145+
driver.close()
146+
147+
def test_should_be_able_to_write_as_default(self):
148+
self.start_stub_server(9001, "router.script")
149+
self.start_stub_server(9006, "create_a.script")
150+
uri = "bolt+routing://127.0.0.1:9001"
151+
driver = GraphDatabase.driver(uri, auth=basic_auth("neo4j", "password"), encrypted=False)
152+
try:
153+
with driver.session() as session:
154+
result = session.run("CREATE (a $x)", {"x": {"name": "Alice"}})
155+
assert not list(result)
156+
assert session.connection.address == ('127.0.0.1', 9006)
157+
finally:
158+
driver.close()
159+
160+
def test_routing_disconnect_on_run(self):
161+
self.start_stub_server(9001, "router.script")
162+
self.start_stub_server(9004, "disconnect_on_run.script")
163+
uri = "bolt+routing://127.0.0.1:9001"
164+
driver = GraphDatabase.driver(uri, auth=basic_auth("neo4j", "password"), encrypted=False)
165+
try:
166+
with driver.session(READ_ACCESS) as session:
167+
with self.assertRaises(SessionExpired):
168+
session.run("RETURN $x", {"x": 1}).consume()
169+
finally:
170+
driver.close()
171+
172+
def test_routing_disconnect_on_pull_all(self):
173+
self.start_stub_server(9001, "router.script")
174+
self.start_stub_server(9004, "disconnect_on_pull_all.script")
175+
uri = "bolt+routing://127.0.0.1:9001"
176+
driver = GraphDatabase.driver(uri, auth=basic_auth("neo4j", "password"), encrypted=False)
177+
try:
178+
with driver.session(READ_ACCESS) as session:
179+
with self.assertRaises(SessionExpired):
180+
session.run("RETURN $x", {"x": 1}).consume()
181+
finally:
182+
driver.close()
183+
184+
185+
class SecurityTestCase(ServerTestCase):
186+
187+
def test_insecure_session_uses_normal_socket(self):
188+
driver = GraphDatabase.driver(BOLT_URI, auth=AUTH_TOKEN, encrypted=False)
189+
with driver.session() as session:
190+
connection = session.connection
191+
assert isinstance(connection.channel.socket, socket)
192+
assert connection.der_encoded_server_certificate is None
193+
194+
@skipUnless(SSL_AVAILABLE, "Bolt over TLS is not supported by this version of Python")
195+
def test_tofu_session_uses_secure_socket(self):
196+
driver = GraphDatabase.driver(BOLT_URI, auth=AUTH_TOKEN, encrypted=True, trust=TRUST_ON_FIRST_USE)
197+
with driver.session() as session:
198+
connection = session.connection
199+
assert isinstance(connection.channel.socket, SSLSocket)
200+
assert connection.der_encoded_server_certificate is not None
201+
202+
@skipUnless(SSL_AVAILABLE, "Bolt over TLS is not supported by this version of Python")
203+
def test_tofu_session_trusts_certificate_after_first_use(self):
204+
driver = GraphDatabase.driver(BOLT_URI, auth=AUTH_TOKEN, encrypted=True, trust=TRUST_ON_FIRST_USE)
205+
with driver.session() as session:
206+
connection = session.connection
207+
certificate = connection.der_encoded_server_certificate
208+
with driver.session() as session:
209+
connection = session.connection
210+
assert connection.der_encoded_server_certificate == certificate
211+
212+
def test_routing_driver_not_compatible_with_tofu(self):
213+
with self.assertRaises(ValueError):
214+
GraphDatabase.driver(BOLT_ROUTING_URI, auth=AUTH_TOKEN, trust=TRUST_ON_FIRST_USE)
215+
216+
def test_custom_ca_not_implemented(self):
217+
with self.assertRaises(NotImplementedError):
218+
GraphDatabase.driver(BOLT_URI, auth=AUTH_TOKEN,
219+
trust=TRUST_CUSTOM_CA_SIGNED_CERTIFICATES)

0 commit comments

Comments
 (0)