Skip to content

Commit 7318329

Browse files
committed
PYTHON-5087 - Convert test.test_load_balancer to async
1 parent 34ae214 commit 7318329

File tree

4 files changed

+377
-48
lines changed

4 files changed

+377
-48
lines changed
Lines changed: 244 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,244 @@
1+
# Copyright 2021-present MongoDB, Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Test the Load Balancer unified spec tests."""
16+
from __future__ import annotations
17+
18+
import asyncio
19+
import gc
20+
import os
21+
import pathlib
22+
import sys
23+
import threading
24+
25+
import pytest
26+
27+
sys.path[0:0] = [""]
28+
29+
from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest
30+
from test.asynchronous.unified_format import generate_test_classes
31+
from test.utils import (
32+
ExceptionCatchingTask,
33+
ExceptionCatchingThread,
34+
async_get_pool,
35+
async_wait_until,
36+
)
37+
38+
from pymongo.asynchronous.helpers import anext
39+
40+
_IS_SYNC = False
41+
42+
pytestmark = pytest.mark.load_balancer
43+
44+
# Location of JSON test specifications.
45+
if _IS_SYNC:
46+
_TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "load_balancer")
47+
else:
48+
_TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "load_balancer")
49+
50+
# Generate unified tests.
51+
globals().update(generate_test_classes(_TEST_PATH, module=__name__))
52+
53+
54+
class TestLB(AsyncIntegrationTest):
55+
RUN_ON_LOAD_BALANCER = True
56+
RUN_ON_SERVERLESS = True
57+
58+
async def test_connections_are_only_returned_once(self):
59+
if "PyPy" in sys.version:
60+
# Tracked in PYTHON-3011
61+
self.skipTest("Test is flaky on PyPy")
62+
pool = await async_get_pool(self.client)
63+
n_conns = len(pool.conns)
64+
await self.db.test.find_one({})
65+
self.assertEqual(len(pool.conns), n_conns)
66+
await (await self.db.test.aggregate([{"$limit": 1}])).to_list()
67+
self.assertEqual(len(pool.conns), n_conns)
68+
69+
@async_client_context.require_load_balancer
70+
async def test_unpin_committed_transaction(self):
71+
client = await self.async_rs_client()
72+
pool = await async_get_pool(client)
73+
coll = client[self.db.name].test
74+
async with client.start_session() as session:
75+
async with await session.start_transaction():
76+
self.assertEqual(pool.active_sockets, 0)
77+
await coll.insert_one({}, session=session)
78+
self.assertEqual(pool.active_sockets, 1) # Pinned.
79+
self.assertEqual(pool.active_sockets, 1) # Still pinned.
80+
self.assertEqual(pool.active_sockets, 0) # Unpinned.
81+
82+
@async_client_context.require_failCommand_fail_point
83+
async def test_cursor_gc(self):
84+
async def create_resource(coll):
85+
cursor = coll.find({}, batch_size=3)
86+
await anext(cursor)
87+
return cursor
88+
89+
await self._test_no_gc_deadlock(create_resource)
90+
91+
@async_client_context.require_failCommand_fail_point
92+
async def test_command_cursor_gc(self):
93+
async def create_resource(coll):
94+
cursor = await coll.aggregate([], batchSize=3)
95+
await anext(cursor)
96+
return cursor
97+
98+
await self._test_no_gc_deadlock(create_resource)
99+
100+
async def _test_no_gc_deadlock(self, create_resource):
101+
client = await self.async_rs_client()
102+
pool = await async_get_pool(client)
103+
coll = client[self.db.name].test
104+
await coll.insert_many([{} for _ in range(10)])
105+
self.assertEqual(pool.active_sockets, 0)
106+
# Cause the initial find attempt to fail to induce a reference cycle.
107+
args = {
108+
"mode": {"times": 1},
109+
"data": {
110+
"failCommands": ["find", "aggregate"],
111+
"closeConnection": True,
112+
},
113+
}
114+
async with self.fail_point(args):
115+
resource = await create_resource(coll)
116+
if async_client_context.load_balancer:
117+
self.assertEqual(pool.active_sockets, 1) # Pinned.
118+
119+
if _IS_SYNC:
120+
thread = PoolLocker(pool)
121+
thread.start()
122+
self.assertTrue(thread.locked.wait(5), "timed out")
123+
# Garbage collect the resource while the pool is locked to ensure we
124+
# don't deadlock.
125+
del resource
126+
# On PyPy it can take a few rounds to collect the cursor.
127+
for _ in range(3):
128+
gc.collect()
129+
thread.unlock.set()
130+
thread.join(5)
131+
self.assertFalse(thread.is_alive())
132+
self.assertIsNone(thread.exc)
133+
134+
else:
135+
task = PoolLocker(pool)
136+
self.assertTrue(await asyncio.wait_for(task.locked.wait(), timeout=5), "timed out") # type: ignore[arg-type]
137+
138+
# Garbage collect the resource while the pool is locked to ensure we
139+
# don't deadlock.
140+
del resource
141+
# On PyPy it can take a few rounds to collect the cursor.
142+
for _ in range(3):
143+
gc.collect()
144+
task.unlock.set()
145+
await task.run()
146+
self.assertFalse(task.is_alive())
147+
self.assertIsNone(task.exc)
148+
149+
await async_wait_until(lambda: pool.active_sockets == 0, "return socket")
150+
# Run another operation to ensure the socket still works.
151+
await coll.delete_many({})
152+
153+
@async_client_context.require_transactions
154+
async def test_session_gc(self):
155+
client = await self.async_rs_client()
156+
pool = await async_get_pool(client)
157+
session = client.start_session()
158+
await session.start_transaction()
159+
await client.test_session_gc.test.find_one({}, session=session)
160+
# Cleanup the transaction left open on the server unless we're
161+
# testing serverless which does not support killSessions.
162+
if not async_client_context.serverless:
163+
self.addAsyncCleanup(self.client.admin.command, "killSessions", [session.session_id])
164+
if async_client_context.load_balancer:
165+
self.assertEqual(pool.active_sockets, 1) # Pinned.
166+
167+
if _IS_SYNC:
168+
thread = PoolLocker(pool)
169+
thread.start()
170+
self.assertTrue(thread.locked.wait(5), "timed out")
171+
# Garbage collect the session while the pool is locked to ensure we
172+
# don't deadlock.
173+
del session
174+
# On PyPy it can take a few rounds to collect the session.
175+
for _ in range(3):
176+
gc.collect()
177+
thread.unlock.set()
178+
thread.join(5)
179+
self.assertFalse(thread.is_alive())
180+
self.assertIsNone(thread.exc)
181+
182+
else:
183+
task = PoolLocker(pool)
184+
self.assertTrue(await asyncio.wait_for(task.locked.wait(), timeout=5), "timed out") # type: ignore[arg-type]
185+
186+
# Garbage collect the session while the pool is locked to ensure we
187+
# don't deadlock.
188+
del session
189+
# On PyPy it can take a few rounds to collect the cursor.
190+
for _ in range(3):
191+
gc.collect()
192+
task.unlock.set()
193+
await task.run()
194+
self.assertFalse(task.is_alive())
195+
self.assertIsNone(task.exc)
196+
197+
await async_wait_until(lambda: pool.active_sockets == 0, "return socket")
198+
# Run another operation to ensure the socket still works.
199+
await client[self.db.name].test.delete_many({})
200+
201+
202+
if _IS_SYNC:
203+
204+
class PoolLocker(ExceptionCatchingThread):
205+
def __init__(self, pool):
206+
super().__init__(target=self.lock_pool)
207+
self.pool = pool
208+
self.daemon = True
209+
self.locked = threading.Event()
210+
self.unlock = threading.Event()
211+
212+
def lock_pool(self):
213+
with self.pool.lock:
214+
self.locked.set()
215+
# Wait for the unlock flag.
216+
unlock_pool = self.unlock.wait(10)
217+
if not unlock_pool:
218+
raise Exception("timed out waiting for unlock signal: deadlock?")
219+
220+
else:
221+
222+
class PoolLocker(ExceptionCatchingTask):
223+
def __init__(self, pool):
224+
super().__init__(self.lock_pool)
225+
self.pool = pool
226+
self.daemon = True
227+
self.locked = asyncio.Event()
228+
self.unlock = asyncio.Event()
229+
230+
async def lock_pool(self):
231+
async with self.pool.lock:
232+
self.locked.set()
233+
# Wait for the unlock flag.
234+
try:
235+
await asyncio.wait_for(self.unlock.wait(), timeout=10)
236+
except asyncio.TimeoutError:
237+
raise Exception("timed out waiting for unlock signal: deadlock?")
238+
239+
def is_alive(self):
240+
return not self.task.done()
241+
242+
243+
if __name__ == "__main__":
244+
unittest.main()

0 commit comments

Comments
 (0)