Skip to content

Commit 47bd136

Browse files
committed
Tests for routing strategies
1 parent 7f5ab21 commit 47bd136

File tree

2 files changed

+146
-43
lines changed

2 files changed

+146
-43
lines changed

neo4j/v1/routing.py

Lines changed: 31 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,7 @@
1717
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1818
# See the License for the specific language governing permissions and
1919
# limitations under the License.
20-
21-
20+
from abc import abstractmethod
2221
from sys import maxsize
2322
from threading import Lock
2423
from time import clock
@@ -175,11 +174,12 @@ def build(cls, connection_pool, **config):
175174
return RoundRobinLoadBalancingStrategy()
176175
else:
177176
raise ValueError("Unknown load balancing strategy '%s'" % load_balancing_strategy)
178-
pass
179177

178+
@abstractmethod
180179
def select_reader(self, known_readers):
181180
raise NotImplementedError()
182181

182+
@abstractmethod
183183
def select_writer(self, known_writers):
184184
raise NotImplementedError()
185185

@@ -190,22 +190,20 @@ class RoundRobinLoadBalancingStrategy(LoadBalancingStrategy):
190190
_writers_offset = 0
191191

192192
def select_reader(self, known_readers):
193-
address = self.select(self._readers_offset, known_readers)
193+
address = self._select(self._readers_offset, known_readers)
194194
self._readers_offset += 1
195195
return address
196196

197197
def select_writer(self, known_writers):
198-
address = self.select(self._writers_offset, known_writers)
198+
address = self._select(self._writers_offset, known_writers)
199199
self._writers_offset += 1
200200
return address
201201

202-
def select(self, offset, addresses):
203-
length = len(addresses)
204-
if length == 0:
202+
@classmethod
203+
def _select(cls, offset, addresses):
204+
if not addresses:
205205
return None
206-
else:
207-
index = offset % length
208-
return addresses.get(index)
206+
return addresses[offset % len(addresses)]
209207

210208

211209
class LeastConnectedLoadBalancingStrategy(LoadBalancingStrategy):
@@ -216,43 +214,42 @@ def __init__(self, connection_pool):
216214
self._connection_pool = connection_pool
217215

218216
def select_reader(self, known_readers):
219-
address = self.select(self._readers_offset, known_readers)
217+
address = self._select(self._readers_offset, known_readers)
220218
self._readers_offset += 1
221219
return address
222220

223221
def select_writer(self, known_writers):
224-
address = self.select(self._writers_offset, known_writers)
222+
address = self._select(self._writers_offset, known_writers)
225223
self._writers_offset += 1
226224
return address
227225

228-
def select(self, offset, addresses):
229-
length = len(addresses)
230-
if length == 0:
226+
def _select(self, offset, addresses):
227+
if not addresses:
231228
return None
232-
else:
233-
start_index = offset % length
234-
index = start_index
229+
num_addresses = len(addresses)
230+
start_index = offset % num_addresses
231+
index = start_index
235232

236-
least_connected_address = None
237-
least_in_use_connections = maxsize
233+
least_connected_address = None
234+
least_in_use_connections = maxsize
238235

239-
while True:
240-
address = addresses[index]
241-
in_use_connections = self._connection_pool.in_use_connection_count(address)
236+
while True:
237+
address = addresses[index]
238+
in_use_connections = self._connection_pool.in_use_connection_count(address)
242239

243-
if in_use_connections < least_in_use_connections:
244-
least_connected_address = address
245-
least_in_use_connections = in_use_connections
240+
if in_use_connections < least_in_use_connections:
241+
least_connected_address = address
242+
least_in_use_connections = in_use_connections
246243

247-
if index == length - 1:
248-
index = 0
249-
else:
250-
index += 1
244+
if index == num_addresses - 1:
245+
index = 0
246+
else:
247+
index += 1
251248

252-
if index == start_index:
253-
break
249+
if index == start_index:
250+
break
254251

255-
return least_connected_address
252+
return least_connected_address
256253

257254

258255
class RoutingConnectionPool(ConnectionPool):

test/unit/test_routing.py

Lines changed: 115 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,15 @@
1717
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1818
# See the License for the specific language governing permissions and
1919
# limitations under the License.
20-
20+
from collections import OrderedDict
2121
from unittest import TestCase
2222

2323
from neo4j.bolt import ProtocolError
2424
from neo4j.bolt.connection import connect
25-
from neo4j.v1.routing import OrderedSet, RoutingTable, RoutingConnectionPool
25+
from neo4j.v1.routing import OrderedSet, RoutingTable, RoutingConnectionPool, LeastConnectedLoadBalancingStrategy, \
26+
RoundRobinLoadBalancingStrategy
2627
from neo4j.v1.security import basic_auth
27-
from neo4j.v1.api import Driver, READ_ACCESS, WRITE_ACCESS
28+
from neo4j.v1.api import READ_ACCESS, WRITE_ACCESS
2829

