From 73183295dcd808c3075caae5439e286f34f8d675 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Thu, 30 Jan 2025 13:34:42 -0500 Subject: [PATCH 1/5] PYTHON-5087 - Convert test.test_load_balancer to async --- test/asynchronous/test_load_balancer.py | 244 ++++++++++++++++++++++++ test/test_load_balancer.py | 163 +++++++++++----- test/utils.py | 17 ++ tools/synchro.py | 1 + 4 files changed, 377 insertions(+), 48 deletions(-) create mode 100644 test/asynchronous/test_load_balancer.py diff --git a/test/asynchronous/test_load_balancer.py b/test/asynchronous/test_load_balancer.py new file mode 100644 index 0000000000..0e75d75734 --- /dev/null +++ b/test/asynchronous/test_load_balancer.py @@ -0,0 +1,244 @@ +# Copyright 2021-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test the Load Balancer unified spec tests.""" +from __future__ import annotations + +import asyncio +import gc +import os +import pathlib +import sys +import threading + +import pytest + +sys.path[0:0] = [""] + +from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest +from test.asynchronous.unified_format import generate_test_classes +from test.utils import ( + ExceptionCatchingTask, + ExceptionCatchingThread, + async_get_pool, + async_wait_until, +) + +from pymongo.asynchronous.helpers import anext + +_IS_SYNC = False + +pytestmark = pytest.mark.load_balancer + +# Location of JSON test specifications. 
+if _IS_SYNC: + _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "load_balancer") +else: + _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "load_balancer") + +# Generate unified tests. +globals().update(generate_test_classes(_TEST_PATH, module=__name__)) + + +class TestLB(AsyncIntegrationTest): + RUN_ON_LOAD_BALANCER = True + RUN_ON_SERVERLESS = True + + async def test_connections_are_only_returned_once(self): + if "PyPy" in sys.version: + # Tracked in PYTHON-3011 + self.skipTest("Test is flaky on PyPy") + pool = await async_get_pool(self.client) + n_conns = len(pool.conns) + await self.db.test.find_one({}) + self.assertEqual(len(pool.conns), n_conns) + await (await self.db.test.aggregate([{"$limit": 1}])).to_list() + self.assertEqual(len(pool.conns), n_conns) + + @async_client_context.require_load_balancer + async def test_unpin_committed_transaction(self): + client = await self.async_rs_client() + pool = await async_get_pool(client) + coll = client[self.db.name].test + async with client.start_session() as session: + async with await session.start_transaction(): + self.assertEqual(pool.active_sockets, 0) + await coll.insert_one({}, session=session) + self.assertEqual(pool.active_sockets, 1) # Pinned. + self.assertEqual(pool.active_sockets, 1) # Still pinned. + self.assertEqual(pool.active_sockets, 0) # Unpinned. 
+ + @async_client_context.require_failCommand_fail_point + async def test_cursor_gc(self): + async def create_resource(coll): + cursor = coll.find({}, batch_size=3) + await anext(cursor) + return cursor + + await self._test_no_gc_deadlock(create_resource) + + @async_client_context.require_failCommand_fail_point + async def test_command_cursor_gc(self): + async def create_resource(coll): + cursor = await coll.aggregate([], batchSize=3) + await anext(cursor) + return cursor + + await self._test_no_gc_deadlock(create_resource) + + async def _test_no_gc_deadlock(self, create_resource): + client = await self.async_rs_client() + pool = await async_get_pool(client) + coll = client[self.db.name].test + await coll.insert_many([{} for _ in range(10)]) + self.assertEqual(pool.active_sockets, 0) + # Cause the initial find attempt to fail to induce a reference cycle. + args = { + "mode": {"times": 1}, + "data": { + "failCommands": ["find", "aggregate"], + "closeConnection": True, + }, + } + async with self.fail_point(args): + resource = await create_resource(coll) + if async_client_context.load_balancer: + self.assertEqual(pool.active_sockets, 1) # Pinned. + + if _IS_SYNC: + thread = PoolLocker(pool) + thread.start() + self.assertTrue(thread.locked.wait(5), "timed out") + # Garbage collect the resource while the pool is locked to ensure we + # don't deadlock. + del resource + # On PyPy it can take a few rounds to collect the cursor. + for _ in range(3): + gc.collect() + thread.unlock.set() + thread.join(5) + self.assertFalse(thread.is_alive()) + self.assertIsNone(thread.exc) + + else: + task = PoolLocker(pool) + self.assertTrue(await asyncio.wait_for(task.locked.wait(), timeout=5), "timed out") # type: ignore[arg-type] + + # Garbage collect the resource while the pool is locked to ensure we + # don't deadlock. + del resource + # On PyPy it can take a few rounds to collect the cursor. 
+ for _ in range(3): + gc.collect() + task.unlock.set() + await task.run() + self.assertFalse(task.is_alive()) + self.assertIsNone(task.exc) + + await async_wait_until(lambda: pool.active_sockets == 0, "return socket") + # Run another operation to ensure the socket still works. + await coll.delete_many({}) + + @async_client_context.require_transactions + async def test_session_gc(self): + client = await self.async_rs_client() + pool = await async_get_pool(client) + session = client.start_session() + await session.start_transaction() + await client.test_session_gc.test.find_one({}, session=session) + # Cleanup the transaction left open on the server unless we're + # testing serverless which does not support killSessions. + if not async_client_context.serverless: + self.addAsyncCleanup(self.client.admin.command, "killSessions", [session.session_id]) + if async_client_context.load_balancer: + self.assertEqual(pool.active_sockets, 1) # Pinned. + + if _IS_SYNC: + thread = PoolLocker(pool) + thread.start() + self.assertTrue(thread.locked.wait(5), "timed out") + # Garbage collect the session while the pool is locked to ensure we + # don't deadlock. + del session + # On PyPy it can take a few rounds to collect the session. + for _ in range(3): + gc.collect() + thread.unlock.set() + thread.join(5) + self.assertFalse(thread.is_alive()) + self.assertIsNone(thread.exc) + + else: + task = PoolLocker(pool) + self.assertTrue(await asyncio.wait_for(task.locked.wait(), timeout=5), "timed out") # type: ignore[arg-type] + + # Garbage collect the session while the pool is locked to ensure we + # don't deadlock. + del session + # On PyPy it can take a few rounds to collect the cursor. + for _ in range(3): + gc.collect() + task.unlock.set() + await task.run() + self.assertFalse(task.is_alive()) + self.assertIsNone(task.exc) + + await async_wait_until(lambda: pool.active_sockets == 0, "return socket") + # Run another operation to ensure the socket still works. 
+ await client[self.db.name].test.delete_many({}) + + +if _IS_SYNC: + + class PoolLocker(ExceptionCatchingThread): + def __init__(self, pool): + super().__init__(target=self.lock_pool) + self.pool = pool + self.daemon = True + self.locked = threading.Event() + self.unlock = threading.Event() + + def lock_pool(self): + with self.pool.lock: + self.locked.set() + # Wait for the unlock flag. + unlock_pool = self.unlock.wait(10) + if not unlock_pool: + raise Exception("timed out waiting for unlock signal: deadlock?") + +else: + + class PoolLocker(ExceptionCatchingTask): + def __init__(self, pool): + super().__init__(self.lock_pool) + self.pool = pool + self.daemon = True + self.locked = asyncio.Event() + self.unlock = asyncio.Event() + + async def lock_pool(self): + async with self.pool.lock: + self.locked.set() + # Wait for the unlock flag. + try: + await asyncio.wait_for(self.unlock.wait(), timeout=10) + except asyncio.TimeoutError: + raise Exception("timed out waiting for unlock signal: deadlock?") + + def is_alive(self): + return not self.task.done() + + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_load_balancer.py b/test/test_load_balancer.py index 23bea4d984..042fbea33e 100644 --- a/test/test_load_balancer.py +++ b/test/test_load_balancer.py @@ -15,8 +15,10 @@ """Test the Load Balancer unified spec tests.""" from __future__ import annotations +import asyncio import gc import os +import pathlib import sys import threading @@ -26,15 +28,27 @@ from test import IntegrationTest, client_context, unittest from test.unified_format import generate_test_classes -from test.utils import ExceptionCatchingThread, get_pool, wait_until +from test.utils import ( + ExceptionCatchingTask, + ExceptionCatchingThread, + get_pool, + wait_until, +) + +from pymongo.synchronous.helpers import next + +_IS_SYNC = True pytestmark = pytest.mark.load_balancer # Location of JSON test specifications. 
-TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "load_balancer") +if _IS_SYNC: + _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "load_balancer") +else: + _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "load_balancer") # Generate unified tests. -globals().update(generate_test_classes(TEST_PATH, module=__name__)) +globals().update(generate_test_classes(_TEST_PATH, module=__name__)) class TestLB(IntegrationTest): @@ -49,13 +63,12 @@ def test_connections_are_only_returned_once(self): n_conns = len(pool.conns) self.db.test.find_one({}) self.assertEqual(len(pool.conns), n_conns) - list(self.db.test.aggregate([{"$limit": 1}])) + (self.db.test.aggregate([{"$limit": 1}])).to_list() self.assertEqual(len(pool.conns), n_conns) @client_context.require_load_balancer def test_unpin_committed_transaction(self): client = self.rs_client() - self.addCleanup(client.close) pool = get_pool(client) coll = client[self.db.name].test with client.start_session() as session: @@ -86,7 +99,6 @@ def create_resource(coll): def _test_no_gc_deadlock(self, create_resource): client = self.rs_client() - self.addCleanup(client.close) pool = get_pool(client) coll = client[self.db.name].test coll.insert_many([{} for _ in range(10)]) @@ -104,19 +116,35 @@ def _test_no_gc_deadlock(self, create_resource): if client_context.load_balancer: self.assertEqual(pool.active_sockets, 1) # Pinned. - thread = PoolLocker(pool) - thread.start() - self.assertTrue(thread.locked.wait(5), "timed out") - # Garbage collect the resource while the pool is locked to ensure we - # don't deadlock. - del resource - # On PyPy it can take a few rounds to collect the cursor. 
- for _ in range(3): - gc.collect() - thread.unlock.set() - thread.join(5) - self.assertFalse(thread.is_alive()) - self.assertIsNone(thread.exc) + if _IS_SYNC: + thread = PoolLocker(pool) + thread.start() + self.assertTrue(thread.locked.wait(5), "timed out") + # Garbage collect the resource while the pool is locked to ensure we + # don't deadlock. + del resource + # On PyPy it can take a few rounds to collect the cursor. + for _ in range(3): + gc.collect() + thread.unlock.set() + thread.join(5) + self.assertFalse(thread.is_alive()) + self.assertIsNone(thread.exc) + + else: + task = PoolLocker(pool) + self.assertTrue(asyncio.wait_for(task.locked.wait(), timeout=5), "timed out") # type: ignore[arg-type] + + # Garbage collect the resource while the pool is locked to ensure we + # don't deadlock. + del resource + # On PyPy it can take a few rounds to collect the cursor. + for _ in range(3): + gc.collect() + task.unlock.set() + task.run() + self.assertFalse(task.is_alive()) + self.assertIsNone(task.exc) wait_until(lambda: pool.active_sockets == 0, "return socket") # Run another operation to ensure the socket still works. @@ -125,7 +153,6 @@ def _test_no_gc_deadlock(self, create_resource): @client_context.require_transactions def test_session_gc(self): client = self.rs_client() - self.addCleanup(client.close) pool = get_pool(client) session = client.start_session() session.start_transaction() @@ -137,40 +164,80 @@ def test_session_gc(self): if client_context.load_balancer: self.assertEqual(pool.active_sockets, 1) # Pinned. - thread = PoolLocker(pool) - thread.start() - self.assertTrue(thread.locked.wait(5), "timed out") - # Garbage collect the session while the pool is locked to ensure we - # don't deadlock. - del session - # On PyPy it can take a few rounds to collect the session. 
- for _ in range(3): - gc.collect() - thread.unlock.set() - thread.join(5) - self.assertFalse(thread.is_alive()) - self.assertIsNone(thread.exc) + if _IS_SYNC: + thread = PoolLocker(pool) + thread.start() + self.assertTrue(thread.locked.wait(5), "timed out") + # Garbage collect the session while the pool is locked to ensure we + # don't deadlock. + del session + # On PyPy it can take a few rounds to collect the session. + for _ in range(3): + gc.collect() + thread.unlock.set() + thread.join(5) + self.assertFalse(thread.is_alive()) + self.assertIsNone(thread.exc) + + else: + task = PoolLocker(pool) + self.assertTrue(asyncio.wait_for(task.locked.wait(), timeout=5), "timed out") # type: ignore[arg-type] + + # Garbage collect the session while the pool is locked to ensure we + # don't deadlock. + del session + # On PyPy it can take a few rounds to collect the cursor. + for _ in range(3): + gc.collect() + task.unlock.set() + task.run() + self.assertFalse(task.is_alive()) + self.assertIsNone(task.exc) wait_until(lambda: pool.active_sockets == 0, "return socket") # Run another operation to ensure the socket still works. client[self.db.name].test.delete_many({}) -class PoolLocker(ExceptionCatchingThread): - def __init__(self, pool): - super().__init__(target=self.lock_pool) - self.pool = pool - self.daemon = True - self.locked = threading.Event() - self.unlock = threading.Event() - - def lock_pool(self): - with self.pool.lock: - self.locked.set() - # Wait for the unlock flag. - unlock_pool = self.unlock.wait(10) - if not unlock_pool: - raise Exception("timed out waiting for unlock signal: deadlock?") +if _IS_SYNC: + + class PoolLocker(ExceptionCatchingThread): + def __init__(self, pool): + super().__init__(target=self.lock_pool) + self.pool = pool + self.daemon = True + self.locked = threading.Event() + self.unlock = threading.Event() + + def lock_pool(self): + with self.pool.lock: + self.locked.set() + # Wait for the unlock flag. 
+ unlock_pool = self.unlock.wait(10) + if not unlock_pool: + raise Exception("timed out waiting for unlock signal: deadlock?") + +else: + + class PoolLocker(ExceptionCatchingTask): + def __init__(self, pool): + super().__init__(self.lock_pool) + self.pool = pool + self.daemon = True + self.locked = asyncio.Event() + self.unlock = asyncio.Event() + + def lock_pool(self): + with self.pool.lock: + self.locked.set() + # Wait for the unlock flag. + try: + asyncio.wait_for(self.unlock.wait(), timeout=10) + except asyncio.TimeoutError: + raise Exception("timed out waiting for unlock signal: deadlock?") + + def is_alive(self): + return not self.task.done() if __name__ == "__main__": diff --git a/test/utils.py b/test/utils.py index 69154bc63b..5973db0cd0 100644 --- a/test/utils.py +++ b/test/utils.py @@ -39,6 +39,7 @@ from bson.objectid import ObjectId from bson.son import SON from pymongo import AsyncMongoClient, monitoring, operations, read_preferences +from pymongo._asyncio_task import create_task from pymongo.cursor_shared import CursorType from pymongo.errors import ConfigurationError, OperationFailure from pymongo.hello import HelloCompat @@ -870,6 +871,22 @@ def run(self): raise +class ExceptionCatchingTask: + """A Task that stores any exception encountered from its task.""" + + def __init__(self, target): + self.exc = None + self.target = target + self.task = create_task(self.run()) + + async def run(self): + try: + await self.target() + except BaseException as exc: + self.exc = exc + raise + + def parse_read_preference(pref): # Make first letter lowercase to match read_pref's modes. 
mode_string = pref.get("mode", "primary") diff --git a/tools/synchro.py b/tools/synchro.py index 897e5e8018..16aede8ed6 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -207,6 +207,7 @@ def async_only_test(f: str) -> bool: "test_data_lake.py", "test_encryption.py", "test_grid_file.py", + "test_load_balancer.py", "test_logger.py", "test_monitoring.py", "test_raw_bson.py", From 51a69f5d6a8e50f2085ad19a311dbb9276a8fbd5 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Tue, 4 Feb 2025 10:34:27 -0500 Subject: [PATCH 2/5] Refactor --- test/asynchronous/helpers.py | 25 +++- test/asynchronous/test_load_balancer.py | 149 +++++++++--------------- test/helpers.py | 25 +++- test/test_load_balancer.py | 149 +++++++++--------------- test/utils.py | 39 ++----- tools/synchro.py | 1 + 6 files changed, 161 insertions(+), 227 deletions(-) diff --git a/test/asynchronous/helpers.py b/test/asynchronous/helpers.py index 7758f281e1..dc05271cf3 100644 --- a/test/asynchronous/helpers.py +++ b/test/asynchronous/helpers.py @@ -404,5 +404,28 @@ def is_alive(self): async def run(self): if self.target: - await self.target() + if _IS_SYNC: + super().run() + else: + await self.target() self.stopped = True + + +class ExceptionCatchingTask(ConcurrentRunner): + """A Task that stores any exception encountered while running.""" + + def __init__(self, *args, **kwargs): + super().__init__("ExceptionCatchingTask", *args, **kwargs) + self.exc = None + + async def run(self): + try: + if _IS_SYNC: + await super().run() + else: + await self.target() + except BaseException as exc: + self.exc = exc + raise + finally: + self.stopped = True diff --git a/test/asynchronous/test_load_balancer.py b/test/asynchronous/test_load_balancer.py index 0e75d75734..13b4035852 100644 --- a/test/asynchronous/test_load_balancer.py +++ b/test/asynchronous/test_load_balancer.py @@ -21,6 +21,8 @@ import pathlib import sys import threading +from asyncio import Event +from test.asynchronous.helpers import ConcurrentRunner, 
ExceptionCatchingTask import pytest @@ -29,10 +31,9 @@ from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest from test.asynchronous.unified_format import generate_test_classes from test.utils import ( - ExceptionCatchingTask, - ExceptionCatchingThread, async_get_pool, async_wait_until, + create_async_event, ) from pymongo.asynchronous.helpers import anext @@ -116,35 +117,19 @@ async def _test_no_gc_deadlock(self, create_resource): if async_client_context.load_balancer: self.assertEqual(pool.active_sockets, 1) # Pinned. - if _IS_SYNC: - thread = PoolLocker(pool) - thread.start() - self.assertTrue(thread.locked.wait(5), "timed out") - # Garbage collect the resource while the pool is locked to ensure we - # don't deadlock. - del resource - # On PyPy it can take a few rounds to collect the cursor. - for _ in range(3): - gc.collect() - thread.unlock.set() - thread.join(5) - self.assertFalse(thread.is_alive()) - self.assertIsNone(thread.exc) - - else: - task = PoolLocker(pool) - self.assertTrue(await asyncio.wait_for(task.locked.wait(), timeout=5), "timed out") # type: ignore[arg-type] - - # Garbage collect the resource while the pool is locked to ensure we - # don't deadlock. - del resource - # On PyPy it can take a few rounds to collect the cursor. - for _ in range(3): - gc.collect() - task.unlock.set() - await task.run() - self.assertFalse(task.is_alive()) - self.assertIsNone(task.exc) + task = PoolLocker(pool) + await task.start() + self.assertTrue(await task.wait(task.locked, 5), "timed out") + # Garbage collect the resource while the pool is locked to ensure we + # don't deadlock. + del resource + # On PyPy it can take a few rounds to collect the cursor. + for _ in range(3): + gc.collect() + task.unlock.set() + await task.join(5) + self.assertFalse(task.is_alive()) + self.assertIsNone(task.exc) await async_wait_until(lambda: pool.active_sockets == 0, "return socket") # Run another operation to ensure the socket still works. 
@@ -164,80 +149,50 @@ async def test_session_gc(self): if async_client_context.load_balancer: self.assertEqual(pool.active_sockets, 1) # Pinned. - if _IS_SYNC: - thread = PoolLocker(pool) - thread.start() - self.assertTrue(thread.locked.wait(5), "timed out") - # Garbage collect the session while the pool is locked to ensure we - # don't deadlock. - del session - # On PyPy it can take a few rounds to collect the session. - for _ in range(3): - gc.collect() - thread.unlock.set() - thread.join(5) - self.assertFalse(thread.is_alive()) - self.assertIsNone(thread.exc) - - else: - task = PoolLocker(pool) - self.assertTrue(await asyncio.wait_for(task.locked.wait(), timeout=5), "timed out") # type: ignore[arg-type] - - # Garbage collect the session while the pool is locked to ensure we - # don't deadlock. - del session - # On PyPy it can take a few rounds to collect the cursor. - for _ in range(3): - gc.collect() - task.unlock.set() - await task.run() - self.assertFalse(task.is_alive()) - self.assertIsNone(task.exc) + task = PoolLocker(pool) + await task.start() + self.assertTrue(await task.wait(task.locked, 5), "timed out") + # Garbage collect the session while the pool is locked to ensure we + # don't deadlock. + del session + # On PyPy it can take a few rounds to collect the session. + for _ in range(3): + gc.collect() + task.unlock.set() + await task.join(5) + self.assertFalse(task.is_alive()) + self.assertIsNone(task.exc) await async_wait_until(lambda: pool.active_sockets == 0, "return socket") # Run another operation to ensure the socket still works. await client[self.db.name].test.delete_many({}) -if _IS_SYNC: - - class PoolLocker(ExceptionCatchingThread): - def __init__(self, pool): - super().__init__(target=self.lock_pool) - self.pool = pool - self.daemon = True - self.locked = threading.Event() - self.unlock = threading.Event() - - def lock_pool(self): - with self.pool.lock: - self.locked.set() - # Wait for the unlock flag. 
- unlock_pool = self.unlock.wait(10) - if not unlock_pool: - raise Exception("timed out waiting for unlock signal: deadlock?") +class PoolLocker(ExceptionCatchingTask): + def __init__(self, pool): + super().__init__(target=self.lock_pool) + self.pool = pool + self.daemon = True + self.locked = create_async_event() + self.unlock = create_async_event() -else: + async def lock_pool(self): + async with self.pool.lock: + self.locked.set() + # Wait for the unlock flag. + unlock_pool = await self.wait(self.unlock, 10) + if not unlock_pool: + raise Exception("timed out waiting for unlock signal: deadlock?") - class PoolLocker(ExceptionCatchingTask): - def __init__(self, pool): - super().__init__(self.lock_pool) - self.pool = pool - self.daemon = True - self.locked = asyncio.Event() - self.unlock = asyncio.Event() - - async def lock_pool(self): - async with self.pool.lock: - self.locked.set() - # Wait for the unlock flag. - try: - await asyncio.wait_for(self.unlock.wait(), timeout=10) - except asyncio.TimeoutError: - raise Exception("timed out waiting for unlock signal: deadlock?") - - def is_alive(self): - return not self.task.done() + async def wait(self, event: Event, timeout: int): + if _IS_SYNC: + return event.wait(timeout) + else: + try: + await asyncio.wait_for(event.wait(), timeout=timeout) + except asyncio.TimeoutError: + return False + return True if __name__ == "__main__": diff --git a/test/helpers.py b/test/helpers.py index bd9e23bba4..d6d1202893 100644 --- a/test/helpers.py +++ b/test/helpers.py @@ -404,5 +404,28 @@ def is_alive(self): def run(self): if self.target: - self.target() + if _IS_SYNC: + super().run() + else: + self.target() self.stopped = True + + +class ExceptionCatchingTask(ConcurrentRunner): + """A Task that stores any exception encountered while running.""" + + def __init__(self, *args, **kwargs): + super().__init__("ExceptionCatchingTask", *args, **kwargs) + self.exc = None + + def run(self): + try: + if _IS_SYNC: + super().run() + else: + 
self.target() + except BaseException as exc: + self.exc = exc + raise + finally: + self.stopped = True diff --git a/test/test_load_balancer.py b/test/test_load_balancer.py index 042fbea33e..d0afab35a1 100644 --- a/test/test_load_balancer.py +++ b/test/test_load_balancer.py @@ -21,6 +21,8 @@ import pathlib import sys import threading +from asyncio import Event +from test.helpers import ConcurrentRunner, ExceptionCatchingTask import pytest @@ -29,8 +31,7 @@ from test import IntegrationTest, client_context, unittest from test.unified_format import generate_test_classes from test.utils import ( - ExceptionCatchingTask, - ExceptionCatchingThread, + create_event, get_pool, wait_until, ) @@ -116,35 +117,19 @@ def _test_no_gc_deadlock(self, create_resource): if client_context.load_balancer: self.assertEqual(pool.active_sockets, 1) # Pinned. - if _IS_SYNC: - thread = PoolLocker(pool) - thread.start() - self.assertTrue(thread.locked.wait(5), "timed out") - # Garbage collect the resource while the pool is locked to ensure we - # don't deadlock. - del resource - # On PyPy it can take a few rounds to collect the cursor. - for _ in range(3): - gc.collect() - thread.unlock.set() - thread.join(5) - self.assertFalse(thread.is_alive()) - self.assertIsNone(thread.exc) - - else: - task = PoolLocker(pool) - self.assertTrue(asyncio.wait_for(task.locked.wait(), timeout=5), "timed out") # type: ignore[arg-type] - - # Garbage collect the resource while the pool is locked to ensure we - # don't deadlock. - del resource - # On PyPy it can take a few rounds to collect the cursor. - for _ in range(3): - gc.collect() - task.unlock.set() - task.run() - self.assertFalse(task.is_alive()) - self.assertIsNone(task.exc) + task = PoolLocker(pool) + task.start() + self.assertTrue(task.wait(task.locked, 5), "timed out") + # Garbage collect the resource while the pool is locked to ensure we + # don't deadlock. + del resource + # On PyPy it can take a few rounds to collect the cursor. 
+ for _ in range(3): + gc.collect() + task.unlock.set() + task.join(5) + self.assertFalse(task.is_alive()) + self.assertIsNone(task.exc) wait_until(lambda: pool.active_sockets == 0, "return socket") # Run another operation to ensure the socket still works. @@ -164,80 +149,50 @@ def test_session_gc(self): if client_context.load_balancer: self.assertEqual(pool.active_sockets, 1) # Pinned. - if _IS_SYNC: - thread = PoolLocker(pool) - thread.start() - self.assertTrue(thread.locked.wait(5), "timed out") - # Garbage collect the session while the pool is locked to ensure we - # don't deadlock. - del session - # On PyPy it can take a few rounds to collect the session. - for _ in range(3): - gc.collect() - thread.unlock.set() - thread.join(5) - self.assertFalse(thread.is_alive()) - self.assertIsNone(thread.exc) - - else: - task = PoolLocker(pool) - self.assertTrue(asyncio.wait_for(task.locked.wait(), timeout=5), "timed out") # type: ignore[arg-type] - - # Garbage collect the session while the pool is locked to ensure we - # don't deadlock. - del session - # On PyPy it can take a few rounds to collect the cursor. - for _ in range(3): - gc.collect() - task.unlock.set() - task.run() - self.assertFalse(task.is_alive()) - self.assertIsNone(task.exc) + task = PoolLocker(pool) + task.start() + self.assertTrue(task.wait(task.locked, 5), "timed out") + # Garbage collect the session while the pool is locked to ensure we + # don't deadlock. + del session + # On PyPy it can take a few rounds to collect the session. + for _ in range(3): + gc.collect() + task.unlock.set() + task.join(5) + self.assertFalse(task.is_alive()) + self.assertIsNone(task.exc) wait_until(lambda: pool.active_sockets == 0, "return socket") # Run another operation to ensure the socket still works. 
client[self.db.name].test.delete_many({}) -if _IS_SYNC: - - class PoolLocker(ExceptionCatchingThread): - def __init__(self, pool): - super().__init__(target=self.lock_pool) - self.pool = pool - self.daemon = True - self.locked = threading.Event() - self.unlock = threading.Event() - - def lock_pool(self): - with self.pool.lock: - self.locked.set() - # Wait for the unlock flag. - unlock_pool = self.unlock.wait(10) - if not unlock_pool: - raise Exception("timed out waiting for unlock signal: deadlock?") +class PoolLocker(ExceptionCatchingTask): + def __init__(self, pool): + super().__init__(target=self.lock_pool) + self.pool = pool + self.daemon = True + self.locked = create_event() + self.unlock = create_event() -else: + def lock_pool(self): + with self.pool.lock: + self.locked.set() + # Wait for the unlock flag. + unlock_pool = self.wait(self.unlock, 10) + if not unlock_pool: + raise Exception("timed out waiting for unlock signal: deadlock?") - class PoolLocker(ExceptionCatchingTask): - def __init__(self, pool): - super().__init__(self.lock_pool) - self.pool = pool - self.daemon = True - self.locked = asyncio.Event() - self.unlock = asyncio.Event() - - def lock_pool(self): - with self.pool.lock: - self.locked.set() - # Wait for the unlock flag. 
- try: - asyncio.wait_for(self.unlock.wait(), timeout=10) - except asyncio.TimeoutError: - raise Exception("timed out waiting for unlock signal: deadlock?") - - def is_alive(self): - return not self.task.done() + def wait(self, event: Event, timeout: int): + if _IS_SYNC: + return event.wait(timeout) + else: + try: + asyncio.wait_for(event.wait(), timeout=timeout) + except asyncio.TimeoutError: + return False + return True if __name__ == "__main__": diff --git a/test/utils.py b/test/utils.py index b2897cff36..073fa45da4 100644 --- a/test/utils.py +++ b/test/utils.py @@ -913,37 +913,6 @@ def is_greenthread_patched(): return gevent_monkey_patched() or eventlet_monkey_patched() -class ExceptionCatchingThread(threading.Thread): - """A thread that stores any exception encountered from run().""" - - def __init__(self, *args, **kwargs): - self.exc = None - super().__init__(*args, **kwargs) - - def run(self): - try: - super().run() - except BaseException as exc: - self.exc = exc - raise - - -class ExceptionCatchingTask: - """A Task that stores any exception encountered from its task.""" - - def __init__(self, target): - self.exc = None - self.target = target - self.task = create_task(self.run()) - - async def run(self): - try: - await self.target() - except BaseException as exc: - self.exc = exc - raise - - def parse_read_preference(pref): # Make first letter lowercase to match read_pref's modes. 
 mode_string = pref.get("mode", "primary")
