diff --git a/neo4j/io/__init__.py b/neo4j/io/__init__.py index 4b6a2b93..92b517e6 100644 --- a/neo4j/io/__init__.py +++ b/neo4j/io/__init__.py @@ -34,7 +34,10 @@ ] import abc -from collections import deque +from collections import ( + defaultdict, + deque, +) from logging import getLogger from random import choice from select import select @@ -610,7 +613,7 @@ def __init__(self, opener, pool_config, workspace_config): self.opener = opener self.pool_config = pool_config self.workspace_config = workspace_config - self.connections = {} + self.connections = defaultdict(deque) self.lock = RLock() self.cond = Condition(self.lock) @@ -632,18 +635,13 @@ def _acquire(self, address, timeout): timeout = self.workspace_config.connection_acquisition_timeout with self.lock: - try: - connections = self.connections[address] - except KeyError: - connections = self.connections[address] = deque() - def time_remaining(): t = timeout - (perf_counter() - t0) return t if t > 0 else 0 while True: # try to find a free connection in pool - for connection in list(connections): + for connection in list(self.connections.get(address, [])): if (connection.closed() or connection.defunct() or connection.stale()): # `close` is a noop on already closed connections. @@ -651,16 +649,30 @@ def time_remaining(): # closed, e.g. if it's just marked as `stale` but still # alive. connection.close() - connections.remove(connection) + try: + self.connections.get(address, []).remove(connection) + except ValueError: + # If closure fails (e.g. because the server went + # down), all connections to the same address will + # be removed. Therefore, we silently ignore if the + # connection isn't in the pool anymore. + pass continue if not connection.in_use: connection.in_use = True return connection # all connections in pool are in-use - infinite_pool_size = (self.pool_config.max_connection_pool_size < 0 or self.pool_config.max_connection_pool_size == float("inf")) - can_create_new_connection = infinite_pool_size or len(connections) < self.pool_config.max_connection_pool_size + connections = self.connections[address] + max_pool_size = self.pool_config.max_connection_pool_size + infinite_pool_size = (max_pool_size < 0 + or max_pool_size == float("inf")) + can_create_new_connection = ( + infinite_pool_size + or len(connections) < max_pool_size + ) if can_create_new_connection: - timeout = min(self.pool_config.connection_timeout, time_remaining()) + timeout = min(self.pool_config.connection_timeout, + time_remaining()) try: connection = self.opener(address, timeout) except ServiceUnavailable: diff --git a/tests/unit/io/test_neo4j_pool.py b/tests/unit/io/test_neo4j_pool.py new file mode 100644 index 00000000..aadf6300 --- /dev/null +++ b/tests/unit/io/test_neo4j_pool.py @@ -0,0 +1,115 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from unittest.mock import Mock + +import pytest + +from ..work import FakeConnection + +from neo4j import ( + READ_ACCESS, + WRITE_ACCESS, +) +from neo4j.addressing import ResolvedAddress +from neo4j.conf import ( + PoolConfig, + WorkspaceConfig +) +from neo4j.io import Neo4jPool + + +ROUTER_ADDRESS = ResolvedAddress(("1.2.3.1", 9001), host_name="host") +READER_ADDRESS = ResolvedAddress(("1.2.3.1", 9002), host_name="host") +WRITER_ADDRESS = ResolvedAddress(("1.2.3.1", 9003), host_name="host") + + +@pytest.fixture() +def opener(): + def open_(addr, timeout): + connection = FakeConnection() + connection.addr = addr + connection.timeout = timeout + route_mock = Mock() + route_mock.return_value = [{ + "ttl": 1000, + "servers": [ + {"addresses": [str(ROUTER_ADDRESS)], "role": "ROUTE"}, + {"addresses": [str(READER_ADDRESS)], "role": "READ"}, + {"addresses": [str(WRITER_ADDRESS)], "role": "WRITE"}, + ], + }] + connection.attach_mock(route_mock, "route") + opener_.connections.append(connection) + return connection + + opener_ = Mock() + opener_.connections = [] + opener_.side_effect = open_ + return opener_ + + +@pytest.mark.parametrize("type_", ("r", "w")) +def test_chooses_right_connection_type(opener, type_): + pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS) + cx1 = pool.acquire(READ_ACCESS if type_ == "r" else WRITE_ACCESS, + 30, "test_db", None) + pool.release(cx1) + if type_ == "r": + assert cx1.addr == READER_ADDRESS + else: + assert cx1.addr == WRITER_ADDRESS + + +def test_reuses_connection(opener): + pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS) + cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None) + pool.release(cx1) + cx2 = pool.acquire(READ_ACCESS, 30, "test_db", None) + assert cx1 is cx2 + + +@pytest.mark.parametrize("break_on_close", (True, False)) +def test_closes_stale_connections(opener, break_on_close): + def break_connection(): + pool.deactivate(cx1.addr) + + if cx_close_mock_side_effect: + cx_close_mock_side_effect() + + pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS) + cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None) + pool.release(cx1) + assert cx1 in pool.connections[cx1.addr] + # simulate connection going stale (e.g. exceeding) and than breaking when + # the pool tries to close the connection + cx1.stale.return_value = True + cx_close_mock = cx1.close + if break_on_close: + cx_close_mock_side_effect = cx_close_mock.side_effect + cx_close_mock.side_effect = break_connection + cx2 = pool.acquire(READ_ACCESS, 30, "test_db", None) + pool.release(cx2) + assert cx1.close.called_once() + assert cx2 is not cx1 + assert cx2.addr == cx1.addr + assert cx1 not in pool.connections[cx1.addr] + assert cx2 in pool.connections[cx2.addr] diff --git a/tests/unit/work/__init__.py b/tests/unit/work/__init__.py index e69de29b..1bc320c6 100644 --- a/tests/unit/work/__init__.py +++ b/tests/unit/work/__init__.py @@ -0,0 +1 @@ +from ._fake_connection import FakeConnection diff --git a/tests/unit/work/_fake_connection.py b/tests/unit/work/_fake_connection.py new file mode 100644 index 00000000..25b272fe --- /dev/null +++ b/tests/unit/work/_fake_connection.py @@ -0,0 +1,92 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [http://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import inspect +from unittest import mock + +import pytest + +from neo4j import ServerInfo + + +class FakeConnection(mock.NonCallableMagicMock): + callbacks = [] + server_info = ServerInfo("127.0.0.1", (4, 3)) + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.attach_mock(mock.PropertyMock(return_value=True), "is_reset") + self.attach_mock(mock.Mock(return_value=False), "defunct") + self.attach_mock(mock.Mock(return_value=False), "stale") + self.attach_mock(mock.Mock(return_value=False), "closed") + + def close_side_effect(): + self.closed.return_value = True + + self.attach_mock(mock.Mock(side_effect=close_side_effect), "close") + + def fetch_message(self, *args, **kwargs): + if self.callbacks: + cb = self.callbacks.pop(0) + cb() + return super().__getattr__("fetch_message")(*args, **kwargs) + + def fetch_all(self, *args, **kwargs): + while self.callbacks: + cb = self.callbacks.pop(0) + cb() + return super().__getattr__("fetch_all")(*args, **kwargs) + + def __getattr__(self, name): + parent = super() + + def build_message_handler(name): + def func(*args, **kwargs): + def callback(): + for cb_name, param_count in ( + ("on_success", 1), + ("on_summary", 0) + ): + cb = kwargs.get(cb_name, None) + if callable(cb): + try: + param_count = \ + len(inspect.signature(cb).parameters) + except ValueError: + # e.g. built-in method as cb + pass + if param_count == 1: + cb({}) + else: + cb() + self.callbacks.append(callback) + return parent.__getattr__(name)(*args, **kwargs) + + return func + + if name in ("run", "commit", "pull", "rollback", "discard"): + return build_message_handler(name) + return parent.__getattr__(name) + + +@pytest.fixture +def fake_connection(): + return FakeConnection()