Skip to content

Commit c660463

Browse files
lutovichtechnige
authored andcommitted
Initial impl of least connected
1 parent 871009c commit c660463

File tree

2 files changed

+122
-5
lines changed

2 files changed

+122
-5
lines changed

neo4j/bolt/connection.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -418,6 +418,20 @@ def release(self, connection):
418418
with self.lock:
419419
connection.in_use = False
420420

421+
def in_use_connection_count(self, address):
422+
try:
423+
connections = self.connections[address]
424+
except KeyError:
425+
return 0
426+
else:
427+
in_use_count = 0
428+
429+
for connection in list(connections):
430+
if connection.in_use:
431+
in_use_count += 1
432+
433+
return in_use_count
434+
421435
def remove(self, address):
422436
""" Remove an address from the connection pool, if present, closing
423437
all connections to that address.

neo4j/v1/routing.py

Lines changed: 108 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,18 +19,24 @@
1919
# limitations under the License.
2020

2121

22+
from sys import maxsize
2223
from threading import Lock
2324
from time import clock
2425

2526
from neo4j.addressing import SocketAddress, resolve
2627
from neo4j.bolt import ConnectionPool, ServiceUnavailable, ProtocolError, DEFAULT_PORT, connect
2728
from neo4j.compat.collections import MutableSet, OrderedDict
2829
from neo4j.exceptions import CypherError
30+
from neo4j.util import ServerVersion
2931
from neo4j.v1.api import Driver, READ_ACCESS, WRITE_ACCESS, fix_statement, fix_parameters
3032
from neo4j.v1.exceptions import SessionExpired
3133
from neo4j.v1.security import SecurityPlan
3234
from neo4j.v1.session import BoltSession
33-
from neo4j.util import ServerVersion
35+
36+
37+
LOAD_BALANCING_STRATEGY_LEAST_CONNECTED = 0
38+
LOAD_BALANCING_STRATEGY_ROUND_ROBIN = 1
39+
LOAD_BALANCING_STRATEGY_DEFAULT = LOAD_BALANCING_STRATEGY_LEAST_CONNECTED
3440

3541

3642
class RoundRobinSet(MutableSet):
@@ -52,7 +58,7 @@ def __next__(self):
5258
self._current = 0
5359
else:
5460
self._current = (self._current + 1) % len(self._elements)
55-
current = list(self._elements.keys())[self._current]
61+
current = self.get(self._current)
5662
return current
5763

5864
def __iter__(self):
@@ -90,6 +96,9 @@ def replace(self, elements=()):
9096
e.clear()
9197
e.update(OrderedDict.fromkeys(elements))
9298

99+
def get(self, index):
100+
return list(self._elements.keys())[index]
101+
93102

94103
class RoutingTable(object):
95104

@@ -168,17 +177,109 @@ def __run__(self, ignored, routing_context):
168177
return self._run(fix_statement(statement), fix_parameters(parameters))
169178

170179

180+
class LoadBalancingStrategy(object):
181+
182+
@classmethod
183+
def build(cls, connection_pool, **config):
184+
load_balancing_strategy = config.get("load_balancing_strategy", LOAD_BALANCING_STRATEGY_DEFAULT)
185+
if load_balancing_strategy == LOAD_BALANCING_STRATEGY_LEAST_CONNECTED:
186+
return LeastConnectedLoadBalancingStrategy(connection_pool)
187+
elif load_balancing_strategy == LOAD_BALANCING_STRATEGY_ROUND_ROBIN:
188+
return RoundRobinLoadBalancingStrategy()
189+
else:
190+
raise ValueError("Unknown load balancing strategy '%s'" % load_balancing_strategy)
191+
pass
192+
193+
def select_reader(self, known_readers):
194+
raise NotImplementedError()
195+
196+
def select_writer(self, known_writers):
197+
raise NotImplementedError()
198+
199+
200+
class RoundRobinLoadBalancingStrategy(LoadBalancingStrategy):
201+
202+
_readers_offset = 0
203+
_writers_offset = 0
204+
205+
def select_reader(self, known_readers):
206+
address = self.select(self._readers_offset, known_readers)
207+
self._readers_offset += 1
208+
return address
209+
210+
def select_writer(self, known_writers):
211+
address = self.select(self._writers_offset, known_writers)
212+
self._writers_offset += 1
213+
return address
214+
215+
def select(self, offset, addresses):
216+
length = len(addresses)
217+
if length == 0:
218+
return None
219+
else:
220+
index = offset % length
221+
return addresses.get(index)
222+
223+
224+
class LeastConnectedLoadBalancingStrategy(LoadBalancingStrategy):
225+
226+
def __init__(self, connection_pool):
227+
self._readers_offset = 0
228+
self._writers_offset = 0
229+
self._connection_pool = connection_pool
230+
231+
def select_reader(self, known_readers):
232+
address = self.select(self._readers_offset, known_readers)
233+
self._readers_offset += 1
234+
return address
235+
236+
def select_writer(self, known_writers):
237+
address = self.select(self._writers_offset, known_writers)
238+
self._writers_offset += 1
239+
return address
240+
241+
def select(self, offset, addresses):
242+
length = len(addresses)
243+
if length == 0:
244+
return None
245+
else:
246+
start_index = offset % length
247+
index = start_index
248+
249+
least_connected_address = None
250+
least_in_use_connections = maxsize
251+
252+
while True:
253+
address = addresses.get(index)
254+
in_use_connections = self._connection_pool.in_use_connection_count(address)
255+
256+
if in_use_connections < least_in_use_connections:
257+
least_connected_address = address
258+
least_in_use_connections = in_use_connections
259+
260+
if index == length - 1:
261+
index = 0
262+
else:
263+
index += 1
264+
265+
if index == start_index:
266+
break
267+
268+
return least_connected_address
269+
270+
171271
class RoutingConnectionPool(ConnectionPool):
172272
""" Connection pool with routing table.
173273
"""
174274

175-
def __init__(self, connector, initial_address, routing_context, *routers):
275+
def __init__(self, connector, initial_address, routing_context, *routers, **config):
176276
super(RoutingConnectionPool, self).__init__(connector)
177277
self.initial_address = initial_address
178278
self.routing_context = routing_context
179279
self.routing_table = RoutingTable(routers)
180280
self.missing_writer = False
181281
self.refresh_lock = Lock()
282+
self.load_balancing_strategy = LoadBalancingStrategy.build(self, **config)
182283

183284
def fetch_routing_info(self, address):
184285
""" Fetch raw routing info from a given router address.
@@ -304,14 +405,16 @@ def acquire(self, access_mode=None):
304405
access_mode = WRITE_ACCESS
305406
if access_mode == READ_ACCESS:
306407
server_list = self.routing_table.readers
408+
server_selector = self.load_balancing_strategy.select_reader
307409
elif access_mode == WRITE_ACCESS:
308410
server_list = self.routing_table.writers
411+
server_selector = self.load_balancing_strategy.select_writer
309412
else:
310413
raise ValueError("Unsupported access mode {}".format(access_mode))
311414

312415
self.ensure_routing_table_is_fresh(access_mode)
313416
while True:
314-
address = next(server_list)
417+
address = server_selector(server_list)
315418
if address is None:
316419
break
317420
try:
@@ -354,7 +457,7 @@ def __init__(self, uri, **config):
354457
def connector(a):
355458
return connect(a, security_plan.ssl_context, **config)
356459

357-
pool = RoutingConnectionPool(connector, initial_address, routing_context, *resolve(initial_address))
460+
pool = RoutingConnectionPool(connector, initial_address, routing_context, *resolve(initial_address), **config)
358461
try:
359462
pool.update_routing_table()
360463
except:

0 commit comments

Comments
 (0)