From 19896a88e164b5c3b782e380cc05cb4c5b18e338 Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Tue, 28 Jan 2025 15:13:11 -0800 Subject: [PATCH 1/7] Convert test.test_connection_monitoring to async --- .../test_connection_monitoring.py | 479 ++++++++++++++++++ test/test_connection_monitoring.py | 6 +- test/utils.py | 2 +- tools/synchro.py | 1 + 4 files changed, 485 insertions(+), 3 deletions(-) create mode 100644 test/asynchronous/test_connection_monitoring.py diff --git a/test/asynchronous/test_connection_monitoring.py b/test/asynchronous/test_connection_monitoring.py new file mode 100644 index 0000000000..f1d0346d92 --- /dev/null +++ b/test/asynchronous/test_connection_monitoring.py @@ -0,0 +1,479 @@ +# Copyright 2019-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Execute Transactions Spec tests.""" +from __future__ import annotations + +import os +import sys +import time + +sys.path[0:0] = [""] + +from test.asynchronous import AsyncIntegrationTest, client_knobs, unittest +from test.asynchronous.utils_spec_runner import AsyncSpecTestCreator, SpecRunnerThread +from test.pymongo_mocks import DummyMonitor +from test.utils import ( + CMAPListener, + async_client_context, + async_get_pool, + async_get_pools, + async_wait_until, + camel_to_snake, +) + +from bson.objectid import ObjectId +from bson.son import SON +from pymongo.asynchronous.pool import PoolState, _PoolClosedError +from pymongo.errors import ( + ConnectionFailure, + OperationFailure, + PyMongoError, + WaitQueueTimeoutError, +) +from pymongo.monitoring import ( + ConnectionCheckedInEvent, + ConnectionCheckedOutEvent, + ConnectionCheckOutFailedEvent, + ConnectionCheckOutFailedReason, + ConnectionCheckOutStartedEvent, + ConnectionClosedEvent, + ConnectionClosedReason, + ConnectionCreatedEvent, + ConnectionReadyEvent, + PoolClearedEvent, + PoolClosedEvent, + PoolCreatedEvent, + PoolReadyEvent, +) +from pymongo.read_preferences import ReadPreference +from pymongo.topology_description import updated_topology_description + +_IS_SYNC = False + +OBJECT_TYPES = { + # Event types. + "ConnectionCheckedIn": ConnectionCheckedInEvent, + "ConnectionCheckedOut": ConnectionCheckedOutEvent, + "ConnectionCheckOutFailed": ConnectionCheckOutFailedEvent, + "ConnectionClosed": ConnectionClosedEvent, + "ConnectionCreated": ConnectionCreatedEvent, + "ConnectionReady": ConnectionReadyEvent, + "ConnectionCheckOutStarted": ConnectionCheckOutStartedEvent, + "ConnectionPoolCreated": PoolCreatedEvent, + "ConnectionPoolReady": PoolReadyEvent, + "ConnectionPoolCleared": PoolClearedEvent, + "ConnectionPoolClosed": PoolClosedEvent, + # Error types. + "PoolClosedError": _PoolClosedError, + "WaitQueueTimeoutError": WaitQueueTimeoutError, +} + + +class AsyncTestCMAP(AsyncIntegrationTest): + # Location of JSON test specifications. + TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "connection_monitoring") + + # Test operations: + + def start(self, op): + """Run the 'start' thread operation.""" + target = op["target"] + thread = SpecRunnerThread(target) + thread.start() + self.targets[target] = thread + + def wait(self, op): + """Run the 'wait' operation.""" + time.sleep(op["ms"] / 1000.0) + + def wait_for_thread(self, op): + """Run the 'waitForThread' operation.""" + target = op["target"] + thread = self.targets[target] + thread.stop() + thread.join() + if thread.exc: + raise thread.exc + self.assertFalse(thread.ops) + + async def wait_for_event(self, op): + """Run the 'waitForEvent' operation.""" + event = OBJECT_TYPES[op["event"]] + count = op["count"] + timeout = op.get("timeout", 10000) / 1000.0 + await async_wait_until( + lambda: self.listener.event_count(event) >= count, + f"find {count} {event} event(s)", + timeout=timeout, + ) + + def check_out(self, op): + """Run the 'checkOut' operation.""" + label = op["label"] + with self.pool.checkout() as conn: + # Call 'pin_cursor' so we can hold the socket. + conn.pin_cursor() + if label: + self.labels[label] = conn + else: + self.addAsyncCleanup(conn.aclose_conn, None) + + def check_in(self, op): + """Run the 'checkIn' operation.""" + label = op["connection"] + conn = self.labels[label] + self.pool.checkin(conn) + + def ready(self, op): + """Run the 'ready' operation.""" + self.pool.ready() + + def clear(self, op): + """Run the 'clear' operation.""" + if "interruptInUseAsyncConnections" in op: + self.pool.reset(interrupt_connections=op["interruptInUseAsyncConnections"]) + else: + self.pool.reset() + + async def close(self, op): + """Run the 'aclose' operation.""" + await self.pool.aclose() + + def run_operation(self, op): + """Run a single operation in a test.""" + op_name = camel_to_snake(op["name"]) + thread = op["thread"] + meth = getattr(self, op_name) + if thread: + self.targets[thread].schedule(lambda: meth(op)) + else: + meth(op) + + def run_operations(self, ops): + """Run a test's operations.""" + for op in ops: + self._ops.append(op) + self.run_operation(op) + + def check_object(self, actual, expected): + """Assert that the actual object matches the expected object.""" + self.assertEqual(type(actual), OBJECT_TYPES[expected["type"]]) + for attr, expected_val in expected.items(): + if attr == "type": + continue + c2s = camel_to_snake(attr) + if c2s == "interrupt_in_use_connections": + c2s = "interrupt_connections" + actual_val = getattr(actual, c2s) + if expected_val == 42: + self.assertIsNotNone(actual_val) + else: + self.assertEqual(actual_val, expected_val) + + def check_event(self, actual, expected): + """Assert that the actual event matches the expected event.""" + self.check_object(actual, expected) + + def actual_events(self, ignore): + """Return all the non-ignored events.""" + ignore = tuple(OBJECT_TYPES[name] for name in ignore) + return [event for event in self.listener.events if not isinstance(event, ignore)] + + def check_events(self, events, ignore): + """Check the events of a test.""" + actual_events = self.actual_events(ignore) + for actual, expected in zip(actual_events, events): + self.logs.append(f"Checking event actual: {actual!r} vs expected: {expected!r}") + self.check_event(actual, expected) + + if len(events) > len(actual_events): + self.fail(f"missing events: {events[len(actual_events) :]!r}") + + def check_error(self, actual, expected): + message = expected.pop("message") + self.check_object(actual, expected) + self.assertIn(message, str(actual)) + + async def _set_fail_point(self, client, command_args): + cmd = SON([("configureFailPoint", "failCommand")]) + cmd.update(command_args) + await client.admin.command(cmd) + + def set_fail_point(self, command_args): + if not async_client_context.supports_failCommand_fail_point: + self.skipTest("failCommand fail point must be supported") + self._set_fail_point(self.client, command_args) + + async def run_scenario(self, scenario_def, test): + """Run a CMAP spec test.""" + self.logs: list = [] + self.assertEqual(scenario_def["version"], 1) + self.assertIn(scenario_def["style"], ["unit", "integration"]) + self.listener = CMAPListener() + self._ops: list = [] + + # Configure the fail point before creating the client. + if "failPoint" in test: + fp = test["failPoint"] + self.set_fail_point(fp) + self.addAsyncCleanup( + self.set_fail_point, {"configureFailPoint": fp["configureFailPoint"], "mode": "off"} + ) + + opts = test["poolOptions"].copy() + opts["event_listeners"] = [self.listener] + opts["_monitor_class"] = DummyMonitor + opts["connect"] = False + # Support backgroundThreadIntervalMS, default to 50ms. + interval = opts.pop("backgroundThreadIntervalMS", 50) + if interval < 0: + kill_cursor_frequency = 99999999 + else: + kill_cursor_frequency = interval / 1000.0 + with client_knobs(kill_cursor_frequency=kill_cursor_frequency, min_heartbeat_interval=0.05): + client = await self.async_single_client(**opts) + # Update the SD to a known type because the DummyMonitor will not. + # Note we cannot simply call topology.on_change because that would + # internally call pool.ready() which introduces unexpected + # PoolReadyEvents. Instead, update the initial state before + # opening the Topology. + td = async_client_context.client._topology.description + sd = td.server_descriptions()[(async_client_context.host, async_client_context.port)] + client._topology._description = updated_topology_description( + client._topology._description, sd + ) + # When backgroundThreadIntervalMS is negative we do not start the + # background thread to ensure it never runs. + if interval < 0: + client._topology.open() + else: + client._get_topology() + self.addAsyncCleanup(client.close) + self.pool = list(client._topology._servers.values())[0].pool + + # Map of target names to Thread objects. + self.targets: dict = {} + # Map of label names to AsyncConnection objects + self.labels: dict = {} + + async def cleanup(): + for t in self.targets.values(): + t.stop() + for t in self.targets.values(): + t.join(5) + for conn in self.labels.values(): + await conn.aclose_conn(None) + + self.addAsyncCleanup(cleanup) + + try: + if test["error"]: + with self.assertRaises(PyMongoError) as ctx: + self.run_operations(test["operations"]) + self.check_error(ctx.exception, test["error"]) + else: + self.run_operations(test["operations"]) + + self.check_events(test["events"], test["ignore"]) + except Exception: + # Print the events after a test failure. + print("\nFailed test: {!r}".format(test["description"])) + print("Operations:") + for op in self._ops: + print(op) + print("Threads:") + print(self.targets) + print("AsyncConnections:") + print(self.labels) + print("Events:") + for event in self.listener.events: + print(event) + print("Log:") + for log in self.logs: + print(log) + raise + + POOL_OPTIONS = { + "maxPoolSize": 50, + "minPoolSize": 1, + "maxIdleTimeMS": 10000, + "waitQueueTimeoutMS": 10000, + } + + # + # Prose tests. Numbers correspond to the prose test number in the spec. + # + async def test_1_client_connection_pool_options(self): + client = await self.async_rs_or_single_client(**self.POOL_OPTIONS) + self.addAsyncCleanup(client.close) + pool_opts = (await async_get_pool(client)).opts + self.assertEqual(pool_opts.non_default_options, self.POOL_OPTIONS) + + async def test_2_all_client_pools_have_same_options(self): + client = await self.async_rs_or_single_client(**self.POOL_OPTIONS) + self.addAsyncCleanup(client.close) + await client.admin.command("ping") + # Discover at least one secondary. + if await async_client_context.has_secondaries: + await client.admin.command("ping", read_preference=ReadPreference.SECONDARY) + pools = await async_get_pools(client) + pool_opts = pools[0].opts + + self.assertEqual(pool_opts.non_default_options, self.POOL_OPTIONS) + for pool in pools[1:]: + self.assertEqual(pool.opts, pool_opts) + + async def test_3_uri_connection_pool_options(self): + opts = "&".join([f"{k}={v}" for k, v in self.POOL_OPTIONS.items()]) + uri = f"mongodb://{async_client_context.pair}/?{opts}" + client = await self.async_rs_or_single_client(uri) + self.addAsyncCleanup(client.close) + pool_opts = (await async_get_pool(client)).opts + self.assertEqual(pool_opts.non_default_options, self.POOL_OPTIONS) + + async def test_4_subscribe_to_events(self): + listener = CMAPListener() + client = await self.async_single_client(event_listeners=[listener]) + self.addAsyncCleanup(client.close) + self.assertEqual(listener.event_count(PoolCreatedEvent), 1) + + # Creates a new connection. + await client.admin.command("ping") + self.assertEqual(listener.event_count(ConnectionCheckOutStartedEvent), 1) + self.assertEqual(listener.event_count(ConnectionCreatedEvent), 1) + self.assertEqual(listener.event_count(ConnectionReadyEvent), 1) + self.assertEqual(listener.event_count(ConnectionCheckedOutEvent), 1) + self.assertEqual(listener.event_count(ConnectionCheckedInEvent), 1) + + # Uses the existing connection. + await client.admin.command("ping") + self.assertEqual(listener.event_count(ConnectionCheckOutStartedEvent), 2) + self.assertEqual(listener.event_count(ConnectionCheckedOutEvent), 2) + self.assertEqual(listener.event_count(ConnectionCheckedInEvent), 2) + + await client.close() + self.assertEqual(listener.event_count(PoolClosedEvent), 1) + self.assertEqual(listener.event_count(ConnectionClosedEvent), 1) + + async def test_5_check_out_fails_connection_error(self): + listener = CMAPListener() + client = await self.async_single_client(event_listeners=[listener]) + self.addAsyncCleanup(client.close) + pool = await async_get_pool(client) + + def mock_connect(*args, **kwargs): + raise ConnectionFailure("connect failed") + + pool.connect = mock_connect + # Un-patch Pool.connect to break the cyclic reference. + self.addAsyncCleanup(delattr, pool, "connect") + + # Attempt to create a new connection. + with self.assertRaisesRegex(ConnectionFailure, "connect failed"): + await client.admin.command("ping") + + self.assertIsInstance(listener.events[0], PoolCreatedEvent) + self.assertIsInstance(listener.events[1], PoolReadyEvent) + self.assertIsInstance(listener.events[2], ConnectionCheckOutStartedEvent) + self.assertIsInstance(listener.events[3], ConnectionCheckOutFailedEvent) + self.assertIsInstance(listener.events[4], PoolClearedEvent) + + failed_event = listener.events[3] + self.assertEqual(failed_event.reason, ConnectionCheckOutFailedReason.CONN_ERROR) + + @async_client_context.require_no_fips + async def test_5_check_out_fails_auth_error(self): + listener = CMAPListener() + client = await self.async_single_client_noauth( + username="notauser", password="fail", event_listeners=[listener] + ) + self.addAsyncCleanup(client.close) + + # Attempt to create a new connection. + with self.assertRaisesRegex(OperationFailure, "failed"): + await client.admin.command("ping") + + self.assertIsInstance(listener.events[0], PoolCreatedEvent) + self.assertIsInstance(listener.events[1], PoolReadyEvent) + self.assertIsInstance(listener.events[2], ConnectionCheckOutStartedEvent) + self.assertIsInstance(listener.events[3], ConnectionCreatedEvent) + # Error happens here. + self.assertIsInstance(listener.events[4], ConnectionClosedEvent) + self.assertIsInstance(listener.events[5], ConnectionCheckOutFailedEvent) + self.assertEqual(listener.events[5].reason, ConnectionCheckOutFailedReason.CONN_ERROR) + + # + # Extra non-spec tests + # + def assertRepr(self, obj): + new_obj = eval(repr(obj)) + self.assertEqual(type(new_obj), type(obj)) + self.assertEqual(repr(new_obj), repr(obj)) + + async def test_events_repr(self): + host = ("localhost", 27017) + self.assertRepr(ConnectionCheckedInEvent(host, 1)) + self.assertRepr(ConnectionCheckedOutEvent(host, 1, time.monotonic())) + self.assertRepr( + ConnectionCheckOutFailedEvent( + host, ConnectionCheckOutFailedReason.POOL_CLOSED, time.monotonic() + ) + ) + self.assertRepr(ConnectionClosedEvent(host, 1, ConnectionClosedReason.POOL_CLOSED)) + self.assertRepr(ConnectionCreatedEvent(host, 1)) + self.assertRepr(ConnectionReadyEvent(host, 1, time.monotonic())) + self.assertRepr(ConnectionCheckOutStartedEvent(host)) + self.assertRepr(PoolCreatedEvent(host, {})) + self.assertRepr(PoolClearedEvent(host)) + self.assertRepr(PoolClearedEvent(host, service_id=ObjectId())) + self.assertRepr(PoolClosedEvent(host)) + + async def test_close_leaves_pool_unpaused(self): + listener = CMAPListener() + client = await self.async_single_client(event_listeners=[listener]) + await client.admin.command("ping") + pool = await async_get_pool(client) + await client.close() + self.assertEqual(1, listener.event_count(PoolClosedEvent)) + self.assertEqual(PoolState.CLOSED, pool.state) + # Checking out a connection should fail + with self.assertRaises(_PoolClosedError): + async with pool.checkout(): + pass + + +def create_test(scenario_def, test, name): + def run_scenario(self): + self.run_scenario(scenario_def, test) + + return run_scenario + + +class CMAPSpecTestCreator(AsyncSpecTestCreator): + async def tests(self, scenario_def): + """Extract the tests from a spec file. + + CMAP tests do not have a 'tests' field. The whole file represents + a single test case. + """ + return [scenario_def] + + +test_creator = CMAPSpecTestCreator(create_test, AsyncTestCMAP, AsyncTestCMAP.TEST_PATH) +test_creator.create_tests() + + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_connection_monitoring.py b/test/test_connection_monitoring.py index 05411d17ba..7b95bb2615 100644 --- a/test/test_connection_monitoring.py +++ b/test/test_connection_monitoring.py @@ -60,6 +60,8 @@ from pymongo.synchronous.pool import PoolState, _PoolClosedError from pymongo.topology_description import updated_topology_description +_IS_SYNC = True + OBJECT_TYPES = { # Event types. "ConnectionCheckedIn": ConnectionCheckedInEvent, @@ -316,7 +318,7 @@ def cleanup(): def test_1_client_connection_pool_options(self): client = self.rs_or_single_client(**self.POOL_OPTIONS) self.addCleanup(client.close) - pool_opts = get_pool(client).opts + pool_opts = (get_pool(client)).opts self.assertEqual(pool_opts.non_default_options, self.POOL_OPTIONS) def test_2_all_client_pools_have_same_options(self): @@ -338,7 +340,7 @@ def test_3_uri_connection_pool_options(self): uri = f"mongodb://{client_context.pair}/?{opts}" client = self.rs_or_single_client(uri) self.addCleanup(client.close) - pool_opts = get_pool(client).opts + pool_opts = (get_pool(client)).opts self.assertEqual(pool_opts.non_default_options, self.POOL_OPTIONS) def test_4_subscribe_to_events(self): diff --git a/test/utils.py b/test/utils.py index 69154bc63b..9901575d57 100644 --- a/test/utils.py +++ b/test/utils.py @@ -769,7 +769,7 @@ async def async_get_pools(client): """Get all pools.""" return [ server.pool - async for server in await (await client._get_topology()).select_servers( + for server in await (await client._get_topology()).select_servers( any_server_selector, _Op.TEST ) ] diff --git a/tools/synchro.py b/tools/synchro.py index dbcbbd1351..971992257e 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -198,6 +198,7 @@ def async_only_test(f: str) -> bool: "test_comment.py", "test_common.py", "test_connection_logging.py", + "test_connection_monitoring.py", "test_connections_survive_primary_stepdown_spec.py", "test_create_entities.py", "test_crud_unified.py", From fadef89c771df33fbbd29c3f738f7644a67970eb Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Tue, 28 Jan 2025 15:39:54 -0800 Subject: [PATCH 2/7] fix --- test/asynchronous/test_connection_monitoring.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/asynchronous/test_connection_monitoring.py b/test/asynchronous/test_connection_monitoring.py index f1d0346d92..85e322a11c 100644 --- a/test/asynchronous/test_connection_monitoring.py +++ b/test/asynchronous/test_connection_monitoring.py @@ -337,7 +337,7 @@ async def test_2_all_client_pools_have_same_options(self): async def test_3_uri_connection_pool_options(self): opts = "&".join([f"{k}={v}" for k, v in self.POOL_OPTIONS.items()]) - uri = f"mongodb://{async_client_context.pair}/?{opts}" + uri = f"mongodb://{await async_client_context.pair}/?{opts}" client = await self.async_rs_or_single_client(uri) self.addAsyncCleanup(client.close) pool_opts = (await async_get_pool(client)).opts @@ -378,7 +378,7 @@ def mock_connect(*args, **kwargs): pool.connect = mock_connect # Un-patch Pool.connect to break the cyclic reference. - self.addAsyncCleanup(delattr, pool, "connect") + self.addCleanup(delattr, pool, "connect") # Attempt to create a new connection. with self.assertRaisesRegex(ConnectionFailure, "connect failed"): From f0bb7c8cccb45883f973bb6f82168f191b6fe03b Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Wed, 29 Jan 2025 09:34:24 -0800 Subject: [PATCH 3/7] address review comments --- .../test_connection_monitoring.py | 22 +++++++++---------- test/test_connection_monitoring.py | 20 ++++++++--------- 2 files changed, 21 insertions(+), 21 deletions(-) diff --git a/test/asynchronous/test_connection_monitoring.py b/test/asynchronous/test_connection_monitoring.py index 85e322a11c..c90871b8e1 100644 --- a/test/asynchronous/test_connection_monitoring.py +++ b/test/asynchronous/test_connection_monitoring.py @@ -16,6 +16,7 @@ from __future__ import annotations import os +import pathlib import sys import time @@ -83,7 +84,12 @@ class AsyncTestCMAP(AsyncIntegrationTest): # Location of JSON test specifications. - TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "connection_monitoring") + if _IS_SYNC: + _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "connection_monitoring") + else: + _TEST_PATH = os.path.join( + pathlib.Path(__file__).resolve().parent.parent, "connection_monitoring" + ) # Test operations: @@ -128,7 +134,7 @@ def check_out(self, op): if label: self.labels[label] = conn else: - self.addAsyncCleanup(conn.aclose_conn, None) + self.addAsyncCleanup(conn.close_conn, None) def check_in(self, op): """Run the 'checkIn' operation.""" @@ -260,7 +266,6 @@ async def run_scenario(self, scenario_def, test): client._topology.open() else: client._get_topology() - self.addAsyncCleanup(client.close) self.pool = list(client._topology._servers.values())[0].pool # Map of target names to Thread objects. @@ -317,13 +322,11 @@ async def cleanup(): # async def test_1_client_connection_pool_options(self): client = await self.async_rs_or_single_client(**self.POOL_OPTIONS) - self.addAsyncCleanup(client.close) pool_opts = (await async_get_pool(client)).opts self.assertEqual(pool_opts.non_default_options, self.POOL_OPTIONS) async def test_2_all_client_pools_have_same_options(self): client = await self.async_rs_or_single_client(**self.POOL_OPTIONS) - self.addAsyncCleanup(client.close) await client.admin.command("ping") # Discover at least one secondary. if await async_client_context.has_secondaries: @@ -339,14 +342,12 @@ async def test_3_uri_connection_pool_options(self): opts = "&".join([f"{k}={v}" for k, v in self.POOL_OPTIONS.items()]) uri = f"mongodb://{await async_client_context.pair}/?{opts}" client = await self.async_rs_or_single_client(uri) - self.addAsyncCleanup(client.close) pool_opts = (await async_get_pool(client)).opts self.assertEqual(pool_opts.non_default_options, self.POOL_OPTIONS) async def test_4_subscribe_to_events(self): listener = CMAPListener() client = await self.async_single_client(event_listeners=[listener]) - self.addAsyncCleanup(client.close) self.assertEqual(listener.event_count(PoolCreatedEvent), 1) # Creates a new connection. @@ -370,7 +371,6 @@ async def test_4_subscribe_to_events(self): async def test_5_check_out_fails_connection_error(self): listener = CMAPListener() client = await self.async_single_client(event_listeners=[listener]) - self.addAsyncCleanup(client.close) pool = await async_get_pool(client) def mock_connect(*args, **kwargs): @@ -399,7 +399,6 @@ async def test_5_check_out_fails_auth_error(self): client = await self.async_single_client_noauth( username="notauser", password="fail", event_listeners=[listener] ) - self.addAsyncCleanup(client.close) # Attempt to create a new connection. with self.assertRaisesRegex(OperationFailure, "failed"): @@ -471,8 +470,9 @@ async def tests(self, scenario_def): return [scenario_def] -test_creator = CMAPSpecTestCreator(create_test, AsyncTestCMAP, AsyncTestCMAP.TEST_PATH) -test_creator.create_tests() +if _IS_SYNC: + test_creator = CMAPSpecTestCreator(create_test, AsyncTestCMAP, AsyncTestCMAP.TEST_PATH) + test_creator.create_tests() if __name__ == "__main__": diff --git a/test/test_connection_monitoring.py b/test/test_connection_monitoring.py index 7b95bb2615..f057ad2889 100644 --- a/test/test_connection_monitoring.py +++ b/test/test_connection_monitoring.py @@ -16,6 +16,7 @@ from __future__ import annotations import os +import pathlib import sys import time @@ -83,7 +84,12 @@ class TestCMAP(IntegrationTest): # Location of JSON test specifications. - TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "connection_monitoring") + if _IS_SYNC: + _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "connection_monitoring") + else: + _TEST_PATH = os.path.join( + pathlib.Path(__file__).resolve().parent.parent, "connection_monitoring" + ) # Test operations: @@ -260,7 +266,6 @@ def run_scenario(self, scenario_def, test): client._topology.open() else: client._get_topology() - self.addCleanup(client.close) self.pool = list(client._topology._servers.values())[0].pool # Map of target names to Thread objects. @@ -317,13 +322,11 @@ def cleanup(): # def test_1_client_connection_pool_options(self): client = self.rs_or_single_client(**self.POOL_OPTIONS) - self.addCleanup(client.close) pool_opts = (get_pool(client)).opts self.assertEqual(pool_opts.non_default_options, self.POOL_OPTIONS) def test_2_all_client_pools_have_same_options(self): client = self.rs_or_single_client(**self.POOL_OPTIONS) - self.addCleanup(client.close) client.admin.command("ping") # Discover at least one secondary. if client_context.has_secondaries: @@ -339,14 +342,12 @@ def test_3_uri_connection_pool_options(self): opts = "&".join([f"{k}={v}" for k, v in self.POOL_OPTIONS.items()]) uri = f"mongodb://{client_context.pair}/?{opts}" client = self.rs_or_single_client(uri) - self.addCleanup(client.close) pool_opts = (get_pool(client)).opts self.assertEqual(pool_opts.non_default_options, self.POOL_OPTIONS) def test_4_subscribe_to_events(self): listener = CMAPListener() client = self.single_client(event_listeners=[listener]) - self.addCleanup(client.close) self.assertEqual(listener.event_count(PoolCreatedEvent), 1) # Creates a new connection. @@ -370,7 +371,6 @@ def test_4_subscribe_to_events(self): def test_5_check_out_fails_connection_error(self): listener = CMAPListener() client = self.single_client(event_listeners=[listener]) - self.addCleanup(client.close) pool = get_pool(client) def mock_connect(*args, **kwargs): @@ -399,7 +399,6 @@ def test_5_check_out_fails_auth_error(self): client = self.single_client_noauth( username="notauser", password="fail", event_listeners=[listener] ) - self.addCleanup(client.close) # Attempt to create a new connection. with self.assertRaisesRegex(OperationFailure, "failed"): @@ -471,8 +470,9 @@ def tests(self, scenario_def): return [scenario_def] -test_creator = CMAPSpecTestCreator(create_test, TestCMAP, TestCMAP.TEST_PATH) -test_creator.create_tests() +if _IS_SYNC: + test_creator = CMAPSpecTestCreator(create_test, TestCMAP, TestCMAP.TEST_PATH) + test_creator.create_tests() if __name__ == "__main__": From 6a6577a40a4c10a172b76573a2702be9df32ceaf Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Wed, 29 Jan 2025 09:47:54 -0800 Subject: [PATCH 4/7] don't mind me being silly --- test/asynchronous/test_connection_monitoring.py | 4 ++-- test/test_connection_monitoring.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/test/asynchronous/test_connection_monitoring.py b/test/asynchronous/test_connection_monitoring.py index c90871b8e1..f973725932 100644 --- a/test/asynchronous/test_connection_monitoring.py +++ b/test/asynchronous/test_connection_monitoring.py @@ -85,9 +85,9 @@ class AsyncTestCMAP(AsyncIntegrationTest): # Location of JSON test specifications. if _IS_SYNC: - _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "connection_monitoring") + TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "connection_monitoring") else: - _TEST_PATH = os.path.join( + TEST_PATH = os.path.join( pathlib.Path(__file__).resolve().parent.parent, "connection_monitoring" ) diff --git a/test/test_connection_monitoring.py b/test/test_connection_monitoring.py index f057ad2889..88b5119d23 100644 --- a/test/test_connection_monitoring.py +++ b/test/test_connection_monitoring.py @@ -85,9 +85,9 @@ class TestCMAP(IntegrationTest): # Location of JSON test specifications. if _IS_SYNC: - _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "connection_monitoring") + TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "connection_monitoring") else: - _TEST_PATH = os.path.join( + TEST_PATH = os.path.join( pathlib.Path(__file__).resolve().parent.parent, "connection_monitoring" ) From 1dbe1f9862df793af0d9aedbb765eaefdb2f5d70 Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Wed, 29 Jan 2025 11:24:07 -0800 Subject: [PATCH 5/7] change import --- test/asynchronous/test_connection_monitoring.py | 8 +++----- test/test_connection_monitoring.py | 8 +++----- 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/test/asynchronous/test_connection_monitoring.py b/test/asynchronous/test_connection_monitoring.py index f973725932..7234c6a2b7 100644 --- a/test/asynchronous/test_connection_monitoring.py +++ b/test/asynchronous/test_connection_monitoring.py @@ -16,9 +16,9 @@ from __future__ import annotations import os -import pathlib import sys import time +from pathlib import Path sys.path[0:0] = [""] @@ -85,11 +85,9 @@ class AsyncTestCMAP(AsyncIntegrationTest): # Location of JSON test specifications. if _IS_SYNC: - TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "connection_monitoring") + TEST_PATH = os.path.join(Path(__file__).resolve().parent, "connection_monitoring") else: - TEST_PATH = os.path.join( - pathlib.Path(__file__).resolve().parent.parent, "connection_monitoring" - ) + TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "connection_monitoring") # Test operations: diff --git a/test/test_connection_monitoring.py b/test/test_connection_monitoring.py index 88b5119d23..7ca843a0da 100644 --- a/test/test_connection_monitoring.py +++ b/test/test_connection_monitoring.py @@ -16,9 +16,9 @@ from __future__ import annotations import os -import pathlib import sys import time +from pathlib import Path sys.path[0:0] = [""] @@ -85,11 +85,9 @@ class TestCMAP(IntegrationTest): # Location of JSON test specifications. if _IS_SYNC: - TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "connection_monitoring") + TEST_PATH = os.path.join(Path(__file__).resolve().parent, "connection_monitoring") else: - TEST_PATH = os.path.join( - pathlib.Path(__file__).resolve().parent.parent, "connection_monitoring" - ) + TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "connection_monitoring") # Test operations: From 0e5df26bc542cc2d08c23205bd3171502782d1c0 Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Wed, 5 Feb 2025 11:22:12 -0800 Subject: [PATCH 6/7] run all tests with SpecRunnerTask --- test/asynchronous/helpers.py | 2 +- test/asynchronous/pymongo_mocks.py | 4 +- .../test_connection_monitoring.py | 88 ++++++++++--------- test/asynchronous/utils_spec_runner.py | 4 +- test/helpers.py | 2 +- test/test_connection_monitoring.py | 6 +- 6 files changed, 54 insertions(+), 52 deletions(-) diff --git a/test/asynchronous/helpers.py b/test/asynchronous/helpers.py index 7758f281e1..2ef16adb67 100644 --- a/test/asynchronous/helpers.py +++ b/test/asynchronous/helpers.py @@ -395,7 +395,7 @@ def __init__(self, name, *args, **kwargs): async def start(self): self.task = create_task(self.run(), name=self.name) - async def join(self, timeout: float | None = 0): # type: ignore[override] + async def join(self, timeout: float | None = None): # type: ignore[override] if self.task is not None: await asyncio.wait([self.task], timeout=timeout) diff --git a/test/asynchronous/pymongo_mocks.py b/test/asynchronous/pymongo_mocks.py index ed2395bc98..40beb3c0dc 100644 --- a/test/asynchronous/pymongo_mocks.py +++ b/test/asynchronous/pymongo_mocks.py @@ -66,7 +66,7 @@ def __init__(self, server_description, topology, pool, topology_settings): def cancel_check(self): pass - def join(self): + async def join(self): pass def open(self): @@ -75,7 +75,7 @@ def open(self): def request_check(self): pass - def close(self): + async def close(self): self.opened = False diff --git a/test/asynchronous/test_connection_monitoring.py b/test/asynchronous/test_connection_monitoring.py index 7234c6a2b7..289617a3e4 100644 --- a/test/asynchronous/test_connection_monitoring.py +++ b/test/asynchronous/test_connection_monitoring.py @@ -15,6 +15,7 @@ """Execute Transactions Spec tests.""" from __future__ import annotations +import asyncio import os import sys import time @@ -23,8 +24,8 @@ sys.path[0:0] = [""] from test.asynchronous import AsyncIntegrationTest, client_knobs, unittest -from test.asynchronous.utils_spec_runner import AsyncSpecTestCreator, SpecRunnerThread -from test.pymongo_mocks import DummyMonitor +from test.asynchronous.pymongo_mocks import DummyMonitor +from test.asynchronous.utils_spec_runner import AsyncSpecTestCreator, SpecRunnerTask from test.utils import ( CMAPListener, async_client_context, @@ -91,23 +92,23 @@ class AsyncTestCMAP(AsyncIntegrationTest): # Test operations: - def start(self, op): + async def start(self, op): """Run the 'start' thread operation.""" target = op["target"] - thread = SpecRunnerThread(target) - thread.start() + thread = SpecRunnerTask(target) + await thread.start() self.targets[target] = thread - def wait(self, op): + async def wait(self, op): """Run the 'wait' operation.""" - time.sleep(op["ms"] / 1000.0) + await asyncio.sleep(op["ms"] / 1000.0) - def wait_for_thread(self, op): + async def wait_for_thread(self, op): """Run the 'waitForThread' operation.""" target = op["target"] thread = self.targets[target] - thread.stop() - thread.join() + await thread.stop() + await thread.join() if thread.exc: raise thread.exc self.assertFalse(thread.ops) @@ -123,10 +124,10 @@ async def wait_for_event(self, op): timeout=timeout, ) - def check_out(self, op): + async def check_out(self, op): """Run the 'checkOut' operation.""" label = op["label"] - with self.pool.checkout() as conn: + async with self.pool.checkout() as conn: # Call 'pin_cursor' so we can hold the socket. conn.pin_cursor() if label: @@ -134,42 +135,42 @@ def check_out(self, op): else: self.addAsyncCleanup(conn.close_conn, None) - def check_in(self, op): + async def check_in(self, op): """Run the 'checkIn' operation.""" label = op["connection"] conn = self.labels[label] - self.pool.checkin(conn) + await self.pool.checkin(conn) - def ready(self, op): + async def ready(self, op): """Run the 'ready' operation.""" - self.pool.ready() + await self.pool.ready() - def clear(self, op): + async def clear(self, op): """Run the 'clear' operation.""" - if "interruptInUseAsyncConnections" in op: - self.pool.reset(interrupt_connections=op["interruptInUseAsyncConnections"]) + if "interruptInUseConnections" in op: + await self.pool.reset(interrupt_connections=op["interruptInUseConnections"]) else: - self.pool.reset() + await self.pool.reset() async def close(self, op): - """Run the 'aclose' operation.""" - await self.pool.aclose() + """Run the 'close' operation.""" + await self.pool.close() - def run_operation(self, op): + async def run_operation(self, op): """Run a single operation in a test.""" op_name = camel_to_snake(op["name"]) thread = op["thread"] meth = getattr(self, op_name) if thread: - self.targets[thread].schedule(lambda: meth(op)) + await self.targets[thread].schedule(lambda: meth(op)) else: - meth(op) + await meth(op) - def run_operations(self, ops): + async def run_operations(self, ops): """Run a test's operations.""" for op in ops: self._ops.append(op) - self.run_operation(op) + await self.run_operation(op) def check_object(self, actual, expected): """Assert that the actual object matches the expected object.""" @@ -215,10 +216,10 @@ async def _set_fail_point(self, client, command_args): cmd.update(command_args) await client.admin.command(cmd) - def set_fail_point(self, command_args): + async def set_fail_point(self, command_args): if not async_client_context.supports_failCommand_fail_point: self.skipTest("failCommand fail point must be supported") - self._set_fail_point(self.client, command_args) + await self._set_fail_point(self.client, command_args) async def run_scenario(self, scenario_def, test): """Run a CMAP spec test.""" @@ -231,7 +232,7 @@ async def run_scenario(self, scenario_def, test): # Configure the fail point before creating the client. if "failPoint" in test: fp = test["failPoint"] - self.set_fail_point(fp) + await self.set_fail_point(fp) self.addAsyncCleanup( self.set_fail_point, {"configureFailPoint": fp["configureFailPoint"], "mode": "off"} ) @@ -254,16 +255,18 @@ async def run_scenario(self, scenario_def, test): # PoolReadyEvents. Instead, update the initial state before # opening the Topology. td = async_client_context.client._topology.description - sd = td.server_descriptions()[(async_client_context.host, async_client_context.port)] + sd = td.server_descriptions()[ + (await async_client_context.host, await async_client_context.port) + ] client._topology._description = updated_topology_description( client._topology._description, sd ) # When backgroundThreadIntervalMS is negative we do not start the # background thread to ensure it never runs. if interval < 0: - client._topology.open() + await client._topology.open() else: - client._get_topology() + await client._get_topology() self.pool = list(client._topology._servers.values())[0].pool # Map of target names to Thread objects. @@ -273,21 +276,21 @@ async def run_scenario(self, scenario_def, test): async def cleanup(): for t in self.targets.values(): - t.stop() + await t.stop() for t in self.targets.values(): - t.join(5) + await t.join(5) for conn in self.labels.values(): - await conn.aclose_conn(None) + conn.close_conn(None) self.addAsyncCleanup(cleanup) try: if test["error"]: with self.assertRaises(PyMongoError) as ctx: - self.run_operations(test["operations"]) + await self.run_operations(test["operations"]) self.check_error(ctx.exception, test["error"]) else: - self.run_operations(test["operations"]) + await self.run_operations(test["operations"]) self.check_events(test["events"], test["ignore"]) except Exception: @@ -452,8 +455,8 @@ async def test_close_leaves_pool_unpaused(self): def create_test(scenario_def, test, name): - def run_scenario(self): - self.run_scenario(scenario_def, test) + async def run_scenario(self): + await self.run_scenario(scenario_def, test) return run_scenario @@ -468,9 +471,8 @@ async def tests(self, scenario_def): return [scenario_def] -if _IS_SYNC: - test_creator = CMAPSpecTestCreator(create_test, AsyncTestCMAP, AsyncTestCMAP.TEST_PATH) - test_creator.create_tests() +test_creator = CMAPSpecTestCreator(create_test, AsyncTestCMAP, AsyncTestCMAP.TEST_PATH) +test_creator.create_tests() if __name__ == "__main__": diff --git a/test/asynchronous/utils_spec_runner.py b/test/asynchronous/utils_spec_runner.py index d103374313..8d043285d0 100644 --- a/test/asynchronous/utils_spec_runner.py +++ b/test/asynchronous/utils_spec_runner.py @@ -222,14 +222,14 @@ async def _create_tests(self): test_type = os.path.splitext(filename)[0] # Construct test from scenario. - for test_def in self.tests(scenario_def): + for test_def in await self.tests(scenario_def): test_name = "test_{}_{}_{}".format( dirname, test_type.replace("-", "_").replace(".", "_"), str(test_def["description"].replace(" ", "_").replace(".", "_")), ) - new_test = await self._create_test(scenario_def, test_def, test_name) + new_test = self._create_test(scenario_def, test_def, test_name) new_test = self._ensure_min_max_server_version(scenario_def, new_test) new_test = self.ensure_run_on(scenario_def, new_test) diff --git a/test/helpers.py b/test/helpers.py index bd9e23bba4..770da539d2 100644 --- a/test/helpers.py +++ b/test/helpers.py @@ -395,7 +395,7 @@ def __init__(self, name, *args, **kwargs): def start(self): self.task = create_task(self.run(), name=self.name) - def join(self, timeout: float | None = 0): # type: ignore[override] + def join(self, timeout: float | None = None): # type: ignore[override] if self.task is not None: asyncio.wait([self.task], timeout=timeout) diff --git a/test/test_connection_monitoring.py b/test/test_connection_monitoring.py index 7ca843a0da..810d440932 100644 --- a/test/test_connection_monitoring.py +++ b/test/test_connection_monitoring.py @@ -15,6 +15,7 @@ """Execute Transactions Spec tests.""" from __future__ import annotations +import asyncio import os import sys import time @@ -468,9 +469,8 @@ def tests(self, scenario_def): return [scenario_def] -if _IS_SYNC: - test_creator = CMAPSpecTestCreator(create_test, TestCMAP, TestCMAP.TEST_PATH) - test_creator.create_tests() +test_creator = CMAPSpecTestCreator(create_test, TestCMAP, TestCMAP.TEST_PATH) +test_creator.create_tests() if __name__ == "__main__": From 8cd9f1ae34e0d36ade6fabcd9e98db7e7f18dc9b Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Thu, 6 Feb 2025 10:15:29 -0800 Subject: [PATCH 7/7] fix tests --- test/asynchronous/test_connection_monitoring.py | 2 +- test/asynchronous/utils_spec_runner.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/test/asynchronous/test_connection_monitoring.py b/test/asynchronous/test_connection_monitoring.py index 289617a3e4..a68b2a90cb 100644 --- a/test/asynchronous/test_connection_monitoring.py +++ b/test/asynchronous/test_connection_monitoring.py @@ -462,7 +462,7 @@ async def run_scenario(self): class CMAPSpecTestCreator(AsyncSpecTestCreator): - async def tests(self, scenario_def): + def tests(self, scenario_def): """Extract the tests from a spec file. CMAP tests do not have a 'tests' field. The whole file represents diff --git a/test/asynchronous/utils_spec_runner.py b/test/asynchronous/utils_spec_runner.py index 5ef1f8b255..11d88850fc 100644 --- a/test/asynchronous/utils_spec_runner.py +++ b/test/asynchronous/utils_spec_runner.py @@ -222,7 +222,7 @@ async def _create_tests(self): test_type = os.path.splitext(filename)[0] # Construct test from scenario. - for test_def in await self.tests(scenario_def): + for test_def in self.tests(scenario_def): test_name = "test_{}_{}_{}".format( dirname, test_type.replace("-", "_").replace(".", "_"),