diff --git a/test/asynchronous/helpers.py b/test/asynchronous/helpers.py index b5fc5d8ac4..7758f281e1 100644 --- a/test/asynchronous/helpers.py +++ b/test/asynchronous/helpers.py @@ -15,6 +15,7 @@ """Shared constants and helper methods for pymongo, bson, and gridfs test suites.""" from __future__ import annotations +import asyncio import base64 import gc import multiprocessing @@ -30,6 +31,8 @@ import warnings from asyncio import iscoroutinefunction +from pymongo._asyncio_task import create_task + try: import ipaddress @@ -369,3 +372,37 @@ def disable(self): os.environ.pop("SSL_CERT_FILE") else: os.environ["SSL_CERT_FILE"] = self.original_certs + + +if _IS_SYNC: + PARENT = threading.Thread +else: + PARENT = object + + +class ConcurrentRunner(PARENT): + def __init__(self, name, *args, **kwargs): + if _IS_SYNC: + super().__init__(*args, **kwargs) + self.name = name + self.stopped = False + self.task = None + if "target" in kwargs: + self.target = kwargs["target"] + + if not _IS_SYNC: + + async def start(self): + self.task = create_task(self.run(), name=self.name) + + async def join(self, timeout: float | None = 0): # type: ignore[override] + if self.task is not None: + await asyncio.wait([self.task], timeout=timeout) + + def is_alive(self): + return not self.stopped + + async def run(self): + if self.target: + await self.target() + self.stopped = True diff --git a/test/asynchronous/unified_format.py b/test/asynchronous/unified_format.py index 52d964eb3e..a7a6364497 100644 --- a/test/asynchronous/unified_format.py +++ b/test/asynchronous/unified_format.py @@ -35,6 +35,7 @@ client_knobs, unittest, ) +from test.asynchronous.utils_spec_runner import SpecRunnerTask from test.unified_format_shared import ( KMS_TLS_OPTS, PLACEHOLDER_MAP, @@ -58,7 +59,6 @@ snake_to_camel, wait_until, ) -from test.utils_spec_runner import SpecRunnerThread from test.version import Version from typing import Any, Dict, List, Mapping, Optional @@ -382,8 +382,8 @@ async def drop(self: AsyncGridFSBucket, *args: Any, **kwargs: Any) -> None: return elif entity_type == "thread": name = spec["id"] - thread = SpecRunnerThread(name) - thread.start() + thread = SpecRunnerTask(name) + await thread.start() self[name] = thread return @@ -1177,16 +1177,16 @@ def primary_changed() -> bool: wait_until(primary_changed, "change primary", timeout=timeout) - def _testOperation_runOnThread(self, spec): + async def _testOperation_runOnThread(self, spec): """Run the 'runOnThread' operation.""" thread = self.entity_map[spec["thread"]] - thread.schedule(lambda: self.run_entity_operation(spec["operation"])) + await thread.schedule(functools.partial(self.run_entity_operation, spec["operation"])) - def _testOperation_waitForThread(self, spec): + async def _testOperation_waitForThread(self, spec): """Run the 'waitForThread' operation.""" thread = self.entity_map[spec["thread"]] - thread.stop() - thread.join(10) + await thread.stop() + await thread.join(10) if thread.exc: raise thread.exc self.assertFalse(thread.is_alive(), "Thread {} is still running".format(spec["thread"])) diff --git a/test/asynchronous/utils_spec_runner.py b/test/asynchronous/utils_spec_runner.py index b79e5258b5..d103374313 100644 --- a/test/asynchronous/utils_spec_runner.py +++ b/test/asynchronous/utils_spec_runner.py @@ -18,11 +18,11 @@ import asyncio import functools import os -import threading import unittest from asyncio import iscoroutinefunction from collections import abc from test.asynchronous import AsyncIntegrationTest, async_client_context, client_knobs +from test.asynchronous.helpers import ConcurrentRunner from test.utils import ( CMAPListener, CompareType, @@ -47,6 +47,7 @@ from pymongo.asynchronous.command_cursor import AsyncCommandCursor from pymongo.asynchronous.cursor import AsyncCursor from pymongo.errors import AutoReconnect, BulkWriteError, OperationFailure, PyMongoError +from pymongo.lock import _async_cond_wait, _async_create_condition, _async_create_lock from pymongo.read_concern import ReadConcern from pymongo.read_preferences import ReadPreference from pymongo.results import BulkWriteResult, _WriteResult @@ -55,38 +56,36 @@ _IS_SYNC = False -class SpecRunnerThread(threading.Thread): +class SpecRunnerTask(ConcurrentRunner): def __init__(self, name): - super().__init__() - self.name = name + super().__init__(name) self.exc = None self.daemon = True - self.cond = threading.Condition() + self.cond = _async_create_condition(_async_create_lock()) self.ops = [] - self.stopped = False - def schedule(self, work): + async def schedule(self, work): self.ops.append(work) - with self.cond: + async with self.cond: self.cond.notify() - def stop(self): + async def stop(self): self.stopped = True - with self.cond: + async with self.cond: self.cond.notify() - def run(self): + async def run(self): while not self.stopped or self.ops: if not self.ops: - with self.cond: - self.cond.wait(10) + async with self.cond: + await _async_cond_wait(self.cond, 10) if self.ops: try: work = self.ops.pop(0) - work() + await work() except Exception as exc: self.exc = exc - self.stop() + await self.stop() class AsyncSpecTestCreator: diff --git a/test/helpers.py b/test/helpers.py index 11d5ab0374..bd9e23bba4 100644 --- a/test/helpers.py +++ b/test/helpers.py @@ -15,6 +15,7 @@ """Shared constants and helper methods for pymongo, bson, and gridfs test suites.""" from __future__ import annotations +import asyncio import base64 import gc import multiprocessing @@ -30,6 +31,8 @@ import warnings from asyncio import iscoroutinefunction +from pymongo._asyncio_task import create_task + try: import ipaddress @@ -369,3 +372,37 @@ def disable(self): os.environ.pop("SSL_CERT_FILE") else: os.environ["SSL_CERT_FILE"] = self.original_certs + + +if _IS_SYNC: + PARENT = threading.Thread +else: + PARENT = object + + +class ConcurrentRunner(PARENT): + def __init__(self, name, *args, **kwargs): + if _IS_SYNC: + super().__init__(*args, **kwargs) + self.name = name + self.stopped = False + self.task = None + if "target" in kwargs: + self.target = kwargs["target"] + + if not _IS_SYNC: + + def start(self): + self.task = create_task(self.run(), name=self.name) + + def join(self, timeout: float | None = 0): # type: ignore[override] + if self.task is not None: + asyncio.wait([self.task], timeout=timeout) + + def is_alive(self): + return not self.stopped + + def run(self): + if self.target: + self.target() + self.stopped = True diff --git a/test/unified_format.py b/test/unified_format.py index 372eb8abba..84f1553c53 100644 --- a/test/unified_format.py +++ b/test/unified_format.py @@ -1167,7 +1167,7 @@ def primary_changed() -> bool: def _testOperation_runOnThread(self, spec): """Run the 'runOnThread' operation.""" thread = self.entity_map[spec["thread"]] - thread.schedule(lambda: self.run_entity_operation(spec["operation"])) + thread.schedule(functools.partial(self.run_entity_operation, spec["operation"])) def _testOperation_waitForThread(self, spec): """Run the 'waitForThread' operation.""" diff --git a/test/utils_spec_runner.py b/test/utils_spec_runner.py index 4508502cd0..6a62112afb 100644 --- a/test/utils_spec_runner.py +++ b/test/utils_spec_runner.py @@ -18,11 +18,11 @@ import asyncio import functools import os -import threading import unittest from asyncio import iscoroutinefunction from collections import abc from test import IntegrationTest, client_context, client_knobs +from test.helpers import ConcurrentRunner from test.utils import ( CMAPListener, CompareType, @@ -44,6 +44,7 @@ from gridfs import GridFSBucket from gridfs.synchronous.grid_file import GridFSBucket from pymongo.errors import AutoReconnect, BulkWriteError, OperationFailure, PyMongoError +from pymongo.lock import _cond_wait, _create_condition, _create_lock from pymongo.read_concern import ReadConcern from pymongo.read_preferences import ReadPreference from pymongo.results import BulkWriteResult, _WriteResult @@ -55,15 +56,13 @@ _IS_SYNC = True -class SpecRunnerThread(threading.Thread): +class SpecRunnerThread(ConcurrentRunner): def __init__(self, name): - super().__init__() - self.name = name + super().__init__(name) self.exc = None self.daemon = True - self.cond = threading.Condition() + self.cond = _create_condition(_create_lock()) self.ops = [] - self.stopped = False def schedule(self, work): self.ops.append(work) @@ -79,7 +78,7 @@ def run(self): while not self.stopped or self.ops: if not self.ops: with self.cond: - self.cond.wait(10) + _cond_wait(self.cond, 10) if self.ops: try: work = self.ops.pop(0) diff --git a/tools/synchro.py b/tools/synchro.py index ef82db756d..4584c41a6c 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -119,6 +119,7 @@ "_async_create_lock": "_create_lock", "_async_create_condition": "_create_condition", "_async_cond_wait": "_cond_wait", + "SpecRunnerTask": "SpecRunnerThread", "AsyncMockConnection": "MockConnection", "AsyncMockPool": "MockPool", }