diff --git a/neo4j/io/__init__.py b/neo4j/io/__init__.py index d7ed2323..dd47538f 100644 --- a/neo4j/io/__init__.py +++ b/neo4j/io/__init__.py @@ -695,7 +695,7 @@ def time_remaining(): try: connection = self.opener(address, timeout) except ServiceUnavailable: - self.remove(address) + self.deactivate(address) raise else: connection.pool = self @@ -772,30 +772,22 @@ def deactivate(self, address): connections = self.connections[address] except KeyError: # already removed from the connection pool return - for conn in list(connections): - if not conn.in_use: - connections.remove(conn) - try: - conn.close() - except OSError: - pass - if not connections: - self.remove(address) + closable_connections = [ + conn for conn in connections if not conn.in_use + ] + # First remove all connections in question, then try to close them. + # If closing of a connection fails, we will end up in this method + # again. + for conn in closable_connections: + connections.remove(conn) + for conn in closable_connections: + conn.close() + if not self.connections[address]: + del self.connections[address] def on_write_failure(self, address): raise WriteServiceUnavailable("No write service available for pool {}".format(self)) - def remove(self, address): - """ Remove an address from the connection pool, if present, closing - all connections to that address. - """ - with self.lock: - for connection in self.connections.pop(address, ()): - try: - connection.close() - except OSError: - pass - def close(self): """ Close all connections and empty the pool. This method is thread safe. @@ -803,7 +795,11 @@ def close(self): try: with self.lock: for address in list(self.connections): - self.remove(address) + for connection in self.connections.pop(address, ()): + try: + connection.close() + except OSError: + pass except TypeError: pass diff --git a/tests/unit/io/test_neo4j_pool.py b/tests/unit/io/test_neo4j_pool.py index a5df0a90..5aba030a 100644 --- a/tests/unit/io/test_neo4j_pool.py +++ b/tests/unit/io/test_neo4j_pool.py @@ -35,6 +35,10 @@ RoutingConfig, WorkspaceConfig ) +from neo4j.exceptions import ( + ServiceUnavailable, + SessionExpired +) from neo4j.io import Neo4jPool @@ -226,3 +230,44 @@ def test_release_does_not_resets_defunct_connections(opener): cx1.defunct.assert_called_once() cx1.is_reset_mock.asset_not_called() cx1.reset.asset_not_called() + + +def test_multiple_broken_connections_on_close(opener): + def mock_connection_breaks_on_close(cx): + def close_side_effect(): + cx.closed.return_value = True + cx.defunct.return_value = True + pool.deactivate(READER_ADDRESS) + + cx.attach_mock(Mock(side_effect=close_side_effect), "close") + + # create pool with 2 idle connections + pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS) + cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None) + cx2 = pool.acquire(READ_ACCESS, 30, "test_db", None) + pool.release(cx1) + pool.release(cx2) + + # both will loose connection + mock_connection_breaks_on_close(cx1) + mock_connection_breaks_on_close(cx2) + + # force pool to close cx1, which will make it realize that the server is + # unreachable + cx1.stale.return_value = True + + cx3 = pool.acquire(READ_ACCESS, 30, "test_db", None) + + assert cx3 is not cx1 + assert cx3 is not cx2 + + +def test_failing_opener_leaves_connections_in_use_alone(opener): + pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS) + cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None) + + opener.side_effect = ServiceUnavailable("Server overloaded") + with pytest.raises((ServiceUnavailable, SessionExpired)): + pool.acquire(READ_ACCESS, 30, "test_db", None) + + assert not cx1.closed()