diff --git a/neo4j/io/__init__.py b/neo4j/io/__init__.py index 4804df0b..348bbce8 100644 --- a/neo4j/io/__init__.py +++ b/neo4j/io/__init__.py @@ -39,6 +39,7 @@ defaultdict, deque, ) +import logging from logging import getLogger from random import choice import selectors @@ -652,11 +653,19 @@ def time_remaining(): # try to find a free connection in pool for connection in list(self.connections.get(address, [])): if (connection.closed() or connection.defunct() - or connection.stale()): + or (connection.stale() and not connection.in_use)): # `close` is a noop on already closed connections. # This is to make sure that the connection is gracefully # closed, e.g. if it's just marked as `stale` but still # alive. + if log.isEnabledFor(logging.DEBUG): + log.debug( + "[#%04X] C: removing old connection " + "(closed=%s, defunct=%s, stale=%s, in_use=%s)", + connection.local_port, + connection.closed(), connection.defunct(), + connection.stale(), connection.in_use + ) connection.close() try: self.connections.get(address, []).remove(connection) diff --git a/tests/unit/io/test_neo4j_pool.py b/tests/unit/io/test_neo4j_pool.py index 5dbab9b4..a5df0a90 100644 --- a/tests/unit/io/test_neo4j_pool.py +++ b/tests/unit/io/test_neo4j_pool.py @@ -18,7 +18,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import inspect + from unittest.mock import Mock import pytest @@ -121,7 +121,7 @@ def test_chooses_right_connection_type(opener, type_): cx1 = pool.acquire(READ_ACCESS if type_ == "r" else WRITE_ACCESS, 30, "test_db", None) pool.release(cx1) - if type_ == "r": + if type_ == "r": assert cx1.addr == READER_ADDRESS else: assert cx1.addr == WRITER_ADDRESS @@ -147,7 +147,7 @@ def break_connection(): cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None) pool.release(cx1) assert cx1 in pool.connections[cx1.addr] - # simulate connection going stale (e.g. exceeding) and than breaking when + # simulate connection going stale (e.g. exceeding) and then breaking when # the pool tries to close the connection cx1.stale.return_value = True cx_close_mock = cx1.close @@ -156,13 +156,44 @@ def break_connection(): cx_close_mock.side_effect = break_connection cx2 = pool.acquire(READ_ACCESS, 30, "test_db", None) pool.release(cx2) - assert cx1.close.called_once() + if break_on_close: + cx1.close.assert_called() + else: + cx1.close.assert_called_once() assert cx2 is not cx1 assert cx2.addr == cx1.addr assert cx1 not in pool.connections[cx1.addr] assert cx2 in pool.connections[cx2.addr] +def test_does_not_close_stale_connections_in_use(opener): + pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS) + cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None) + assert cx1 in pool.connections[cx1.addr] + # simulate connection going stale (e.g. exceeding) while being in use + cx1.stale.return_value = True + cx2 = pool.acquire(READ_ACCESS, 30, "test_db", None) + pool.release(cx2) + cx1.close.assert_not_called() + assert cx2 is not cx1 + assert cx2.addr == cx1.addr + assert cx1 in pool.connections[cx1.addr] + assert cx2 in pool.connections[cx2.addr] + + pool.release(cx1) + # now that cx1 is back in the pool and still stale, + # it should be closed when trying to acquire the next connection + cx1.close.assert_not_called() + + cx3 = pool.acquire(READ_ACCESS, 30, "test_db", None) + pool.release(cx3) + cx1.close.assert_called_once() + assert cx2 is cx3 + assert cx3.addr == cx1.addr + assert cx1 not in pool.connections[cx1.addr] + assert cx3 in pool.connections[cx2.addr] + + def test_release_resets_connections(opener): pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS) cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None)