19
19
# limitations under the License.
20
20
21
21
22
+ from sys import maxsize
22
23
from threading import Lock
23
24
from time import clock
24
25
25
26
from neo4j .addressing import SocketAddress , resolve
26
27
from neo4j .bolt import ConnectionPool , ServiceUnavailable , ProtocolError , DEFAULT_PORT , connect
27
28
from neo4j .compat .collections import MutableSet , OrderedDict
28
29
from neo4j .exceptions import CypherError
30
+ from neo4j .util import ServerVersion
29
31
from neo4j .v1 .api import Driver , READ_ACCESS , WRITE_ACCESS , fix_statement , fix_parameters
30
32
from neo4j .v1 .exceptions import SessionExpired
31
33
from neo4j .v1 .security import SecurityPlan
32
34
from neo4j .v1 .session import BoltSession
33
- from neo4j .util import ServerVersion
35
+
36
+
37
+ LOAD_BALANCING_STRATEGY_LEAST_CONNECTED = 0
38
+ LOAD_BALANCING_STRATEGY_ROUND_ROBIN = 1
39
+ LOAD_BALANCING_STRATEGY_DEFAULT = LOAD_BALANCING_STRATEGY_LEAST_CONNECTED
34
40
35
41
36
42
class RoundRobinSet (MutableSet ):
@@ -52,7 +58,7 @@ def __next__(self):
52
58
self ._current = 0
53
59
else :
54
60
self ._current = (self ._current + 1 ) % len (self ._elements )
55
- current = list ( self ._elements . keys ())[ self ._current ]
61
+ current = self .get ( self ._current )
56
62
return current
57
63
58
64
def __iter__ (self ):
@@ -90,6 +96,9 @@ def replace(self, elements=()):
90
96
e .clear ()
91
97
e .update (OrderedDict .fromkeys (elements ))
92
98
99
+ def get (self , index ):
100
+ return list (self ._elements .keys ())[index ]
101
+
93
102
94
103
class RoutingTable (object ):
95
104
@@ -168,17 +177,109 @@ def __run__(self, ignored, routing_context):
168
177
return self ._run (fix_statement (statement ), fix_parameters (parameters ))
169
178
170
179
180
+ class LoadBalancingStrategy (object ):
181
+
182
+ @classmethod
183
+ def build (cls , connection_pool , ** config ):
184
+ load_balancing_strategy = config .get ("load_balancing_strategy" , LOAD_BALANCING_STRATEGY_DEFAULT )
185
+ if load_balancing_strategy == LOAD_BALANCING_STRATEGY_LEAST_CONNECTED :
186
+ return LeastConnectedLoadBalancingStrategy (connection_pool )
187
+ elif load_balancing_strategy == LOAD_BALANCING_STRATEGY_ROUND_ROBIN :
188
+ return RoundRobinLoadBalancingStrategy ()
189
+ else :
190
+ raise ValueError ("Unknown load balancing strategy '%s'" % load_balancing_strategy )
191
+ pass
192
+
193
+ def select_reader (self , known_readers ):
194
+ raise NotImplementedError ()
195
+
196
+ def select_writer (self , known_writers ):
197
+ raise NotImplementedError ()
198
+
199
+
200
+ class RoundRobinLoadBalancingStrategy (LoadBalancingStrategy ):
201
+
202
+ _readers_offset = 0
203
+ _writers_offset = 0
204
+
205
+ def select_reader (self , known_readers ):
206
+ address = self .select (self ._readers_offset , known_readers )
207
+ self ._readers_offset += 1
208
+ return address
209
+
210
+ def select_writer (self , known_writers ):
211
+ address = self .select (self ._writers_offset , known_writers )
212
+ self ._writers_offset += 1
213
+ return address
214
+
215
+ def select (self , offset , addresses ):
216
+ length = len (addresses )
217
+ if length == 0 :
218
+ return None
219
+ else :
220
+ index = offset % length
221
+ return addresses .get (index )
222
+
223
+
224
+ class LeastConnectedLoadBalancingStrategy (LoadBalancingStrategy ):
225
+
226
+ def __init__ (self , connection_pool ):
227
+ self ._readers_offset = 0
228
+ self ._writers_offset = 0
229
+ self ._connection_pool = connection_pool
230
+
231
+ def select_reader (self , known_readers ):
232
+ address = self .select (self ._readers_offset , known_readers )
233
+ self ._readers_offset += 1
234
+ return address
235
+
236
+ def select_writer (self , known_writers ):
237
+ address = self .select (self ._writers_offset , known_writers )
238
+ self ._writers_offset += 1
239
+ return address
240
+
241
+ def select (self , offset , addresses ):
242
+ length = len (addresses )
243
+ if length == 0 :
244
+ return None
245
+ else :
246
+ start_index = offset % length
247
+ index = start_index
248
+
249
+ least_connected_address = None
250
+ least_in_use_connections = maxsize
251
+
252
+ while True :
253
+ address = addresses .get (index )
254
+ in_use_connections = self ._connection_pool .in_use_connection_count (address )
255
+
256
+ if in_use_connections < least_in_use_connections :
257
+ least_connected_address = address
258
+ least_in_use_connections = in_use_connections
259
+
260
+ if index == length - 1 :
261
+ index = 0
262
+ else :
263
+ index += 1
264
+
265
+ if index == start_index :
266
+ break
267
+
268
+ return least_connected_address
269
+
270
+
171
271
class RoutingConnectionPool (ConnectionPool ):
172
272
""" Connection pool with routing table.
173
273
"""
174
274
175
- def __init__ (self , connector , initial_address , routing_context , * routers ):
275
+ def __init__ (self , connector , initial_address , routing_context , * routers , ** config ):
176
276
super (RoutingConnectionPool , self ).__init__ (connector )
177
277
self .initial_address = initial_address
178
278
self .routing_context = routing_context
179
279
self .routing_table = RoutingTable (routers )
180
280
self .missing_writer = False
181
281
self .refresh_lock = Lock ()
282
+ self .load_balancing_strategy = LoadBalancingStrategy .build (self , ** config )
182
283
183
284
def fetch_routing_info (self , address ):
184
285
""" Fetch raw routing info from a given router address.
@@ -304,14 +405,16 @@ def acquire(self, access_mode=None):
304
405
access_mode = WRITE_ACCESS
305
406
if access_mode == READ_ACCESS :
306
407
server_list = self .routing_table .readers
408
+ server_selector = self .load_balancing_strategy .select_reader
307
409
elif access_mode == WRITE_ACCESS :
308
410
server_list = self .routing_table .writers
411
+ server_selector = self .load_balancing_strategy .select_writer
309
412
else :
310
413
raise ValueError ("Unsupported access mode {}" .format (access_mode ))
311
414
312
415
self .ensure_routing_table_is_fresh (access_mode )
313
416
while True :
314
- address = next (server_list )
417
+ address = server_selector (server_list )
315
418
if address is None :
316
419
break
317
420
try :
@@ -354,7 +457,7 @@ def __init__(self, uri, **config):
354
457
def connector (a ):
355
458
return connect (a , security_plan .ssl_context , ** config )
356
459
357
- pool = RoutingConnectionPool (connector , initial_address , routing_context , * resolve (initial_address ))
460
+ pool = RoutingConnectionPool (connector , initial_address , routing_context , * resolve (initial_address ), ** config )
358
461
try :
359
462
pool .update_routing_table ()
360
463
except :
0 commit comments