diff --git a/tests/unit/async_/io/test_direct.py b/tests/unit/async_/io/test_direct.py index 711ec711..21bdd8d6 100644 --- a/tests/unit/async_/io/test_direct.py +++ b/tests/unit/async_/io/test_direct.py @@ -83,6 +83,9 @@ def defunct(self): def timedout(self): return False + def assert_re_auth_support(self): + pass + class AsyncFakeBoltPool(AsyncIOPool): is_direct_pool = False diff --git a/tests/unit/mixed/io/test_direct.py b/tests/unit/mixed/io/test_direct.py index 63fba1e7..04e6b98d 100644 --- a/tests/unit/mixed/io/test_direct.py +++ b/tests/unit/mixed/io/test_direct.py @@ -17,6 +17,8 @@ import asyncio +import threading +import time from asyncio import Event as AsyncEvent from threading import ( Event, @@ -26,10 +28,20 @@ import pytest +from neo4j._async.io._pool import AcquireAuth as AsyncAcquireAuth from neo4j._deadline import Deadline +from neo4j._sync.io._pool import AcquireAuth +from neo4j.auth_management import ( + AsyncAuthManagers, + AuthManagers, +) from ...async_.io.test_direct import AsyncFakeBoltPool +from ...async_.test_auth_manager import ( + static_auth_manager as static_async_auth_manager, +) from ...sync.io.test_direct import FakeBoltPool +from ...sync.test_auth_manager import static_auth_manager from ._common import ( AsyncMultiEvent, MultiEvent, @@ -111,6 +123,49 @@ def acquire_release_conn(pool_, address_, acquired_counter_, # The pool size is still 5, but all are free self.assert_pool_size(address, 0, 5, pool) + def test_full_pool_re_auth(self, mocker): + address = ("127.0.0.1", 7687) + acquire_auth1 = AcquireAuth(auth=static_auth_manager( + ("user1", "pass1")) + ) + auth2 = ("user2", "pass2") + acquire_auth2 = AcquireAuth(auth=static_auth_manager(auth2)) + acquire1_event = threading.Event() + cx1 = None + + def acquire1(pool_): + nonlocal cx1 + cx = pool_._acquire(address, acquire_auth1, Deadline(0), None) + acquire1_event.set() + cx1 = cx + while True: + with pool_.cond: + # _waiters is an internal attribute of threading.Condition + # this might break in the future, but we couldn't come up + # with a better way of waiting for the other thread to block. + waiters = len(pool_.cond._waiters) + if waiters: + break + time.sleep(0.001) + cx.re_auth = mocker.Mock(spec=cx.re_auth) + pool_.release(cx) + + def acquire2(pool_): + acquire1_event.wait(timeout=10) + cx = pool_._acquire(address, acquire_auth2, Deadline(10), None) + assert cx is cx1 + cx.re_auth.assert_called_once() + assert auth2 in cx.re_auth.call_args.args + pool_.release(cx) + + with FakeBoltPool((), max_connection_pool_size=1) as pool: + t1 = threading.Thread(target=acquire1, args=(pool,), daemon=True) + t2 = threading.Thread(target=acquire2, args=(pool,), daemon=True) + t1.start() + t2.start() + t1.join() + t2.join() + @pytest.mark.parametrize("pre_populated", (0, 3, 5)) @pytest.mark.asyncio async def test_multi_coroutine(self, pre_populated): @@ -172,3 +227,35 @@ async def waiter(pool_, acquired_counter_, release_event_): waiter(pool, acquired_counter, release_event), *coroutines ) + + @pytest.mark.asyncio + async def test_full_pool_re_auth_async(self, mocker): + address = ("127.0.0.1", 7687) + acquire_auth1 = AsyncAcquireAuth(auth=static_async_auth_manager( + ("user1", "pass1")) + ) + auth2 = ("user2", "pass2") + acquire_auth2 = AsyncAcquireAuth(auth=static_async_auth_manager(auth2)) + cx1 = None + + async def acquire1(pool_): + nonlocal cx1 + cx = await pool_._acquire(address, acquire_auth1, Deadline(0), None) + cx1 = cx + while len(pool_.cond._waiters) == 0: + await asyncio.sleep(0) + cx.re_auth = mocker.Mock(spec=cx.re_auth) + await pool_.release(cx) + + async def acquire2(pool_): + while cx1 is None: + await asyncio.sleep(0) + cx = await pool_._acquire(address, acquire_auth2, + Deadline(float("inf")), None) + assert cx is cx1 + cx.re_auth.assert_called_once() + assert auth2 in cx.re_auth.call_args.args + await pool_.release(cx) + + async with AsyncFakeBoltPool((), max_connection_pool_size=1) as pool: + await asyncio.gather(acquire1(pool), acquire2(pool)) diff --git a/tests/unit/sync/io/test_direct.py b/tests/unit/sync/io/test_direct.py index 4a4ef6a7..e75aab94 100644 --- a/tests/unit/sync/io/test_direct.py +++ b/tests/unit/sync/io/test_direct.py @@ -83,6 +83,9 @@ def defunct(self): def timedout(self): return False + def assert_re_auth_support(self): + pass + class FakeBoltPool(IOPool): is_direct_pool = False