diff --git a/test/asynchronous/pymongo_mocks.py b/test/asynchronous/pymongo_mocks.py index ed2395bc98..40beb3c0dc 100644 --- a/test/asynchronous/pymongo_mocks.py +++ b/test/asynchronous/pymongo_mocks.py @@ -66,7 +66,7 @@ def __init__(self, server_description, topology, pool, topology_settings): def cancel_check(self): pass - def join(self): + async def join(self): pass def open(self): @@ -75,7 +75,7 @@ def open(self): def request_check(self): pass - def close(self): + async def close(self): self.opened = False diff --git a/test/asynchronous/test_connection_monitoring.py b/test/asynchronous/test_connection_monitoring.py new file mode 100644 index 0000000000..a68b2a90cb --- /dev/null +++ b/test/asynchronous/test_connection_monitoring.py @@ -0,0 +1,479 @@ +# Copyright 2019-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Execute Transactions Spec tests.""" +from __future__ import annotations + +import asyncio +import os +import sys +import time +from pathlib import Path + +sys.path[0:0] = [""] + +from test.asynchronous import AsyncIntegrationTest, client_knobs, unittest +from test.asynchronous.pymongo_mocks import DummyMonitor +from test.asynchronous.utils_spec_runner import AsyncSpecTestCreator, SpecRunnerTask +from test.utils import ( + CMAPListener, + async_client_context, + async_get_pool, + async_get_pools, + async_wait_until, + camel_to_snake, +) + +from bson.objectid import ObjectId +from bson.son import SON +from pymongo.asynchronous.pool import PoolState, _PoolClosedError +from pymongo.errors import ( + ConnectionFailure, + OperationFailure, + PyMongoError, + WaitQueueTimeoutError, +) +from pymongo.monitoring import ( + ConnectionCheckedInEvent, + ConnectionCheckedOutEvent, + ConnectionCheckOutFailedEvent, + ConnectionCheckOutFailedReason, + ConnectionCheckOutStartedEvent, + ConnectionClosedEvent, + ConnectionClosedReason, + ConnectionCreatedEvent, + ConnectionReadyEvent, + PoolClearedEvent, + PoolClosedEvent, + PoolCreatedEvent, + PoolReadyEvent, +) +from pymongo.read_preferences import ReadPreference +from pymongo.topology_description import updated_topology_description + +_IS_SYNC = False + +OBJECT_TYPES = { + # Event types. + "ConnectionCheckedIn": ConnectionCheckedInEvent, + "ConnectionCheckedOut": ConnectionCheckedOutEvent, + "ConnectionCheckOutFailed": ConnectionCheckOutFailedEvent, + "ConnectionClosed": ConnectionClosedEvent, + "ConnectionCreated": ConnectionCreatedEvent, + "ConnectionReady": ConnectionReadyEvent, + "ConnectionCheckOutStarted": ConnectionCheckOutStartedEvent, + "ConnectionPoolCreated": PoolCreatedEvent, + "ConnectionPoolReady": PoolReadyEvent, + "ConnectionPoolCleared": PoolClearedEvent, + "ConnectionPoolClosed": PoolClosedEvent, + # Error types. + "PoolClosedError": _PoolClosedError, + "WaitQueueTimeoutError": WaitQueueTimeoutError, +} + + +class AsyncTestCMAP(AsyncIntegrationTest): + # Location of JSON test specifications. + if _IS_SYNC: + TEST_PATH = os.path.join(Path(__file__).resolve().parent, "connection_monitoring") + else: + TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "connection_monitoring") + + # Test operations: + + async def start(self, op): + """Run the 'start' thread operation.""" + target = op["target"] + thread = SpecRunnerTask(target) + await thread.start() + self.targets[target] = thread + + async def wait(self, op): + """Run the 'wait' operation.""" + await asyncio.sleep(op["ms"] / 1000.0) + + async def wait_for_thread(self, op): + """Run the 'waitForThread' operation.""" + target = op["target"] + thread = self.targets[target] + await thread.stop() + await thread.join() + if thread.exc: + raise thread.exc + self.assertFalse(thread.ops) + + async def wait_for_event(self, op): + """Run the 'waitForEvent' operation.""" + event = OBJECT_TYPES[op["event"]] + count = op["count"] + timeout = op.get("timeout", 10000) / 1000.0 + await async_wait_until( + lambda: self.listener.event_count(event) >= count, + f"find {count} {event} event(s)", + timeout=timeout, + ) + + async def check_out(self, op): + """Run the 'checkOut' operation.""" + label = op["label"] + async with self.pool.checkout() as conn: + # Call 'pin_cursor' so we can hold the socket. + conn.pin_cursor() + if label: + self.labels[label] = conn + else: + self.addAsyncCleanup(conn.close_conn, None) + + async def check_in(self, op): + """Run the 'checkIn' operation.""" + label = op["connection"] + conn = self.labels[label] + await self.pool.checkin(conn) + + async def ready(self, op): + """Run the 'ready' operation.""" + await self.pool.ready() + + async def clear(self, op): + """Run the 'clear' operation.""" + if "interruptInUseConnections" in op: + await self.pool.reset(interrupt_connections=op["interruptInUseConnections"]) + else: + await self.pool.reset() + + async def close(self, op): + """Run the 'close' operation.""" + await self.pool.close() + + async def run_operation(self, op): + """Run a single operation in a test.""" + op_name = camel_to_snake(op["name"]) + thread = op["thread"] + meth = getattr(self, op_name) + if thread: + await self.targets[thread].schedule(lambda: meth(op)) + else: + await meth(op) + + async def run_operations(self, ops): + """Run a test's operations.""" + for op in ops: + self._ops.append(op) + await self.run_operation(op) + + def check_object(self, actual, expected): + """Assert that the actual object matches the expected object.""" + self.assertEqual(type(actual), OBJECT_TYPES[expected["type"]]) + for attr, expected_val in expected.items(): + if attr == "type": + continue + c2s = camel_to_snake(attr) + if c2s == "interrupt_in_use_connections": + c2s = "interrupt_connections" + actual_val = getattr(actual, c2s) + if expected_val == 42: + self.assertIsNotNone(actual_val) + else: + self.assertEqual(actual_val, expected_val) + + def check_event(self, actual, expected): + """Assert that the actual event matches the expected event.""" + self.check_object(actual, expected) + + def actual_events(self, ignore): + """Return all the non-ignored events.""" + ignore = tuple(OBJECT_TYPES[name] for name in ignore) + return [event for event in self.listener.events if not isinstance(event, ignore)] + + def check_events(self, events, ignore): + """Check the events of a test.""" + actual_events = self.actual_events(ignore) + for actual, expected in zip(actual_events, events): + self.logs.append(f"Checking event actual: {actual!r} vs expected: {expected!r}") + self.check_event(actual, expected) + + if len(events) > len(actual_events): + self.fail(f"missing events: {events[len(actual_events) :]!r}") + + def check_error(self, actual, expected): + message = expected.pop("message") + self.check_object(actual, expected) + self.assertIn(message, str(actual)) + + async def _set_fail_point(self, client, command_args): + cmd = SON([("configureFailPoint", "failCommand")]) + cmd.update(command_args) + await client.admin.command(cmd) + + async def set_fail_point(self, command_args): + if not async_client_context.supports_failCommand_fail_point: + self.skipTest("failCommand fail point must be supported") + await self._set_fail_point(self.client, command_args) + + async def run_scenario(self, scenario_def, test): + """Run a CMAP spec test.""" + self.logs: list = [] + self.assertEqual(scenario_def["version"], 1) + self.assertIn(scenario_def["style"], ["unit", "integration"]) + self.listener = CMAPListener() + self._ops: list = [] + + # Configure the fail point before creating the client. + if "failPoint" in test: + fp = test["failPoint"] + await self.set_fail_point(fp) + self.addAsyncCleanup( + self.set_fail_point, {"configureFailPoint": fp["configureFailPoint"], "mode": "off"} + ) + + opts = test["poolOptions"].copy() + opts["event_listeners"] = [self.listener] + opts["_monitor_class"] = DummyMonitor + opts["connect"] = False + # Support backgroundThreadIntervalMS, default to 50ms. + interval = opts.pop("backgroundThreadIntervalMS", 50) + if interval < 0: + kill_cursor_frequency = 99999999 + else: + kill_cursor_frequency = interval / 1000.0 + with client_knobs(kill_cursor_frequency=kill_cursor_frequency, min_heartbeat_interval=0.05): + client = await self.async_single_client(**opts) + # Update the SD to a known type because the DummyMonitor will not. + # Note we cannot simply call topology.on_change because that would + # internally call pool.ready() which introduces unexpected + # PoolReadyEvents. Instead, update the initial state before + # opening the Topology. + td = async_client_context.client._topology.description + sd = td.server_descriptions()[ + (await async_client_context.host, await async_client_context.port) + ] + client._topology._description = updated_topology_description( + client._topology._description, sd + ) + # When backgroundThreadIntervalMS is negative we do not start the + # background thread to ensure it never runs. + if interval < 0: + await client._topology.open() + else: + await client._get_topology() + self.pool = list(client._topology._servers.values())[0].pool + + # Map of target names to Thread objects. + self.targets: dict = {} + # Map of label names to AsyncConnection objects + self.labels: dict = {} + + async def cleanup(): + for t in self.targets.values(): + await t.stop() + for t in self.targets.values(): + await t.join(5) + for conn in self.labels.values(): + conn.close_conn(None) + + self.addAsyncCleanup(cleanup) + + try: + if test["error"]: + with self.assertRaises(PyMongoError) as ctx: + await self.run_operations(test["operations"]) + self.check_error(ctx.exception, test["error"]) + else: + await self.run_operations(test["operations"]) + + self.check_events(test["events"], test["ignore"]) + except Exception: + # Print the events after a test failure. + print("\nFailed test: {!r}".format(test["description"])) + print("Operations:") + for op in self._ops: + print(op) + print("Threads:") + print(self.targets) + print("AsyncConnections:") + print(self.labels) + print("Events:") + for event in self.listener.events: + print(event) + print("Log:") + for log in self.logs: + print(log) + raise + + POOL_OPTIONS = { + "maxPoolSize": 50, + "minPoolSize": 1, + "maxIdleTimeMS": 10000, + "waitQueueTimeoutMS": 10000, + } + + # + # Prose tests. Numbers correspond to the prose test number in the spec. + # + async def test_1_client_connection_pool_options(self): + client = await self.async_rs_or_single_client(**self.POOL_OPTIONS) + pool_opts = (await async_get_pool(client)).opts + self.assertEqual(pool_opts.non_default_options, self.POOL_OPTIONS) + + async def test_2_all_client_pools_have_same_options(self): + client = await self.async_rs_or_single_client(**self.POOL_OPTIONS) + await client.admin.command("ping") + # Discover at least one secondary. + if await async_client_context.has_secondaries: + await client.admin.command("ping", read_preference=ReadPreference.SECONDARY) + pools = await async_get_pools(client) + pool_opts = pools[0].opts + + self.assertEqual(pool_opts.non_default_options, self.POOL_OPTIONS) + for pool in pools[1:]: + self.assertEqual(pool.opts, pool_opts) + + async def test_3_uri_connection_pool_options(self): + opts = "&".join([f"{k}={v}" for k, v in self.POOL_OPTIONS.items()]) + uri = f"mongodb://{await async_client_context.pair}/?{opts}" + client = await self.async_rs_or_single_client(uri) + pool_opts = (await async_get_pool(client)).opts + self.assertEqual(pool_opts.non_default_options, self.POOL_OPTIONS) + + async def test_4_subscribe_to_events(self): + listener = CMAPListener() + client = await self.async_single_client(event_listeners=[listener]) + self.assertEqual(listener.event_count(PoolCreatedEvent), 1) + + # Creates a new connection. + await client.admin.command("ping") + self.assertEqual(listener.event_count(ConnectionCheckOutStartedEvent), 1) + self.assertEqual(listener.event_count(ConnectionCreatedEvent), 1) + self.assertEqual(listener.event_count(ConnectionReadyEvent), 1) + self.assertEqual(listener.event_count(ConnectionCheckedOutEvent), 1) + self.assertEqual(listener.event_count(ConnectionCheckedInEvent), 1) + + # Uses the existing connection. + await client.admin.command("ping") + self.assertEqual(listener.event_count(ConnectionCheckOutStartedEvent), 2) + self.assertEqual(listener.event_count(ConnectionCheckedOutEvent), 2) + self.assertEqual(listener.event_count(ConnectionCheckedInEvent), 2) + + await client.close() + self.assertEqual(listener.event_count(PoolClosedEvent), 1) + self.assertEqual(listener.event_count(ConnectionClosedEvent), 1) + + async def test_5_check_out_fails_connection_error(self): + listener = CMAPListener() + client = await self.async_single_client(event_listeners=[listener]) + pool = await async_get_pool(client) + + def mock_connect(*args, **kwargs): + raise ConnectionFailure("connect failed") + + pool.connect = mock_connect + # Un-patch Pool.connect to break the cyclic reference. + self.addCleanup(delattr, pool, "connect") + + # Attempt to create a new connection. + with self.assertRaisesRegex(ConnectionFailure, "connect failed"): + await client.admin.command("ping") + + self.assertIsInstance(listener.events[0], PoolCreatedEvent) + self.assertIsInstance(listener.events[1], PoolReadyEvent) + self.assertIsInstance(listener.events[2], ConnectionCheckOutStartedEvent) + self.assertIsInstance(listener.events[3], ConnectionCheckOutFailedEvent) + self.assertIsInstance(listener.events[4], PoolClearedEvent) + + failed_event = listener.events[3] + self.assertEqual(failed_event.reason, ConnectionCheckOutFailedReason.CONN_ERROR) + + @async_client_context.require_no_fips + async def test_5_check_out_fails_auth_error(self): + listener = CMAPListener() + client = await self.async_single_client_noauth( + username="notauser", password="fail", event_listeners=[listener] + ) + + # Attempt to create a new connection. + with self.assertRaisesRegex(OperationFailure, "failed"): + await client.admin.command("ping") + + self.assertIsInstance(listener.events[0], PoolCreatedEvent) + self.assertIsInstance(listener.events[1], PoolReadyEvent) + self.assertIsInstance(listener.events[2], ConnectionCheckOutStartedEvent) + self.assertIsInstance(listener.events[3], ConnectionCreatedEvent) + # Error happens here. + self.assertIsInstance(listener.events[4], ConnectionClosedEvent) + self.assertIsInstance(listener.events[5], ConnectionCheckOutFailedEvent) + self.assertEqual(listener.events[5].reason, ConnectionCheckOutFailedReason.CONN_ERROR) + + # + # Extra non-spec tests + # + def assertRepr(self, obj): + new_obj = eval(repr(obj)) + self.assertEqual(type(new_obj), type(obj)) + self.assertEqual(repr(new_obj), repr(obj)) + + async def test_events_repr(self): + host = ("localhost", 27017) + self.assertRepr(ConnectionCheckedInEvent(host, 1)) + self.assertRepr(ConnectionCheckedOutEvent(host, 1, time.monotonic())) + self.assertRepr( + ConnectionCheckOutFailedEvent( + host, ConnectionCheckOutFailedReason.POOL_CLOSED, time.monotonic() + ) + ) + self.assertRepr(ConnectionClosedEvent(host, 1, ConnectionClosedReason.POOL_CLOSED)) + self.assertRepr(ConnectionCreatedEvent(host, 1)) + self.assertRepr(ConnectionReadyEvent(host, 1, time.monotonic())) + self.assertRepr(ConnectionCheckOutStartedEvent(host)) + self.assertRepr(PoolCreatedEvent(host, {})) + self.assertRepr(PoolClearedEvent(host)) + self.assertRepr(PoolClearedEvent(host, service_id=ObjectId())) + self.assertRepr(PoolClosedEvent(host)) + + async def test_close_leaves_pool_unpaused(self): + listener = CMAPListener() + client = await self.async_single_client(event_listeners=[listener]) + await client.admin.command("ping") + pool = await async_get_pool(client) + await client.close() + self.assertEqual(1, listener.event_count(PoolClosedEvent)) + self.assertEqual(PoolState.CLOSED, pool.state) + # Checking out a connection should fail + with self.assertRaises(_PoolClosedError): + async with pool.checkout(): + pass + + +def create_test(scenario_def, test, name): + async def run_scenario(self): + await self.run_scenario(scenario_def, test) + + return run_scenario + + +class CMAPSpecTestCreator(AsyncSpecTestCreator): + def tests(self, scenario_def): + """Extract the tests from a spec file. + + CMAP tests do not have a 'tests' field. The whole file represents + a single test case. + """ + return [scenario_def] + + +test_creator = CMAPSpecTestCreator(create_test, AsyncTestCMAP, AsyncTestCMAP.TEST_PATH) +test_creator.create_tests() + + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_connection_monitoring.py b/test/test_connection_monitoring.py index 05411d17ba..810d440932 100644 --- a/test/test_connection_monitoring.py +++ b/test/test_connection_monitoring.py @@ -15,9 +15,11 @@ """Execute Transactions Spec tests.""" from __future__ import annotations +import asyncio import os import sys import time +from pathlib import Path sys.path[0:0] = [""] @@ -60,6 +62,8 @@ from pymongo.synchronous.pool import PoolState, _PoolClosedError from pymongo.topology_description import updated_topology_description +_IS_SYNC = True + OBJECT_TYPES = { # Event types. "ConnectionCheckedIn": ConnectionCheckedInEvent, @@ -81,7 +85,10 @@ class TestCMAP(IntegrationTest): # Location of JSON test specifications. - TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "connection_monitoring") + if _IS_SYNC: + TEST_PATH = os.path.join(Path(__file__).resolve().parent, "connection_monitoring") + else: + TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "connection_monitoring") # Test operations: @@ -258,7 +265,6 @@ def run_scenario(self, scenario_def, test): client._topology.open() else: client._get_topology() - self.addCleanup(client.close) self.pool = list(client._topology._servers.values())[0].pool # Map of target names to Thread objects. @@ -315,13 +321,11 @@ def cleanup(): # def test_1_client_connection_pool_options(self): client = self.rs_or_single_client(**self.POOL_OPTIONS) - self.addCleanup(client.close) - pool_opts = get_pool(client).opts + pool_opts = (get_pool(client)).opts self.assertEqual(pool_opts.non_default_options, self.POOL_OPTIONS) def test_2_all_client_pools_have_same_options(self): client = self.rs_or_single_client(**self.POOL_OPTIONS) - self.addCleanup(client.close) client.admin.command("ping") # Discover at least one secondary. if client_context.has_secondaries: @@ -337,14 +341,12 @@ def test_3_uri_connection_pool_options(self): opts = "&".join([f"{k}={v}" for k, v in self.POOL_OPTIONS.items()]) uri = f"mongodb://{client_context.pair}/?{opts}" client = self.rs_or_single_client(uri) - self.addCleanup(client.close) - pool_opts = get_pool(client).opts + pool_opts = (get_pool(client)).opts self.assertEqual(pool_opts.non_default_options, self.POOL_OPTIONS) def test_4_subscribe_to_events(self): listener = CMAPListener() client = self.single_client(event_listeners=[listener]) - self.addCleanup(client.close) self.assertEqual(listener.event_count(PoolCreatedEvent), 1) # Creates a new connection. @@ -368,7 +370,6 @@ def test_4_subscribe_to_events(self): def test_5_check_out_fails_connection_error(self): listener = CMAPListener() client = self.single_client(event_listeners=[listener]) - self.addCleanup(client.close) pool = get_pool(client) def mock_connect(*args, **kwargs): @@ -397,7 +398,6 @@ def test_5_check_out_fails_auth_error(self): client = self.single_client_noauth( username="notauser", password="fail", event_listeners=[listener] ) - self.addCleanup(client.close) # Attempt to create a new connection. with self.assertRaisesRegex(OperationFailure, "failed"): diff --git a/test/utils.py b/test/utils.py index 40eec01cb4..e089b3fc2f 100644 --- a/test/utils.py +++ b/test/utils.py @@ -832,7 +832,7 @@ async def async_get_pools(client): """Get all pools.""" return [ server.pool - async for server in await (await client._get_topology()).select_servers( + for server in await (await client._get_topology()).select_servers( any_server_selector, _Op.TEST ) ] diff --git a/tools/synchro.py b/tools/synchro.py index bc1fcd7869..aa681df54e 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -204,6 +204,7 @@ def async_only_test(f: str) -> bool: "test_comment.py", "test_common.py", "test_connection_logging.py", + "test_connection_monitoring.py", "test_connections_survive_primary_stepdown_spec.py", "test_create_entities.py", "test_crud_unified.py",