Skip to content

Commit 8a00a78

Browse files
author
Zhen
committed
Adding max_connection_pool_size as well as connection_acquisition_timeout
Moved connection tests from integration test to unit test as no server is needed
1 parent 2c0b708 commit 8a00a78

File tree

5 files changed

+200
-155
lines changed

5 files changed

+200
-155
lines changed

neo4j/bolt/connection.py

Lines changed: 47 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,9 @@
3232
from select import select
3333
from socket import socket, SOL_SOCKET, SO_KEEPALIVE, SHUT_RDWR, error as SocketError, timeout as SocketTimeout, AF_INET, AF_INET6
3434
from struct import pack as struct_pack, unpack as struct_unpack
35-
from threading import RLock
35+
from threading import RLock, Condition
3636

37+
from neo4j.v1 import ClientError
3738
from neo4j.addressing import SocketAddress, is_ip_address
3839
from neo4j.bolt.cert import KNOWN_HOSTS
3940
from neo4j.bolt.response import InitResponse, AckFailureResponse, ResetResponse
@@ -48,9 +49,11 @@
4849
ChunkedOutputBuffer = _import_best("neo4j.bolt._io", "neo4j.bolt.io").ChunkedOutputBuffer
4950

5051

51-
INFINITE_CONNECTION_LIFETIME = -1
52-
DEFAULT_MAX_CONNECTION_LIFETIME = INFINITE_CONNECTION_LIFETIME
52+
INFINITE = -1
53+
DEFAULT_MAX_CONNECTION_LIFETIME = INFINITE
54+
DEFAULT_MAX_CONNECTION_POOL_SIZE = INFINITE
5355
DEFAULT_CONNECTION_TIMEOUT = 5.0
56+
DEFAULT_CONNECTION_ACQUISITION_TIMEOUT = 60
5457
DEFAULT_PORT = 7687
5558
DEFAULT_USER_AGENT = "neo4j-python/%s" % version
5659

@@ -405,11 +408,14 @@ class ConnectionPool(object):
405408

406409
_closed = False
407410

408-
def __init__(self, connector, connection_error_handler):
411+
def __init__(self, connector, connection_error_handler, **config):
409412
self.connector = connector
410413
self.connection_error_handler = connection_error_handler
411414
self.connections = {}
412415
self.lock = RLock()
416+
self.cond = Condition(self.lock)
417+
self._max_connection_pool_size = config.get("max_connection_pool_size", DEFAULT_MAX_CONNECTION_POOL_SIZE)
418+
self._connection_acquisition_timeout = config.get("connection_acquisition_timeout", DEFAULT_CONNECTION_ACQUISITION_TIMEOUT)
413419

414420
def __enter__(self):
415421
return self
@@ -433,23 +439,42 @@ def acquire_direct(self, address):
433439
connections = self.connections[address]
434440
except KeyError:
435441
connections = self.connections[address] = deque()
436-
for connection in list(connections):
437-
if connection.closed() or connection.defunct() or connection.timedout():
438-
connections.remove(connection)
439-
continue
440-
if not connection.in_use:
441-
connection.in_use = True
442-
return connection
443-
try:
444-
connection = self.connector(address, self.connection_error_handler)
445-
except ServiceUnavailable:
446-
self.remove(address)
447-
raise
448-
else:
449-
connection.pool = self
450-
connection.in_use = True
451-
connections.append(connection)
452-
return connection
442+
443+
connection_acquisition_start_timestamp = clock()
444+
while True:
445+
# try to find a free connection in pool
446+
for connection in list(connections):
447+
if connection.closed() or connection.defunct() or connection.timedout():
448+
connections.remove(connection)
449+
continue
450+
if not connection.in_use:
451+
connection.in_use = True
452+
return connection
453+
# all connections in pool are in-use
454+
can_create_new_connection = self._max_connection_pool_size == INFINITE or len(connections) < self._max_connection_pool_size
455+
if can_create_new_connection:
456+
try:
457+
connection = self.connector(address, self.connection_error_handler)
458+
except ServiceUnavailable:
459+
self.remove(address)
460+
raise
461+
else:
462+
connection.pool = self
463+
connection.in_use = True
464+
connections.append(connection)
465+
return connection
466+
467+
# failed to obtain a connection from pool because the pool is full and no free connection in the pool
468+
span_timeout = self._connection_acquisition_timeout - (clock() - connection_acquisition_start_timestamp)
469+
if span_timeout > 0:
470+
self.cond.wait(span_timeout)
471+
# if timed out, then we throw error. This time computation is needed, as with python 2.7, we cannot
472+
# tell if the condition is notified or timed out when we come to this line
473+
if self._connection_acquisition_timeout <= (clock() - connection_acquisition_start_timestamp):
474+
raise ClientError("Failed to obtain a connection from pool within {!r}s".format(
475+
self._connection_acquisition_timeout))
476+
else:
477+
raise ClientError("Failed to obtain a connection from pool within {!r}s".format(self._connection_acquisition_timeout))
453478

