From 2e6260580df4ef0bed9826c5fbd805f53096bfe5 Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Wed, 29 Jan 2025 10:32:36 -0800 Subject: [PATCH 01/19] Convert test.test_discovery_and_monitoring to async --- test/asynchronous/pymongo_mocks.py | 2 +- .../test_discovery_and_monitoring.py | 466 ++++++++++++++++++ test/test_discovery_and_monitoring.py | 51 +- tools/synchro.py | 1 + 4 files changed, 500 insertions(+), 20 deletions(-) create mode 100644 test/asynchronous/test_discovery_and_monitoring.py diff --git a/test/asynchronous/pymongo_mocks.py b/test/asynchronous/pymongo_mocks.py index ed2395bc98..25316bd11f 100644 --- a/test/asynchronous/pymongo_mocks.py +++ b/test/asynchronous/pymongo_mocks.py @@ -75,7 +75,7 @@ def open(self): def request_check(self): pass - def close(self): + async def close(self): self.opened = False diff --git a/test/asynchronous/test_discovery_and_monitoring.py b/test/asynchronous/test_discovery_and_monitoring.py new file mode 100644 index 0000000000..e63fd9ec80 --- /dev/null +++ b/test/asynchronous/test_discovery_and_monitoring.py @@ -0,0 +1,466 @@ +# Copyright 2014-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test the topology module.""" +from __future__ import annotations + +import asyncio +import os +import pathlib +import socketserver +import sys +import threading + +sys.path[0:0] = [""] + +from test.asynchronous import AsyncIntegrationTest, AsyncPyMongoTestCase, unittest +from test.asynchronous.pymongo_mocks import DummyMonitor +from test.unified_format import generate_test_classes +from test.utils import ( + CMAPListener, + HeartbeatEventListener, + HeartbeatEventsListListener, + assertion_context, + async_client_context, + async_get_pool, + async_wait_until, + server_name_to_type, + wait_until, +) +from unittest.mock import patch + +from bson import Timestamp, json_util +from pymongo import AsyncMongoClient, common, monitoring +from pymongo.asynchronous.settings import TopologySettings +from pymongo.asynchronous.topology import Topology, _ErrorContext +from pymongo.errors import ( + AutoReconnect, + ConfigurationError, + NetworkTimeout, + NotPrimaryError, + OperationFailure, +) +from pymongo.hello import Hello, HelloCompat +from pymongo.helpers_shared import _check_command_response, _check_write_command_response +from pymongo.monitoring import ServerHeartbeatFailedEvent, ServerHeartbeatStartedEvent +from pymongo.server_description import SERVER_TYPE, ServerDescription +from pymongo.topology_description import TOPOLOGY_TYPE +from pymongo.uri_parser import parse_uri + +_IS_SYNC = False + +# Location of JSON test specifications. 
+if _IS_SYNC: + SDAM_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "discovery_and_monitoring") +else: + SDAM_PATH = os.path.join( + pathlib.Path(__file__).resolve().parent.parent, + "discovery_and_monitoring", + ) + + +async def create_mock_topology(uri, monitor_class=DummyMonitor): + parsed_uri = parse_uri(uri) + replica_set_name = None + direct_connection = None + load_balanced = None + if "replicaset" in parsed_uri["options"]: + replica_set_name = parsed_uri["options"]["replicaset"] + if "directConnection" in parsed_uri["options"]: + direct_connection = parsed_uri["options"]["directConnection"] + if "loadBalanced" in parsed_uri["options"]: + load_balanced = parsed_uri["options"]["loadBalanced"] + + topology_settings = TopologySettings( + parsed_uri["nodelist"], + replica_set_name=replica_set_name, + monitor_class=monitor_class, + direct_connection=direct_connection, + load_balanced=load_balanced, + ) + + c = Topology(topology_settings) + await c.open() + return c + + +async def got_hello(topology, server_address, hello_response): + server_description = ServerDescription(server_address, Hello(hello_response), 0) + await topology.on_change(server_description) + + +async def got_app_error(topology, app_error): + server_address = common.partition_node(app_error["address"]) + server = topology.get_server_by_address(server_address) + error_type = app_error["type"] + generation = app_error.get("generation", server.pool.gen.get_overall()) + when = app_error["when"] + max_wire_version = app_error["maxWireVersion"] + # XXX: We could get better test coverage by mocking the errors on the + # Pool/AsyncConnection. + try: + if error_type == "command": + _check_command_response(app_error["response"], max_wire_version) + _check_write_command_response(app_error["response"]) + elif error_type == "network": + raise AutoReconnect("mock non-timeout network error") + elif error_type == "timeout": + raise NetworkTimeout("mock network timeout error") + else: + raise AssertionError(f"unknown error type: {error_type}") + raise AssertionError + except (AutoReconnect, NotPrimaryError, OperationFailure) as e: + if when == "beforeHandshakeCompletes": + completed_handshake = False + elif when == "afterHandshakeCompletes": + completed_handshake = True + else: + raise AssertionError(f"Unknown when field {when}") + + await topology.handle_error( + server_address, + _ErrorContext(e, max_wire_version, generation, completed_handshake, None), + ) + + +def get_type(topology, hostname): + description = topology.get_server_by_address((hostname, 27017)).description + return description.server_type + + +class TestAllScenarios(unittest.IsolatedAsyncioTestCase): + pass + + +def topology_type_name(topology_type): + return TOPOLOGY_TYPE._fields[topology_type] + + +def server_type_name(server_type): + return SERVER_TYPE._fields[server_type] + + +def check_outcome(self, topology, outcome): + expected_servers = outcome["servers"] + + # Check weak equality before proceeding. + self.assertEqual(len(topology.description.server_descriptions()), len(expected_servers)) + + if outcome.get("compatible") is False: + with self.assertRaises(ConfigurationError): + topology.description.check_compatible() + else: + # No error. + topology.description.check_compatible() + + # Since lengths are equal, every actual server must have a corresponding + # expected server. 
+ for expected_server_address, expected_server in expected_servers.items(): + node = common.partition_node(expected_server_address) + self.assertTrue(topology.has_server(node)) + actual_server = topology.get_server_by_address(node) + actual_server_description = actual_server.description + expected_server_type = server_name_to_type(expected_server["type"]) + + self.assertEqual( + server_type_name(expected_server_type), + server_type_name(actual_server_description.server_type), + ) + + self.assertEqual(expected_server.get("setName"), actual_server_description.replica_set_name) + + self.assertEqual(expected_server.get("setVersion"), actual_server_description.set_version) + + self.assertEqual(expected_server.get("electionId"), actual_server_description.election_id) + + self.assertEqual( + expected_server.get("topologyVersion"), actual_server_description.topology_version + ) + + expected_pool = expected_server.get("pool") + if expected_pool: + self.assertEqual(expected_pool.get("generation"), actual_server.pool.gen.get_overall()) + + self.assertEqual(outcome["setName"], topology.description.replica_set_name) + self.assertEqual( + outcome.get("logicalSessionTimeoutMinutes"), + topology.description.logical_session_timeout_minutes, + ) + + expected_topology_type = getattr(TOPOLOGY_TYPE, outcome["topologyType"]) + self.assertEqual( + topology_type_name(expected_topology_type), + topology_type_name(topology.description.topology_type), + ) + + self.assertEqual(outcome.get("maxSetVersion"), topology.description.max_set_version) + self.assertEqual(outcome.get("maxElectionId"), topology.description.max_election_id) + + +def create_test(scenario_def): + async def run_scenario(self): + c = await create_mock_topology(scenario_def["uri"]) + + for i, phase in enumerate(scenario_def["phases"]): + # Including the phase description makes failures easier to debug. + description = phase.get("description", str(i)) + with assertion_context(f"phase: {description}"): + for response in phase.get("responses", []): + await got_hello(c, common.partition_node(response[0]), response[1]) + + for app_error in phase.get("applicationErrors", []): + await got_app_error(c, app_error) + + check_outcome(self, c, phase["outcome"]) + + return run_scenario + + +def create_tests(): + for dirpath, _, filenames in os.walk(SDAM_PATH): + dirname = os.path.split(dirpath)[-1] + # SDAM unified tests are handled separately. + if dirname == "unified": + continue + + for filename in filenames: + if os.path.splitext(filename)[1] != ".json": + continue + with open(os.path.join(dirpath, filename)) as scenario_stream: + scenario_def = json_util.loads(scenario_stream.read()) + + # Construct test from scenario. 
+ new_test = create_test(scenario_def) + test_name = f"test_{dirname}_{os.path.splitext(filename)[0]}" + + new_test.__name__ = test_name + setattr(TestAllScenarios, new_test.__name__, new_test) + + +create_tests() + + +class TestClusterTimeComparison(unittest.IsolatedAsyncioTestCase): + async def test_cluster_time_comparison(self): + t = await create_mock_topology("mongodb://host") + + async def send_cluster_time(time, inc, should_update): + old = t.max_cluster_time() + new = {"clusterTime": Timestamp(time, inc)} + await got_hello( + t, + ("host", 27017), + {"ok": 1, "minWireVersion": 0, "maxWireVersion": 6, "$clusterTime": new}, + ) + + actual = t.max_cluster_time() + if should_update: + self.assertEqual(actual, new) + else: + self.assertEqual(actual, old) + + await send_cluster_time(0, 1, True) + await send_cluster_time(2, 2, True) + await send_cluster_time(2, 1, False) + await send_cluster_time(1, 3, False) + await send_cluster_time(2, 3, True) + + +class TestIgnoreStaleErrors(AsyncIntegrationTest): + @async_client_context.require_sync + async def test_ignore_stale_connection_errors(self): + N_THREADS = 5 + barrier = threading.Barrier(N_THREADS, timeout=30) + client = await self.async_rs_or_single_client(minPoolSize=N_THREADS) + + # Wait for initial discovery. + await client.admin.command("ping") + pool = await async_get_pool(client) + starting_generation = pool.gen.get_overall() + await async_wait_until(lambda: len(pool.conns) == N_THREADS, "created conns") + + def mock_command(*args, **kwargs): + # Synchronize all threads to ensure they use the same generation. + barrier.wait() + raise AutoReconnect("mock AsyncConnection.command error") + + for conn in pool.conns: + conn.command = mock_command + + async def insert_command(i): + try: + await client.test.command("insert", "test", documents=[{"i": i}]) + except AutoReconnect: + pass + + threads = [] + for i in range(N_THREADS): + threads.append(threading.Thread(target=insert_command, args=(i,))) + for t in threads: + t.start() + for t in threads: + t.join() + + # Expect a single pool reset for the network error + self.assertEqual(starting_generation + 1, pool.gen.get_overall()) + + # Server should be selectable. + await client.admin.command("ping") + + +class CMAPHeartbeatListener(HeartbeatEventListener, CMAPListener): + pass + + +class TestPoolManagement(AsyncIntegrationTest): + @async_client_context.require_failCommand_appName + async def test_pool_unpause(self): + # This test implements the prose test "AsyncConnection Pool Management" + listener = CMAPHeartbeatListener() + _ = await self.async_single_client( + appName="SDAMPoolManagementTest", heartbeatFrequencyMS=500, event_listeners=[listener] + ) + # Assert that AsyncConnectionPoolReadyEvent occurs after the first + # ServerHeartbeatSucceededEvent. 
+ await listener.async_wait_for_event(monitoring.PoolReadyEvent, 1) + pool_ready = listener.events_by_type(monitoring.PoolReadyEvent)[0] + hb_succeeded = listener.events_by_type(monitoring.ServerHeartbeatSucceededEvent)[0] + self.assertGreater(listener.events.index(pool_ready), listener.events.index(hb_succeeded)) + + listener.reset() + fail_hello = { + "mode": {"times": 2}, + "data": { + "failCommands": [HelloCompat.LEGACY_CMD, "hello"], + "errorCode": 1234, + "appName": "SDAMPoolManagementTest", + }, + } + async with self.fail_point(fail_hello): + await listener.async_wait_for_event(monitoring.ServerHeartbeatFailedEvent, 1) + await listener.async_wait_for_event(monitoring.PoolClearedEvent, 1) + await listener.async_wait_for_event(monitoring.ServerHeartbeatSucceededEvent, 1) + await listener.async_wait_for_event(monitoring.PoolReadyEvent, 1) + + +class TestServerMonitoringMode(AsyncIntegrationTest): + @async_client_context.require_no_serverless + @async_client_context.require_no_load_balancer + async def asyncSetUp(self): + await super().asyncSetUp() + + async def test_rtt_connection_is_enabled_stream(self): + client = await self.async_rs_or_single_client(serverMonitoringMode="stream") + await client.admin.command("ping") + + def predicate(): + for _, server in client._topology._servers.items(): + monitor = server._monitor + if not monitor._stream: + return False + if async_client_context.version >= (4, 4): + if _IS_SYNC: + if monitor._rtt_monitor._executor._thread is None: + return False + else: + if monitor._rtt_monitor._executor._task is None: + return False + else: + if _IS_SYNC: + if monitor._rtt_monitor._executor._thread is not None: + return False + else: + if monitor._rtt_monitor._executor._task is not None: + return False + return True + + await async_wait_until(predicate, "find all RTT monitors") + + async def test_rtt_connection_is_disabled_poll(self): + client = await self.async_rs_or_single_client(serverMonitoringMode="poll") + + await self.assert_rtt_connection_is_disabled(client) + + async def test_rtt_connection_is_disabled_auto(self): + envs = [ + {"AWS_EXECUTION_ENV": "AWS_Lambda_python3.9"}, + {"FUNCTIONS_WORKER_RUNTIME": "python"}, + {"K_SERVICE": "gcpservicename"}, + {"FUNCTION_NAME": "gcpfunctionname"}, + {"VERCEL": "1"}, + ] + for env in envs: + with patch.dict("os.environ", env): + client = await self.async_rs_or_single_client(serverMonitoringMode="auto") + await self.assert_rtt_connection_is_disabled(client) + + async def assert_rtt_connection_is_disabled(self, client): + await client.admin.command("ping") + for _, server in client._topology._servers.items(): + monitor = server._monitor + self.assertFalse(monitor._stream) + if _IS_SYNC: + self.assertIsNone(monitor._rtt_monitor._executor._thread) + else: + self.assertIsNone(monitor._rtt_monitor._executor._task) + + +class MockTCPHandler(socketserver.BaseRequestHandler): + def handle(self): + self.server.events.append("client connected") + if self.request.recv(1024).strip(): + self.server.events.append("client hello received") + self.request.close() + + +class TCPServer(socketserver.TCPServer): + allow_reuse_address = True + + def handle_request_and_shutdown(self): + self.handle_request() + self.server_close() + + +class TestHeartbeatStartOrdering(AsyncPyMongoTestCase): + @async_client_context.require_sync + async def test_heartbeat_start_ordering(self): + events = [] + listener = HeartbeatEventsListListener(events) + server = TCPServer(("localhost", 9999), MockTCPHandler) + server.events = events + server_thread = 
threading.Thread(target=server.handle_request_and_shutdown) + server_thread.start() + _c = await self.simple_client( + "mongodb://localhost:9999", serverSelectionTimeoutMS=500, event_listeners=(listener,) + ) + server_thread.join() + listener.wait_for_event(ServerHeartbeatStartedEvent, 1) + listener.wait_for_event(ServerHeartbeatFailedEvent, 1) + + self.assertEqual( + events, + [ + "serverHeartbeatStartedEvent", + "client connected", + "client hello received", + "serverHeartbeatFailedEvent", + ], + ) + + +# Generate unified tests. +globals().update(generate_test_classes(os.path.join(SDAM_PATH, "unified"), module=__name__)) + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_discovery_and_monitoring.py b/test/test_discovery_and_monitoring.py index ce7a52f1a0..114cf1b54e 100644 --- a/test/test_discovery_and_monitoring.py +++ b/test/test_discovery_and_monitoring.py @@ -15,7 +15,9 @@ """Test the topology module.""" from __future__ import annotations +import asyncio import os +import pathlib import socketserver import sys import threading @@ -55,8 +57,16 @@ from pymongo.topology_description import TOPOLOGY_TYPE from pymongo.uri_parser import parse_uri +_IS_SYNC = True + # Location of JSON test specifications. -SDAM_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "discovery_and_monitoring") +if _IS_SYNC: + SDAM_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "discovery_and_monitoring") +else: + SDAM_PATH = os.path.join( + pathlib.Path(__file__).resolve().parent.parent, + "discovery_and_monitoring", + ) def create_mock_topology(uri, monitor_class=DummyMonitor): @@ -250,12 +260,7 @@ def send_cluster_time(time, inc, should_update): got_hello( t, ("host", 27017), - { - "ok": 1, - "minWireVersion": 0, - "maxWireVersion": common.MIN_SUPPORTED_WIRE_VERSION, - "$clusterTime": new, - }, + {"ok": 1, "minWireVersion": 0, "maxWireVersion": 6, "$clusterTime": new}, ) actual = t.max_cluster_time() @@ -272,11 +277,11 @@ def send_cluster_time(time, inc, should_update): class TestIgnoreStaleErrors(IntegrationTest): + @client_context.require_sync def test_ignore_stale_connection_errors(self): N_THREADS = 5 barrier = threading.Barrier(N_THREADS, timeout=30) client = self.rs_or_single_client(minPoolSize=N_THREADS) - self.addCleanup(client.close) # Wait for initial discovery. client.admin.command("ping") @@ -322,10 +327,9 @@ class TestPoolManagement(IntegrationTest): def test_pool_unpause(self): # This test implements the prose test "Connection Pool Management" listener = CMAPHeartbeatListener() - client = self.single_client( + _ = self.single_client( appName="SDAMPoolManagementTest", heartbeatFrequencyMS=500, event_listeners=[listener] ) - self.addCleanup(client.close) # Assert that ConnectionPoolReadyEvent occurs after the first # ServerHeartbeatSucceededEvent. 
listener.wait_for_event(monitoring.PoolReadyEvent, 1) @@ -357,7 +361,6 @@ def setUp(self): def test_rtt_connection_is_enabled_stream(self): client = self.rs_or_single_client(serverMonitoringMode="stream") - self.addCleanup(client.close) client.admin.command("ping") def predicate(): @@ -366,18 +369,26 @@ def predicate(): if not monitor._stream: return False if client_context.version >= (4, 4): - if monitor._rtt_monitor._executor._thread is None: - return False + if _IS_SYNC: + if monitor._rtt_monitor._executor._thread is None: + return False + else: + if monitor._rtt_monitor._executor._task is None: + return False else: - if monitor._rtt_monitor._executor._thread is not None: - return False + if _IS_SYNC: + if monitor._rtt_monitor._executor._thread is not None: + return False + else: + if monitor._rtt_monitor._executor._task is not None: + return False return True wait_until(predicate, "find all RTT monitors") def test_rtt_connection_is_disabled_poll(self): client = self.rs_or_single_client(serverMonitoringMode="poll") - self.addCleanup(client.close) + self.assert_rtt_connection_is_disabled(client) def test_rtt_connection_is_disabled_auto(self): @@ -391,7 +402,6 @@ def test_rtt_connection_is_disabled_auto(self): for env in envs: with patch.dict("os.environ", env): client = self.rs_or_single_client(serverMonitoringMode="auto") - self.addCleanup(client.close) self.assert_rtt_connection_is_disabled(client) def assert_rtt_connection_is_disabled(self, client): @@ -399,7 +409,10 @@ def assert_rtt_connection_is_disabled(self, client): for _, server in client._topology._servers.items(): monitor = server._monitor self.assertFalse(monitor._stream) - self.assertIsNone(monitor._rtt_monitor._executor._thread) + if _IS_SYNC: + self.assertIsNone(monitor._rtt_monitor._executor._thread) + else: + self.assertIsNone(monitor._rtt_monitor._executor._task) class MockTCPHandler(socketserver.BaseRequestHandler): @@ -419,6 +432,7 @@ def handle_request_and_shutdown(self): class TestHeartbeatStartOrdering(PyMongoTestCase): + @client_context.require_sync def test_heartbeat_start_ordering(self): events = [] listener = HeartbeatEventsListListener(events) @@ -447,6 +461,5 @@ def test_heartbeat_start_ordering(self): # Generate unified tests. 
globals().update(generate_test_classes(os.path.join(SDAM_PATH, "unified"), module=__name__)) - if __name__ == "__main__": unittest.main() diff --git a/tools/synchro.py b/tools/synchro.py index dbcbbd1351..692cff7de5 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -203,6 +203,7 @@ def async_only_test(f: str) -> bool: "test_crud_unified.py", "test_cursor.py", "test_database.py", + "test_discovery_and_monitoring.py", "test_encryption.py", "test_grid_file.py", "test_logger.py", From 3b6384073ca34f0bd9dfe7632c89a28691a9870a Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Wed, 29 Jan 2025 11:27:32 -0800 Subject: [PATCH 02/19] modify path --- test/asynchronous/test_discovery_and_monitoring.py | 7 +++---- test/test_discovery_and_monitoring.py | 6 +++--- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/test/asynchronous/test_discovery_and_monitoring.py b/test/asynchronous/test_discovery_and_monitoring.py index e63fd9ec80..0fd55f909b 100644 --- a/test/asynchronous/test_discovery_and_monitoring.py +++ b/test/asynchronous/test_discovery_and_monitoring.py @@ -17,10 +17,10 @@ import asyncio import os -import pathlib import socketserver import sys import threading +from pathlib import Path sys.path[0:0] = [""] @@ -36,7 +36,6 @@ async_get_pool, async_wait_until, server_name_to_type, - wait_until, ) from unittest.mock import patch @@ -62,10 +61,10 @@ # Location of JSON test specifications. if _IS_SYNC: - SDAM_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "discovery_and_monitoring") + SDAM_PATH = os.path.join(Path(__file__).resolve().parent, "discovery_and_monitoring") else: SDAM_PATH = os.path.join( - pathlib.Path(__file__).resolve().parent.parent, + Path(__file__).resolve().parent.parent, "discovery_and_monitoring", ) diff --git a/test/test_discovery_and_monitoring.py b/test/test_discovery_and_monitoring.py index 114cf1b54e..092936efcb 100644 --- a/test/test_discovery_and_monitoring.py +++ b/test/test_discovery_and_monitoring.py @@ -17,10 +17,10 @@ import asyncio import os -import pathlib import socketserver import sys import threading +from pathlib import Path sys.path[0:0] = [""] @@ -61,10 +61,10 @@ # Location of JSON test specifications. 
if _IS_SYNC: - SDAM_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "discovery_and_monitoring") + SDAM_PATH = os.path.join(Path(__file__).resolve().parent, "discovery_and_monitoring") else: SDAM_PATH = os.path.join( - pathlib.Path(__file__).resolve().parent.parent, + Path(__file__).resolve().parent.parent, "discovery_and_monitoring", ) From 004f8f0e67929678afac1c1a1eb9ccb36fcaa61a Mon Sep 17 00:00:00 2001 From: Iris <58442094+sleepyStick@users.noreply.github.com> Date: Wed, 29 Jan 2025 13:48:47 -0800 Subject: [PATCH 03/19] Update test/asynchronous/test_discovery_and_monitoring.py Co-authored-by: Noah Stapp --- test/asynchronous/test_discovery_and_monitoring.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/asynchronous/test_discovery_and_monitoring.py b/test/asynchronous/test_discovery_and_monitoring.py index 0fd55f909b..edf0790c86 100644 --- a/test/asynchronous/test_discovery_and_monitoring.py +++ b/test/asynchronous/test_discovery_and_monitoring.py @@ -138,7 +138,7 @@ def get_type(topology, hostname): return description.server_type -class TestAllScenarios(unittest.IsolatedAsyncioTestCase): +class TestAllScenarios(AsyncUnitTest): pass From 70d09d3af1361c6ecc8c37b60998ce18511351ef Mon Sep 17 00:00:00 2001 From: Iris <58442094+sleepyStick@users.noreply.github.com> Date: Wed, 29 Jan 2025 13:48:55 -0800 Subject: [PATCH 04/19] Update test/asynchronous/test_discovery_and_monitoring.py Co-authored-by: Noah Stapp --- test/asynchronous/test_discovery_and_monitoring.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/asynchronous/test_discovery_and_monitoring.py b/test/asynchronous/test_discovery_and_monitoring.py index edf0790c86..e019f9586e 100644 --- a/test/asynchronous/test_discovery_and_monitoring.py +++ b/test/asynchronous/test_discovery_and_monitoring.py @@ -26,7 +26,7 @@ from test.asynchronous import AsyncIntegrationTest, AsyncPyMongoTestCase, unittest from test.asynchronous.pymongo_mocks import DummyMonitor -from test.unified_format import generate_test_classes +from test.asynchronous.unified_format import generate_test_classes from test.utils import ( CMAPListener, HeartbeatEventListener, From 568cc6409e6d4b03987550d548cf9492bf76fd48 Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Wed, 29 Jan 2025 18:32:15 -0800 Subject: [PATCH 05/19] debugging.. --- .../test_discovery_and_monitoring.py | 199 ++++++++++++------ test/test_discovery_and_monitoring.py | 199 ++++++++++++------ tools/synchro.py | 1 - 3 files changed, 268 insertions(+), 131 deletions(-) diff --git a/test/asynchronous/test_discovery_and_monitoring.py b/test/asynchronous/test_discovery_and_monitoring.py index e019f9586e..152d76d31d 100644 --- a/test/asynchronous/test_discovery_and_monitoring.py +++ b/test/asynchronous/test_discovery_and_monitoring.py @@ -20,11 +20,12 @@ import socketserver import sys import threading +from asyncio import StreamReader from pathlib import Path sys.path[0:0] = [""] -from test.asynchronous import AsyncIntegrationTest, AsyncPyMongoTestCase, unittest +from test.asynchronous import AsyncIntegrationTest, AsyncPyMongoTestCase, AsyncUnitTest, unittest from test.asynchronous.pymongo_mocks import DummyMonitor from test.asynchronous.unified_format import generate_test_classes from test.utils import ( @@ -226,7 +227,7 @@ async def run_scenario(self): return run_scenario -def create_tests(): +async def create_tests(): for dirpath, _, filenames in os.walk(SDAM_PATH): dirname = os.path.split(dirpath)[-1] # SDAM unified tests are handled separately. 
@@ -247,7 +248,6 @@ def create_tests(): setattr(TestAllScenarios, new_test.__name__, new_test) -create_tests() class TestClusterTimeComparison(unittest.IsolatedAsyncioTestCase): @@ -277,45 +277,82 @@ async def send_cluster_time(time, inc, should_update): class TestIgnoreStaleErrors(AsyncIntegrationTest): - @async_client_context.require_sync - async def test_ignore_stale_connection_errors(self): - N_THREADS = 5 - barrier = threading.Barrier(N_THREADS, timeout=30) - client = await self.async_rs_or_single_client(minPoolSize=N_THREADS) + if _IS_SYNC: + async def test_ignore_stale_connection_errors(self): + N_THREADS = 5 + barrier = threading.Barrier(N_THREADS, timeout=30) + client = await self.async_rs_or_single_client(minPoolSize=N_THREADS) + + # Wait for initial discovery. + await client.admin.command("ping") + pool = await async_get_pool(client) + starting_generation = pool.gen.get_overall() + await async_wait_until(lambda: len(pool.conns) == N_THREADS, "created conns") + + def mock_command(*args, **kwargs): + # Synchronize all threads to ensure they use the same generation. + barrier.wait() + raise AutoReconnect("mock AsyncConnection.command error") + + for conn in pool.conns: + conn.command = mock_command + + async def insert_command(i): + try: + await client.test.command("insert", "test", documents=[{"i": i}]) + except AutoReconnect: + pass + + threads = [] + for i in range(N_THREADS): + threads.append(threading.Thread(target=insert_command, args=(i,))) + for t in threads: + t.start() + for t in threads: + t.join() + + # Expect a single pool reset for the network error + self.assertEqual(starting_generation + 1, pool.gen.get_overall()) + + # Server should be selectable. + await client.admin.command("ping") + else: + async def test_ignore_stale_connection_errors(self): + N_TASKS = 5 + barrier = asyncio.Barrier(N_TASKS) + client = await self.async_rs_or_single_client(minPoolSize=N_TASKS) - # Wait for initial discovery. - await client.admin.command("ping") - pool = await async_get_pool(client) - starting_generation = pool.gen.get_overall() - await async_wait_until(lambda: len(pool.conns) == N_THREADS, "created conns") - - def mock_command(*args, **kwargs): - # Synchronize all threads to ensure they use the same generation. - barrier.wait() - raise AutoReconnect("mock AsyncConnection.command error") - - for conn in pool.conns: - conn.command = mock_command - - async def insert_command(i): - try: - await client.test.command("insert", "test", documents=[{"i": i}]) - except AutoReconnect: - pass - - threads = [] - for i in range(N_THREADS): - threads.append(threading.Thread(target=insert_command, args=(i,))) - for t in threads: - t.start() - for t in threads: - t.join() - - # Expect a single pool reset for the network error - self.assertEqual(starting_generation + 1, pool.gen.get_overall()) - - # Server should be selectable. - await client.admin.command("ping") + # Wait for initial discovery. + await client.admin.command("ping") + pool = await async_get_pool(client) + starting_generation = pool.gen.get_overall() + await async_wait_until(lambda: len(pool.conns) == N_TASKS, "created conns") + + async def mock_command(*args, **kwargs): + # Synchronize all threads to ensure they use the same generation. 
+ await asyncio.wait_for(barrier.wait(), timeout=30) + raise AutoReconnect("mock AsyncConnection.command error") + + for conn in pool.conns: + conn.command = mock_command + + async def insert_command(i): + try: + await client.test.command("insert", "test", documents=[{"i": i}]) + except AutoReconnect: + pass + + tasks = [] + for i in range(N_TASKS): + tasks.append(asyncio.create_task(insert_command(i))) + for t in tasks: + await t + + # Expect a single pool reset for the network error + self.assertEqual(starting_generation + 1, pool.gen.get_overall()) + + # Server should be selectable. + await client.admin.command("ping") class CMAPHeartbeatListener(HeartbeatEventListener, CMAPListener): @@ -432,30 +469,62 @@ def handle_request_and_shutdown(self): class TestHeartbeatStartOrdering(AsyncPyMongoTestCase): - @async_client_context.require_sync - async def test_heartbeat_start_ordering(self): - events = [] - listener = HeartbeatEventsListListener(events) - server = TCPServer(("localhost", 9999), MockTCPHandler) - server.events = events - server_thread = threading.Thread(target=server.handle_request_and_shutdown) - server_thread.start() - _c = await self.simple_client( - "mongodb://localhost:9999", serverSelectionTimeoutMS=500, event_listeners=(listener,) - ) - server_thread.join() - listener.wait_for_event(ServerHeartbeatStartedEvent, 1) - listener.wait_for_event(ServerHeartbeatFailedEvent, 1) - - self.assertEqual( - events, - [ - "serverHeartbeatStartedEvent", - "client connected", - "client hello received", - "serverHeartbeatFailedEvent", - ], - ) + if _IS_SYNC: + async def test_heartbeat_start_ordering(self): + events = [] + listener = HeartbeatEventsListListener(events) + server = TCPServer(("localhost", 9999), MockTCPHandler) + server.events = events + server_thread = threading.Thread(target=server.handle_request_and_shutdown) + server_thread.start() + _c = await self.simple_client( + "mongodb://localhost:9999", serverSelectionTimeoutMS=500, event_listeners=(listener,) + ) + server_thread.join() + listener.wait_for_event(ServerHeartbeatStartedEvent, 1) + listener.wait_for_event(ServerHeartbeatFailedEvent, 1) + + self.assertEqual( + events, + [ + "serverHeartbeatStartedEvent", + "client connected", + "client hello received", + "serverHeartbeatFailedEvent", + ], + ) + else: + async def test_heartbeat_start_ordering(self): + events = [] + + async def handle_client(reader: StreamReader, writer): + server.events.append("client connected") + print("clent connected") + if (await reader.read(1024)).strip(): + server.events.append("client hello received") + print("client helllo recieved") + listener = HeartbeatEventsListListener(events) + server = await asyncio.start_server(handle_client, "localhost", 9999) + async with server: + server.events = events + _c = self.simple_client( + "mongodb://localhost:9999", serverSelectionTimeoutMS=500, event_listeners=(listener,) + ) + server.close() + server_task = asyncio.create_task(server.wait_closed()) + await server_task + await listener.async_wait_for_event(ServerHeartbeatStartedEvent, 1) + await listener.async_wait_for_event(ServerHeartbeatFailedEvent, 1) + + self.assertEqual( + events, + [ + "serverHeartbeatStartedEvent", + "client connected", + "client hello received", + "serverHeartbeatFailedEvent", + ], + ) # Generate unified tests. 
diff --git a/test/test_discovery_and_monitoring.py b/test/test_discovery_and_monitoring.py index 092936efcb..a4bbcbeb47 100644 --- a/test/test_discovery_and_monitoring.py +++ b/test/test_discovery_and_monitoring.py @@ -20,11 +20,12 @@ import socketserver import sys import threading +from asyncio import StreamReader from pathlib import Path sys.path[0:0] = [""] -from test import IntegrationTest, PyMongoTestCase, unittest +from test import IntegrationTest, PyMongoTestCase, UnitTest, unittest from test.pymongo_mocks import DummyMonitor from test.unified_format import generate_test_classes from test.utils import ( @@ -138,7 +139,7 @@ def get_type(topology, hostname): return description.server_type -class TestAllScenarios(unittest.TestCase): +class TestAllScenarios(UnitTest): pass @@ -247,7 +248,6 @@ def create_tests(): setattr(TestAllScenarios, new_test.__name__, new_test) -create_tests() class TestClusterTimeComparison(unittest.TestCase): @@ -277,45 +277,82 @@ def send_cluster_time(time, inc, should_update): class TestIgnoreStaleErrors(IntegrationTest): - @client_context.require_sync - def test_ignore_stale_connection_errors(self): - N_THREADS = 5 - barrier = threading.Barrier(N_THREADS, timeout=30) - client = self.rs_or_single_client(minPoolSize=N_THREADS) + if _IS_SYNC: + def test_ignore_stale_connection_errors(self): + N_THREADS = 5 + barrier = threading.Barrier(N_THREADS, timeout=30) + client = self.rs_or_single_client(minPoolSize=N_THREADS) + + # Wait for initial discovery. + client.admin.command("ping") + pool = get_pool(client) + starting_generation = pool.gen.get_overall() + wait_until(lambda: len(pool.conns) == N_THREADS, "created conns") + + def mock_command(*args, **kwargs): + # Synchronize all threads to ensure they use the same generation. + barrier.wait() + raise AutoReconnect("mock Connection.command error") + + for conn in pool.conns: + conn.command = mock_command + + def insert_command(i): + try: + client.test.command("insert", "test", documents=[{"i": i}]) + except AutoReconnect: + pass + + threads = [] + for i in range(N_THREADS): + threads.append(threading.Thread(target=insert_command, args=(i,))) + for t in threads: + t.start() + for t in threads: + t.join() + + # Expect a single pool reset for the network error + self.assertEqual(starting_generation + 1, pool.gen.get_overall()) + + # Server should be selectable. + client.admin.command("ping") + else: + def test_ignore_stale_connection_errors(self): + N_TASKS = 5 + barrier = asyncio.Barrier(N_TASKS) + client = self.rs_or_single_client(minPoolSize=N_TASKS) - # Wait for initial discovery. - client.admin.command("ping") - pool = get_pool(client) - starting_generation = pool.gen.get_overall() - wait_until(lambda: len(pool.conns) == N_THREADS, "created conns") - - def mock_command(*args, **kwargs): - # Synchronize all threads to ensure they use the same generation. - barrier.wait() - raise AutoReconnect("mock Connection.command error") - - for conn in pool.conns: - conn.command = mock_command - - def insert_command(i): - try: - client.test.command("insert", "test", documents=[{"i": i}]) - except AutoReconnect: - pass - - threads = [] - for i in range(N_THREADS): - threads.append(threading.Thread(target=insert_command, args=(i,))) - for t in threads: - t.start() - for t in threads: - t.join() - - # Expect a single pool reset for the network error - self.assertEqual(starting_generation + 1, pool.gen.get_overall()) - - # Server should be selectable. - client.admin.command("ping") + # Wait for initial discovery. 
+ client.admin.command("ping") + pool = get_pool(client) + starting_generation = pool.gen.get_overall() + wait_until(lambda: len(pool.conns) == N_TASKS, "created conns") + + def mock_command(*args, **kwargs): + # Synchronize all threads to ensure they use the same generation. + asyncio.wait_for(barrier.wait(), timeout=30) + raise AutoReconnect("mock Connection.command error") + + for conn in pool.conns: + conn.command = mock_command + + def insert_command(i): + try: + client.test.command("insert", "test", documents=[{"i": i}]) + except AutoReconnect: + pass + + tasks = [] + for i in range(N_TASKS): + tasks.append(asyncio.create_task(insert_command(i))) + for t in tasks: + t + + # Expect a single pool reset for the network error + self.assertEqual(starting_generation + 1, pool.gen.get_overall()) + + # Server should be selectable. + client.admin.command("ping") class CMAPHeartbeatListener(HeartbeatEventListener, CMAPListener): @@ -432,30 +469,62 @@ def handle_request_and_shutdown(self): class TestHeartbeatStartOrdering(PyMongoTestCase): - @client_context.require_sync - def test_heartbeat_start_ordering(self): - events = [] - listener = HeartbeatEventsListListener(events) - server = TCPServer(("localhost", 9999), MockTCPHandler) - server.events = events - server_thread = threading.Thread(target=server.handle_request_and_shutdown) - server_thread.start() - _c = self.simple_client( - "mongodb://localhost:9999", serverSelectionTimeoutMS=500, event_listeners=(listener,) - ) - server_thread.join() - listener.wait_for_event(ServerHeartbeatStartedEvent, 1) - listener.wait_for_event(ServerHeartbeatFailedEvent, 1) - - self.assertEqual( - events, - [ - "serverHeartbeatStartedEvent", - "client connected", - "client hello received", - "serverHeartbeatFailedEvent", - ], - ) + if _IS_SYNC: + def test_heartbeat_start_ordering(self): + events = [] + listener = HeartbeatEventsListListener(events) + server = TCPServer(("localhost", 9999), MockTCPHandler) + server.events = events + server_thread = threading.Thread(target=server.handle_request_and_shutdown) + server_thread.start() + _c = self.simple_client( + "mongodb://localhost:9999", serverSelectionTimeoutMS=500, event_listeners=(listener,) + ) + server_thread.join() + listener.wait_for_event(ServerHeartbeatStartedEvent, 1) + listener.wait_for_event(ServerHeartbeatFailedEvent, 1) + + self.assertEqual( + events, + [ + "serverHeartbeatStartedEvent", + "client connected", + "client hello received", + "serverHeartbeatFailedEvent", + ], + ) + else: + def test_heartbeat_start_ordering(self): + events = [] + + def handle_client(reader: StreamReader, writer): + server.events.append("client connected") + print("clent connected") + if (reader.read(1024)).strip(): + server.events.append("client hello received") + print("client helllo recieved") + listener = HeartbeatEventsListListener(events) + server = asyncio.start_server(handle_client, "localhost", 9999) + with server: + server.events = events + _c = self.simple_client( + "mongodb://localhost:9999", serverSelectionTimeoutMS=500, event_listeners=(listener,) + ) + server.close() + server_task = asyncio.create_task(server.wait_closed()) + server_task + listener.wait_for_event(ServerHeartbeatStartedEvent, 1) + listener.wait_for_event(ServerHeartbeatFailedEvent, 1) + + self.assertEqual( + events, + [ + "serverHeartbeatStartedEvent", + "client connected", + "client hello received", + "serverHeartbeatFailedEvent", + ], + ) # Generate unified tests. 
diff --git a/tools/synchro.py b/tools/synchro.py index df611ee37b..79775f74e2 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -204,7 +204,6 @@ def async_only_test(f: str) -> bool: "test_cursor.py", "test_custom_types.py", "test_database.py", - "test_discovery_and_monitoring.py", "test_data_lake.py", "test_discovery_and_monitoring.py", "test_encryption.py", From 47d9ebd31ae8839f8dc29746b7c2798c9dc7f804 Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Thu, 6 Feb 2025 16:21:24 -0800 Subject: [PATCH 06/19] make test_heartbeat_start_ordering async --- .../test_discovery_and_monitoring.py | 93 +++++++++---------- test/asynchronous/unified_format.py | 8 +- test/test_discovery_and_monitoring.py | 91 +++++++++--------- 3 files changed, 95 insertions(+), 97 deletions(-) diff --git a/test/asynchronous/test_discovery_and_monitoring.py b/test/asynchronous/test_discovery_and_monitoring.py index 152d76d31d..9947d405f3 100644 --- a/test/asynchronous/test_discovery_and_monitoring.py +++ b/test/asynchronous/test_discovery_and_monitoring.py @@ -20,7 +20,7 @@ import socketserver import sys import threading -from asyncio import StreamReader +from asyncio import StreamReader, StreamWriter from pathlib import Path sys.path[0:0] = [""] @@ -227,7 +227,7 @@ async def run_scenario(self): return run_scenario -async def create_tests(): +def create_tests(): for dirpath, _, filenames in os.walk(SDAM_PATH): dirname = os.path.split(dirpath)[-1] # SDAM unified tests are handled separately. @@ -248,8 +248,6 @@ async def create_tests(): setattr(TestAllScenarios, new_test.__name__, new_test) - - class TestClusterTimeComparison(unittest.IsolatedAsyncioTestCase): async def test_cluster_time_comparison(self): t = await create_mock_topology("mongodb://host") @@ -278,6 +276,7 @@ async def send_cluster_time(time, inc, should_update): class TestIgnoreStaleErrors(AsyncIntegrationTest): if _IS_SYNC: + async def test_ignore_stale_connection_errors(self): N_THREADS = 5 barrier = threading.Barrier(N_THREADS, timeout=30) @@ -317,6 +316,7 @@ async def insert_command(i): # Server should be selectable. 
await client.admin.command("ping") else: + async def test_ignore_stale_connection_errors(self): N_TASKS = 5 barrier = asyncio.Barrier(N_TASKS) @@ -469,62 +469,61 @@ def handle_request_and_shutdown(self): class TestHeartbeatStartOrdering(AsyncPyMongoTestCase): - if _IS_SYNC: - async def test_heartbeat_start_ordering(self): - events = [] - listener = HeartbeatEventsListListener(events) + async def test_heartbeat_start_ordering(self): + events = [] + listener = HeartbeatEventsListListener(events) + + if _IS_SYNC: server = TCPServer(("localhost", 9999), MockTCPHandler) server.events = events server_thread = threading.Thread(target=server.handle_request_and_shutdown) server_thread.start() _c = await self.simple_client( - "mongodb://localhost:9999", serverSelectionTimeoutMS=500, event_listeners=(listener,) + "mongodb://localhost:9999", + serverSelectionTimeoutMS=500, + event_listeners=(listener,), ) server_thread.join() listener.wait_for_event(ServerHeartbeatStartedEvent, 1) listener.wait_for_event(ServerHeartbeatFailedEvent, 1) - self.assertEqual( - events, - [ - "serverHeartbeatStartedEvent", - "client connected", - "client hello received", - "serverHeartbeatFailedEvent", - ], - ) - else: - async def test_heartbeat_start_ordering(self): - events = [] + else: - async def handle_client(reader: StreamReader, writer): - server.events.append("client connected") - print("clent connected") + async def handle_client(reader: StreamReader, writer: StreamWriter): + events.append("client connected") if (await reader.read(1024)).strip(): - server.events.append("client hello received") - print("client helllo recieved") - listener = HeartbeatEventsListListener(events) + events.append("client hello received") + writer.close() + await writer.wait_closed() + server = await asyncio.start_server(handle_client, "localhost", 9999) - async with server: - server.events = events - _c = self.simple_client( - "mongodb://localhost:9999", serverSelectionTimeoutMS=500, event_listeners=(listener,) - ) - server.close() - server_task = asyncio.create_task(server.wait_closed()) - await server_task - await listener.async_wait_for_event(ServerHeartbeatStartedEvent, 1) - await listener.async_wait_for_event(ServerHeartbeatFailedEvent, 1) - - self.assertEqual( - events, - [ - "serverHeartbeatStartedEvent", - "client connected", - "client hello received", - "serverHeartbeatFailedEvent", - ], - ) + server.events = events + await server.start_serving() + print(server.is_serving()) + _c = self.simple_client( + "mongodb://localhost:9999", + serverSelectionTimeoutMS=500, + event_listeners=(listener,), + ) + if _c._options.connect: + await _c.aconnect() + + await listener.async_wait_for_event(ServerHeartbeatStartedEvent, 1) + await listener.async_wait_for_event(ServerHeartbeatFailedEvent, 1) + + server.close() + await server.wait_closed() + await _c.close() + + self.assertEqual( + events, + [ + "serverHeartbeatStartedEvent", + "client connected", + "client hello received", + "serverHeartbeatFailedEvent", + ], + ) # Generate unified tests. 
diff --git a/test/asynchronous/unified_format.py b/test/asynchronous/unified_format.py index 149aad9786..34ad2b3da9 100644 --- a/test/asynchronous/unified_format.py +++ b/test/asynchronous/unified_format.py @@ -1155,7 +1155,7 @@ def _testOperation_assertTopologyType(self, spec): self.assertIsInstance(description, TopologyDescription) self.assertEqual(description.topology_type_name, spec["topologyType"]) - def _testOperation_waitForPrimaryChange(self, spec: dict) -> None: + async def _testOperation_waitForPrimaryChange(self, spec: dict) -> None: """Run the waitForPrimaryChange test operation.""" client = self.entity_map[spec["client"]] old_description: TopologyDescription = self.entity_map[spec["priorTopologyDescription"]] @@ -1169,13 +1169,13 @@ def get_primary(td: TopologyDescription) -> Optional[_Address]: old_primary = get_primary(old_description) - def primary_changed() -> bool: - primary = client.primary + async def primary_changed() -> bool: + primary = await client.primary if primary is None: return False return primary != old_primary - wait_until(primary_changed, "change primary", timeout=timeout) + await async_wait_until(primary_changed, "change primary", timeout=timeout) async def _testOperation_runOnThread(self, spec): """Run the 'runOnThread' operation.""" diff --git a/test/test_discovery_and_monitoring.py b/test/test_discovery_and_monitoring.py index a4bbcbeb47..fdf1677cbe 100644 --- a/test/test_discovery_and_monitoring.py +++ b/test/test_discovery_and_monitoring.py @@ -20,7 +20,7 @@ import socketserver import sys import threading -from asyncio import StreamReader +from asyncio import StreamReader, StreamWriter from pathlib import Path sys.path[0:0] = [""] @@ -248,8 +248,6 @@ def create_tests(): setattr(TestAllScenarios, new_test.__name__, new_test) - - class TestClusterTimeComparison(unittest.TestCase): def test_cluster_time_comparison(self): t = create_mock_topology("mongodb://host") @@ -278,6 +276,7 @@ def send_cluster_time(time, inc, should_update): class TestIgnoreStaleErrors(IntegrationTest): if _IS_SYNC: + def test_ignore_stale_connection_errors(self): N_THREADS = 5 barrier = threading.Barrier(N_THREADS, timeout=30) @@ -317,6 +316,7 @@ def insert_command(i): # Server should be selectable. 
client.admin.command("ping") else: + def test_ignore_stale_connection_errors(self): N_TASKS = 5 barrier = asyncio.Barrier(N_TASKS) @@ -469,62 +469,61 @@ def handle_request_and_shutdown(self): class TestHeartbeatStartOrdering(PyMongoTestCase): - if _IS_SYNC: - def test_heartbeat_start_ordering(self): - events = [] - listener = HeartbeatEventsListListener(events) + def test_heartbeat_start_ordering(self): + events = [] + listener = HeartbeatEventsListListener(events) + + if _IS_SYNC: server = TCPServer(("localhost", 9999), MockTCPHandler) server.events = events server_thread = threading.Thread(target=server.handle_request_and_shutdown) server_thread.start() _c = self.simple_client( - "mongodb://localhost:9999", serverSelectionTimeoutMS=500, event_listeners=(listener,) + "mongodb://localhost:9999", + serverSelectionTimeoutMS=500, + event_listeners=(listener,), ) server_thread.join() listener.wait_for_event(ServerHeartbeatStartedEvent, 1) listener.wait_for_event(ServerHeartbeatFailedEvent, 1) - self.assertEqual( - events, - [ - "serverHeartbeatStartedEvent", - "client connected", - "client hello received", - "serverHeartbeatFailedEvent", - ], - ) - else: - def test_heartbeat_start_ordering(self): - events = [] + else: - def handle_client(reader: StreamReader, writer): - server.events.append("client connected") - print("clent connected") + def handle_client(reader: StreamReader, writer: StreamWriter): + events.append("client connected") if (reader.read(1024)).strip(): - server.events.append("client hello received") - print("client helllo recieved") - listener = HeartbeatEventsListListener(events) + events.append("client hello received") + writer.close() + writer.wait_closed() + server = asyncio.start_server(handle_client, "localhost", 9999) - with server: - server.events = events - _c = self.simple_client( - "mongodb://localhost:9999", serverSelectionTimeoutMS=500, event_listeners=(listener,) - ) - server.close() - server_task = asyncio.create_task(server.wait_closed()) - server_task - listener.wait_for_event(ServerHeartbeatStartedEvent, 1) - listener.wait_for_event(ServerHeartbeatFailedEvent, 1) - - self.assertEqual( - events, - [ - "serverHeartbeatStartedEvent", - "client connected", - "client hello received", - "serverHeartbeatFailedEvent", - ], - ) + server.events = events + server.start_serving() + print(server.is_serving()) + _c = self.simple_client( + "mongodb://localhost:9999", + serverSelectionTimeoutMS=500, + event_listeners=(listener,), + ) + if _c._options.connect: + _c._connect() + + listener.wait_for_event(ServerHeartbeatStartedEvent, 1) + listener.wait_for_event(ServerHeartbeatFailedEvent, 1) + + server.close() + server.wait_closed() + _c.close() + + self.assertEqual( + events, + [ + "serverHeartbeatStartedEvent", + "client connected", + "client hello received", + "serverHeartbeatFailedEvent", + ], + ) # Generate unified tests. 
From 24d30af0654e5f571796c2303fad27206c20dcf7 Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Tue, 11 Feb 2025 08:56:31 -0800 Subject: [PATCH 07/19] address review --- test/asynchronous/helpers.py | 16 +++ .../test_discovery_and_monitoring.py | 119 ++++++------------ test/helpers.py | 16 +++ test/test_discovery_and_monitoring.py | 119 ++++++------------ tools/synchro.py | 2 + 5 files changed, 112 insertions(+), 160 deletions(-) diff --git a/test/asynchronous/helpers.py b/test/asynchronous/helpers.py index a35c71b107..66b5fe9cc7 100644 --- a/test/asynchronous/helpers.py +++ b/test/asynchronous/helpers.py @@ -407,3 +407,19 @@ async def run(self): await self.target(*self.args) finally: self.stopped = True + + +def create_barrier(N_TASKS, timeout: float | None = None): + return threading.Barrier(N_TASKS, timeout) + + +def async_create_barrier(N_TASKS, timeout: float | None = None): + return asyncio.Barrier(N_TASKS) + + +def barrier_wait(barrier, timeout: float | None = None): + barrier.wait() + + +async def async_barrier_wait(barrier, timeout: float | None = None): + await asyncio.wait_for(barrier.wait(), timeout) diff --git a/test/asynchronous/test_discovery_and_monitoring.py b/test/asynchronous/test_discovery_and_monitoring.py index 9947d405f3..64ac400af8 100644 --- a/test/asynchronous/test_discovery_and_monitoring.py +++ b/test/asynchronous/test_discovery_and_monitoring.py @@ -22,6 +22,7 @@ import threading from asyncio import StreamReader, StreamWriter from pathlib import Path +from test.asynchronous.helpers import ConcurrentRunner, async_barrier_wait, async_create_barrier sys.path[0:0] = [""] @@ -275,84 +276,44 @@ async def send_cluster_time(time, inc, should_update): class TestIgnoreStaleErrors(AsyncIntegrationTest): - if _IS_SYNC: - - async def test_ignore_stale_connection_errors(self): - N_THREADS = 5 - barrier = threading.Barrier(N_THREADS, timeout=30) - client = await self.async_rs_or_single_client(minPoolSize=N_THREADS) - - # Wait for initial discovery. - await client.admin.command("ping") - pool = await async_get_pool(client) - starting_generation = pool.gen.get_overall() - await async_wait_until(lambda: len(pool.conns) == N_THREADS, "created conns") - - def mock_command(*args, **kwargs): - # Synchronize all threads to ensure they use the same generation. - barrier.wait() - raise AutoReconnect("mock AsyncConnection.command error") - - for conn in pool.conns: - conn.command = mock_command - - async def insert_command(i): - try: - await client.test.command("insert", "test", documents=[{"i": i}]) - except AutoReconnect: - pass - - threads = [] - for i in range(N_THREADS): - threads.append(threading.Thread(target=insert_command, args=(i,))) - for t in threads: - t.start() - for t in threads: - t.join() - - # Expect a single pool reset for the network error - self.assertEqual(starting_generation + 1, pool.gen.get_overall()) - - # Server should be selectable. - await client.admin.command("ping") - else: - - async def test_ignore_stale_connection_errors(self): - N_TASKS = 5 - barrier = asyncio.Barrier(N_TASKS) - client = await self.async_rs_or_single_client(minPoolSize=N_TASKS) - - # Wait for initial discovery. - await client.admin.command("ping") - pool = await async_get_pool(client) - starting_generation = pool.gen.get_overall() - await async_wait_until(lambda: len(pool.conns) == N_TASKS, "created conns") - - async def mock_command(*args, **kwargs): - # Synchronize all threads to ensure they use the same generation. 
- await asyncio.wait_for(barrier.wait(), timeout=30) - raise AutoReconnect("mock AsyncConnection.command error") - - for conn in pool.conns: - conn.command = mock_command + async def test_ignore_stale_connection_errors(self): + N_TASKS = 5 + barrier = async_create_barrier(N_TASKS, timeout=30) + client = await self.async_rs_or_single_client(minPoolSize=N_TASKS) - async def insert_command(i): - try: - await client.test.command("insert", "test", documents=[{"i": i}]) - except AutoReconnect: - pass - - tasks = [] - for i in range(N_TASKS): - tasks.append(asyncio.create_task(insert_command(i))) - for t in tasks: - await t - - # Expect a single pool reset for the network error - self.assertEqual(starting_generation + 1, pool.gen.get_overall()) - - # Server should be selectable. - await client.admin.command("ping") + # Wait for initial discovery. + await client.admin.command("ping") + pool = await async_get_pool(client) + starting_generation = pool.gen.get_overall() + await async_wait_until(lambda: len(pool.conns) == N_TASKS, "created conns") + + async def mock_command(*args, **kwargs): + # Synchronize all threads to ensure they use the same generation. + await async_barrier_wait(barrier, timeout=30) + raise AutoReconnect("mock AsyncConnection.command error") + + for conn in pool.conns: + conn.command = mock_command + + async def insert_command(i): + try: + await client.test.command("insert", "test", documents=[{"i": i}]) + except AutoReconnect: + pass + + tasks = [] + for i in range(N_TASKS): + tasks.append(ConcurrentRunner(target=insert_command, args=(i,))) + for t in tasks: + await t.start() + for t in tasks: + await t.join() + + # Expect a single pool reset for the network error + self.assertEqual(starting_generation + 1, pool.gen.get_overall()) + + # Server should be selectable. 
+ await client.admin.command("ping") class CMAPHeartbeatListener(HeartbeatEventListener, CMAPListener): @@ -499,14 +460,12 @@ async def handle_client(reader: StreamReader, writer: StreamWriter): server = await asyncio.start_server(handle_client, "localhost", 9999) server.events = events await server.start_serving() - print(server.is_serving()) _c = self.simple_client( "mongodb://localhost:9999", serverSelectionTimeoutMS=500, event_listeners=(listener,), ) - if _c._options.connect: - await _c.aconnect() + await _c.aconnect() await listener.async_wait_for_event(ServerHeartbeatStartedEvent, 1) await listener.async_wait_for_event(ServerHeartbeatFailedEvent, 1) diff --git a/test/helpers.py b/test/helpers.py index 705843efcd..fb4e58a0c6 100644 --- a/test/helpers.py +++ b/test/helpers.py @@ -407,3 +407,19 @@ def run(self): self.target(*self.args) finally: self.stopped = True + + +def create_barrier(N_TASKS, timeout: float | None = None): + return threading.Barrier(N_TASKS, timeout) + + +def create_barrier(N_TASKS, timeout: float | None = None): + return asyncio.Barrier(N_TASKS) + + +def barrier_wait(barrier, timeout: float | None = None): + barrier.wait() + + +def barrier_wait(barrier, timeout: float | None = None): + asyncio.wait_for(barrier.wait(), timeout) diff --git a/test/test_discovery_and_monitoring.py b/test/test_discovery_and_monitoring.py index fdf1677cbe..1f3c48d6cd 100644 --- a/test/test_discovery_and_monitoring.py +++ b/test/test_discovery_and_monitoring.py @@ -22,6 +22,7 @@ import threading from asyncio import StreamReader, StreamWriter from pathlib import Path +from test.helpers import ConcurrentRunner, barrier_wait, create_barrier sys.path[0:0] = [""] @@ -275,84 +276,44 @@ def send_cluster_time(time, inc, should_update): class TestIgnoreStaleErrors(IntegrationTest): - if _IS_SYNC: - - def test_ignore_stale_connection_errors(self): - N_THREADS = 5 - barrier = threading.Barrier(N_THREADS, timeout=30) - client = self.rs_or_single_client(minPoolSize=N_THREADS) - - # Wait for initial discovery. - client.admin.command("ping") - pool = get_pool(client) - starting_generation = pool.gen.get_overall() - wait_until(lambda: len(pool.conns) == N_THREADS, "created conns") - - def mock_command(*args, **kwargs): - # Synchronize all threads to ensure they use the same generation. - barrier.wait() - raise AutoReconnect("mock Connection.command error") - - for conn in pool.conns: - conn.command = mock_command - - def insert_command(i): - try: - client.test.command("insert", "test", documents=[{"i": i}]) - except AutoReconnect: - pass - - threads = [] - for i in range(N_THREADS): - threads.append(threading.Thread(target=insert_command, args=(i,))) - for t in threads: - t.start() - for t in threads: - t.join() - - # Expect a single pool reset for the network error - self.assertEqual(starting_generation + 1, pool.gen.get_overall()) - - # Server should be selectable. - client.admin.command("ping") - else: - - def test_ignore_stale_connection_errors(self): - N_TASKS = 5 - barrier = asyncio.Barrier(N_TASKS) - client = self.rs_or_single_client(minPoolSize=N_TASKS) - - # Wait for initial discovery. - client.admin.command("ping") - pool = get_pool(client) - starting_generation = pool.gen.get_overall() - wait_until(lambda: len(pool.conns) == N_TASKS, "created conns") - - def mock_command(*args, **kwargs): - # Synchronize all threads to ensure they use the same generation. 
-                asyncio.wait_for(barrier.wait(), timeout=30)
-                raise AutoReconnect("mock Connection.command error")
-
-            for conn in pool.conns:
-                conn.command = mock_command
+    def test_ignore_stale_connection_errors(self):
+        N_TASKS = 5
+        barrier = create_barrier(N_TASKS, timeout=30)
+        client = self.rs_or_single_client(minPoolSize=N_TASKS)
 
-            def insert_command(i):
-                try:
-                    client.test.command("insert", "test", documents=[{"i": i}])
-                except AutoReconnect:
-                    pass
-
-            tasks = []
-            for i in range(N_TASKS):
-                tasks.append(asyncio.create_task(insert_command(i)))
-            for t in tasks:
-                t
-
-            # Expect a single pool reset for the network error
-            self.assertEqual(starting_generation + 1, pool.gen.get_overall())
-
-            # Server should be selectable.
-            client.admin.command("ping")
+        # Wait for initial discovery.
+        client.admin.command("ping")
+        pool = get_pool(client)
+        starting_generation = pool.gen.get_overall()
+        wait_until(lambda: len(pool.conns) == N_TASKS, "created conns")
+
+        def mock_command(*args, **kwargs):
+            # Synchronize all threads to ensure they use the same generation.
+            barrier_wait(barrier, timeout=30)
+            raise AutoReconnect("mock Connection.command error")
+
+        for conn in pool.conns:
+            conn.command = mock_command
+
+        def insert_command(i):
+            try:
+                client.test.command("insert", "test", documents=[{"i": i}])
+            except AutoReconnect:
+                pass
+
+        tasks = []
+        for i in range(N_TASKS):
+            tasks.append(ConcurrentRunner(target=insert_command, args=(i,)))
+        for t in tasks:
+            t.start()
+        for t in tasks:
+            t.join()
+
+        # Expect a single pool reset for the network error
+        self.assertEqual(starting_generation + 1, pool.gen.get_overall())
+
+        # Server should be selectable.
+        client.admin.command("ping")
 
 
 class CMAPHeartbeatListener(HeartbeatEventListener, CMAPListener):
@@ -499,14 +460,12 @@ def handle_client(reader: StreamReader, writer: StreamWriter):
             server = asyncio.start_server(handle_client, "localhost", 9999)
             server.events = events
             server.start_serving()
-            print(server.is_serving())
             _c = self.simple_client(
                 "mongodb://localhost:9999",
                 serverSelectionTimeoutMS=500,
                 event_listeners=(listener,),
             )
-            if _c._options.connect:
-                _c._connect()
+            _c._connect()
             listener.wait_for_event(ServerHeartbeatStartedEvent, 1)
             listener.wait_for_event(ServerHeartbeatFailedEvent, 1)
diff --git a/tools/synchro.py b/tools/synchro.py
index abf60f914e..9d22f6e2c2 100644
--- a/tools/synchro.py
+++ b/tools/synchro.py
@@ -122,6 +122,8 @@
     "SpecRunnerTask": "SpecRunnerThread",
     "AsyncMockConnection": "MockConnection",
     "AsyncMockPool": "MockPool",
+    "async_create_barrier": "create_barrier",
+    "async_barrier_wait": "barrier_wait",
 }
 
 docstring_replacements: dict[tuple[str, str], str] = {

From ffb07b39dd89fcad62f407867f10efbee8cf65c1 Mon Sep 17 00:00:00 2001
From: Iris Ho
Date: Tue, 11 Feb 2025 09:10:56 -0800
Subject: [PATCH 08/19] fix typing and lint

---
 test/asynchronous/helpers.py | 7 +++----
 test/helpers.py              | 5 ++---
 2 files changed, 5 insertions(+), 7 deletions(-)

diff --git a/test/asynchronous/helpers.py b/test/asynchronous/helpers.py
index 9a73051bfb..06cd3f5c99 100644
--- a/test/asynchronous/helpers.py
+++ b/test/asynchronous/helpers.py
@@ -408,7 +408,7 @@ async def run(self):
         finally:
             self.stopped = True
 
-    
+
 class ExceptionCatchingTask(ConcurrentRunner):
     """A Task that stores any exception encountered while running."""
 
@@ -425,7 +425,7 @@ async def run(self):
 
 
 def create_barrier(N_TASKS, timeout: float | None = None):
-    return threading.Barrier(N_TASKS, timeout)
+    return threading.Barrier(N_TASKS, timeout=timeout)
 
 
 def async_create_barrier(N_TASKS, timeout: float | None = None):
@@ -437,5 +437,4 @@ def barrier_wait(barrier, timeout: float | None = None):
 
 
 async def async_barrier_wait(barrier, timeout: float | None = None):
-    await asyncio.wait_for(barrier.wait(), timeout)
-
+    await asyncio.wait_for(barrier.wait(), timeout=timeout)
diff --git a/test/helpers.py b/test/helpers.py
index 7211a31e1d..1c5b8996c1 100644
--- a/test/helpers.py
+++ b/test/helpers.py
@@ -425,7 +425,7 @@ def run(self):
 
 
 def create_barrier(N_TASKS, timeout: float | None = None):
-    return threading.Barrier(N_TASKS, timeout)
+    return threading.Barrier(N_TASKS, timeout=timeout)
 
 
 def create_barrier(N_TASKS, timeout: float | None = None):
@@ -437,5 +437,4 @@ def barrier_wait(barrier, timeout: float | None = None):
 
 
 def barrier_wait(barrier, timeout: float | None = None):
-    asyncio.wait_for(barrier.wait(), timeout)
-
+    asyncio.wait_for(barrier.wait(), timeout=timeout)

From 1c1ebef380abd8be3e18c4a28cf24e8dfa15bc4b Mon Sep 17 00:00:00 2001
From: Iris Ho
Date: Tue, 11 Feb 2025 09:49:35 -0800
Subject: [PATCH 09/19] fix typing

---
 test/asynchronous/helpers.py | 16 ----------------
 test/helpers.py              | 16 ----------------
 test/utils.py                | 16 ++++++++++++++++
 3 files changed, 16 insertions(+), 32 deletions(-)

diff --git a/test/asynchronous/helpers.py b/test/asynchronous/helpers.py
index 06cd3f5c99..28260d0a52 100644
--- a/test/asynchronous/helpers.py
+++ b/test/asynchronous/helpers.py
@@ -422,19 +422,3 @@ async def run(self):
         except BaseException as exc:
             self.exc = exc
             raise
-
-
-def create_barrier(N_TASKS, timeout: float | None = None):
-    return threading.Barrier(N_TASKS, timeout=timeout)
-
-
-def async_create_barrier(N_TASKS, timeout: float | None = None):
-    return asyncio.Barrier(N_TASKS)
-
-
-def barrier_wait(barrier, timeout: float | None = None):
-    barrier.wait()
-
-
-async def async_barrier_wait(barrier, timeout: float | None = None):
-    await asyncio.wait_for(barrier.wait(), timeout=timeout)
diff --git a/test/helpers.py b/test/helpers.py
index 1c5b8996c1..3f51fde08c 100644
--- a/test/helpers.py
+++ b/test/helpers.py
@@ -422,19 +422,3 @@ def run(self):
         except BaseException as exc:
             self.exc = exc
             raise
-
-
-def create_barrier(N_TASKS, timeout: float | None = None):
-    return threading.Barrier(N_TASKS, timeout=timeout)
-
-
-def create_barrier(N_TASKS, timeout: float | None = None):
-    return asyncio.Barrier(N_TASKS)
-
-
-def barrier_wait(barrier, timeout: float | None = None):
-    barrier.wait()
-
-
-def barrier_wait(barrier, timeout: float | None = None):
-    asyncio.wait_for(barrier.wait(), timeout=timeout)
diff --git a/test/utils.py b/test/utils.py
index 5c1e0bfb7c..0fe6d444dd 100644
--- a/test/utils.py
+++ b/test/utils.py
@@ -1073,3 +1073,19 @@ def create_async_event():
 
 def create_event():
     return threading.Event()
+
+
+def async_create_barrier(N_TASKS, timeout: float | None = None):
+    return asyncio.Barrier(N_TASKS)
+
+
+def create_barrier(N_TASKS, timeout: float | None = None):
+    return threading.Barrier(N_TASKS, timeout=timeout)
+
+
+async def async_barrier_wait(barrier, timeout: float | None = None):
+    await asyncio.wait_for(barrier.wait(), timeout=timeout)
+
+
+def barrier_wait(barrier, timeout: float | None = None):
+    barrier.wait()

From d00eb4536240f36f20590d334705fd26ad31e442 Mon Sep 17 00:00:00 2001
From: Iris Ho
Date: Tue, 11 Feb 2025 09:54:05 -0800
Subject: [PATCH 10/19] test_ignore_stale_connection_errors should only run on python3.11 and newer for async

---
 test/asynchronous/test_discovery_and_monitoring.py | 2 ++
 test/test_discovery_and_monitoring.py              | 2 ++
 2 files changed, 4 insertions(+)

diff --git a/test/asynchronous/test_discovery_and_monitoring.py b/test/asynchronous/test_discovery_and_monitoring.py
index 64ac400af8..8be92802fd 100644
--- a/test/asynchronous/test_discovery_and_monitoring.py
+++ b/test/asynchronous/test_discovery_and_monitoring.py
@@ -277,6 +277,8 @@ async def send_cluster_time(time, inc, should_update):
 
 class TestIgnoreStaleErrors(AsyncIntegrationTest):
     async def test_ignore_stale_connection_errors(self):
+        if not _IS_SYNC and sys.version_info < (3, 11):
+            self.skipTest("Test requires asyncio.Barrier (added in Python 3.11)")
         N_TASKS = 5
         barrier = async_create_barrier(N_TASKS, timeout=30)
         client = await self.async_rs_or_single_client(minPoolSize=N_TASKS)
diff --git a/test/test_discovery_and_monitoring.py b/test/test_discovery_and_monitoring.py
index 1f3c48d6cd..4e20af1b3c 100644
--- a/test/test_discovery_and_monitoring.py
+++ b/test/test_discovery_and_monitoring.py
@@ -277,6 +277,8 @@ def send_cluster_time(time, inc, should_update):
 
 class TestIgnoreStaleErrors(IntegrationTest):
     def test_ignore_stale_connection_errors(self):
+        if not _IS_SYNC and sys.version_info < (3, 11):
+            self.skipTest("Test requires asyncio.Barrier (added in Python 3.11)")
         N_TASKS = 5
         barrier = create_barrier(N_TASKS, timeout=30)
         client = self.rs_or_single_client(minPoolSize=N_TASKS)

From 8d0921b03c0405c85a5d9643a5b16705fdffbe49 Mon Sep 17 00:00:00 2001
From: Iris Ho
Date: Tue, 11 Feb 2025 12:02:34 -0800
Subject: [PATCH 11/19] fix imports

---
 test/asynchronous/test_discovery_and_monitoring.py | 4 +++-
 test/test_discovery_and_monitoring.py              | 4 +++-
 2 files changed, 6 insertions(+), 2 deletions(-)

diff --git a/test/asynchronous/test_discovery_and_monitoring.py b/test/asynchronous/test_discovery_and_monitoring.py
index 8be92802fd..72efa31763 100644
--- a/test/asynchronous/test_discovery_and_monitoring.py
+++ b/test/asynchronous/test_discovery_and_monitoring.py
@@ -22,7 +22,7 @@
 import threading
 from asyncio import StreamReader, StreamWriter
 from pathlib import Path
-from test.asynchronous.helpers import ConcurrentRunner, async_barrier_wait, async_create_barrier
+from test.asynchronous.helpers import ConcurrentRunner
 
 sys.path[0:0] = [""]
 
@@ -34,7 +34,9 @@
     HeartbeatEventListener,
     HeartbeatEventsListListener,
    assertion_context,
+    async_barrier_wait,
     async_client_context,
+    async_create_barrier,
     async_get_pool,
     async_wait_until,
     server_name_to_type,
diff --git a/test/test_discovery_and_monitoring.py b/test/test_discovery_and_monitoring.py
index 4e20af1b3c..f3319ae763 100644
--- a/test/test_discovery_and_monitoring.py
+++ b/test/test_discovery_and_monitoring.py
@@ -22,7 +22,7 @@
 import threading
 from asyncio import StreamReader, StreamWriter
 from pathlib import Path
-from test.helpers import ConcurrentRunner, barrier_wait, create_barrier
+from test.helpers import ConcurrentRunner
 
 sys.path[0:0] = [""]
 
@@ -34,7 +34,9 @@
     HeartbeatEventListener,
     HeartbeatEventsListListener,
     assertion_context,
+    barrier_wait,
     client_context,
+    create_barrier,
     get_pool,
     server_name_to_type,
     wait_until,

From a5c893024303e1aa3e53fa9ff4a25cdbb789e28e Mon Sep 17 00:00:00 2001
From: Iris Ho
Date: Tue, 11 Feb 2025 13:55:39 -0800
Subject: [PATCH 12/19] undo changes to hello response

---
 test/asynchronous/test_discovery_and_monitoring.py | 7 ++++++-
 test/test_discovery_and_monitoring.py              | 7 ++++++-
 2 files changed, 12 insertions(+), 2 deletions(-)

diff --git a/test/asynchronous/test_discovery_and_monitoring.py b/test/asynchronous/test_discovery_and_monitoring.py
index 72efa31763..fc1465489d 100644
--- a/test/asynchronous/test_discovery_and_monitoring.py
+++ b/test/asynchronous/test_discovery_and_monitoring.py
@@ -261,7 +261,12 @@ async def send_cluster_time(time, inc, should_update):
             await got_hello(
                 t,
                 ("host", 27017),
-                {"ok": 1, "minWireVersion": 0, "maxWireVersion": 6, "$clusterTime": new},
+                {
+                    "ok": 1,
+                    "minWireVersion": 0,
+                    "maxWireVersion": common.MIN_SUPPORTED_WIRE_VERSION,
+                    "$clusterTime": new,
+                },
             )
 
             actual = t.max_cluster_time()
diff --git a/test/test_discovery_and_monitoring.py b/test/test_discovery_and_monitoring.py
index f3319ae763..be5508229c 100644
--- a/test/test_discovery_and_monitoring.py
+++ b/test/test_discovery_and_monitoring.py
@@ -261,7 +261,12 @@ def send_cluster_time(time, inc, should_update):
             got_hello(
                 t,
                 ("host", 27017),
-                {"ok": 1, "minWireVersion": 0, "maxWireVersion": 6, "$clusterTime": new},
+                {
+                    "ok": 1,
+                    "minWireVersion": 0,
+                    "maxWireVersion": common.MIN_SUPPORTED_WIRE_VERSION,
+                    "$clusterTime": new,
+                },
             )
 
             actual = t.max_cluster_time()

From 17d0e4c40454f046f7c9c488fa7ed994b2d1e95e Mon Sep 17 00:00:00 2001
From: Iris Ho
Date: Wed, 12 Feb 2025 11:49:18 -0800
Subject: [PATCH 13/19] address review -ish

---
 test/asynchronous/test_discovery_and_monitoring.py | 10 +++++-----
 test/test_discovery_and_monitoring.py              | 6 +++---
 2 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/test/asynchronous/test_discovery_and_monitoring.py b/test/asynchronous/test_discovery_and_monitoring.py
index fc1465489d..fcc20a33aa 100644
--- a/test/asynchronous/test_discovery_and_monitoring.py
+++ b/test/asynchronous/test_discovery_and_monitoring.py
@@ -251,7 +251,7 @@ def create_tests():
     setattr(TestAllScenarios, new_test.__name__, new_test)
 
 
-class TestClusterTimeComparison(unittest.IsolatedAsyncioTestCase):
+class TestClusterTimeComparison(AsyncPyMongoTestCase):
     async def test_cluster_time_comparison(self):
         t = await create_mock_topology("mongodb://host")
 
@@ -297,7 +297,7 @@ async def test_ignore_stale_connection_errors(self):
         await async_wait_until(lambda: len(pool.conns) == N_TASKS, "created conns")
 
         async def mock_command(*args, **kwargs):
-            # Synchronize all threads to ensure they use the same generation.
+            # Synchronize all tasks to ensure they use the same generation.
             await async_barrier_wait(barrier, timeout=30)
             raise AutoReconnect("mock AsyncConnection.command error")
 
@@ -446,14 +446,14 @@ async def test_heartbeat_start_ordering(self):
         if _IS_SYNC:
             server = TCPServer(("localhost", 9999), MockTCPHandler)
             server.events = events
-            server_thread = threading.Thread(target=server.handle_request_and_shutdown)
-            server_thread.start()
+            server_thread = ConcurrentRunner(target=server.handle_request_and_shutdown)
+            await server_thread.start()
             _c = await self.simple_client(
                 "mongodb://localhost:9999",
                 serverSelectionTimeoutMS=500,
                 event_listeners=(listener,),
             )
-            server_thread.join()
+            await server_thread.join()
             listener.wait_for_event(ServerHeartbeatStartedEvent, 1)
             listener.wait_for_event(ServerHeartbeatFailedEvent, 1)
diff --git a/test/test_discovery_and_monitoring.py b/test/test_discovery_and_monitoring.py
index be5508229c..dbdb3ef084 100644
--- a/test/test_discovery_and_monitoring.py
+++ b/test/test_discovery_and_monitoring.py
@@ -251,7 +251,7 @@ def create_tests():
     setattr(TestAllScenarios, new_test.__name__, new_test)
 
 
-class TestClusterTimeComparison(unittest.TestCase):
+class TestClusterTimeComparison(PyMongoTestCase):
     def test_cluster_time_comparison(self):
         t = create_mock_topology("mongodb://host")
 
@@ -297,7 +297,7 @@ def test_ignore_stale_connection_errors(self):
         wait_until(lambda: len(pool.conns) == N_TASKS, "created conns")
 
         def mock_command(*args, **kwargs):
-            # Synchronize all threads to ensure they use the same generation.
+            # Synchronize all tasks to ensure they use the same generation.
             barrier_wait(barrier, timeout=30)
             raise AutoReconnect("mock Connection.command error")
 
@@ -446,7 +446,7 @@ def test_heartbeat_start_ordering(self):
         if _IS_SYNC:
             server = TCPServer(("localhost", 9999), MockTCPHandler)
             server.events = events
-            server_thread = threading.Thread(target=server.handle_request_and_shutdown)
+            server_thread = ConcurrentRunner(target=server.handle_request_and_shutdown)
             server_thread.start()
             _c = self.simple_client(
                 "mongodb://localhost:9999",
                 serverSelectionTimeoutMS=500,
                 event_listeners=(listener,),
             )
             server_thread.join()
             listener.wait_for_event(ServerHeartbeatStartedEvent, 1)
             listener.wait_for_event(ServerHeartbeatFailedEvent, 1)

From c4c9e00b2476bee7ff9d6ee5625f1633f862288a Mon Sep 17 00:00:00 2001
From: Iris Ho
Date: Tue, 18 Feb 2025 11:12:09 -0800
Subject: [PATCH 14/19] undo accidental changes

---
 test/asynchronous/test_discovery_and_monitoring.py | 4 ++++
 test/test_discovery_and_monitoring.py              | 4 ++++
 2 files changed, 8 insertions(+)

diff --git a/test/asynchronous/test_discovery_and_monitoring.py b/test/asynchronous/test_discovery_and_monitoring.py
index fcc20a33aa..b58f2a7474 100644
--- a/test/asynchronous/test_discovery_and_monitoring.py
+++ b/test/asynchronous/test_discovery_and_monitoring.py
@@ -251,6 +251,9 @@ def create_tests():
     setattr(TestAllScenarios, new_test.__name__, new_test)
 
 
+create_tests()
+
+
 class TestClusterTimeComparison(AsyncPyMongoTestCase):
     async def test_cluster_time_comparison(self):
         t = await create_mock_topology("mongodb://host")
@@ -497,5 +500,6 @@ async def handle_client(reader: StreamReader, writer: StreamWriter):
 # Generate unified tests.
 globals().update(generate_test_classes(os.path.join(SDAM_PATH, "unified"), module=__name__))
 
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/test/test_discovery_and_monitoring.py b/test/test_discovery_and_monitoring.py
index dbdb3ef084..50f1985b80 100644
--- a/test/test_discovery_and_monitoring.py
+++ b/test/test_discovery_and_monitoring.py
@@ -251,6 +251,9 @@ def create_tests():
     setattr(TestAllScenarios, new_test.__name__, new_test)
 
 
+create_tests()
+
+
 class TestClusterTimeComparison(PyMongoTestCase):
     def test_cluster_time_comparison(self):
         t = create_mock_topology("mongodb://host")
@@ -497,5 +500,6 @@ def handle_client(reader: StreamReader, writer: StreamWriter):
 # Generate unified tests.
 globals().update(generate_test_classes(os.path.join(SDAM_PATH, "unified"), module=__name__))
 
+
 if __name__ == "__main__":
     unittest.main()

From a3465e9aa4d69b4204f8168a692454fe05ef4686 Mon Sep 17 00:00:00 2001
From: Iris Ho
Date: Mon, 24 Feb 2025 09:00:52 -0800
Subject: [PATCH 15/19] retrying the tests because they're flakey

---
 test/asynchronous/unified_format.py | 8 ++++++--
 test/unified_format.py              | 8 ++++++--
 2 files changed, 12 insertions(+), 4 deletions(-)

diff --git a/test/asynchronous/unified_format.py b/test/asynchronous/unified_format.py
index 27d1dffef3..b961eb424b 100644
--- a/test/asynchronous/unified_format.py
+++ b/test/asynchronous/unified_format.py
@@ -87,6 +87,7 @@
     NotPrimaryError,
     OperationFailure,
     PyMongoError,
+    _OperationCancelled,
 )
 from pymongo.monitoring import (
     CommandStartedEvent,
@@ -1388,17 +1389,20 @@ async def run_scenario(self, spec, uri=None):
         # operations during test set up and tear down.
         await self.kill_all_sessions()
 
-        if "csot" in self.id().lower():
+        if "csot" in self.id().lower() or "discovery_and_monitoring" in self.id().lower():
             # Retry CSOT tests up to 2 times to deal with flakey tests.
+            # discovery_and_monitoring tests on windows are also flakey
             attempts = 3
             for i in range(attempts):
                 try:
                     return await self._run_scenario(spec, uri)
-                except (AssertionError, OperationFailure) as exc:
+                except (AssertionError, OperationFailure, _OperationCancelled) as exc:
                     if isinstance(exc, OperationFailure) and (
                         _IS_SYNC or "failpoint" not in exc._message
                     ):
                         raise
+                    if isinstance(exc, _OperationCancelled) and _IS_SYNC:
+                        raise
                     if i < attempts - 1:
                         print(
                             f"Retrying after attempt {i+1} of {self.id()} failed with:\n"
diff --git a/test/unified_format.py b/test/unified_format.py
index 73dee10ddf..acf26d84fa 100644
--- a/test/unified_format.py
+++ b/test/unified_format.py
@@ -79,6 +79,7 @@
     NotPrimaryError,
     OperationFailure,
     PyMongoError,
+    _OperationCancelled,
 )
 from pymongo.monitoring import (
     CommandStartedEvent,
@@ -1375,17 +1376,20 @@ def run_scenario(self, spec, uri=None):
         # operations during test set up and tear down.
         self.kill_all_sessions()
 
-        if "csot" in self.id().lower():
+        if "csot" in self.id().lower() or "discovery_and_monitoring" in self.id().lower():
             # Retry CSOT tests up to 2 times to deal with flakey tests.
+            # discovery_and_monitoring tests on windows are also flakey
             attempts = 3
             for i in range(attempts):
                 try:
                     return self._run_scenario(spec, uri)
-                except (AssertionError, OperationFailure) as exc:
+                except (AssertionError, OperationFailure, _OperationCancelled) as exc:
                     if isinstance(exc, OperationFailure) and (
                         _IS_SYNC or "failpoint" not in exc._message
                     ):
                         raise
+                    if isinstance(exc, _OperationCancelled) and _IS_SYNC:
+                        raise
                     if i < attempts - 1:
                         print(
                             f"Retrying after attempt {i+1} of {self.id()} failed with:\n"

From c4beba1f14771531fc72fd2c3ea4a4a7bb85743e Mon Sep 17 00:00:00 2001
From: Iris Ho
Date: Thu, 27 Feb 2025 08:17:52 -0800
Subject: [PATCH 16/19] update discovery and monitoring based on #1925

---
 .../test_discovery_and_monitoring.py | 20 +++++++++----------
 test/asynchronous/unified_format.py  | 6 ++----
 test/unified_format.py               | 6 ++----
 3 files changed, 13 insertions(+), 19 deletions(-)

diff --git a/test/asynchronous/test_discovery_and_monitoring.py b/test/asynchronous/test_discovery_and_monitoring.py
index b58f2a7474..7c4095ebb8 100644
--- a/test/asynchronous/test_discovery_and_monitoring.py
+++ b/test/asynchronous/test_discovery_and_monitoring.py
@@ -258,7 +258,7 @@ class TestClusterTimeComparison(AsyncPyMongoTestCase):
     async def test_cluster_time_comparison(self):
         t = await create_mock_topology("mongodb://host")
 
-        async def send_cluster_time(time, inc, should_update):
+        async def send_cluster_time(time, inc):
             old = t.max_cluster_time()
             new = {"clusterTime": Timestamp(time, inc)}
             await got_hello(
@@ -273,16 +273,14 @@ async def send_cluster_time(time, inc, should_update):
             )
 
             actual = t.max_cluster_time()
-            if should_update:
-                self.assertEqual(actual, new)
-            else:
-                self.assertEqual(actual, old)
-
-        await send_cluster_time(0, 1, True)
-        await send_cluster_time(2, 2, True)
-        await send_cluster_time(2, 1, False)
-        await send_cluster_time(1, 3, False)
-        await send_cluster_time(2, 3, True)
+            # We never update $clusterTime from monitoring connections.
+            self.assertEqual(actual, old)
+
+        await send_cluster_time(0, 1)
+        await send_cluster_time(2, 2)
+        await send_cluster_time(2, 1)
+        await send_cluster_time(1, 3)
+        await send_cluster_time(2, 3)
 
 
 class TestIgnoreStaleErrors(AsyncIntegrationTest):
diff --git a/test/asynchronous/unified_format.py b/test/asynchronous/unified_format.py
index 384eb3d69b..908f2a6b75 100644
--- a/test/asynchronous/unified_format.py
+++ b/test/asynchronous/unified_format.py
@@ -1385,20 +1385,18 @@ async def run_scenario(self, spec, uri=None):
         # operations during test set up and tear down.
         await self.kill_all_sessions()
 
-        if "csot" in self.id().lower() or "discovery_and_monitoring" in self.id().lower():
+        if "csot" in self.id().lower():
             # Retry CSOT tests up to 2 times to deal with flakey tests.
             # discovery_and_monitoring tests on windows are also flakey
             attempts = 3
             for i in range(attempts):
                 try:
                     return await self._run_scenario(spec, uri)
-                except (AssertionError, OperationFailure, _OperationCancelled) as exc:
+                except (AssertionError, OperationFailure) as exc:
                     if isinstance(exc, OperationFailure) and (
                         _IS_SYNC or "failpoint" not in exc._message
                     ):
                         raise
-                    if isinstance(exc, _OperationCancelled) and _IS_SYNC:
-                        raise
                     if i < attempts - 1:
                         print(
                             f"Retrying after attempt {i+1} of {self.id()} failed with:\n"
diff --git a/test/unified_format.py b/test/unified_format.py
index 13f8487d7b..b899c21bff 100644
--- a/test/unified_format.py
+++ b/test/unified_format.py
@@ -1372,20 +1372,18 @@ def run_scenario(self, spec, uri=None):
         # operations during test set up and tear down.
         self.kill_all_sessions()
 
-        if "csot" in self.id().lower() or "discovery_and_monitoring" in self.id().lower():
+        if "csot" in self.id().lower():
             # Retry CSOT tests up to 2 times to deal with flakey tests.
-            # discovery_and_monitoring tests on windows are also flakey
             attempts = 3
             for i in range(attempts):
                 try:
                     return self._run_scenario(spec, uri)
-                except (AssertionError, OperationFailure, _OperationCancelled) as exc:
+                except (AssertionError, OperationFailure) as exc:
                     if isinstance(exc, OperationFailure) and (
                         _IS_SYNC or "failpoint" not in exc._message
                     ):
                         raise
-                    if isinstance(exc, _OperationCancelled) and _IS_SYNC:
-                        raise
                     if i < attempts - 1:
                         print(
                             f"Retrying after attempt {i+1} of {self.id()} failed with:\n"

From 9b4fd4536bfe6c9f1551257ca47e8e1f481c63ae Mon Sep 17 00:00:00 2001
From: Iris Ho
Date: Fri, 28 Feb 2025 13:25:20 -0800
Subject: [PATCH 17/19] skip flakey test for now

---
 test/asynchronous/unified_format.py | 7 ++++++-
 test/unified_format.py              | 7 ++++++-
 2 files changed, 12 insertions(+), 2 deletions(-)

diff --git a/test/asynchronous/unified_format.py b/test/asynchronous/unified_format.py
index 908f2a6b75..973c019ba6 100644
--- a/test/asynchronous/unified_format.py
+++ b/test/asynchronous/unified_format.py
@@ -545,6 +545,12 @@ def maybe_skip_test(self, spec):
             self.skipTest("Implement PYTHON-1894")
         if "timeoutMS applied to entire download" in spec["description"]:
             self.skipTest("PyMongo's open_download_stream does not cap the stream's lifetime")
+        if (
+            "Error returned from connection pool clear with interruptInUseConnections=true is retryable"
+            in spec["description"]
+            and not _IS_SYNC
+        ):
+            self.skipTest("PYTHON-5170 tests are flakey")
 
         class_name = self.__class__.__name__.lower()
         description = spec["description"].lower()
@@ -1387,7 +1393,6 @@ async def run_scenario(self, spec, uri=None):
 
         if "csot" in self.id().lower():
             # Retry CSOT tests up to 2 times to deal with flakey tests.
-            # discovery_and_monitoring tests on windows are also flakey
             attempts = 3
             for i in range(attempts):
                 try:
diff --git a/test/unified_format.py b/test/unified_format.py
index 13f8487d7b..b899c21bff 100644
--- a/test/unified_format.py
+++ b/test/unified_format.py
@@ -544,6 +544,12 @@ def maybe_skip_test(self, spec):
             self.skipTest("Implement PYTHON-1894")
         if "timeoutMS applied to entire download" in spec["description"]:
             self.skipTest("PyMongo's open_download_stream does not cap the stream's lifetime")
+        if (
+            "Error returned from connection pool clear with interruptInUseConnections=true is retryable"
+            in spec["description"]
+            and not _IS_SYNC
+        ):
+            self.skipTest("PYTHON-5170 tests are flakey")
 
         class_name = self.__class__.__name__.lower()
         description = spec["description"].lower()
@@ -1374,7 +1380,6 @@ def run_scenario(self, spec, uri=None):
 
         if "csot" in self.id().lower():
             # Retry CSOT tests up to 2 times to deal with flakey tests.
-            # discovery_and_monitoring tests on windows are also flakey
             attempts = 3
             for i in range(attempts):
                 try:

From c0e37e5f8755daae2fe3a656924f019af980b914 Mon Sep 17 00:00:00 2001
From: Iris Ho
Date: Fri, 28 Feb 2025 13:32:29 -0800
Subject: [PATCH 18/19] remove import

---
 test/asynchronous/unified_format.py | 1 -
 test/unified_format.py              | 1 -
 2 files changed, 2 deletions(-)

diff --git a/test/asynchronous/unified_format.py b/test/asynchronous/unified_format.py
index 973c019ba6..ba31c58d41 100644
--- a/test/asynchronous/unified_format.py
+++ b/test/asynchronous/unified_format.py
@@ -87,7 +87,6 @@
     NotPrimaryError,
     OperationFailure,
     PyMongoError,
-    _OperationCancelled,
 )
 from pymongo.monitoring import (
     CommandStartedEvent,
diff --git a/test/unified_format.py b/test/unified_format.py
index b899c21bff..bdccb5c3e2 100644
--- a/test/unified_format.py
+++ b/test/unified_format.py
@@ -79,7 +79,6 @@
     NotPrimaryError,
     OperationFailure,
     PyMongoError,
-    _OperationCancelled,
 )
 from pymongo.monitoring import (
     CommandStartedEvent,

From 81f7a1b2fab3b142aa43cb0983e277a2f4ceb690 Mon Sep 17 00:00:00 2001
From: Iris Ho
Date: Fri, 28 Feb 2025 15:42:29 -0800
Subject: [PATCH 19/19] another flakey test to skip

---
 test/asynchronous/unified_format.py | 2 ++
 test/unified_format.py              | 2 ++
 2 files changed, 4 insertions(+)

diff --git a/test/asynchronous/unified_format.py b/test/asynchronous/unified_format.py
index ba31c58d41..c3931da936 100644
--- a/test/asynchronous/unified_format.py
+++ b/test/asynchronous/unified_format.py
@@ -550,6 +550,8 @@ def maybe_skip_test(self, spec):
             and not _IS_SYNC
         ):
             self.skipTest("PYTHON-5170 tests are flakey")
+        if "Driver extends timeout while streaming" in spec["description"] and not _IS_SYNC:
+            self.skipTest("PYTHON-5174 tests are flakey")
 
         class_name = self.__class__.__name__.lower()
         description = spec["description"].lower()
diff --git a/test/unified_format.py b/test/unified_format.py
index bdccb5c3e2..8ed9e214bb 100644
--- a/test/unified_format.py
+++ b/test/unified_format.py
@@ -549,6 +549,8 @@ def maybe_skip_test(self, spec):
             and not _IS_SYNC
         ):
             self.skipTest("PYTHON-5170 tests are flakey")
+        if "Driver extends timeout while streaming" in spec["description"] and not _IS_SYNC:
+            self.skipTest("PYTHON-5174 tests are flakey")
 
         class_name = self.__class__.__name__.lower()
         description = spec["description"].lower()
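
A note on the barrier helpers used throughout this series: the async test synchronizes its worker tasks with async_create_barrier/async_barrier_wait, which tools/synchro.py maps to create_barrier/barrier_wait when the synchronous test is generated. The following is a minimal standalone sketch of that pattern, not part of the patches above; it assumes Python 3.11+ (asyncio.Barrier and asyncio.TaskGroup) and uses a plain coroutine in place of the mocked pool connection command.

import asyncio


def async_create_barrier(n_tasks: int, timeout: float | None = None) -> asyncio.Barrier:
    # asyncio.Barrier has no timeout parameter; the timeout is applied at wait time.
    return asyncio.Barrier(n_tasks)


async def async_barrier_wait(barrier: asyncio.Barrier, timeout: float | None = None) -> None:
    # Bound the rendezvous with wait_for, mirroring the helper added in test/utils.py.
    await asyncio.wait_for(barrier.wait(), timeout=timeout)


async def main() -> None:
    n_tasks = 5
    barrier = async_create_barrier(n_tasks)
    released: list[int] = []

    async def worker(i: int) -> None:
        # Every worker blocks here until all n_tasks workers have arrived,
        # so they all proceed against the same state, as in the stale-error test.
        await async_barrier_wait(barrier, timeout=30)
        released.append(i)

    async with asyncio.TaskGroup() as tg:
        for i in range(n_tasks):
            tg.create_task(worker(i))
    assert sorted(released) == list(range(n_tasks))


if __name__ == "__main__":
    asyncio.run(main())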