Skip to content

Fix flaky tests #541

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 3, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion neo4j/io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -893,7 +893,7 @@ def fetch_routing_info(self, address, database, bookmarks, timeout):
# routing is broken.
log.debug("Routing is broken (%s)", error)
raise ServiceUnavailable(*error.args)
except ServiceUnavailable as error:
except (ServiceUnavailable, SessionExpired) as error:
# The routing table request suffered a connection
# failure. This should return a null routing table,
# signalling to the caller to retry the request
Expand Down
74 changes: 66 additions & 8 deletions tests/unit/io/test_direct.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,13 @@
from unittest import TestCase
import pytest
from threading import (
Thread,
Condition,
Event,
Lock,
Thread,
)
import time

from neo4j import (
Config,
PoolConfig,
Expand Down Expand Up @@ -120,6 +124,52 @@ def test_ping_timeout(self):
assert protocol_version is None


class MultiEvent:
# Adopted from threading.Event

def __init__(self):
super().__init__()
self._cond = Condition(Lock())
self._counter = 0

def _reset_internal_locks(self):
# private! called by Thread._reset_internal_locks by _after_fork()
self._cond.__init__(Lock())

def counter(self):
return self._counter

def increment(self):
with self._cond:
self._counter += 1
self._cond.notify_all()

def decrement(self):
with self._cond:
self._counter -= 1
self._cond.notify_all()

def clear(self):
with self._cond:
self._counter = 0
self._cond.notify_all()

def wait(self, value=0, timeout=None):
with self._cond:
t_start = time.time()
while True:
if value == self._counter:
return True
if timeout is None:
time_left = None
else:
time_left = timeout - (time.time() - t_start)
if time_left <= 0:
return False
if not self._cond.wait(time_left):
return False


class ConnectionPoolTestCase(TestCase):

def setUp(self):
Expand Down Expand Up @@ -200,28 +250,36 @@ def test_max_conn_pool_size(self):
def test_multithread(self):
with FakeBoltPool((), max_connection_pool_size=5) as pool:
address = ("127.0.0.1", 7687)
releasing_event = Event()
acquired_counter = MultiEvent()
release_event = Event()

# We start 10 threads to compete connections from pool with size of 5
# start 10 threads competing for connections from a pool of size 5
threads = []
for i in range(10):
t = Thread(target=acquire_release_conn, args=(pool, address, releasing_event))
t = Thread(
target=acquire_release_conn,
args=(pool, address, acquired_counter, release_event),
daemon=True
)
t.start()
threads.append(t)

if not acquired_counter.wait(5, timeout=1):
raise RuntimeError("Acquire threads not fast enough")
# The pool size should be 5, all are in-use
self.assert_pool_size(address, 5, 0, pool)
# Now we allow thread to release connections they obtained from pool
releasing_event.set()
release_event.set()

# wait for all threads to release connections back to pool
for t in threads:
t.join()
t.join(timeout=1)
# The pool size is still 5, but all are free
self.assert_pool_size(address, 0, 5, pool)


def acquire_release_conn(pool, address, releasing_event):
def acquire_release_conn(pool, address, acquired_counter, release_event):
conn = pool._acquire(address, timeout=3)
releasing_event.wait()
acquired_counter.increment()
release_event.wait()
pool.release(conn)