Skip to content

Commit 91fc210

Browse files
authored
Merge branch 'master' into PYTHON-5109
2 parents a3d14cc + 68237f7 commit 91fc210

11 files changed

+284
-40
lines changed

test/asynchronous/helpers.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"""Shared constants and helper methods for pymongo, bson, and gridfs test suites."""
1616
from __future__ import annotations
1717

18+
import asyncio
1819
import base64
1920
import gc
2021
import multiprocessing
@@ -30,6 +31,8 @@
3031
import warnings
3132
from asyncio import iscoroutinefunction
3233

34+
from pymongo._asyncio_task import create_task
35+
3336
try:
3437
import ipaddress
3538

@@ -369,3 +372,37 @@ def disable(self):
369372
os.environ.pop("SSL_CERT_FILE")
370373
else:
371374
os.environ["SSL_CERT_FILE"] = self.original_certs
375+
376+
377+
if _IS_SYNC:
378+
PARENT = threading.Thread
379+
else:
380+
PARENT = object
381+
382+
383+
class ConcurrentRunner(PARENT):
384+
def __init__(self, name, *args, **kwargs):
385+
if _IS_SYNC:
386+
super().__init__(*args, **kwargs)
387+
self.name = name
388+
self.stopped = False
389+
self.task = None
390+
if "target" in kwargs:
391+
self.target = kwargs["target"]
392+
393+
if not _IS_SYNC:
394+
395+
async def start(self):
396+
self.task = create_task(self.run(), name=self.name)
397+
398+
async def join(self, timeout: float | None = 0): # type: ignore[override]
399+
if self.task is not None:
400+
await asyncio.wait([self.task], timeout=timeout)
401+
402+
def is_alive(self):
403+
return not self.stopped
404+
405+
async def run(self):
406+
if self.target:
407+
await self.target()
408+
self.stopped = True

test/asynchronous/test_run_command.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# Copyright 2024-Present MongoDB, Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Run Command unified tests."""
16+
from __future__ import annotations
17+
18+
import os
19+
import unittest
20+
from pathlib import Path
21+
from test.asynchronous.unified_format import generate_test_classes
22+
23+
_IS_SYNC = False
24+
25+
# Location of JSON test specifications.
26+
if _IS_SYNC:
27+
TEST_PATH = os.path.join(Path(__file__).resolve().parent, "run_command")
28+
else:
29+
TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "run_command")
30+
31+
32+
globals().update(
33+
generate_test_classes(
34+
os.path.join(TEST_PATH, "unified"),
35+
module=__name__,
36+
)
37+
)
38+
39+
40+
if __name__ == "__main__":
41+
unittest.main()
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
# Copyright 2020-present MongoDB, Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from __future__ import annotations
15+
16+
import os
17+
import sys
18+
from pathlib import Path
19+
from typing import Any
20+
21+
sys.path[0:0] = [""]
22+
23+
from test import UnitTest, unittest
24+
from test.asynchronous.unified_format import MatchEvaluatorUtil, generate_test_classes
25+
26+
from bson import ObjectId
27+
28+
_IS_SYNC = False
29+
30+
# Location of JSON test specifications.
31+
if _IS_SYNC:
32+
TEST_PATH = os.path.join(Path(__file__).resolve().parent, "unified-test-format")
33+
else:
34+
TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "unified-test-format")
35+
36+
37+
globals().update(
38+
generate_test_classes(
39+
os.path.join(TEST_PATH, "valid-pass"),
40+
module=__name__,
41+
class_name_prefix="UnifiedTestFormat",
42+
expected_failures=[
43+
"Client side error in command starting transaction", # PYTHON-1894
44+
],
45+
RUN_ON_SERVERLESS=False,
46+
)
47+
)
48+
49+
50+
globals().update(
51+
generate_test_classes(
52+
os.path.join(TEST_PATH, "valid-fail"),
53+
module=__name__,
54+
class_name_prefix="UnifiedTestFormat",
55+
bypass_test_generation_errors=True,
56+
expected_failures=[
57+
".*", # All tests expected to fail
58+
],
59+
RUN_ON_SERVERLESS=False,
60+
)
61+
)
62+
63+
64+
class TestMatchEvaluatorUtil(UnitTest):
65+
def setUp(self):
66+
self.match_evaluator = MatchEvaluatorUtil(self)
67+
68+
def test_unsetOrMatches(self):
69+
spec: dict[str, Any] = {"$$unsetOrMatches": {"y": {"$$unsetOrMatches": 2}}}
70+
for actual in [{}, {"y": 2}, None]:
71+
self.match_evaluator.match_result(spec, actual)
72+
73+
spec = {"x": {"$$unsetOrMatches": {"y": {"$$unsetOrMatches": 2}}}}
74+
for actual in [{}, {"x": {}}, {"x": {"y": 2}}]:
75+
self.match_evaluator.match_result(spec, actual)
76+
77+
spec = {"y": {"$$unsetOrMatches": {"$$exists": True}}}
78+
self.match_evaluator.match_result(spec, {})
79+
self.match_evaluator.match_result(spec, {"y": 2})
80+
self.match_evaluator.match_result(spec, {"x": 1})
81+
self.match_evaluator.match_result(spec, {"y": {}})
82+
83+
def test_type(self):
84+
self.match_evaluator.match_result(
85+
{
86+
"operationType": "insert",
87+
"ns": {"db": "change-stream-tests", "coll": "test"},
88+
"fullDocument": {"_id": {"$$type": "objectId"}, "x": 1},
89+
},
90+
{
91+
"operationType": "insert",
92+
"fullDocument": {"_id": ObjectId("5fc93511ac93941052098f0c"), "x": 1},
93+
"ns": {"db": "change-stream-tests", "coll": "test"},
94+
},
95+
)
96+
97+
98+
if __name__ == "__main__":
99+
unittest.main()

