From aaeee66d4b4216b4aaf16aee285882e0cca354bf Mon Sep 17 00:00:00 2001 From: Rouven Bauer Date: Fri, 1 Oct 2021 14:10:55 +0200 Subject: [PATCH] Fix removing connection twice from pool. When trying to close a stale connection the driver count realize that the connection is dead on trying to send GOODBYE. This would cause the connection to make sure that all connections to the same address would get removed from the pool as well. Since this removal only happens as a side effect of `connection.close()` and does not always happen, the driver would still try to remove the (now already removed) connection form the pool after closure. Fixes: `ValueError: deque.remove(x): x not in deque` --- neo4j/io/__init__.py | 36 ++++++--- tests/unit/io/test_neo4j_pool.py | 115 ++++++++++++++++++++++++++++ tests/unit/work/__init__.py | 1 + tests/unit/work/_fake_connection.py | 92 ++++++++++++++++++++++ 4 files changed, 232 insertions(+), 12 deletions(-) create mode 100644 tests/unit/io/test_neo4j_pool.py create mode 100644 tests/unit/work/_fake_connection.py diff --git a/neo4j/io/__init__.py b/neo4j/io/__init__.py index 4b6a2b93..92b517e6 100644 --- a/neo4j/io/__init__.py +++ b/neo4j/io/__init__.py @@ -34,7 +34,10 @@ ] import abc -from collections import deque +from collections import ( + defaultdict, + deque, +) from logging import getLogger from random import choice from select import select @@ -610,7 +613,7 @@ def __init__(self, opener, pool_config, workspace_config): self.opener = opener self.pool_config = pool_config self.workspace_config = workspace_config - self.connections = {} + self.connections = defaultdict(deque) self.lock = RLock() self.cond = Condition(self.lock) @@ -632,18 +635,13 @@ def _acquire(self, address, timeout): timeout = self.workspace_config.connection_acquisition_timeout with self.lock: - try: - connections = self.connections[address] - except KeyError: - connections = self.connections[address] = deque() - def time_remaining(): t = timeout - (perf_counter() - t0) return t if t > 0 else 0 while True: # try to find a free connection in pool - for connection in list(connections): + for connection in list(self.connections.get(address, [])): if (connection.closed() or connection.defunct() or connection.stale()): # `close` is a noop on already closed connections. @@ -651,16 +649,30 @@ def time_remaining(): # closed, e.g. if it's just marked as `stale` but still # alive. connection.close() - connections.remove(connection) + try: + self.connections.get(address, []).remove(connection) + except ValueError: + # If closure fails (e.g. because the server went + # down), all connections to the same address will + # be removed. Therefore, we silently ignore if the + # connection isn't in the pool anymore. + pass continue if not connection.in_use: connection.in_use = True return connection # all connections in pool are in-use - infinite_pool_size = (self.pool_config.max_connection_pool_size < 0 or self.pool_config.max_connection_pool_size == float("inf")) - can_create_new_connection = infinite_pool_size or len(connections) < self.pool_config.max_connection_pool_size + connections = self.connections[address] + max_pool_size = self.pool_config.max_connection_pool_size + infinite_pool_size = (max_pool_size < 0 + or max_pool_size == float("inf")) + can_create_new_connection = ( + infinite_pool_size + or len(connections) < max_pool_size + ) if can_create_new_connection: - timeout = min(self.pool_config.connection_timeout, time_remaining()) + timeout = min(self.pool_config.connection_timeout, + time_remaining()) try: connection = self.opener(address, timeout) except ServiceUnavailable: diff --git a/tests/unit/io/test_neo4j_pool.py b/tests/unit/io/test_neo4j_pool.py new file mode 100644 index 00000000..aadf6300 --- /dev/null +++ b/tests/unit/io/test_neo4j_pool.py @@ -0,0 +1,115 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from unittest.mock import Mock + +import pytest + +from ..work import FakeConnection + +from neo4j import ( + READ_ACCESS, + WRITE_ACCESS, +) +from neo4j.addressing import ResolvedAddress +from neo4j.conf import ( + PoolConfig, + WorkspaceConfig +) +from neo4j.io import Neo4jPool + + +ROUTER_ADDRESS = ResolvedAddress(("1.2.3.1", 9001), host_name="host") +READER_ADDRESS = ResolvedAddress(("1.2.3.1", 9002), host_name="host") +WRITER_ADDRESS = ResolvedAddress(("1.2.3.1", 9003), host_name="host") + + +@pytest.fixture() +def opener(): + def open_(addr, timeout): + connection = FakeConnection() + connection.addr = addr + connection.timeout = timeout + route_mock = Mock() + route_mock.return_value = [{ + "ttl": 1000, + "servers": [ + {"addresses": [str(ROUTER_ADDRESS)], "role": "ROUTE"}, + {"addresses": [str(READER_ADDRESS)], "role": "READ"}, + {"addresses": [str(WRITER_ADDRESS)], "role": "WRITE"}, + ], + }] + connection.attach_mock(route_mock, "route") + opener_.connections.append(connection) + return connection + + opener_ = Mock() + opener_.connections = [] + opener_.side_effect = open_ + return opener_ + + +@pytest.mark.parametrize("type_", ("r", "w")) +def test_chooses_right_connection_type(opener, type_): + pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS) + cx1 = pool.acquire(READ_ACCESS if type_ == "r" else WRITE_ACCESS, + 30, "test_db", None) + pool.release(cx1) + if type_ == "r": + assert cx1.addr == READER_ADDRESS + else: + assert cx1.addr == WRITER_ADDRESS + + +def test_reuses_connection(opener): + pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS) + cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None) + pool.release(cx1) + cx2 = pool.acquire(READ_ACCESS, 30, "test_db", None) + assert cx1 is cx2 + + +@pytest.mark.parametrize("break_on_close", (True, False)) +def test_closes_stale_connections(opener, break_on_close): + def break_connection(): + pool.deactivate(cx1.addr) + + if cx_close_mock_side_effect: + cx_close_mock_side_effect() + + pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS) + cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None) + pool.release(cx1) + assert cx1 in pool.connections[cx1.addr] + # simulate connection going stale (e.g. exceeding) and than breaking when + # the pool tries to close the connection + cx1.stale.return_value = True + cx_close_mock = cx1.close + if break_on_close: + cx_close_mock_side_effect = cx_close_mock.side_effect + cx_close_mock.side_effect = break_connection + cx2 = pool.acquire(READ_ACCESS, 30, "test_db", None) + pool.release(cx2) + assert cx1.close.called_once() + assert cx2 is not cx1 + assert cx2.addr == cx1.addr + assert cx1 not in pool.connections[cx1.addr] + assert cx2 in pool.connections[cx2.addr] diff --git a/tests/unit/work/__init__.py b/tests/unit/work/__init__.py index e69de29b..1bc320c6 100644 --- a/tests/unit/work/__init__.py +++ b/tests/unit/work/__init__.py @@ -0,0 +1 @@ +from ._fake_connection import FakeConnection diff --git a/tests/unit/work/_fake_connection.py b/tests/unit/work/_fake_connection.py new file mode 100644 index 00000000..25b272fe --- /dev/null +++ b/tests/unit/work/_fake_connection.py @@ -0,0 +1,92 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import inspect +from unittest import mock + +import pytest + +from neo4j import ServerInfo + + +class FakeConnection(mock.NonCallableMagicMock): + callbacks = [] + server_info = ServerInfo("127.0.0.1", (4, 3)) + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.attach_mock(mock.PropertyMock(return_value=True), "is_reset") + self.attach_mock(mock.Mock(return_value=False), "defunct") + self.attach_mock(mock.Mock(return_value=False), "stale") + self.attach_mock(mock.Mock(return_value=False), "closed") + + def close_side_effect(): + self.closed.return_value = True + + self.attach_mock(mock.Mock(side_effect=close_side_effect), "close") + + def fetch_message(self, *args, **kwargs): + if self.callbacks: + cb = self.callbacks.pop(0) + cb() + return super().__getattr__("fetch_message")(*args, **kwargs) + + def fetch_all(self, *args, **kwargs): + while self.callbacks: + cb = self.callbacks.pop(0) + cb() + return super().__getattr__("fetch_all")(*args, **kwargs) + + def __getattr__(self, name): + parent = super() + + def build_message_handler(name): + def func(*args, **kwargs): + def callback(): + for cb_name, param_count in ( + ("on_success", 1), + ("on_summary", 0) + ): + cb = kwargs.get(cb_name, None) + if callable(cb): + try: + param_count = \ + len(inspect.signature(cb).parameters) + except ValueError: + # e.g. built-in method as cb + pass + if param_count == 1: + cb({}) + else: + cb() + self.callbacks.append(callback) + return parent.__getattr__(name)(*args, **kwargs) + + return func + + if name in ("run", "commit", "pull", "rollback", "discard"): + return build_message_handler(name) + return parent.__getattr__(name) + + +@pytest.fixture +def fake_connection(): + return FakeConnection()