2930

3031
VALID_ROUTING_RECORD = {
@@ -56,7 +57,6 @@ def connector(address):
5657

5758

5859
class RoundRobinSetTestCase(TestCase):
59-
6060
def test_should_repr_as_set(self):
6161
s = OrderedSet([1, 2, 3])
6262
assert repr(s) == "{1, 2, 3}"
@@ -135,15 +135,13 @@ def test_should_be_able_to_replace(self):
135135

136136

137137
class RoutingTableConstructionTestCase(TestCase):
138-
139138
def test_should_be_initially_stale(self):
140139
table = RoutingTable()
141140
assert not table.is_fresh(READ_ACCESS)
142141
assert not table.is_fresh(WRITE_ACCESS)
143142

144143

145144
class RoutingTableParseRoutingInfoTestCase(TestCase):
146-
147145
def test_should_return_routing_table_on_valid_record(self):
148146
table = RoutingTable.parse_routing_info([VALID_ROUTING_RECORD])
149147
assert table.routers == {('127.0.0.1', 9001), ('127.0.0.1', 9002), ('127.0.0.1', 9003)}
@@ -172,7 +170,6 @@ def test_should_fail_on_multiple_records(self):
172170

173171

174172
class RoutingTableFreshnessTestCase(TestCase):
175-
176173
def test_should_be_fresh_after_update(self):
177174
table = RoutingTable.parse_routing_info([VALID_ROUTING_RECORD])
178175
assert table.is_fresh(READ_ACCESS)
@@ -198,7 +195,6 @@ def test_should_become_stale_if_no_writers(self):
198195

199196

200197
class RoutingTableUpdateTestCase(TestCase):
201-
202198
def setUp(self):
203199
self.table = RoutingTable(
204200
[("192.168.1.1", 7687), ("192.168.1.2", 7687)], [("192.168.1.3", 7687)], [], 0)
@@ -224,9 +220,119 @@ def test_update_should_replace_ttl(self):
224220

225221

226222
class RoutingConnectionPoolConstructionTestCase(TestCase):
227-
228223
def test_should_populate_initial_router(self):
229224
initial_router = ("127.0.0.1", 9001)
230225
router = ("127.0.0.1", 9002)
231226
with RoutingConnectionPool(connector, initial_router, {}, router) as pool:
232227
assert pool.routing_table.routers == {("127.0.0.1", 9002)}
228+
229+
230+
class FakeConnectionPool(object):
231+
232+
def __init__(self, addresses):
233+
self._addresses = addresses
234+
235+
def in_use_connection_count(self, address):
236+
return self._addresses.get(address, 0)
237+
238+
239+
class RoundRobinLoadBalancingStrategyTestCase(TestCase):
240+
241+
def test_simple_reader_selection(self):
242+
strategy = RoundRobinLoadBalancingStrategy()
243+
self.assertEqual(strategy.select_reader(["0.0.0.0", "1.1.1.1", "2.2.2.2"]), "0.0.0.0")
244+
self.assertEqual(strategy.select_reader(["0.0.0.0", "1.1.1.1", "2.2.2.2"]), "1.1.1.1")
245+
self.assertEqual(strategy.select_reader(["0.0.0.0", "1.1.1.1", "2.2.2.2"]), "2.2.2.2")
246+
self.assertEqual(strategy.select_reader(["0.0.0.0", "1.1.1.1", "2.2.2.2"]), "0.0.0.0")
247+
248+
def test_empty_reader_selection(self):
249+
strategy = RoundRobinLoadBalancingStrategy()
250+
self.assertIsNone(strategy.select_reader([]))
251+
252+
def test_simple_writer_selection(self):
253+
strategy = RoundRobinLoadBalancingStrategy()
254+
self.assertEqual(strategy.select_writer(["0.0.0.0", "1.1.1.1", "2.2.2.2"]), "0.0.0.0")
255+
self.assertEqual(strategy.select_writer(["0.0.0.0", "1.1.1.1", "2.2.2.2"]), "1.1.1.1")
256+
self.assertEqual(strategy.select_writer(["0.0.0.0", "1.1.1.1", "2.2.2.2"]), "2.2.2.2")
257+
self.assertEqual(strategy.select_writer(["0.0.0.0", "1.1.1.1", "2.2.2.2"]), "0.0.0.0")
258+
259+
def test_empty_writer_selection(self):
260+
strategy = RoundRobinLoadBalancingStrategy()
261+
self.assertIsNone(strategy.select_writer([]))
262+
263+
264+
class LeastConnectedLoadBalancingStrategyTestCase(TestCase):
265+
266+
def test_simple_reader_selection(self):
267+
strategy = LeastConnectedLoadBalancingStrategy(FakeConnectionPool(OrderedDict([
268+
("0.0.0.0", 2),
269+
("1.1.1.1", 1),
270+
("2.2.2.2", 0),
271+
])))
272+
self.assertEqual(strategy.select_reader(["0.0.0.0", "1.1.1.1", "2.2.2.2"]), "2.2.2.2")
273+
274+
def test_reader_selection_with_clash(self):
275+
strategy = LeastConnectedLoadBalancingStrategy(FakeConnectionPool(OrderedDict([
276+
("0.0.0.0", 0),
277+
("0.0.0.1", 0),
278+
("1.1.1.1", 1),
279+
])))
280+
self.assertEqual(strategy.select_reader(["0.0.0.0", "0.0.0.1", "1.1.1.1"]), "0.0.0.0")
281+
self.assertEqual(strategy.select_reader(["0.0.0.0", "0.0.0.1", "1.1.1.1"]), "0.0.0.1")
282+
283+
def test_empty_reader_selection(self):
284+
strategy = LeastConnectedLoadBalancingStrategy(FakeConnectionPool(OrderedDict([
285+
])))
286+
self.assertIsNone(strategy.select_reader([]))
287+
288+
def test_not_in_pool_reader_selection(self):
289+
strategy = LeastConnectedLoadBalancingStrategy(FakeConnectionPool(OrderedDict([
290+
("1.1.1.1", 1),
291+
("2.2.2.2", 2),
292+
])))
293+
self.assertEqual(strategy.select_reader(["2.2.2.2", "3.3.3.3"]), "3.3.3.3")
294+
295+
def test_partially_in_pool_reader_selection(self):
296+
strategy = LeastConnectedLoadBalancingStrategy(FakeConnectionPool(OrderedDict([
297+
("1.1.1.1", 1),
298+
("2.2.2.2", 0),
299+
])))
300+
self.assertEqual(strategy.select_reader(["2.2.2.2", "3.3.3.3"]), "2.2.2.2")
301+
self.assertEqual(strategy.select_reader(["2.2.2.2", "3.3.3.3"]), "3.3.3.3")
302+
303+
def test_simple_writer_selection(self):
304+
strategy = LeastConnectedLoadBalancingStrategy(FakeConnectionPool(OrderedDict([
305+
("0.0.0.0", 2),
306+
("1.1.1.1", 1),
307+
("2.2.2.2", 0),
308+
])))
309+
self.assertEqual(strategy.select_writer(["0.0.0.0", "1.1.1.1", "2.2.2.2"]), "2.2.2.2")
310+
311+
def test_writer_selection_with_clash(self):
312+
strategy = LeastConnectedLoadBalancingStrategy(FakeConnectionPool(OrderedDict([
313+
("0.0.0.0", 0),
314+
("0.0.0.1", 0),
315+
("1.1.1.1", 1),
316+
])))
317+
self.assertEqual(strategy.select_writer(["0.0.0.0", "0.0.0.1", "1.1.1.1"]), "0.0.0.0")
318+
self.assertEqual(strategy.select_writer(["0.0.0.0", "0.0.0.1", "1.1.1.1"]), "0.0.0.1")
319+
320+
def test_empty_writer_selection(self):
321+
strategy = LeastConnectedLoadBalancingStrategy(FakeConnectionPool(OrderedDict([
322+
])))
323+
self.assertIsNone(strategy.select_writer([]))
324+
325+
def test_not_in_pool_writer_selection(self):
326+
strategy = LeastConnectedLoadBalancingStrategy(FakeConnectionPool(OrderedDict([
327+
("1.1.1.1", 1),
328+
("2.2.2.2", 2),
329+
])))
330+
self.assertEqual(strategy.select_writer(["2.2.2.2", "3.3.3.3"]), "3.3.3.3")
331+
332+
def test_partially_in_pool_writer_selection(self):
333+
strategy = LeastConnectedLoadBalancingStrategy(FakeConnectionPool(OrderedDict([
334+
("1.1.1.1", 1),
335+
("2.2.2.2", 0),
336+
])))
337+
self.assertEqual(strategy.select_writer(["2.2.2.2", "3.3.3.3"]), "2.2.2.2")
338+
self.assertEqual(strategy.select_writer(["2.2.2.2", "3.3.3.3"]), "3.3.3.3")

0 commit comments

Comments
 (0)