test/asynchronous/unified_format.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
client_knobs,
3636
unittest,
3737
)
38+
from test.asynchronous.utils_spec_runner import SpecRunnerTask
3839
from test.unified_format_shared import (
3940
KMS_TLS_OPTS,
4041
PLACEHOLDER_MAP,
@@ -58,7 +59,6 @@
5859
snake_to_camel,
5960
wait_until,
6061
)
61-
from test.utils_spec_runner import SpecRunnerThread
6262
from test.version import Version
6363
from typing import Any, Dict, List, Mapping, Optional
6464

@@ -382,8 +382,8 @@ async def drop(self: AsyncGridFSBucket, *args: Any, **kwargs: Any) -> None:
382382
return
383383
elif entity_type == "thread":
384384
name = spec["id"]
385-
thread = SpecRunnerThread(name)
386-
thread.start()
385+
thread = SpecRunnerTask(name)
386+
await thread.start()
387387
self[name] = thread
388388
return
389389

@@ -711,7 +711,7 @@ async def _databaseOperation_runCommand(self, target, **kwargs):
711711
return await target.command(**kwargs)
712712

713713
async def _databaseOperation_runCursorCommand(self, target, **kwargs):
714-
return list(await self._databaseOperation_createCommandCursor(target, **kwargs))
714+
return await (await self._databaseOperation_createCommandCursor(target, **kwargs)).to_list()
715715

716716
async def _databaseOperation_createCommandCursor(self, target, **kwargs):
717717
self.__raise_if_unsupported("createCommandCursor", target, AsyncDatabase)
@@ -1177,16 +1177,16 @@ def primary_changed() -> bool:
11771177

11781178
wait_until(primary_changed, "change primary", timeout=timeout)
11791179

1180-
def _testOperation_runOnThread(self, spec):
1180+
async def _testOperation_runOnThread(self, spec):
11811181
"""Run the 'runOnThread' operation."""
11821182
thread = self.entity_map[spec["thread"]]
1183-
thread.schedule(lambda: self.run_entity_operation(spec["operation"]))
1183+
await thread.schedule(functools.partial(self.run_entity_operation, spec["operation"]))
11841184

1185-
def _testOperation_waitForThread(self, spec):
1185+
async def _testOperation_waitForThread(self, spec):
11861186
"""Run the 'waitForThread' operation."""
11871187
thread = self.entity_map[spec["thread"]]
1188-
thread.stop()
1189-
thread.join(10)
1188+
await thread.stop()
1189+
await thread.join(10)
11901190
if thread.exc:
11911191
raise thread.exc
11921192
self.assertFalse(thread.is_alive(), "Thread {} is still running".format(spec["thread"]))

