Skip to content

PYTHON-4864 - Create async version of SpecRunnerThread #2094

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Feb 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions test/asynchronous/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""Shared constants and helper methods for pymongo, bson, and gridfs test suites."""
from __future__ import annotations

import asyncio
import base64
import gc
import multiprocessing
Expand All @@ -30,6 +31,8 @@
import warnings
from asyncio import iscoroutinefunction

from pymongo._asyncio_task import create_task

try:
import ipaddress

Expand Down Expand Up @@ -369,3 +372,37 @@ def disable(self):
os.environ.pop("SSL_CERT_FILE")
else:
os.environ["SSL_CERT_FILE"] = self.original_certs


if _IS_SYNC:
PARENT = threading.Thread
else:
PARENT = object


class ConcurrentRunner(PARENT):
def __init__(self, name, *args, **kwargs):
if _IS_SYNC:
super().__init__(*args, **kwargs)
self.name = name
self.stopped = False
self.task = None
if "target" in kwargs:
self.target = kwargs["target"]

if not _IS_SYNC:

async def start(self):
self.task = create_task(self.run(), name=self.name)

async def join(self, timeout: float | None = 0): # type: ignore[override]
if self.task is not None:
await asyncio.wait([self.task], timeout=timeout)

def is_alive(self):
return not self.stopped

async def run(self):
if self.target:
await self.target()
self.stopped = True
16 changes: 8 additions & 8 deletions test/asynchronous/unified_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
client_knobs,
unittest,
)
from test.asynchronous.utils_spec_runner import SpecRunnerTask
from test.unified_format_shared import (
KMS_TLS_OPTS,
PLACEHOLDER_MAP,
Expand All @@ -58,7 +59,6 @@
snake_to_camel,
wait_until,
)
from test.utils_spec_runner import SpecRunnerThread
from test.version import Version
from typing import Any, Dict, List, Mapping, Optional

Expand Down Expand Up @@ -382,8 +382,8 @@ async def drop(self: AsyncGridFSBucket, *args: Any, **kwargs: Any) -> None:
return
elif entity_type == "thread":
name = spec["id"]
thread = SpecRunnerThread(name)
thread.start()
thread = SpecRunnerTask(name)
await thread.start()
self[name] = thread
return

Expand Down Expand Up @@ -1177,16 +1177,16 @@ def primary_changed() -> bool:

wait_until(primary_changed, "change primary", timeout=timeout)

def _testOperation_runOnThread(self, spec):
async def _testOperation_runOnThread(self, spec):
"""Run the 'runOnThread' operation."""
thread = self.entity_map[spec["thread"]]
thread.schedule(lambda: self.run_entity_operation(spec["operation"]))
await thread.schedule(functools.partial(self.run_entity_operation, spec["operation"]))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are we running any async unified tests that use runOnThread?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The SDAM unified tests are the only ones that use runOnThread. Those are currently slated to be converted to async, yes.


def _testOperation_waitForThread(self, spec):
async def _testOperation_waitForThread(self, spec):
"""Run the 'waitForThread' operation."""
thread = self.entity_map[spec["thread"]]
thread.stop()
thread.join(10)
await thread.stop()
await thread.join(10)
if thread.exc:
raise thread.exc
self.assertFalse(thread.is_alive(), "Thread {} is still running".format(spec["thread"]))
Expand Down
29 changes: 14 additions & 15 deletions test/asynchronous/utils_spec_runner.py
Copy link
Contributor

@sleepyStick sleepyStick Jan 29, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not loving the duplicated code between sync and async but i'm guessing its because the async version needs some more methods? If so, then i understand and can live with it >.<

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The async version doesn't implement threading.Thread, but it still needs to match the same API as the synchronous version. Let me see if I can reduce some of the duplication though.

Copy link
Contributor Author

@NoahStapp NoahStapp Jan 29, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did some refactoring, much less duplication now. Great call-out!

Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@
import asyncio
import functools
import os
import threading
import unittest
from asyncio import iscoroutinefunction
from collections import abc
from test.asynchronous import AsyncIntegrationTest, async_client_context, client_knobs
from test.asynchronous.helpers import ConcurrentRunner
from test.utils import (
CMAPListener,
CompareType,
Expand All @@ -47,6 +47,7 @@
from pymongo.asynchronous.command_cursor import AsyncCommandCursor
from pymongo.asynchronous.cursor import AsyncCursor
from pymongo.errors import AutoReconnect, BulkWriteError, OperationFailure, PyMongoError
from pymongo.lock import _async_cond_wait, _async_create_condition, _async_create_lock
from pymongo.read_concern import ReadConcern
from pymongo.read_preferences import ReadPreference
from pymongo.results import BulkWriteResult, _WriteResult
Expand All @@ -55,38 +56,36 @@
_IS_SYNC = False


class SpecRunnerThread(threading.Thread):
class SpecRunnerTask(ConcurrentRunner):
def __init__(self, name):
super().__init__()
self.name = name
super().__init__(name)
self.exc = None
self.daemon = True
self.cond = threading.Condition()
self.cond = _async_create_condition(_async_create_lock())
self.ops = []
self.stopped = False

def schedule(self, work):
async def schedule(self, work):
self.ops.append(work)
with self.cond:
async with self.cond:
self.cond.notify()

def stop(self):
async def stop(self):
self.stopped = True
with self.cond:
async with self.cond:
self.cond.notify()

def run(self):
async def run(self):
while not self.stopped or self.ops:
if not self.ops:
with self.cond:
self.cond.wait(10)
async with self.cond:
await _async_cond_wait(self.cond, 10)
if self.ops:
try:
work = self.ops.pop(0)
work()
await work()
except Exception as exc:
self.exc = exc
self.stop()
await self.stop()


class AsyncSpecTestCreator:
Expand Down
37 changes: 37 additions & 0 deletions test/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""Shared constants and helper methods for pymongo, bson, and gridfs test suites."""
from __future__ import annotations

