From bcfc64a2b1047e133565a8fcecb5988c74967f12 Mon Sep 17 00:00:00 2001 From: Iris Date: Wed, 2 Oct 2024 13:15:57 -0700 Subject: [PATCH 1/8] Migrate test_discovery_and_monitoring.py to async --- .../test_discovery_and_monitoring.py | 448 ++++++++++++++++++ test/test_discovery_and_monitoring.py | 10 +- tools/synchro.py | 1 + 3 files changed, 454 insertions(+), 5 deletions(-) create mode 100644 test/asynchronous/test_discovery_and_monitoring.py diff --git a/test/asynchronous/test_discovery_and_monitoring.py b/test/asynchronous/test_discovery_and_monitoring.py new file mode 100644 index 0000000000..068dec6983 --- /dev/null +++ b/test/asynchronous/test_discovery_and_monitoring.py @@ -0,0 +1,448 @@ +# Copyright 2014-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test the topology module.""" +from __future__ import annotations + +import os +import socketserver +import sys +import threading + +sys.path[0:0] = [""] + +from test.asynchronous import AsyncIntegrationTest, AsyncPyMongoTestCase, unittest +from test.pymongo_mocks import DummyMonitor +from test.unified_format import generate_test_classes +from test.utils import ( + CMAPListener, + HeartbeatEventListener, + HeartbeatEventsListListener, + assertion_context, + async_client_context, + async_get_pool, + async_wait_until, + server_name_to_type, + wait_until, +) +from unittest.mock import patch + +from bson import Timestamp, json_util +from pymongo import AsyncMongoClient, common, monitoring +from pymongo.asynchronous.settings import TopologySettings +from pymongo.asynchronous.topology import Topology, _ErrorContext +from pymongo.errors import ( + AutoReconnect, + ConfigurationError, + NetworkTimeout, + NotPrimaryError, + OperationFailure, +) +from pymongo.hello import Hello, HelloCompat +from pymongo.helpers_shared import _check_command_response, _check_write_command_response +from pymongo.monitoring import ServerHeartbeatFailedEvent, ServerHeartbeatStartedEvent +from pymongo.server_description import SERVER_TYPE, ServerDescription +from pymongo.topology_description import TOPOLOGY_TYPE +from pymongo.uri_parser import parse_uri + +_IS_SYNC = False + +# Location of JSON test specifications. +SDAM_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "discovery_and_monitoring") + + +async def create_mock_topology(uri, monitor_class=DummyMonitor): + parsed_uri = parse_uri(uri) + replica_set_name = None + direct_connection = None + load_balanced = None + if "replicaset" in parsed_uri["options"]: + replica_set_name = parsed_uri["options"]["replicaset"] + if "directAsyncConnection" in parsed_uri["options"]: + direct_connection = parsed_uri["options"]["directAsyncConnection"] + if "loadBalanced" in parsed_uri["options"]: + load_balanced = parsed_uri["options"]["loadBalanced"] + + topology_settings = TopologySettings( + parsed_uri["nodelist"], + replica_set_name=replica_set_name, + monitor_class=monitor_class, + direct_connection=direct_connection, + load_balanced=load_balanced, + ) + + c = Topology(topology_settings) + await c.open() + return c + + +async def got_hello(topology, server_address, hello_response): + server_description = ServerDescription(server_address, Hello(hello_response), 0) + await topology.on_change(server_description) + + +def got_app_error(topology, app_error): + server_address = common.partition_node(app_error["address"]) + server = topology.get_server_by_address(server_address) + error_type = app_error["type"] + generation = app_error.get("generation", server.pool.gen.get_overall()) + when = app_error["when"] + max_wire_version = app_error["maxWireVersion"] + # XXX: We could get better test coverage by mocking the errors on the + # Pool/AsyncConnection. + try: + if error_type == "command": + _check_command_response(app_error["response"], max_wire_version) + _check_write_command_response(app_error["response"]) + elif error_type == "network": + raise AutoReconnect("mock non-timeout network error") + elif error_type == "timeout": + raise NetworkTimeout("mock network timeout error") + else: + raise AssertionError(f"unknown error type: {error_type}") + raise AssertionError + except (AutoReconnect, NotPrimaryError, OperationFailure) as e: + if when == "beforeHandshakeCompletes": + completed_handshake = False + elif when == "afterHandshakeCompletes": + completed_handshake = True + else: + raise AssertionError(f"Unknown when field {when}") + + topology.handle_error( + server_address, + _ErrorContext(e, max_wire_version, generation, completed_handshake, None), + ) + + +def get_type(topology, hostname): + description = topology.get_server_by_address((hostname, 27017)).description + return description.server_type + + +class TestAllScenarios(unittest.IsolatedAsyncioTestCase): + pass + + +def topology_type_name(topology_type): + return TOPOLOGY_TYPE._fields[topology_type] + + +def server_type_name(server_type): + return SERVER_TYPE._fields[server_type] + + +def check_outcome(self, topology, outcome): + expected_servers = outcome["servers"] + + # Check weak equality before proceeding. + self.assertEqual(len(topology.description.server_descriptions()), len(expected_servers)) + + if outcome.get("compatible") is False: + with self.assertRaises(ConfigurationError): + topology.description.check_compatible() + else: + # No error. + topology.description.check_compatible() + + # Since lengths are equal, every actual server must have a corresponding + # expected server. + for expected_server_address, expected_server in expected_servers.items(): + node = common.partition_node(expected_server_address) + self.assertTrue(topology.has_server(node)) + actual_server = topology.get_server_by_address(node) + actual_server_description = actual_server.description + expected_server_type = server_name_to_type(expected_server["type"]) + + self.assertEqual( + server_type_name(expected_server_type), + server_type_name(actual_server_description.server_type), + ) + + self.assertEqual(expected_server.get("setName"), actual_server_description.replica_set_name) + + self.assertEqual(expected_server.get("setVersion"), actual_server_description.set_version) + + self.assertEqual(expected_server.get("electionId"), actual_server_description.election_id) + + self.assertEqual( + expected_server.get("topologyVersion"), actual_server_description.topology_version + ) + + expected_pool = expected_server.get("pool") + if expected_pool: + self.assertEqual(expected_pool.get("generation"), actual_server.pool.gen.get_overall()) + + self.assertEqual(outcome["setName"], topology.description.replica_set_name) + self.assertEqual( + outcome.get("logicalSessionTimeoutMinutes"), + topology.description.logical_session_timeout_minutes, + ) + + expected_topology_type = getattr(TOPOLOGY_TYPE, outcome["topologyType"]) + self.assertEqual( + topology_type_name(expected_topology_type), + topology_type_name(topology.description.topology_type), + ) + + self.assertEqual(outcome.get("maxSetVersion"), topology.description.max_set_version) + self.assertEqual(outcome.get("maxElectionId"), topology.description.max_election_id) + + +def create_test(scenario_def): + async def run_scenario(self): + c = await create_mock_topology(scenario_def["uri"]) + + for i, phase in enumerate(scenario_def["phases"]): + # Including the phase description makes failures easier to debug. + description = phase.get("description", str(i)) + with assertion_context(f"phase: {description}"): + for response in phase.get("responses", []): + got_hello(c, common.partition_node(response[0]), response[1]) + + for app_error in phase.get("applicationErrors", []): + got_app_error(c, app_error) + + check_outcome(self, c, phase["outcome"]) + + return run_scenario + + +def create_tests(): + for dirpath, _, filenames in os.walk(SDAM_PATH): + dirname = os.path.split(dirpath)[-1] + # SDAM unified tests are handled separately. + if dirname == "unified": + continue + + for filename in filenames: + if os.path.splitext(filename)[1] != ".json": + continue + with open(os.path.join(dirpath, filename)) as scenario_stream: + scenario_def = json_util.loads(scenario_stream.read()) + + # Construct test from scenario. + new_test = create_test(scenario_def) + test_name = f"test_{dirname}_{os.path.splitext(filename)[0]}" + + new_test.__name__ = test_name + setattr(TestAllScenarios, new_test.__name__, new_test) + + +create_tests() + + +class TestClusterTimeComparison(unittest.IsolatedAsyncioTestCase): + async def test_cluster_time_comparison(self): + t = await create_mock_topology("mongodb://host") + + async def send_cluster_time(time, inc, should_update): + old = t.max_cluster_time() + new = {"clusterTime": Timestamp(time, inc)} + await got_hello( + t, + ("host", 27017), + {"ok": 1, "minWireVersion": 0, "maxWireVersion": 6, "$clusterTime": new}, + ) + + actual = t.max_cluster_time() + if should_update: + self.assertEqual(actual, new) + else: + self.assertEqual(actual, old) + + await send_cluster_time(0, 1, True) + await send_cluster_time(2, 2, True) + await send_cluster_time(2, 1, False) + await send_cluster_time(1, 3, False) + await send_cluster_time(2, 3, True) + + +class TestIgnoreStaleErrors(AsyncIntegrationTest): + @async_client_context.require_sync + async def test_ignore_stale_connection_errors(self): + N_THREADS = 5 + barrier = threading.Barrier(N_THREADS, timeout=30) + client = await self.async_rs_or_single_client(minPoolSize=N_THREADS) + + # Wait for initial discovery. + await client.admin.command("ping") + pool = await async_get_pool(client) + starting_generation = pool.gen.get_overall() + await async_wait_until(lambda: len(pool.conns) == N_THREADS, "created conns") + + def mock_command(*args, **kwargs): + # Synchronize all threads to ensure they use the same generation. + barrier.wait() + raise AutoReconnect("mock AsyncConnection.command error") + + for conn in pool.conns: + conn.command = mock_command + + async def insert_command(i): + try: + await client.test.command("insert", "test", documents=[{"i": i}]) + except AutoReconnect: + pass + + threads = [] + for i in range(N_THREADS): + threads.append(threading.Thread(target=insert_command, args=(i,))) + for t in threads: + t.start() + for t in threads: + t.join() + + # Expect a single pool reset for the network error + self.assertEqual(starting_generation + 1, pool.gen.get_overall()) + + # Server should be selectable. + await client.admin.command("ping") + + +class CMAPHeartbeatListener(HeartbeatEventListener, CMAPListener): + pass + + +class TestPoolManagement(AsyncIntegrationTest): + @async_client_context.require_failCommand_appName + async def test_pool_unpause(self): + # This test implements the prose test "AsyncConnection Pool Management" + listener = CMAPHeartbeatListener() + client = await self.async_single_client( + appName="SDAMPoolManagementTest", heartbeatFrequencyMS=500, event_listeners=[listener] + ) + # Assert that AsyncConnectionPoolReadyEvent occurs after the first + # ServerHeartbeatSucceededEvent. + listener.wait_for_event(monitoring.PoolReadyEvent, 1) + pool_ready = listener.events_by_type(monitoring.PoolReadyEvent)[0] + hb_succeeded = listener.events_by_type(monitoring.ServerHeartbeatSucceededEvent)[0] + self.assertGreater(listener.events.index(pool_ready), listener.events.index(hb_succeeded)) + + listener.reset() + fail_hello = { + "mode": {"times": 2}, + "data": { + "failCommands": [HelloCompat.LEGACY_CMD, "hello"], + "errorCode": 1234, + "appName": "SDAMPoolManagementTest", + }, + } + async with self.fail_point(fail_hello): + listener.wait_for_event(monitoring.ServerHeartbeatFailedEvent, 1) + listener.wait_for_event(monitoring.PoolClearedEvent, 1) + listener.wait_for_event(monitoring.ServerHeartbeatSucceededEvent, 1) + listener.wait_for_event(monitoring.PoolReadyEvent, 1) + + +class TestServerMonitoringMode(AsyncIntegrationTest): + @async_client_context.require_no_serverless + @async_client_context.require_no_load_balancer + async def asyncSetUp(self): + await super().asyncSetUp() + + async def test_rtt_connection_is_enabled_stream(self): + client = await self.async_rs_or_single_client(serverMonitoringMode="stream") + await client.admin.command("ping") + + def predicate(): + for _, server in client._topology._servers.items(): + monitor = server._monitor + if not monitor._stream: + return False + if async_client_context.version >= (4, 4): + if monitor._rtt_monitor._executor._thread is None: + return False + else: + if monitor._rtt_monitor._executor._thread is not None: + return False + return True + + wait_until(predicate, "find all RTT monitors") + + async def test_rtt_connection_is_disabled_poll(self): + client = await self.async_rs_or_single_client(serverMonitoringMode="poll") + + await self.assert_rtt_connection_is_disabled(client) + + async def test_rtt_connection_is_disabled_auto(self): + envs = [ + {"AWS_EXECUTION_ENV": "AWS_Lambda_python3.9"}, + {"FUNCTIONS_WORKER_RUNTIME": "python"}, + {"K_SERVICE": "gcpservicename"}, + {"FUNCTION_NAME": "gcpfunctionname"}, + {"VERCEL": "1"}, + ] + for env in envs: + with patch.dict("os.environ", env): + client = await self.async_rs_or_single_client(serverMonitoringMode="auto") + await self.assert_rtt_connection_is_disabled(client) + + async def assert_rtt_connection_is_disabled(self, client): + await client.admin.command("ping") + for _, server in client._topology._servers.items(): + monitor = server._monitor + self.assertFalse(monitor._stream) + self.assertIsNone(monitor._rtt_monitor._executor._thread) + + +class MockTCPHandler(socketserver.BaseRequestHandler): + def handle(self): + self.server.events.append("client connected") + if self.request.recv(1024).strip(): + self.server.events.append("client hello received") + self.request.close() + + +class TCPServer(socketserver.TCPServer): + allow_reuse_address = True + + def handle_request_and_shutdown(self): + self.handle_request() + self.server_close() + + +class TestHeartbeatStartOrdering(AsyncPyMongoTestCase): + @async_client_context.require_sync + async def test_heartbeat_start_ordering(self): + events = [] + listener = HeartbeatEventsListListener(events) + server = TCPServer(("localhost", 9999), MockTCPHandler) + server.events = events + server_thread = threading.Thread(target=server.handle_request_and_shutdown) + server_thread.start() + _c = self.simple_client( + "mongodb://localhost:9999", serverSelectionTimeoutMS=500, event_listeners=(listener,) + ) + server_thread.join() + listener.wait_for_event(ServerHeartbeatStartedEvent, 1) + listener.wait_for_event(ServerHeartbeatFailedEvent, 1) + + self.assertEqual( + events, + [ + "serverHeartbeatStartedEvent", + "client connected", + "client hello received", + "serverHeartbeatFailedEvent", + ], + ) + + +# Generate unified tests. +globals().update(generate_test_classes(os.path.join(SDAM_PATH, "unified"), module=__name__)) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_discovery_and_monitoring.py b/test/test_discovery_and_monitoring.py index 3554619f12..4ff94c2e37 100644 --- a/test/test_discovery_and_monitoring.py +++ b/test/test_discovery_and_monitoring.py @@ -55,6 +55,8 @@ from pymongo.topology_description import TOPOLOGY_TYPE from pymongo.uri_parser import parse_uri +_IS_SYNC = True + # Location of JSON test specifications. SDAM_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "discovery_and_monitoring") @@ -267,11 +269,11 @@ def send_cluster_time(time, inc, should_update): class TestIgnoreStaleErrors(IntegrationTest): + @client_context.require_sync def test_ignore_stale_connection_errors(self): N_THREADS = 5 barrier = threading.Barrier(N_THREADS, timeout=30) client = self.rs_or_single_client(minPoolSize=N_THREADS) - self.addCleanup(client.close) # Wait for initial discovery. client.admin.command("ping") @@ -320,7 +322,6 @@ def test_pool_unpause(self): client = self.single_client( appName="SDAMPoolManagementTest", heartbeatFrequencyMS=500, event_listeners=[listener] ) - self.addCleanup(client.close) # Assert that ConnectionPoolReadyEvent occurs after the first # ServerHeartbeatSucceededEvent. listener.wait_for_event(monitoring.PoolReadyEvent, 1) @@ -352,7 +353,6 @@ def setUp(self): def test_rtt_connection_is_enabled_stream(self): client = self.rs_or_single_client(serverMonitoringMode="stream") - self.addCleanup(client.close) client.admin.command("ping") def predicate(): @@ -372,7 +372,7 @@ def predicate(): def test_rtt_connection_is_disabled_poll(self): client = self.rs_or_single_client(serverMonitoringMode="poll") - self.addCleanup(client.close) + self.assert_rtt_connection_is_disabled(client) def test_rtt_connection_is_disabled_auto(self): @@ -386,7 +386,6 @@ def test_rtt_connection_is_disabled_auto(self): for env in envs: with patch.dict("os.environ", env): client = self.rs_or_single_client(serverMonitoringMode="auto") - self.addCleanup(client.close) self.assert_rtt_connection_is_disabled(client) def assert_rtt_connection_is_disabled(self, client): @@ -414,6 +413,7 @@ def handle_request_and_shutdown(self): class TestHeartbeatStartOrdering(PyMongoTestCase): + @client_context.require_sync def test_heartbeat_start_ordering(self): events = [] listener = HeartbeatEventsListListener(events) diff --git a/tools/synchro.py b/tools/synchro.py index 3333b0de2e..7ce7b79047 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -188,6 +188,7 @@ def async_only_test(f: str) -> bool: "test_collection.py", "test_cursor.py", "test_database.py", + "test_discovery_and_monitoring.py", "test_encryption.py", "test_grid_file.py", "test_logger.py", From 037afe69de4ce366c614bf42a6c60eecc7a3964a Mon Sep 17 00:00:00 2001 From: Iris Date: Thu, 3 Oct 2024 08:39:39 -0700 Subject: [PATCH 2/8] attempt to fix failing test --- pymongo/asynchronous/monitor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymongo/asynchronous/monitor.py b/pymongo/asynchronous/monitor.py index f9e912b084..42d21493ff 100644 --- a/pymongo/asynchronous/monitor.py +++ b/pymongo/asynchronous/monitor.py @@ -250,7 +250,7 @@ async def _check_server(self) -> ServerDescription: except (OperationFailure, NotPrimaryError) as exc: # Update max cluster time even when hello fails. details = cast(Mapping[str, Any], exc.details) - self._topology.receive_cluster_time(details.get("$clusterTime")) + await self._topology.receive_cluster_time(details.get("$clusterTime")) raise except ReferenceError: raise From 926c1765da080f3cde0f7a0251ea65691dc6e799 Mon Sep 17 00:00:00 2001 From: Iris Date: Thu, 3 Oct 2024 09:31:00 -0700 Subject: [PATCH 3/8] fix unused local var --- test/asynchronous/test_discovery_and_monitoring.py | 2 +- test/test_discovery_and_monitoring.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/test/asynchronous/test_discovery_and_monitoring.py b/test/asynchronous/test_discovery_and_monitoring.py index 068dec6983..c28e11165a 100644 --- a/test/asynchronous/test_discovery_and_monitoring.py +++ b/test/asynchronous/test_discovery_and_monitoring.py @@ -320,7 +320,7 @@ class TestPoolManagement(AsyncIntegrationTest): async def test_pool_unpause(self): # This test implements the prose test "AsyncConnection Pool Management" listener = CMAPHeartbeatListener() - client = await self.async_single_client( + _ = await self.async_single_client( appName="SDAMPoolManagementTest", heartbeatFrequencyMS=500, event_listeners=[listener] ) # Assert that AsyncConnectionPoolReadyEvent occurs after the first diff --git a/test/test_discovery_and_monitoring.py b/test/test_discovery_and_monitoring.py index 4ff94c2e37..e4bcbe0046 100644 --- a/test/test_discovery_and_monitoring.py +++ b/test/test_discovery_and_monitoring.py @@ -319,7 +319,7 @@ class TestPoolManagement(IntegrationTest): def test_pool_unpause(self): # This test implements the prose test "Connection Pool Management" listener = CMAPHeartbeatListener() - client = self.single_client( + _ = self.single_client( appName="SDAMPoolManagementTest", heartbeatFrequencyMS=500, event_listeners=[listener] ) # Assert that ConnectionPoolReadyEvent occurs after the first From 7a17e8f3137f0f4117eb459290d95bb2dc61ca67 Mon Sep 17 00:00:00 2001 From: Iris Date: Mon, 14 Oct 2024 09:32:39 -0700 Subject: [PATCH 4/8] add missing awaits and changed path; some tests are failing but pushing what I have --- test/asynchronous/pymongo_mocks.py | 2 +- .../test_discovery_and_monitoring.py | 25 +++++++++++++------ test/test_discovery_and_monitoring.py | 13 +++++++++- test/utils.py | 5 +++- 4 files changed, 35 insertions(+), 10 deletions(-) diff --git a/test/asynchronous/pymongo_mocks.py b/test/asynchronous/pymongo_mocks.py index ed2395bc98..25316bd11f 100644 --- a/test/asynchronous/pymongo_mocks.py +++ b/test/asynchronous/pymongo_mocks.py @@ -75,7 +75,7 @@ def open(self): def request_check(self): pass - def close(self): + async def close(self): self.opened = False diff --git a/test/asynchronous/test_discovery_and_monitoring.py b/test/asynchronous/test_discovery_and_monitoring.py index c28e11165a..a0085c144a 100644 --- a/test/asynchronous/test_discovery_and_monitoring.py +++ b/test/asynchronous/test_discovery_and_monitoring.py @@ -15,6 +15,7 @@ """Test the topology module.""" from __future__ import annotations +import asyncio import os import socketserver import sys @@ -23,7 +24,7 @@ sys.path[0:0] = [""] from test.asynchronous import AsyncIntegrationTest, AsyncPyMongoTestCase, unittest -from test.pymongo_mocks import DummyMonitor +from test.asynchronous.pymongo_mocks import DummyMonitor from test.unified_format import generate_test_classes from test.utils import ( CMAPListener, @@ -59,7 +60,15 @@ _IS_SYNC = False # Location of JSON test specifications. -SDAM_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "discovery_and_monitoring") +if _IS_SYNC: + SDAM_PATH = os.path.join( + os.path.dirname(os.path.realpath(__file__)), "discovery_and_monitoring" + ) +else: + SDAM_PATH = os.path.join( + os.path.abspath(os.path.join(os.path.dirname(os.path.realpath(__file__)), os.pardir)), + "discovery_and_monitoring", + ) async def create_mock_topology(uri, monitor_class=DummyMonitor): @@ -92,7 +101,7 @@ async def got_hello(topology, server_address, hello_response): await topology.on_change(server_description) -def got_app_error(topology, app_error): +async def got_app_error(topology, app_error): server_address = common.partition_node(app_error["address"]) server = topology.get_server_by_address(server_address) error_type = app_error["type"] @@ -120,7 +129,7 @@ def got_app_error(topology, app_error): else: raise AssertionError(f"Unknown when field {when}") - topology.handle_error( + await topology.handle_error( server_address, _ErrorContext(e, max_wire_version, generation, completed_handshake, None), ) @@ -207,12 +216,14 @@ async def run_scenario(self): for i, phase in enumerate(scenario_def["phases"]): # Including the phase description makes failures easier to debug. description = phase.get("description", str(i)) + if self._testMethodName == "test_single_direct_connection_external_ip": + print("here") with assertion_context(f"phase: {description}"): for response in phase.get("responses", []): - got_hello(c, common.partition_node(response[0]), response[1]) + await got_hello(c, common.partition_node(response[0]), response[1]) for app_error in phase.get("applicationErrors", []): - got_app_error(c, app_error) + await got_app_error(c, app_error) check_outcome(self, c, phase["outcome"]) @@ -369,7 +380,7 @@ def predicate(): return False return True - wait_until(predicate, "find all RTT monitors") + await async_wait_until(predicate, "find all RTT monitors") async def test_rtt_connection_is_disabled_poll(self): client = await self.async_rs_or_single_client(serverMonitoringMode="poll") diff --git a/test/test_discovery_and_monitoring.py b/test/test_discovery_and_monitoring.py index e4bcbe0046..a6b8dca662 100644 --- a/test/test_discovery_and_monitoring.py +++ b/test/test_discovery_and_monitoring.py @@ -15,6 +15,7 @@ """Test the topology module.""" from __future__ import annotations +import asyncio import os import socketserver import sys @@ -58,7 +59,15 @@ _IS_SYNC = True # Location of JSON test specifications. -SDAM_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "discovery_and_monitoring") +if _IS_SYNC: + SDAM_PATH = os.path.join( + os.path.dirname(os.path.realpath(__file__)), "discovery_and_monitoring" + ) +else: + SDAM_PATH = os.path.join( + os.path.abspath(os.path.join(os.path.dirname(os.path.realpath(__file__)), os.pardir)), + "discovery_and_monitoring", + ) def create_mock_topology(uri, monitor_class=DummyMonitor): @@ -206,6 +215,8 @@ def run_scenario(self): for i, phase in enumerate(scenario_def["phases"]): # Including the phase description makes failures easier to debug. description = phase.get("description", str(i)) + if self._testMethodName == "test_single_direct_connection_external_ip": + print("here") with assertion_context(f"phase: {description}"): for response in phase.get("responses", []): got_hello(c, common.partition_node(response[0]), response[1]) diff --git a/test/utils.py b/test/utils.py index 9615034899..c1d03f50dd 100644 --- a/test/utils.py +++ b/test/utils.py @@ -751,7 +751,10 @@ async def async_wait_until(predicate, success_description, timeout=10): start = time.time() interval = min(float(timeout) / 100, 0.1) while True: - retval = await predicate() + if iscoroutinefunction(predicate): + retval = await predicate() + else: + retval = predicate() if retval: return retval From 997260fe7bd0cb6270daf3367398f90fc607030c Mon Sep 17 00:00:00 2001 From: Iris Date: Mon, 14 Oct 2024 09:42:41 -0700 Subject: [PATCH 5/8] remove debugging statements --- test/asynchronous/test_discovery_and_monitoring.py | 2 -- test/test_discovery_and_monitoring.py | 2 -- 2 files changed, 4 deletions(-) diff --git a/test/asynchronous/test_discovery_and_monitoring.py b/test/asynchronous/test_discovery_and_monitoring.py index a0085c144a..d44db28aa6 100644 --- a/test/asynchronous/test_discovery_and_monitoring.py +++ b/test/asynchronous/test_discovery_and_monitoring.py @@ -216,8 +216,6 @@ async def run_scenario(self): for i, phase in enumerate(scenario_def["phases"]): # Including the phase description makes failures easier to debug. description = phase.get("description", str(i)) - if self._testMethodName == "test_single_direct_connection_external_ip": - print("here") with assertion_context(f"phase: {description}"): for response in phase.get("responses", []): await got_hello(c, common.partition_node(response[0]), response[1]) diff --git a/test/test_discovery_and_monitoring.py b/test/test_discovery_and_monitoring.py index a6b8dca662..633163e853 100644 --- a/test/test_discovery_and_monitoring.py +++ b/test/test_discovery_and_monitoring.py @@ -215,8 +215,6 @@ def run_scenario(self): for i, phase in enumerate(scenario_def["phases"]): # Including the phase description makes failures easier to debug. description = phase.get("description", str(i)) - if self._testMethodName == "test_single_direct_connection_external_ip": - print("here") with assertion_context(f"phase: {description}"): for response in phase.get("responses", []): got_hello(c, common.partition_node(response[0]), response[1]) From 7f447c853c9e950ee1eed3dc6eb768ead41db61a Mon Sep 17 00:00:00 2001 From: Iris Date: Mon, 14 Oct 2024 09:53:01 -0700 Subject: [PATCH 6/8] fix tests --- test/asynchronous/test_discovery_and_monitoring.py | 6 ++++-- test/test_discovery_and_monitoring.py | 2 ++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/test/asynchronous/test_discovery_and_monitoring.py b/test/asynchronous/test_discovery_and_monitoring.py index d44db28aa6..9ccca8f8ff 100644 --- a/test/asynchronous/test_discovery_and_monitoring.py +++ b/test/asynchronous/test_discovery_and_monitoring.py @@ -78,8 +78,8 @@ async def create_mock_topology(uri, monitor_class=DummyMonitor): load_balanced = None if "replicaset" in parsed_uri["options"]: replica_set_name = parsed_uri["options"]["replicaset"] - if "directAsyncConnection" in parsed_uri["options"]: - direct_connection = parsed_uri["options"]["directAsyncConnection"] + if "directConnection" in parsed_uri["options"]: + direct_connection = parsed_uri["options"]["directConnection"] if "loadBalanced" in parsed_uri["options"]: load_balanced = parsed_uri["options"]["loadBalanced"] @@ -216,6 +216,8 @@ async def run_scenario(self): for i, phase in enumerate(scenario_def["phases"]): # Including the phase description makes failures easier to debug. description = phase.get("description", str(i)) + if self._testMethodName == "test_single_direct_connection_external_ip": + print("here") with assertion_context(f"phase: {description}"): for response in phase.get("responses", []): await got_hello(c, common.partition_node(response[0]), response[1]) diff --git a/test/test_discovery_and_monitoring.py b/test/test_discovery_and_monitoring.py index 633163e853..a6b8dca662 100644 --- a/test/test_discovery_and_monitoring.py +++ b/test/test_discovery_and_monitoring.py @@ -215,6 +215,8 @@ def run_scenario(self): for i, phase in enumerate(scenario_def["phases"]): # Including the phase description makes failures easier to debug. description = phase.get("description", str(i)) + if self._testMethodName == "test_single_direct_connection_external_ip": + print("here") with assertion_context(f"phase: {description}"): for response in phase.get("responses", []): got_hello(c, common.partition_node(response[0]), response[1]) From 4f357bb16351d861920f18da8f4c8935fe6ce9d1 Mon Sep 17 00:00:00 2001 From: Iris Date: Mon, 14 Oct 2024 10:02:36 -0700 Subject: [PATCH 7/8] remove debugging statement again --- test/asynchronous/test_discovery_and_monitoring.py | 2 -- test/test_discovery_and_monitoring.py | 9 +-------- 2 files changed, 1 insertion(+), 10 deletions(-) diff --git a/test/asynchronous/test_discovery_and_monitoring.py b/test/asynchronous/test_discovery_and_monitoring.py index 9ccca8f8ff..11793a08e0 100644 --- a/test/asynchronous/test_discovery_and_monitoring.py +++ b/test/asynchronous/test_discovery_and_monitoring.py @@ -216,8 +216,6 @@ async def run_scenario(self): for i, phase in enumerate(scenario_def["phases"]): # Including the phase description makes failures easier to debug. description = phase.get("description", str(i)) - if self._testMethodName == "test_single_direct_connection_external_ip": - print("here") with assertion_context(f"phase: {description}"): for response in phase.get("responses", []): await got_hello(c, common.partition_node(response[0]), response[1]) diff --git a/test/test_discovery_and_monitoring.py b/test/test_discovery_and_monitoring.py index 29fa1fc14e..633163e853 100644 --- a/test/test_discovery_and_monitoring.py +++ b/test/test_discovery_and_monitoring.py @@ -215,8 +215,6 @@ def run_scenario(self): for i, phase in enumerate(scenario_def["phases"]): # Including the phase description makes failures easier to debug. description = phase.get("description", str(i)) - if self._testMethodName == "test_single_direct_connection_external_ip": - print("here") with assertion_context(f"phase: {description}"): for response in phase.get("responses", []): got_hello(c, common.partition_node(response[0]), response[1]) @@ -263,12 +261,7 @@ def send_cluster_time(time, inc, should_update): got_hello( t, ("host", 27017), - { - "ok": 1, - "minWireVersion": 0, - "maxWireVersion": common.MIN_SUPPORTED_WIRE_VERSION, - "$clusterTime": new, - }, + {"ok": 1, "minWireVersion": 0, "maxWireVersion": 6, "$clusterTime": new}, ) actual = t.max_cluster_time() From e39192e9a454f9ae271cccc8c7d80df54d851cbe Mon Sep 17 00:00:00 2001 From: Iris Date: Mon, 14 Oct 2024 10:19:23 -0700 Subject: [PATCH 8/8] fix lambda in test_collection --- test/asynchronous/test_collection.py | 5 ++++- test/test_collection.py | 5 ++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/test/asynchronous/test_collection.py b/test/asynchronous/test_collection.py index 612090b69f..cc6ff2f05e 100644 --- a/test/asynchronous/test_collection.py +++ b/test/asynchronous/test_collection.py @@ -1022,7 +1022,10 @@ async def test_replace_bypass_document_validation(self): await db.test.insert_one({"y": 1}, bypass_document_validation=True) await db_w0.test.replace_one({"y": 1}, {"x": 1}, bypass_document_validation=True) - await async_wait_until(lambda: db_w0.test.find_one({"x": 1}), "find w:0 replaced document") + async def async_lambda(): + await db_w0.test.find_one({"x": 1}) + + await async_wait_until(async_lambda, "find w:0 replaced document") async def test_update_bypass_document_validation(self): db = self.db diff --git a/test/test_collection.py b/test/test_collection.py index a2c3b0b0b6..012123c01b 100644 --- a/test/test_collection.py +++ b/test/test_collection.py @@ -1009,7 +1009,10 @@ def test_replace_bypass_document_validation(self): db.test.insert_one({"y": 1}, bypass_document_validation=True) db_w0.test.replace_one({"y": 1}, {"x": 1}, bypass_document_validation=True) - wait_until(lambda: db_w0.test.find_one({"x": 1}), "find w:0 replaced document") + def async_lambda(): + db_w0.test.find_one({"x": 1}) + + wait_until(async_lambda, "find w:0 replaced document") def test_update_bypass_document_validation(self): db = self.db