test/asynchronous/utils_spec_runner.py

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,11 @@
1818
import asyncio
1919
import functools
2020
import os
21-
import threading
2221
import unittest
2322
from asyncio import iscoroutinefunction
2423
from collections import abc
2524
from test.asynchronous import AsyncIntegrationTest, async_client_context, client_knobs
25+
from test.asynchronous.helpers import ConcurrentRunner
2626
from test.utils import (
2727
CMAPListener,
2828
CompareType,
@@ -47,6 +47,7 @@
4747
from pymongo.asynchronous.command_cursor import AsyncCommandCursor
4848
from pymongo.asynchronous.cursor import AsyncCursor
4949
from pymongo.errors import AutoReconnect, BulkWriteError, OperationFailure, PyMongoError
50+
from pymongo.lock import _async_cond_wait, _async_create_condition, _async_create_lock
5051
from pymongo.read_concern import ReadConcern
5152
from pymongo.read_preferences import ReadPreference
5253
from pymongo.results import BulkWriteResult, _WriteResult
@@ -55,38 +56,36 @@
5556
_IS_SYNC = False
5657

5758

58-
class SpecRunnerThread(threading.Thread):
59+
class SpecRunnerTask(ConcurrentRunner):
5960
def __init__(self, name):
60-
super().__init__()
61-
self.name = name
61+
super().__init__(name)
6262
self.exc = None
6363
self.daemon = True
64-
self.cond = threading.Condition()
64+
self.cond = _async_create_condition(_async_create_lock())
6565
self.ops = []
66-
self.stopped = False
6766

68-
def schedule(self, work):
67+
async def schedule(self, work):
6968
self.ops.append(work)
70-
with self.cond:
69+
async with self.cond:
7170
self.cond.notify()
7271

73-
def stop(self):
72+
async def stop(self):
7473
self.stopped = True
75-
with self.cond:
74+
async with self.cond:
7675
self.cond.notify()
7776

78-
def run(self):
77+
async def run(self):
7978
while not self.stopped or self.ops:
8079
if not self.ops:
81-
with self.cond:
82-
self.cond.wait(10)
80+
async with self.cond:
81+
await _async_cond_wait(self.cond, 10)
8382
if self.ops:
8483
try:
8584
work = self.ops.pop(0)
86-
work()
85+
await work()
8786
except Exception as exc:
8887
self.exc = exc
89-
self.stop()
88+
await self.stop()
9089

9190

9291
class AsyncSpecTestCreator:

test/helpers.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"""Shared constants and helper methods for pymongo, bson, and gridfs test suites."""
1616
from __future__ import annotations
1717

18+
import asyncio
1819
import base64
1920
import gc
2021
import multiprocessing
@@ -30,6 +31,8 @@
3031
import warnings
3132
from asyncio import iscoroutinefunction
3233

34+
from pymongo._asyncio_task import create_task
35+
3336
try:
3437
import ipaddress
3538

@@ -369,3 +372,37 @@ def disable(self):
369372
os.environ.pop("SSL_CERT_FILE")
370373
else:
371374
os.environ["SSL_CERT_FILE"] = self.original_certs
375+
376+
377+
if _IS_SYNC:
378+
PARENT = threading.Thread
379+
else:
380+
PARENT = object
381+
382+
383+
class ConcurrentRunner(PARENT):
384+
def __init__(self, name, *args, **kwargs):
385+
if _IS_SYNC:
386+
super().__init__(*args, **kwargs)
387+
self.name = name
388+
self.stopped = False
389+
self.task = None
390+
if "target" in kwargs:
391+
self.target = kwargs["target"]
392+
393+
if not _IS_SYNC:
394+
395+
def start(self):
396+
self.task = create_task(self.run(), name=self.name)
397+
398+
def join(self, timeout: float | None = 0): # type: ignore[override]
399+
if self.task is not None:
400+
asyncio.wait([self.task], timeout=timeout)
401+
402+
def is_alive(self):
403+
return not self.stopped
404+
405+
def run(self):
406+
if self.target:
407+
self.target()
408+
self.stopped = True

0 commit comments

Comments
 (0)