@@ -1096,3 +1065,11 @@ async def async_set_fail_point(client, command_args):
     cmd = SON([("configureFailPoint", "failCommand")])
     cmd.update(command_args)
     await client.admin.command(cmd)
+
+
+def create_async_event():
+    return asyncio.Event()
+
+
+def create_event():
+    return threading.Event()
diff --git a/tools/synchro.py b/tools/synchro.py
index 37f6ba8ef3..9d5a77e241 100644
--- a/tools/synchro.py
+++ b/tools/synchro.py
@@ -122,6 +122,7 @@
     "SpecRunnerTask": "SpecRunnerThread",
     "AsyncMockConnection": "MockConnection",
     "AsyncMockPool": "MockPool",
+    "create_async_event": "create_event",
 }
 
 docstring_replacements: dict[tuple[str, str], str] = {

From 12e2e39fccdc191f6a1381203102c4b5e5f90545 Mon Sep 17 00:00:00 2001
From: Noah Stapp
Date: Tue, 4 Feb 2025 11:04:13 -0500
Subject: [PATCH 3/5] WIP

---
 test/asynchronous/helpers.py      | 22 ++++++++++--------
 test/asynchronous/test_session.py | 38 ++++++++++++++++---------------
 test/helpers.py                   | 22 ++++++++++--------
 test/test_bson.py                 |  6 ++---
 test/test_session.py              | 32 ++++++++++++++------------
 5 files changed, 64 insertions(+), 56 deletions(-)

diff --git a/test/asynchronous/helpers.py b/test/asynchronous/helpers.py
index dc05271cf3..692969ac98 100644
--- a/test/asynchronous/helpers.py
+++ b/test/asynchronous/helpers.py
@@ -381,14 +381,16 @@ def disable(self):
 
 
 class ConcurrentRunner(PARENT):
-    def __init__(self, name, *args, **kwargs):
+    def __init__(self, **kwargs):
         if _IS_SYNC:
-            super().__init__(*args, **kwargs)
-        self.name = name
+            super().__init__(**kwargs)
+        self.name = kwargs.get("name", "ConcurrentRunner")
         self.stopped = False
         self.task = None
         if "target" in kwargs:
             self.target = kwargs["target"]
+        # Default to None so run() can test self.args when no "args" kwarg was passed.
+        self.args = kwargs.get("args")
 
     if not _IS_SYNC:
 
@@ -404,23 +409,23 @@ def is_alive(self):
 
     async def run(self):
         if self.target:
             if _IS_SYNC:
                 super().run()
             else:
-                await self.target()
+                if self.args:
+                    await self.target(*self.args)
+                else:
+                    await self.target()
             self.stopped = True
 
 
 class 
ExceptionCatchingTask(ConcurrentRunner): """A Task that stores any exception encountered while running.""" - def __init__(self, *args, **kwargs): - super().__init__("ExceptionCatchingTask", *args, **kwargs) + def __init__(self, **kwargs): + super().__init__(**kwargs) self.exc = None async def run(self): try: - if _IS_SYNC: - await super().run() - else: - await self.target() + await super().run() except BaseException as exc: self.exc = exc raise diff --git a/test/asynchronous/test_session.py b/test/asynchronous/test_session.py index 42bc253b56..03d1032b5b 100644 --- a/test/asynchronous/test_session.py +++ b/test/asynchronous/test_session.py @@ -15,10 +15,13 @@ """Test the client_session module.""" from __future__ import annotations +import asyncio import copy import sys import time +from asyncio import iscoroutinefunction from io import BytesIO +from test.asynchronous.helpers import ExceptionCatchingTask from typing import Any, Callable, List, Set, Tuple from pymongo.synchronous.mongo_client import MongoClient @@ -35,7 +38,6 @@ ) from test.utils import ( EventListener, - ExceptionCatchingThread, OvertCommandListener, async_wait_until, ) @@ -184,8 +186,7 @@ async def _test_ops(self, client, *ops): f"{f.__name__} did not return implicit session to pool", ) - @async_client_context.require_sync - def test_implicit_sessions_checkout(self): + async def test_implicit_sessions_checkout(self): # "To confirm that implicit sessions only allocate their server session after a # successful connection checkout" test from Driver Sessions Spec. 
succeeded = False @@ -193,7 +194,7 @@ def test_implicit_sessions_checkout(self): failures = 0 for _ in range(5): listener = OvertCommandListener() - client = self.async_rs_or_single_client(event_listeners=[listener], maxPoolSize=1) + client = await self.async_rs_or_single_client(event_listeners=[listener], maxPoolSize=1) cursor = client.db.test.find({}) ops: List[Tuple[Callable, List[Any]]] = [ (client.db.test.find_one, [{"_id": 1}]), @@ -210,26 +211,27 @@ def test_implicit_sessions_checkout(self): (cursor.distinct, ["_id"]), (client.db.list_collections, []), ] - threads = [] + tasks = [] listener.reset() - def thread_target(op, *args): - res = op(*args) + async def target(op, *args): + if iscoroutinefunction(op): + res = await op(*args) + else: + res = op(*args) if isinstance(res, (AsyncCursor, AsyncCommandCursor)): - list(res) # type: ignore[call-overload] + await res.to_list() for op, args in ops: - threads.append( - ExceptionCatchingThread( - target=thread_target, args=[op, *args], name=op.__name__ - ) + tasks.append( + ExceptionCatchingTask(target=target, args=[op, *args], name=op.__name__) ) - threads[-1].start() - self.assertEqual(len(threads), len(ops)) - for thread in threads: - thread.join() - self.assertIsNone(thread.exc) - client.close() + await tasks[-1].start() + self.assertEqual(len(tasks), len(ops)) + for t in tasks: + await t.join() + self.assertIsNone(t.exc) + await client.close() lsid_set.clear() for i in listener.started_events: if i.command.get("lsid"): diff --git a/test/helpers.py b/test/helpers.py index d6d1202893..cfb73ccc96 100644 --- a/test/helpers.py +++ b/test/helpers.py @@ -381,14 +381,16 @@ def disable(self): class ConcurrentRunner(PARENT): - def __init__(self, name, *args, **kwargs): + def __init__(self, **kwargs): if _IS_SYNC: - super().__init__(*args, **kwargs) - self.name = name + super().__init__(**kwargs) + self.name = kwargs.get("name", "ConcurrentRunner") self.stopped = False self.task = None if "target" in kwargs: self.target 
= kwargs["target"] + if "args" in kwargs: + self.args = kwargs["args"] if not _IS_SYNC: @@ -407,23 +409,23 @@ def run(self): if _IS_SYNC: super().run() else: - self.target() + if self.args: + self.target(*self.args) + else: + self.target() self.stopped = True class ExceptionCatchingTask(ConcurrentRunner): """A Task that stores any exception encountered while running.""" - def __init__(self, *args, **kwargs): - super().__init__("ExceptionCatchingTask", *args, **kwargs) + def __init__(self, **kwargs): + super().__init__(**kwargs) self.exc = None def run(self): try: - if _IS_SYNC: - super().run() - else: - self.target() + super().run() except BaseException as exc: self.exc = exc raise diff --git a/test/test_bson.py b/test/test_bson.py index e601be4915..e704efe451 100644 --- a/test/test_bson.py +++ b/test/test_bson.py @@ -33,7 +33,7 @@ sys.path[0:0] = [""] from test import qcheck, unittest -from test.utils import ExceptionCatchingThread +from test.helpers import ExceptionCatchingTask import bson from bson import ( @@ -1075,7 +1075,7 @@ def target(i): my_int = type(f"MyInt_{i}_{j}", (int,), {}) bson.encode({"my_int": my_int()}) - threads = [ExceptionCatchingThread(target=target, args=(i,)) for i in range(3)] + threads = [ExceptionCatchingTask(target=target, args=(i,)) for i in range(3)] for t in threads: t.start() @@ -1114,7 +1114,7 @@ def __repr__(self): def test_doc_in_invalid_document_error_message_mapping(self): class MyMapping(abc.Mapping): - def keys(): + def keys(self): return ["t"] def __getitem__(self, name): diff --git a/test/test_session.py b/test/test_session.py index 634efa11c0..175a282495 100644 --- a/test/test_session.py +++ b/test/test_session.py @@ -15,10 +15,13 @@ """Test the client_session module.""" from __future__ import annotations +import asyncio import copy import sys import time +from asyncio import iscoroutinefunction from io import BytesIO +from test.helpers import ExceptionCatchingTask from typing import Any, Callable, List, Set, Tuple from 
pymongo.synchronous.mongo_client import MongoClient @@ -35,7 +38,6 @@ ) from test.utils import ( EventListener, - ExceptionCatchingThread, OvertCommandListener, wait_until, ) @@ -184,7 +186,6 @@ def _test_ops(self, client, *ops): f"{f.__name__} did not return implicit session to pool", ) - @client_context.require_sync def test_implicit_sessions_checkout(self): # "To confirm that implicit sessions only allocate their server session after a # successful connection checkout" test from Driver Sessions Spec. @@ -210,25 +211,26 @@ def test_implicit_sessions_checkout(self): (cursor.distinct, ["_id"]), (client.db.list_collections, []), ] - threads = [] + tasks = [] listener.reset() - def thread_target(op, *args): - res = op(*args) + def target(op, *args): + if iscoroutinefunction(op): + res = op(*args) + else: + res = op(*args) if isinstance(res, (Cursor, CommandCursor)): - list(res) # type: ignore[call-overload] + res.to_list() for op, args in ops: - threads.append( - ExceptionCatchingThread( - target=thread_target, args=[op, *args], name=op.__name__ - ) + tasks.append( + ExceptionCatchingTask(target=target, args=[op, *args], name=op.__name__) ) - threads[-1].start() - self.assertEqual(len(threads), len(ops)) - for thread in threads: - thread.join() - self.assertIsNone(thread.exc) + tasks[-1].start() + self.assertEqual(len(tasks), len(ops)) + for t in tasks: + t.join() + self.assertIsNone(t.exc) client.close() lsid_set.clear() for i in listener.started_events: From 8c5dbc390a5b1aa5f467d01d74b9828a00637bdf Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Wed, 5 Feb 2025 15:37:58 -0500 Subject: [PATCH 4/5] Fix asyncio.Event creation --- test/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/utils.py b/test/utils.py index 073fa45da4..5c1e0bfb7c 100644 --- a/test/utils.py +++ b/test/utils.py @@ -1068,7 +1068,7 @@ async def async_set_fail_point(client, command_args): def create_async_event(): - return asyncio.Event + return asyncio.Event() def 
create_event(): From 9a7568a89519d0f58bc2e088df53a4c18e30b0a5 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Wed, 5 Feb 2025 16:24:44 -0500 Subject: [PATCH 5/5] ConcurrentRunner.join timeout defaults to None, not 0 --- test/asynchronous/helpers.py | 4 ++-- test/asynchronous/test_load_balancer.py | 2 +- test/helpers.py | 4 ++-- test/test_load_balancer.py | 2 +- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/test/asynchronous/helpers.py b/test/asynchronous/helpers.py index 748965d870..28260d0a52 100644 --- a/test/asynchronous/helpers.py +++ b/test/asynchronous/helpers.py @@ -40,7 +40,7 @@ except ImportError: HAVE_IPADDRESS = False from functools import wraps -from typing import Any, Callable, Dict, Generator, no_type_check +from typing import Any, Callable, Dict, Generator, Optional, no_type_check from unittest import SkipTest from bson.son import SON @@ -395,7 +395,7 @@ def __init__(self, **kwargs): async def start(self): self.task = create_task(self.run(), name=self.name) - async def join(self, timeout: float | None = 0): # type: ignore[override] + async def join(self, timeout: Optional[float] = None): # type: ignore[override] if self.task is not None: await asyncio.wait([self.task], timeout=timeout) diff --git a/test/asynchronous/test_load_balancer.py b/test/asynchronous/test_load_balancer.py index 13b4035852..fd50841c87 100644 --- a/test/asynchronous/test_load_balancer.py +++ b/test/asynchronous/test_load_balancer.py @@ -186,7 +186,7 @@ async def lock_pool(self): async def wait(self, event: Event, timeout: int): if _IS_SYNC: - return event.wait(timeout) + return event.wait(timeout) # type: ignore[call-arg] else: try: await asyncio.wait_for(event.wait(), timeout=timeout) diff --git a/test/helpers.py b/test/helpers.py index b19777e699..3f51fde08c 100644 --- a/test/helpers.py +++ b/test/helpers.py @@ -40,7 +40,7 @@ except ImportError: HAVE_IPADDRESS = False from functools import wraps -from typing import Any, Callable, Dict, Generator, no_type_check 
+from typing import Any, Callable, Dict, Generator, Optional, no_type_check from unittest import SkipTest from bson.son import SON @@ -395,7 +395,7 @@ def __init__(self, **kwargs): def start(self): self.task = create_task(self.run(), name=self.name) - def join(self, timeout: float | None = 0): # type: ignore[override] + def join(self, timeout: Optional[float] = None): # type: ignore[override] if self.task is not None: asyncio.wait([self.task], timeout=timeout) diff --git a/test/test_load_balancer.py b/test/test_load_balancer.py index d0afab35a1..7db19b46b5 100644 --- a/test/test_load_balancer.py +++ b/test/test_load_balancer.py @@ -186,7 +186,7 @@ def lock_pool(self): def wait(self, event: Event, timeout: int): if _IS_SYNC: - return event.wait(timeout) + return event.wait(timeout) # type: ignore[call-arg] else: try: asyncio.wait_for(event.wait(), timeout=timeout)