Skip to content

Commit b47143c

Browse files
authored
PYTHON-4864 - Create async version of SpecRunnerThread (#2094)
1 parent 1fda6a2 commit b47143c

File tree

7 files changed

+104
-31
lines changed

7 files changed

+104
-31
lines changed

test/asynchronous/helpers.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"""Shared constants and helper methods for pymongo, bson, and gridfs test suites."""
1616
from __future__ import annotations
1717

18+
import asyncio
1819
import base64
1920
import gc
2021
import multiprocessing
@@ -30,6 +31,8 @@
3031
import warnings
3132
from asyncio import iscoroutinefunction
3233

34+
from pymongo._asyncio_task import create_task
35+
3336
try:
3437
import ipaddress
3538

@@ -369,3 +372,37 @@ def disable(self):
369372
os.environ.pop("SSL_CERT_FILE")
370373
else:
371374
os.environ["SSL_CERT_FILE"] = self.original_certs
375+
376+
377+
if _IS_SYNC:
378+
PARENT = threading.Thread
379+
else:
380+
PARENT = object
381+
382+
383+
class ConcurrentRunner(PARENT):
384+
def __init__(self, name, *args, **kwargs):
385+
if _IS_SYNC:
386+
super().__init__(*args, **kwargs)
387+
self.name = name
388+
self.stopped = False
389+
self.task = None
390+
if "target" in kwargs:
391+
self.target = kwargs["target"]
392+
393+
if not _IS_SYNC:
394+
395+
async def start(self):
396+
self.task = create_task(self.run(), name=self.name)
397+
398+
async def join(self, timeout: float | None = 0): # type: ignore[override]
399+
if self.task is not None:
400+
await asyncio.wait([self.task], timeout=timeout)
401+
402+
def is_alive(self):
403+
return not self.stopped
404+
405+
async def run(self):
406+
if self.target:
407+
await self.target()
408+
self.stopped = True

test/asynchronous/unified_format.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
client_knobs,
3636
unittest,
3737
)
38+
from test.asynchronous.utils_spec_runner import SpecRunnerTask
3839
from test.unified_format_shared import (
3940
KMS_TLS_OPTS,
4041
PLACEHOLDER_MAP,
@@ -58,7 +59,6 @@
5859
snake_to_camel,
5960
wait_until,
6061
)
61-
from test.utils_spec_runner import SpecRunnerThread
6262
from test.version import Version
6363
from typing import Any, Dict, List, Mapping, Optional
6464

@@ -382,8 +382,8 @@ async def drop(self: AsyncGridFSBucket, *args: Any, **kwargs: Any) -> None:
382382
return
383383
elif entity_type == "thread":
384384
name = spec["id"]
385-
thread = SpecRunnerThread(name)
386-
thread.start()
385+
thread = SpecRunnerTask(name)
386+
await thread.start()
387387
self[name] = thread
388388
return
389389

@@ -1177,16 +1177,16 @@ def primary_changed() -> bool:
11771177

11781178
wait_until(primary_changed, "change primary", timeout=timeout)
11791179

1180-
def _testOperation_runOnThread(self, spec):
1180+
async def _testOperation_runOnThread(self, spec):
11811181
"""Run the 'runOnThread' operation."""
11821182
thread = self.entity_map[spec["thread"]]
1183-
thread.schedule(lambda: self.run_entity_operation(spec["operation"]))
1183+
await thread.schedule(functools.partial(self.run_entity_operation, spec["operation"]))
11841184

1185-
def _testOperation_waitForThread(self, spec):
1185+
async def _testOperation_waitForThread(self, spec):
11861186
"""Run the 'waitForThread' operation."""
11871187
thread = self.entity_map[spec["thread"]]
1188-
thread.stop()
1189-
thread.join(10)
1188+
await thread.stop()
1189+
await thread.join(10)
11901190
if thread.exc:
11911191
raise thread.exc
11921192
self.assertFalse(thread.is_alive(), "Thread {} is still running".format(spec["thread"]))