454479
def acquire(self, access_mode=None):
455480
""" Acquire a connection to a server that can satisfy a set of parameters.
@@ -463,6 +488,7 @@ def release(self, connection):
463488
"""
464489
with self.lock:
465490
connection.in_use = False
491+
self.cond.notify_all()
466492

467493
def in_use_connection_count(self, address):
468494
""" Count the number of connections currently in use to a given

neo4j/v1/direct.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,8 @@ def __init__(self):
3737

3838
class DirectConnectionPool(ConnectionPool):
3939

40-
def __init__(self, connector, address):
41-
super(DirectConnectionPool, self).__init__(connector, DirectConnectionErrorHandler())
40+
def __init__(self, connector, address, **config):
41+
super(DirectConnectionPool, self).__init__(connector, DirectConnectionErrorHandler(), **config)
4242
self.address = address
4343

4444
def acquire(self, access_mode=None):
@@ -73,7 +73,7 @@ def __init__(self, uri, **config):
7373
def connector(address, error_handler):
7474
return connect(address, security_plan.ssl_context, error_handler, **config)
7575

76-
pool = DirectConnectionPool(connector, self.address)
76+
pool = DirectConnectionPool(connector, self.address, **config)
7777
pool.release(pool.acquire())
7878
Driver.__init__(self, pool, **config)
7979

neo4j/v1/routing.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434

3535
LOAD_BALANCING_STRATEGY_LEAST_CONNECTED = 0
3636
LOAD_BALANCING_STRATEGY_ROUND_ROBIN = 1
37-
LOAD_BALANCING_STRATEGY_DEFAULT = LOAD_BALANCING_STRATEGY_LEAST_CONNECTED
37+
DEFAULT_LOAD_BALANCING_STRATEGY = LOAD_BALANCING_STRATEGY_LEAST_CONNECTED
3838

3939

4040
class OrderedSet(MutableSet):
@@ -166,7 +166,7 @@ class LoadBalancingStrategy(object):
166166

167167
@classmethod
168168
def build(cls, connection_pool, **config):
169-
load_balancing_strategy = config.get("load_balancing_strategy", LOAD_BALANCING_STRATEGY_DEFAULT)
169+
load_balancing_strategy = config.get("load_balancing_strategy", DEFAULT_LOAD_BALANCING_STRATEGY)
170170
if load_balancing_strategy == LOAD_BALANCING_STRATEGY_LEAST_CONNECTED:
171171
return LeastConnectedLoadBalancingStrategy(connection_pool)
172172
elif load_balancing_strategy == LOAD_BALANCING_STRATEGY_ROUND_ROBIN:
@@ -265,7 +265,7 @@ class RoutingConnectionPool(ConnectionPool):
265265
"""
266266

267267
def __init__(self, connector, initial_address, routing_context, *routers, **config):
268-
super(RoutingConnectionPool, self).__init__(connector, RoutingConnectionErrorHandler(self))
268+
super(RoutingConnectionPool, self).__init__(connector, RoutingConnectionErrorHandler(self), **config)
269269
self.initial_address = initial_address
270270
self.routing_context = routing_context
271271
self.routing_table = RoutingTable(routers)

test/integration/test_connection.py

Lines changed: 0 additions & 124 deletions
This file was deleted.

0 commit comments

Comments
 (0)