import asyncio
import base64
import gc
import multiprocessing
Expand All @@ -30,6 +31,8 @@
import warnings
from asyncio import iscoroutinefunction

from pymongo._asyncio_task import create_task

try:
import ipaddress

Expand Down Expand Up @@ -369,3 +372,37 @@ def disable(self):
os.environ.pop("SSL_CERT_FILE")
else:
os.environ["SSL_CERT_FILE"] = self.original_certs


if _IS_SYNC:
PARENT = threading.Thread
else:
PARENT = object


class ConcurrentRunner(PARENT):
def __init__(self, name, *args, **kwargs):
if _IS_SYNC:
super().__init__(*args, **kwargs)
self.name = name
self.stopped = False
self.task = None
if "target" in kwargs:
self.target = kwargs["target"]

if not _IS_SYNC:

def start(self):
self.task = create_task(self.run(), name=self.name)

def join(self, timeout: float | None = 0): # type: ignore[override]
if self.task is not None:
asyncio.wait([self.task], timeout=timeout)

def is_alive(self):
return not self.stopped

def run(self):
if self.target:
self.target()
self.stopped = True
2 changes: 1 addition & 1 deletion test/unified_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -1167,7 +1167,7 @@ def primary_changed() -> bool:
def _testOperation_runOnThread(self, spec):
"""Run the 'runOnThread' operation."""
thread = self.entity_map[spec["thread"]]
thread.schedule(lambda: self.run_entity_operation(spec["operation"]))
thread.schedule(functools.partial(self.run_entity_operation, spec["operation"]))

def _testOperation_waitForThread(self, spec):
"""Run the 'waitForThread' operation."""
Expand Down
13 changes: 6 additions & 7 deletions test/utils_spec_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@
import asyncio
import functools
import os
import threading
import unittest
from asyncio import iscoroutinefunction
from collections import abc
from test import IntegrationTest, client_context, client_knobs
from test.helpers import ConcurrentRunner
from test.utils import (
CMAPListener,
CompareType,
Expand All @@ -44,6 +44,7 @@
from gridfs import GridFSBucket
from gridfs.synchronous.grid_file import GridFSBucket
from pymongo.errors import AutoReconnect, BulkWriteError, OperationFailure, PyMongoError
from pymongo.lock import _cond_wait, _create_condition, _create_lock
from pymongo.read_concern import ReadConcern
from pymongo.read_preferences import ReadPreference
from pymongo.results import BulkWriteResult, _WriteResult
Expand All @@ -55,15 +56,13 @@
_IS_SYNC = True


class SpecRunnerThread(threading.Thread):
class SpecRunnerThread(ConcurrentRunner):
def __init__(self, name):
super().__init__()
self.name = name
super().__init__(name)
self.exc = None
self.daemon = True
self.cond = threading.Condition()
self.cond = _create_condition(_create_lock())
self.ops = []
self.stopped = False

def schedule(self, work):
self.ops.append(work)
Expand All @@ -79,7 +78,7 @@ def run(self):
while not self.stopped or self.ops:
if not self.ops:
with self.cond:
self.cond.wait(10)
_cond_wait(self.cond, 10)
if self.ops:
try:
work = self.ops.pop(0)
Expand Down
1 change: 1 addition & 0 deletions tools/synchro.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@
"_async_create_lock": "_create_lock",
"_async_create_condition": "_create_condition",
"_async_cond_wait": "_cond_wait",
"SpecRunnerTask": "SpecRunnerThread",
"AsyncMockConnection": "MockConnection",
"AsyncMockPool": "MockPool",
}
Expand Down
Loading