test/asynchronous/utils_spec_runner.py

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,11 @@
1818
import asyncio
1919
import functools
2020
import os
21-
import threading
2221
import unittest
2322
from asyncio import iscoroutinefunction
2423
from collections import abc
2524
from test.asynchronous import AsyncIntegrationTest, async_client_context, client_knobs
25+
from test.asynchronous.helpers import ConcurrentRunner
2626
from test.utils import (
2727
CMAPListener,
2828
CompareType,
@@ -47,6 +47,7 @@
4747
from pymongo.asynchronous.command_cursor import AsyncCommandCursor
4848
from pymongo.asynchronous.cursor import AsyncCursor
4949
from pymongo.errors import AutoReconnect, BulkWriteError, OperationFailure, PyMongoError
50+
from pymongo.lock import _async_cond_wait, _async_create_condition, _async_create_lock
5051
from pymongo.read_concern import ReadConcern
5152
from pymongo.read_preferences import ReadPreference
5253
from pymongo.results import BulkWriteResult, _WriteResult
@@ -55,38 +56,36 @@
5556
_IS_SYNC = False
5657

5758

58-
class SpecRunnerThread(threading.Thread):
59+
class SpecRunnerTask(ConcurrentRunner):
5960
def __init__(self, name):
60-
super().__init__()
61-
self.name = name
61+
super().__init__(name)
6262
self.exc = None
6363
self.daemon = True
64-
self.cond = threading.Condition()
64+
self.cond = _async_create_condition(_async_create_lock())
6565
self.ops = []
66-
self.stopped = False
6766

68-
def schedule(self, work):
67+
async def schedule(self, work):
6968
self.ops.append(work)
70-
with self.cond:
69+
async with self.cond:
7170
self.cond.notify()
7271

73-
def stop(self):
72+
async def stop(self):
7473
self.stopped = True
75-
with self.cond:
74+
async with self.cond:
7675
self.cond.notify()
7776

78-
def run(self):
77+
async def run(self):
7978
while not self.stopped or self.ops:
8079
if not self.ops:
81-
with self.cond:
82-
self.cond.wait(10)
80+
async with self.cond:
81+
await _async_cond_wait(self.cond, 10)
8382
if self.ops:
8483
try:
8584
work = self.ops.pop(0)
86-
work()
85+
await work()
8786
except Exception as exc:
8887
self.exc = exc
89-
self.stop()
88+
await self.stop()
9089

9190

9291
class AsyncSpecTestCreator:

test/helpers.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"""Shared constants and helper methods for pymongo, bson, and gridfs test suites."""
1616
from __future__ import annotations
1717

18+
import asyncio
1819
import base64
1920
import gc
2021
import multiprocessing
@@ -30,6 +31,8 @@
3031
import warnings
3132
from asyncio import iscoroutinefunction
3233

34+
from pymongo._asyncio_task import create_task
35+
3336
try:
3437
import ipaddress
3538

@@ -369,3 +372,37 @@ def disable(self):
369372
os.environ.pop("SSL_CERT_FILE")
370373
else:
371374
os.environ["SSL_CERT_FILE"] = self.original_certs
375+
376+
377+
if _IS_SYNC:
378+
PARENT = threading.Thread
379+
else:
380+
PARENT = object
381+
382+
383+
class ConcurrentRunner(PARENT):
384+
def __init__(self, name, *args, **kwargs):
385+
if _IS_SYNC:
386+
super().__init__(*args, **kwargs)
387+
self.name = name
388+
self.stopped = False
389+
self.task = None
390+
if "target" in kwargs:
391+
self.target = kwargs["target"]
392+
393+
if not _IS_SYNC:
394+
395+
def start(self):
396+
self.task = create_task(self.run(), name=self.name)
397+
398+
def join(self, timeout: float | None = 0): # type: ignore[override]
399+
if self.task is not None:
400+
asyncio.wait([self.task], timeout=timeout)
401+
402+
def is_alive(self):
403+
return not self.stopped
404+
405+
def run(self):
406+
if self.target:
407+
self.target()
408+
self.stopped = True

test/unified_format.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1167,7 +1167,7 @@ def primary_changed() -> bool:
11671167
def _testOperation_runOnThread(self, spec):
11681168
"""Run the 'runOnThread' operation."""
11691169
thread = self.entity_map[spec["thread"]]
1170-
thread.schedule(lambda: self.run_entity_operation(spec["operation"]))
1170+
thread.schedule(functools.partial(self.run_entity_operation, spec["operation"]))
11711171

11721172
def _testOperation_waitForThread(self, spec):
11731173
"""Run the 'waitForThread' operation."""

test/utils_spec_runner.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,11 @@
1818
import asyncio
1919
import functools
2020
import os
21-
import threading
2221
import unittest
2322
from asyncio import iscoroutinefunction
2423
from collections import abc
2524
from test import IntegrationTest, client_context, client_knobs
25+
from test.helpers import ConcurrentRunner
2626
from test.utils import (
2727
CMAPListener,
2828
CompareType,
@@ -44,6 +44,7 @@
4444
from gridfs import GridFSBucket
4545
from gridfs.synchronous.grid_file import GridFSBucket
4646
from pymongo.errors import AutoReconnect, BulkWriteError, OperationFailure, PyMongoError
47+
from pymongo.lock import _cond_wait, _create_condition, _create_lock
4748
from pymongo.read_concern import ReadConcern
4849
from pymongo.read_preferences import ReadPreference
4950
from pymongo.results import BulkWriteResult, _WriteResult
@@ -55,15 +56,13 @@
5556
_IS_SYNC = True
5657

5758

58-
class SpecRunnerThread(threading.Thread):
59+
class SpecRunnerThread(ConcurrentRunner):
5960
def __init__(self, name):
60-
super().__init__()
61-
self.name = name
61+
super().__init__(name)
6262
self.exc = None
6363
self.daemon = True
64-
self.cond = threading.Condition()
64+
self.cond = _create_condition(_create_lock())
6565
self.ops = []
66-
self.stopped = False
6766

6867
def schedule(self, work):
6968
self.ops.append(work)
@@ -79,7 +78,7 @@ def run(self):
7978
while not self.stopped or self.ops:
8079
if not self.ops:
8180
with self.cond:
82-
self.cond.wait(10)
81+
_cond_wait(self.cond, 10)
8382
if self.ops:
8483
try:
8584
work = self.ops.pop(0)

tools/synchro.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@
119119
"_async_create_lock": "_create_lock",
120120
"_async_create_condition": "_create_condition",
121121
"_async_cond_wait": "_cond_wait",
122+
"SpecRunnerTask": "SpecRunnerThread",
122123
"AsyncMockConnection": "MockConnection",
123124
"AsyncMockPool": "MockPool",
124125
}

0 commit comments

Comments
 (0)