diff --git a/test/asynchronous/test_discovery_and_monitoring.py b/test/asynchronous/test_discovery_and_monitoring.py new file mode 100644 index 0000000000..7c4095ebb8 --- /dev/null +++ b/test/asynchronous/test_discovery_and_monitoring.py @@ -0,0 +1,503 @@ +# Copyright 2014-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test the topology module.""" +from __future__ import annotations + +import asyncio +import os +import socketserver +import sys +import threading +from asyncio import StreamReader, StreamWriter +from pathlib import Path +from test.asynchronous.helpers import ConcurrentRunner + +sys.path[0:0] = [""] + +from test.asynchronous import AsyncIntegrationTest, AsyncPyMongoTestCase, AsyncUnitTest, unittest +from test.asynchronous.pymongo_mocks import DummyMonitor +from test.asynchronous.unified_format import generate_test_classes +from test.utils import ( + CMAPListener, + HeartbeatEventListener, + HeartbeatEventsListListener, + assertion_context, + async_barrier_wait, + async_client_context, + async_create_barrier, + async_get_pool, + async_wait_until, + server_name_to_type, +) +from unittest.mock import patch + +from bson import Timestamp, json_util +from pymongo import AsyncMongoClient, common, monitoring +from pymongo.asynchronous.settings import TopologySettings +from pymongo.asynchronous.topology import Topology, _ErrorContext +from pymongo.errors import ( + AutoReconnect, + ConfigurationError, + NetworkTimeout, + NotPrimaryError, + OperationFailure, +) +from pymongo.hello import Hello, HelloCompat +from pymongo.helpers_shared import _check_command_response, _check_write_command_response +from pymongo.monitoring import ServerHeartbeatFailedEvent, ServerHeartbeatStartedEvent +from pymongo.server_description import SERVER_TYPE, ServerDescription +from pymongo.topology_description import TOPOLOGY_TYPE +from pymongo.uri_parser import parse_uri + +_IS_SYNC = False + +# Location of JSON test specifications. +if _IS_SYNC: + SDAM_PATH = os.path.join(Path(__file__).resolve().parent, "discovery_and_monitoring") +else: + SDAM_PATH = os.path.join( + Path(__file__).resolve().parent.parent, + "discovery_and_monitoring", + ) + + +async def create_mock_topology(uri, monitor_class=DummyMonitor): + parsed_uri = parse_uri(uri) + replica_set_name = None + direct_connection = None + load_balanced = None + if "replicaset" in parsed_uri["options"]: + replica_set_name = parsed_uri["options"]["replicaset"] + if "directConnection" in parsed_uri["options"]: + direct_connection = parsed_uri["options"]["directConnection"] + if "loadBalanced" in parsed_uri["options"]: + load_balanced = parsed_uri["options"]["loadBalanced"] + + topology_settings = TopologySettings( + parsed_uri["nodelist"], + replica_set_name=replica_set_name, + monitor_class=monitor_class, + direct_connection=direct_connection, + load_balanced=load_balanced, + ) + + c = Topology(topology_settings) + await c.open() + return c + + +async def got_hello(topology, server_address, hello_response): + server_description = ServerDescription(server_address, Hello(hello_response), 0) + await topology.on_change(server_description) + + +async def got_app_error(topology, app_error): + server_address = common.partition_node(app_error["address"]) + server = topology.get_server_by_address(server_address) + error_type = app_error["type"] + generation = app_error.get("generation", server.pool.gen.get_overall()) + when = app_error["when"] + max_wire_version = app_error["maxWireVersion"] + # XXX: We could get better test coverage by mocking the errors on the + # Pool/AsyncConnection. + try: + if error_type == "command": + _check_command_response(app_error["response"], max_wire_version) + _check_write_command_response(app_error["response"]) + elif error_type == "network": + raise AutoReconnect("mock non-timeout network error") + elif error_type == "timeout": + raise NetworkTimeout("mock network timeout error") + else: + raise AssertionError(f"unknown error type: {error_type}") + raise AssertionError + except (AutoReconnect, NotPrimaryError, OperationFailure) as e: + if when == "beforeHandshakeCompletes": + completed_handshake = False + elif when == "afterHandshakeCompletes": + completed_handshake = True + else: + raise AssertionError(f"Unknown when field {when}") + + await topology.handle_error( + server_address, + _ErrorContext(e, max_wire_version, generation, completed_handshake, None), + ) + + +def get_type(topology, hostname): + description = topology.get_server_by_address((hostname, 27017)).description + return description.server_type + + +class TestAllScenarios(AsyncUnitTest): + pass + + +def topology_type_name(topology_type): + return TOPOLOGY_TYPE._fields[topology_type] + + +def server_type_name(server_type): + return SERVER_TYPE._fields[server_type] + + +def check_outcome(self, topology, outcome): + expected_servers = outcome["servers"] + + # Check weak equality before proceeding. + self.assertEqual(len(topology.description.server_descriptions()), len(expected_servers)) + + if outcome.get("compatible") is False: + with self.assertRaises(ConfigurationError): + topology.description.check_compatible() + else: + # No error. + topology.description.check_compatible() + + # Since lengths are equal, every actual server must have a corresponding + # expected server. + for expected_server_address, expected_server in expected_servers.items(): + node = common.partition_node(expected_server_address) + self.assertTrue(topology.has_server(node)) + actual_server = topology.get_server_by_address(node) + actual_server_description = actual_server.description + expected_server_type = server_name_to_type(expected_server["type"]) + + self.assertEqual( + server_type_name(expected_server_type), + server_type_name(actual_server_description.server_type), + ) + + self.assertEqual(expected_server.get("setName"), actual_server_description.replica_set_name) + + self.assertEqual(expected_server.get("setVersion"), actual_server_description.set_version) + + self.assertEqual(expected_server.get("electionId"), actual_server_description.election_id) + + self.assertEqual( + expected_server.get("topologyVersion"), actual_server_description.topology_version + ) + + expected_pool = expected_server.get("pool") + if expected_pool: + self.assertEqual(expected_pool.get("generation"), actual_server.pool.gen.get_overall()) + + self.assertEqual(outcome["setName"], topology.description.replica_set_name) + self.assertEqual( + outcome.get("logicalSessionTimeoutMinutes"), + topology.description.logical_session_timeout_minutes, + ) + + expected_topology_type = getattr(TOPOLOGY_TYPE, outcome["topologyType"]) + self.assertEqual( + topology_type_name(expected_topology_type), + topology_type_name(topology.description.topology_type), + ) + + self.assertEqual(outcome.get("maxSetVersion"), topology.description.max_set_version) + self.assertEqual(outcome.get("maxElectionId"), topology.description.max_election_id) + + +def create_test(scenario_def): + async def run_scenario(self): + c = await create_mock_topology(scenario_def["uri"]) + + for i, phase in enumerate(scenario_def["phases"]): + # Including the phase description makes failures easier to debug. + description = phase.get("description", str(i)) + with assertion_context(f"phase: {description}"): + for response in phase.get("responses", []): + await got_hello(c, common.partition_node(response[0]), response[1]) + + for app_error in phase.get("applicationErrors", []): + await got_app_error(c, app_error) + + check_outcome(self, c, phase["outcome"]) + + return run_scenario + + +def create_tests(): + for dirpath, _, filenames in os.walk(SDAM_PATH): + dirname = os.path.split(dirpath)[-1] + # SDAM unified tests are handled separately. + if dirname == "unified": + continue + + for filename in filenames: + if os.path.splitext(filename)[1] != ".json": + continue + with open(os.path.join(dirpath, filename)) as scenario_stream: + scenario_def = json_util.loads(scenario_stream.read()) + + # Construct test from scenario. + new_test = create_test(scenario_def) + test_name = f"test_{dirname}_{os.path.splitext(filename)[0]}" + + new_test.__name__ = test_name + setattr(TestAllScenarios, new_test.__name__, new_test) + + +create_tests() + + +class TestClusterTimeComparison(AsyncPyMongoTestCase): + async def test_cluster_time_comparison(self): + t = await create_mock_topology("mongodb://host") + + async def send_cluster_time(time, inc): + old = t.max_cluster_time() + new = {"clusterTime": Timestamp(time, inc)} + await got_hello( + t, + ("host", 27017), + { + "ok": 1, + "minWireVersion": 0, + "maxWireVersion": common.MIN_SUPPORTED_WIRE_VERSION, + "$clusterTime": new, + }, + ) + + actual = t.max_cluster_time() + # We never update $clusterTime from monitoring connections. + self.assertEqual(actual, old) + + await send_cluster_time(0, 1) + await send_cluster_time(2, 2) + await send_cluster_time(2, 1) + await send_cluster_time(1, 3) + await send_cluster_time(2, 3) + + +class TestIgnoreStaleErrors(AsyncIntegrationTest): + async def test_ignore_stale_connection_errors(self): + if not _IS_SYNC and sys.version_info < (3, 11): + self.skipTest("Test requires asyncio.Barrier (added in Python 3.11)") + N_TASKS = 5 + barrier = async_create_barrier(N_TASKS, timeout=30) + client = await self.async_rs_or_single_client(minPoolSize=N_TASKS) + + # Wait for initial discovery. + await client.admin.command("ping") + pool = await async_get_pool(client) + starting_generation = pool.gen.get_overall() + await async_wait_until(lambda: len(pool.conns) == N_TASKS, "created conns") + + async def mock_command(*args, **kwargs): + # Synchronize all tasks to ensure they use the same generation. + await async_barrier_wait(barrier, timeout=30) + raise AutoReconnect("mock AsyncConnection.command error") + + for conn in pool.conns: + conn.command = mock_command + + async def insert_command(i): + try: + await client.test.command("insert", "test", documents=[{"i": i}]) + except AutoReconnect: + pass + + tasks = [] + for i in range(N_TASKS): + tasks.append(ConcurrentRunner(target=insert_command, args=(i,))) + for t in tasks: + await t.start() + for t in tasks: + await t.join() + + # Expect a single pool reset for the network error + self.assertEqual(starting_generation + 1, pool.gen.get_overall()) + + # Server should be selectable. + await client.admin.command("ping") + + +class CMAPHeartbeatListener(HeartbeatEventListener, CMAPListener): + pass + + +class TestPoolManagement(AsyncIntegrationTest): + @async_client_context.require_failCommand_appName + async def test_pool_unpause(self): + # This test implements the prose test "AsyncConnection Pool Management" + listener = CMAPHeartbeatListener() + _ = await self.async_single_client( + appName="SDAMPoolManagementTest", heartbeatFrequencyMS=500, event_listeners=[listener] + ) + # Assert that AsyncConnectionPoolReadyEvent occurs after the first + # ServerHeartbeatSucceededEvent. + await listener.async_wait_for_event(monitoring.PoolReadyEvent, 1) + pool_ready = listener.events_by_type(monitoring.PoolReadyEvent)[0] + hb_succeeded = listener.events_by_type(monitoring.ServerHeartbeatSucceededEvent)[0] + self.assertGreater(listener.events.index(pool_ready), listener.events.index(hb_succeeded)) + + listener.reset() + fail_hello = { + "mode": {"times": 2}, + "data": { + "failCommands": [HelloCompat.LEGACY_CMD, "hello"], + "errorCode": 1234, + "appName": "SDAMPoolManagementTest", + }, + } + async with self.fail_point(fail_hello): + await listener.async_wait_for_event(monitoring.ServerHeartbeatFailedEvent, 1) + await listener.async_wait_for_event(monitoring.PoolClearedEvent, 1) + await listener.async_wait_for_event(monitoring.ServerHeartbeatSucceededEvent, 1) + await listener.async_wait_for_event(monitoring.PoolReadyEvent, 1) + + +class TestServerMonitoringMode(AsyncIntegrationTest): + @async_client_context.require_no_serverless + @async_client_context.require_no_load_balancer + async def asyncSetUp(self): + await super().asyncSetUp() + + async def test_rtt_connection_is_enabled_stream(self): + client = await self.async_rs_or_single_client(serverMonitoringMode="stream") + await client.admin.command("ping") + + def predicate(): + for _, server in client._topology._servers.items(): + monitor = server._monitor + if not monitor._stream: + return False + if async_client_context.version >= (4, 4): + if _IS_SYNC: + if monitor._rtt_monitor._executor._thread is None: + return False + else: + if monitor._rtt_monitor._executor._task is None: + return False + else: + if _IS_SYNC: + if monitor._rtt_monitor._executor._thread is not None: + return False + else: + if monitor._rtt_monitor._executor._task is not None: + return False + return True + + await async_wait_until(predicate, "find all RTT monitors") + + async def test_rtt_connection_is_disabled_poll(self): + client = await self.async_rs_or_single_client(serverMonitoringMode="poll") + + await self.assert_rtt_connection_is_disabled(client) + + async def test_rtt_connection_is_disabled_auto(self): + envs = [ + {"AWS_EXECUTION_ENV": "AWS_Lambda_python3.9"}, + {"FUNCTIONS_WORKER_RUNTIME": "python"}, + {"K_SERVICE": "gcpservicename"}, + {"FUNCTION_NAME": "gcpfunctionname"}, + {"VERCEL": "1"}, + ] + for env in envs: + with patch.dict("os.environ", env): + client = await self.async_rs_or_single_client(serverMonitoringMode="auto") + await self.assert_rtt_connection_is_disabled(client) + + async def assert_rtt_connection_is_disabled(self, client): + await client.admin.command("ping") + for _, server in client._topology._servers.items(): + monitor = server._monitor + self.assertFalse(monitor._stream) + if _IS_SYNC: + self.assertIsNone(monitor._rtt_monitor._executor._thread) + else: + self.assertIsNone(monitor._rtt_monitor._executor._task) + + +class MockTCPHandler(socketserver.BaseRequestHandler): + def handle(self): + self.server.events.append("client connected") + if self.request.recv(1024).strip(): + self.server.events.append("client hello received") + self.request.close() + + +class TCPServer(socketserver.TCPServer): + allow_reuse_address = True + + def handle_request_and_shutdown(self): + self.handle_request() + self.server_close() + + +class TestHeartbeatStartOrdering(AsyncPyMongoTestCase): + async def test_heartbeat_start_ordering(self): + events = [] + listener = HeartbeatEventsListListener(events) + + if _IS_SYNC: + server = TCPServer(("localhost", 9999), MockTCPHandler) + server.events = events + server_thread = ConcurrentRunner(target=server.handle_request_and_shutdown) + await server_thread.start() + _c = await self.simple_client( + "mongodb://localhost:9999", + serverSelectionTimeoutMS=500, + event_listeners=(listener,), + ) + await server_thread.join() + listener.wait_for_event(ServerHeartbeatStartedEvent, 1) + listener.wait_for_event(ServerHeartbeatFailedEvent, 1) + + else: + + async def handle_client(reader: StreamReader, writer: StreamWriter): + events.append("client connected") + if (await reader.read(1024)).strip(): + events.append("client hello received") + writer.close() + await writer.wait_closed() + + server = await asyncio.start_server(handle_client, "localhost", 9999) + server.events = events + await server.start_serving() + _c = self.simple_client( + "mongodb://localhost:9999", + serverSelectionTimeoutMS=500, + event_listeners=(listener,), + ) + await _c.aconnect() + + await listener.async_wait_for_event(ServerHeartbeatStartedEvent, 1) + await listener.async_wait_for_event(ServerHeartbeatFailedEvent, 1) + + server.close() + await server.wait_closed() + await _c.close() + + self.assertEqual( + events, + [ + "serverHeartbeatStartedEvent", + "client connected", + "client hello received", + "serverHeartbeatFailedEvent", + ], + ) + + +# Generate unified tests. +globals().update(generate_test_classes(os.path.join(SDAM_PATH, "unified"), module=__name__)) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/asynchronous/unified_format.py b/test/asynchronous/unified_format.py index c315e86945..c3931da936 100644 --- a/test/asynchronous/unified_format.py +++ b/test/asynchronous/unified_format.py @@ -544,6 +544,14 @@ def maybe_skip_test(self, spec): self.skipTest("Implement PYTHON-1894") if "timeoutMS applied to entire download" in spec["description"]: self.skipTest("PyMongo's open_download_stream does not cap the stream's lifetime") + if ( + "Error returned from connection pool clear with interruptInUseConnections=true is retryable" + in spec["description"] + and not _IS_SYNC + ): + self.skipTest("PYTHON-5170 tests are flakey") + if "Driver extends timeout while streaming" in spec["description"] and not _IS_SYNC: + self.skipTest("PYTHON-5174 tests are flakey") class_name = self.__class__.__name__.lower() description = spec["description"].lower() @@ -1151,7 +1159,7 @@ def _testOperation_assertTopologyType(self, spec): self.assertIsInstance(description, TopologyDescription) self.assertEqual(description.topology_type_name, spec["topologyType"]) - def _testOperation_waitForPrimaryChange(self, spec: dict) -> None: + async def _testOperation_waitForPrimaryChange(self, spec: dict) -> None: """Run the waitForPrimaryChange test operation.""" client = self.entity_map[spec["client"]] old_description: TopologyDescription = self.entity_map[spec["priorTopologyDescription"]] @@ -1165,13 +1173,13 @@ def get_primary(td: TopologyDescription) -> Optional[_Address]: old_primary = get_primary(old_description) - def primary_changed() -> bool: - primary = client.primary + async def primary_changed() -> bool: + primary = await client.primary if primary is None: return False return primary != old_primary - wait_until(primary_changed, "change primary", timeout=timeout) + await async_wait_until(primary_changed, "change primary", timeout=timeout) async def _testOperation_runOnThread(self, spec): """Run the 'runOnThread' operation.""" diff --git a/test/test_discovery_and_monitoring.py b/test/test_discovery_and_monitoring.py index 70dcfc5b48..2eb278383c 100644 --- a/test/test_discovery_and_monitoring.py +++ b/test/test_discovery_and_monitoring.py @@ -15,14 +15,18 @@ """Test the topology module.""" from __future__ import annotations +import asyncio import os import socketserver import sys import threading +from asyncio import StreamReader, StreamWriter +from pathlib import Path +from test.helpers import ConcurrentRunner sys.path[0:0] = [""] -from test import IntegrationTest, PyMongoTestCase, unittest +from test import IntegrationTest, PyMongoTestCase, UnitTest, unittest from test.pymongo_mocks import DummyMonitor from test.unified_format import generate_test_classes from test.utils import ( @@ -30,7 +34,9 @@ HeartbeatEventListener, HeartbeatEventsListListener, assertion_context, + barrier_wait, client_context, + create_barrier, get_pool, server_name_to_type, wait_until, @@ -55,8 +61,16 @@ from pymongo.topology_description import TOPOLOGY_TYPE from pymongo.uri_parser import parse_uri +_IS_SYNC = True + # Location of JSON test specifications. -SDAM_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "discovery_and_monitoring") +if _IS_SYNC: + SDAM_PATH = os.path.join(Path(__file__).resolve().parent, "discovery_and_monitoring") +else: + SDAM_PATH = os.path.join( + Path(__file__).resolve().parent.parent, + "discovery_and_monitoring", + ) def create_mock_topology(uri, monitor_class=DummyMonitor): @@ -128,7 +142,7 @@ def get_type(topology, hostname): return description.server_type -class TestAllScenarios(unittest.TestCase): +class TestAllScenarios(UnitTest): pass @@ -240,7 +254,7 @@ def create_tests(): create_tests() -class TestClusterTimeComparison(unittest.TestCase): +class TestClusterTimeComparison(PyMongoTestCase): def test_cluster_time_comparison(self): t = create_mock_topology("mongodb://host") @@ -271,20 +285,21 @@ def send_cluster_time(time, inc): class TestIgnoreStaleErrors(IntegrationTest): def test_ignore_stale_connection_errors(self): - N_THREADS = 5 - barrier = threading.Barrier(N_THREADS, timeout=30) - client = self.rs_or_single_client(minPoolSize=N_THREADS) - self.addCleanup(client.close) + if not _IS_SYNC and sys.version_info < (3, 11): + self.skipTest("Test requires asyncio.Barrier (added in Python 3.11)") + N_TASKS = 5 + barrier = create_barrier(N_TASKS, timeout=30) + client = self.rs_or_single_client(minPoolSize=N_TASKS) # Wait for initial discovery. client.admin.command("ping") pool = get_pool(client) starting_generation = pool.gen.get_overall() - wait_until(lambda: len(pool.conns) == N_THREADS, "created conns") + wait_until(lambda: len(pool.conns) == N_TASKS, "created conns") def mock_command(*args, **kwargs): - # Synchronize all threads to ensure they use the same generation. - barrier.wait() + # Synchronize all tasks to ensure they use the same generation. + barrier_wait(barrier, timeout=30) raise AutoReconnect("mock Connection.command error") for conn in pool.conns: @@ -296,12 +311,12 @@ def insert_command(i): except AutoReconnect: pass - threads = [] - for i in range(N_THREADS): - threads.append(threading.Thread(target=insert_command, args=(i,))) - for t in threads: + tasks = [] + for i in range(N_TASKS): + tasks.append(ConcurrentRunner(target=insert_command, args=(i,))) + for t in tasks: t.start() - for t in threads: + for t in tasks: t.join() # Expect a single pool reset for the network error @@ -320,10 +335,9 @@ class TestPoolManagement(IntegrationTest): def test_pool_unpause(self): # This test implements the prose test "Connection Pool Management" listener = CMAPHeartbeatListener() - client = self.single_client( + _ = self.single_client( appName="SDAMPoolManagementTest", heartbeatFrequencyMS=500, event_listeners=[listener] ) - self.addCleanup(client.close) # Assert that ConnectionPoolReadyEvent occurs after the first # ServerHeartbeatSucceededEvent. listener.wait_for_event(monitoring.PoolReadyEvent, 1) @@ -355,7 +369,6 @@ def setUp(self): def test_rtt_connection_is_enabled_stream(self): client = self.rs_or_single_client(serverMonitoringMode="stream") - self.addCleanup(client.close) client.admin.command("ping") def predicate(): @@ -364,18 +377,26 @@ def predicate(): if not monitor._stream: return False if client_context.version >= (4, 4): - if monitor._rtt_monitor._executor._thread is None: - return False + if _IS_SYNC: + if monitor._rtt_monitor._executor._thread is None: + return False + else: + if monitor._rtt_monitor._executor._task is None: + return False else: - if monitor._rtt_monitor._executor._thread is not None: - return False + if _IS_SYNC: + if monitor._rtt_monitor._executor._thread is not None: + return False + else: + if monitor._rtt_monitor._executor._task is not None: + return False return True wait_until(predicate, "find all RTT monitors") def test_rtt_connection_is_disabled_poll(self): client = self.rs_or_single_client(serverMonitoringMode="poll") - self.addCleanup(client.close) + self.assert_rtt_connection_is_disabled(client) def test_rtt_connection_is_disabled_auto(self): @@ -389,7 +410,6 @@ def test_rtt_connection_is_disabled_auto(self): for env in envs: with patch.dict("os.environ", env): client = self.rs_or_single_client(serverMonitoringMode="auto") - self.addCleanup(client.close) self.assert_rtt_connection_is_disabled(client) def assert_rtt_connection_is_disabled(self, client): @@ -397,7 +417,10 @@ def assert_rtt_connection_is_disabled(self, client): for _, server in client._topology._servers.items(): monitor = server._monitor self.assertFalse(monitor._stream) - self.assertIsNone(monitor._rtt_monitor._executor._thread) + if _IS_SYNC: + self.assertIsNone(monitor._rtt_monitor._executor._thread) + else: + self.assertIsNone(monitor._rtt_monitor._executor._task) class MockTCPHandler(socketserver.BaseRequestHandler): @@ -420,16 +443,46 @@ class TestHeartbeatStartOrdering(PyMongoTestCase): def test_heartbeat_start_ordering(self): events = [] listener = HeartbeatEventsListListener(events) - server = TCPServer(("localhost", 9999), MockTCPHandler) - server.events = events - server_thread = threading.Thread(target=server.handle_request_and_shutdown) - server_thread.start() - _c = self.simple_client( - "mongodb://localhost:9999", serverSelectionTimeoutMS=500, event_listeners=(listener,) - ) - server_thread.join() - listener.wait_for_event(ServerHeartbeatStartedEvent, 1) - listener.wait_for_event(ServerHeartbeatFailedEvent, 1) + + if _IS_SYNC: + server = TCPServer(("localhost", 9999), MockTCPHandler) + server.events = events + server_thread = ConcurrentRunner(target=server.handle_request_and_shutdown) + server_thread.start() + _c = self.simple_client( + "mongodb://localhost:9999", + serverSelectionTimeoutMS=500, + event_listeners=(listener,), + ) + server_thread.join() + listener.wait_for_event(ServerHeartbeatStartedEvent, 1) + listener.wait_for_event(ServerHeartbeatFailedEvent, 1) + + else: + + def handle_client(reader: StreamReader, writer: StreamWriter): + events.append("client connected") + if (reader.read(1024)).strip(): + events.append("client hello received") + writer.close() + writer.wait_closed() + + server = asyncio.start_server(handle_client, "localhost", 9999) + server.events = events + server.start_serving() + _c = self.simple_client( + "mongodb://localhost:9999", + serverSelectionTimeoutMS=500, + event_listeners=(listener,), + ) + _c._connect() + + listener.wait_for_event(ServerHeartbeatStartedEvent, 1) + listener.wait_for_event(ServerHeartbeatFailedEvent, 1) + + server.close() + server.wait_closed() + _c.close() self.assertEqual( events, diff --git a/test/unified_format.py b/test/unified_format.py index d5698f5a77..8ed9e214bb 100644 --- a/test/unified_format.py +++ b/test/unified_format.py @@ -543,6 +543,14 @@ def maybe_skip_test(self, spec): self.skipTest("Implement PYTHON-1894") if "timeoutMS applied to entire download" in spec["description"]: self.skipTest("PyMongo's open_download_stream does not cap the stream's lifetime") + if ( + "Error returned from connection pool clear with interruptInUseConnections=true is retryable" + in spec["description"] + and not _IS_SYNC + ): + self.skipTest("PYTHON-5170 tests are flakey") + if "Driver extends timeout while streaming" in spec["description"] and not _IS_SYNC: + self.skipTest("PYTHON-5174 tests are flakey") class_name = self.__class__.__name__.lower() description = spec["description"].lower() diff --git a/test/utils.py b/test/utils.py index e089b3fc2f..ae316d0387 100644 --- a/test/utils.py +++ b/test/utils.py @@ -1078,3 +1078,19 @@ def create_async_event(): def create_event(): return threading.Event() + + +def async_create_barrier(N_TASKS, timeout: float | None = None): + return asyncio.Barrier(N_TASKS) + + +def create_barrier(N_TASKS, timeout: float | None = None): + return threading.Barrier(N_TASKS, timeout=timeout) + + +async def async_barrier_wait(barrier, timeout: float | None = None): + await asyncio.wait_for(barrier.wait(), timeout=timeout) + + +def barrier_wait(barrier, timeout: float | None = None): + barrier.wait() diff --git a/tools/synchro.py b/tools/synchro.py index 877a683531..42d5694f47 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -124,6 +124,8 @@ "AsyncMockPool": "MockPool", "StopAsyncIteration": "StopIteration", "create_async_event": "create_event", + "async_create_barrier": "create_barrier", + "async_barrier_wait": "barrier_wait", "async_joinall": "joinall", } @@ -213,6 +215,7 @@ def async_only_test(f: str) -> bool: "test_custom_types.py", "test_database.py", "test_data_lake.py", + "test_discovery_and_monitoring.py", "test_dns.py", "test_encryption.py", "test_examples.py",