|
22 | 22 | from unittest import TestCase
|
23 | 23 | import pytest
|
24 | 24 | from threading import (
|
25 |
| - Thread, |
| 25 | + Condition, |
26 | 26 | Event,
|
| 27 | + Lock, |
| 28 | + Thread, |
27 | 29 | )
|
| 30 | +import time |
| 31 | + |
28 | 32 | from neo4j import (
|
29 | 33 | Config,
|
30 | 34 | PoolConfig,
|
@@ -120,6 +124,52 @@ def test_ping_timeout(self):
|
120 | 124 | assert protocol_version is None
|
121 | 125 |
|
122 | 126 |
|
class MultiEvent:
    # Adapted from threading.Event: a counting variant that lets waiters
    # block until the counter reaches an exact target value.

    def __init__(self):
        super().__init__()
        self._cond = Condition(Lock())
        self._counter = 0

    def _reset_internal_locks(self):
        # private! called by Thread._reset_internal_locks by _after_fork()
        self._cond.__init__(Lock())

    def counter(self):
        """Return the current counter value (unsynchronized snapshot)."""
        return self._counter

    def increment(self):
        """Atomically increase the counter by one and wake all waiters."""
        with self._cond:
            self._counter += 1
            self._cond.notify_all()

    def decrement(self):
        """Atomically decrease the counter by one and wake all waiters."""
        with self._cond:
            self._counter -= 1
            self._cond.notify_all()

    def clear(self):
        """Reset the counter to zero and wake all waiters."""
        with self._cond:
            self._counter = 0
            self._cond.notify_all()

    def wait(self, value=0, timeout=None):
        """Block until the counter equals *value*.

        :param value: target counter value to wait for
        :param timeout: maximum seconds to wait, or None to wait forever
        :returns: True if the target value was reached, False on timeout
        """
        with self._cond:
            # Condition.wait_for re-checks the predicate after every
            # notification and tracks the remaining timeout with a
            # monotonic clock — unlike the original hand-rolled loop,
            # which used wall-clock time.time() arithmetic and could
            # misbehave if the system clock jumped.
            return self._cond.wait_for(
                lambda: self._counter == value, timeout
            )
| 171 | + |
| 172 | + |
123 | 173 | class ConnectionPoolTestCase(TestCase):
|
124 | 174 |
|
125 | 175 | def setUp(self):
|
@@ -200,28 +250,36 @@ def test_max_conn_pool_size(self):
|
200 | 250 | def test_multithread(self):
|
201 | 251 | with FakeBoltPool((), max_connection_pool_size=5) as pool:
|
202 | 252 | address = ("127.0.0.1", 7687)
|
203 |
| - releasing_event = Event() |
| 253 | + acquired_counter = MultiEvent() |
| 254 | + release_event = Event() |
204 | 255 |
|
205 |
| - # We start 10 threads to compete connections from pool with size of 5 |
| 256 | + # start 10 threads competing for connections from a pool of size 5 |
206 | 257 | threads = []
|
207 | 258 | for i in range(10):
|
208 |
| - t = Thread(target=acquire_release_conn, args=(pool, address, releasing_event)) |
| 259 | + t = Thread( |
| 260 | + target=acquire_release_conn, |
| 261 | + args=(pool, address, acquired_counter, release_event), |
| 262 | + daemon=True |
| 263 | + ) |
209 | 264 | t.start()
|
210 | 265 | threads.append(t)
|
211 | 266 |
|
| 267 | + if not acquired_counter.wait(5, timeout=1): |
| 268 | + raise RuntimeError("Acquire threads not fast enough") |
212 | 269 | # The pool size should be 5, all are in-use
|
213 | 270 | self.assert_pool_size(address, 5, 0, pool)
|
214 | 271 | # Now we allow thread to release connections they obtained from pool
|
215 |
| - releasing_event.set() |
| 272 | + release_event.set() |
216 | 273 |
|
217 | 274 | # wait for all threads to release connections back to pool
|
218 | 275 | for t in threads:
|
219 |
| - t.join() |
| 276 | + t.join(timeout=1) |
220 | 277 | # The pool size is still 5, but all are free
|
221 | 278 | self.assert_pool_size(address, 0, 5, pool)
|
222 | 279 |
|
223 | 280 |
|
224 |
| -def acquire_release_conn(pool, address, releasing_event): |
def acquire_release_conn(pool, address, acquired_counter, release_event):
    # Worker body for test_multithread: grab one connection from the pool,
    # signal that the acquisition succeeded, then hold the connection until
    # the test sets release_event before handing it back.
    conn = pool._acquire(address, timeout=3)
    acquired_counter.increment()
    release_event.wait()
    pool.release(conn)
|
0 commit comments