From f73e0d25c5b476dd335c23542fc3cbc1d5586f63 Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Fri, 31 Jan 2025 19:28:14 -0800 Subject: [PATCH 01/14] Convert test.test_server_selection_in_window to async --- .../test_server_selection_in_window.py | 196 ++++++++++++++++++ test/test_server_selection_in_window.py | 50 +++-- tools/synchro.py | 1 + 3 files changed, 234 insertions(+), 13 deletions(-) create mode 100644 test/asynchronous/test_server_selection_in_window.py diff --git a/test/asynchronous/test_server_selection_in_window.py b/test/asynchronous/test_server_selection_in_window.py new file mode 100644 index 0000000000..5c2210b101 --- /dev/null +++ b/test/asynchronous/test_server_selection_in_window.py @@ -0,0 +1,196 @@ +# Copyright 2020-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test the topology module's Server Selection Spec implementation.""" +from __future__ import annotations + +import asyncio +import os +import threading +from pathlib import Path +from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest +from test.utils import ( + CMAPListener, + OvertCommandListener, + async_get_pool, + async_wait_until, +) +from test.utils_selection_tests import create_topology +from test.utils_spec_runner import SpecTestCreator + +from pymongo.common import clean_node +from pymongo.monitoring import ConnectionReadyEvent +from pymongo.operations import _Op +from pymongo.read_preferences import ReadPreference + +_IS_SYNC = False +# Location of JSON test specifications. +if _IS_SYNC: + TEST_PATH = os.path.join(Path(__file__).resolve().parent, "server_selection", "in_window") +else: + TEST_PATH = os.path.join( + Path(__file__).resolve().parent.parent, "server_selection", "in_window" + ) + + +class TestAllScenarios(unittest.IsolatedAsyncioTestCase): + def run_scenario(self, scenario_def): + topology = create_topology(scenario_def) + + # Update mock operation_count state: + for mock in scenario_def["mocked_topology_state"]: + address = clean_node(mock["address"]) + server = topology.get_server_by_address(address) + server.pool.operation_count = mock["operation_count"] + + pref = ReadPreference.NEAREST + counts = {address: 0 for address in topology._description.server_descriptions()} + + # Number of times to repeat server selection + iterations = scenario_def["iterations"] + for _ in range(iterations): + server = topology.select_server(pref, _Op.TEST, server_selection_timeout=0) + counts[server.description.address] += 1 + + # Verify expected_frequencies + outcome = scenario_def["outcome"] + tolerance = outcome["tolerance"] + expected_frequencies = outcome["expected_frequencies"] + for host_str, freq in expected_frequencies.items(): + address = clean_node(host_str) + actual_freq = float(counts[address]) / iterations + if freq == 0: + # Should be exactly 0. + self.assertEqual(actual_freq, 0) + else: + # Should be within 'tolerance'. + self.assertAlmostEqual(actual_freq, freq, delta=tolerance) + + +def create_test(scenario_def, test, name): + def run_scenario(self): + self.run_scenario(scenario_def) + + return run_scenario + + +class CustomSpecTestCreator(SpecTestCreator): + def tests(self, scenario_def): + """Extract the tests from a spec file. + + Server selection in_window tests do not have a 'tests' field. + The whole file represents a single test case. + """ + return [scenario_def] + + +CustomSpecTestCreator(create_test, TestAllScenarios, TEST_PATH).create_tests() + +if _IS_SYNC: + PARENT = threading.Thread +else: + PARENT = object + + +class FinderThread(PARENT): + def __init__(self, collection, iterations): + super().__init__() + self.daemon = True + self.collection = collection + self.iterations = iterations + self.passed = False + self.task = None + + async def run(self): + for _ in range(self.iterations): + await self.collection.find_one({}) + self.passed = True + + def start(self): + if _IS_SYNC: + super().start() + else: + self.task = asyncio.create_task(self.run()) + + async def join(self): + if _IS_SYNC: + super().join() + else: + await self.task + + +class TestProse(AsyncIntegrationTest): + def frequencies(self, client, listener, n_finds=10): + coll = client.test.test + N_TASKS = 10 + tasks = [FinderThread(coll, n_finds) for _ in range(N_TASKS)] + for task in tasks: + task.start() + for task in tasks: + task.join() + for task in tasks: + self.assertTrue(task.passed) + + events = listener.started_events + self.assertEqual(len(events), n_finds * N_TASKS) + nodes = client.nodes + self.assertEqual(len(nodes), 2) + freqs = {address: 0.0 for address in nodes} + for event in events: + freqs[event.connection_id] += 1 + for address in freqs: + freqs[address] = freqs[address] / float(len(events)) + return freqs + + @async_client_context.require_failCommand_appName + @async_client_context.require_multiple_mongoses + async def test_load_balancing(self): + listener = OvertCommandListener() + cmap_listener = CMAPListener() + # PYTHON-2584: Use a large localThresholdMS to avoid the impact of + # varying RTTs. + client = await self.async_rs_client( + async_client_context.mongos_seeds(), + appName="loadBalancingTest", + event_listeners=[listener, cmap_listener], + localThresholdMS=30000, + minPoolSize=10, + ) + await async_wait_until(lambda: len(client.nodes) == 2, "discover both nodes") + # Wait for both pools to be populated. + cmap_listener.wait_for_event(ConnectionReadyEvent, 20) + # Delay find commands on only one mongos. + delay_finds = { + "configureFailPoint": "failCommand", + "mode": {"times": 10000}, + "data": { + "failCommands": ["find"], + "blockAsyncConnection": True, + "blockTimeMS": 500, + "appName": "loadBalancingTest", + }, + } + with self.fail_point(delay_finds): + nodes = async_client_context.client.nodes + self.assertEqual(len(nodes), 1) + delayed_server = await anext(iter(nodes)) + freqs = self.frequencies(client, listener) + self.assertLessEqual(freqs[delayed_server], 0.25) + listener.reset() + freqs = self.frequencies(client, listener, n_finds=150) + self.assertAlmostEqual(freqs[delayed_server], 0.50, delta=0.15) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_server_selection_in_window.py b/test/test_server_selection_in_window.py index 05772fa385..a531dbafb6 100644 --- a/test/test_server_selection_in_window.py +++ b/test/test_server_selection_in_window.py @@ -15,8 +15,10 @@ """Test the topology module's Server Selection Spec implementation.""" from __future__ import annotations +import asyncio import os import threading +from pathlib import Path from test import IntegrationTest, client_context, unittest from test.utils import ( CMAPListener, @@ -32,10 +34,14 @@ from pymongo.operations import _Op from pymongo.read_preferences import ReadPreference +_IS_SYNC = True # Location of JSON test specifications. -TEST_PATH = os.path.join( - os.path.dirname(os.path.realpath(__file__)), os.path.join("server_selection", "in_window") -) +if _IS_SYNC: + TEST_PATH = os.path.join(Path(__file__).resolve().parent, "server_selection", "in_window") +else: + TEST_PATH = os.path.join( + Path(__file__).resolve().parent.parent, "server_selection", "in_window" + ) class TestAllScenarios(unittest.TestCase): @@ -91,35 +97,53 @@ def tests(self, scenario_def): CustomSpecTestCreator(create_test, TestAllScenarios, TEST_PATH).create_tests() +if _IS_SYNC: + PARENT = threading.Thread +else: + PARENT = object + -class FinderThread(threading.Thread): +class FinderThread(PARENT): def __init__(self, collection, iterations): super().__init__() self.daemon = True self.collection = collection self.iterations = iterations self.passed = False + self.task = None def run(self): for _ in range(self.iterations): self.collection.find_one({}) self.passed = True + def start(self): + if _IS_SYNC: + super().start() + else: + self.task = asyncio.create_task(self.run()) + + def join(self): + if _IS_SYNC: + super().join() + else: + self.task + class TestProse(IntegrationTest): def frequencies(self, client, listener, n_finds=10): coll = client.test.test - N_THREADS = 10 - threads = [FinderThread(coll, n_finds) for _ in range(N_THREADS)] - for thread in threads: - thread.start() - for thread in threads: - thread.join() - for thread in threads: - self.assertTrue(thread.passed) + N_TASKS = 10 + tasks = [FinderThread(coll, n_finds) for _ in range(N_TASKS)] + for task in tasks: + task.start() + for task in tasks: + task.join() + for task in tasks: + self.assertTrue(task.passed) events = listener.started_events - self.assertEqual(len(events), n_finds * N_THREADS) + self.assertEqual(len(events), n_finds * N_TASKS) nodes = client.nodes self.assertEqual(len(nodes), 2) freqs = {address: 0.0 for address in nodes} diff --git a/tools/synchro.py b/tools/synchro.py index eb44ef4ac0..16b28fffa7 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -224,6 +224,7 @@ def async_only_test(f: str) -> bool: "test_retryable_reads_unified.py", "test_retryable_writes.py", "test_retryable_writes_unified.py", + "test_server_selection_in_window.py", "test_session.py", "test_transactions.py", "unified_format.py", From ee04430e94acbc683c9aa48621644ba1179befd9 Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Fri, 31 Jan 2025 20:21:59 -0800 Subject: [PATCH 02/14] remove anext --- test/asynchronous/test_server_selection_in_window.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/asynchronous/test_server_selection_in_window.py b/test/asynchronous/test_server_selection_in_window.py index 5c2210b101..a170180b47 100644 --- a/test/asynchronous/test_server_selection_in_window.py +++ b/test/asynchronous/test_server_selection_in_window.py @@ -184,7 +184,7 @@ async def test_load_balancing(self): with self.fail_point(delay_finds): nodes = async_client_context.client.nodes self.assertEqual(len(nodes), 1) - delayed_server = await anext(iter(nodes)) + delayed_server = next(iter(nodes)) freqs = self.frequencies(client, listener) self.assertLessEqual(freqs[delayed_server], 0.25) listener.reset() From 3df3d8be07a2baa27d6f296dbcf7ee1c07822de0 Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Wed, 5 Feb 2025 18:15:34 -0800 Subject: [PATCH 03/14] use ConcurrentRunner --- .../test_server_selection_in_window.py | 23 +++---------------- test/test_server_selection_in_window.py | 23 +++---------------- 2 files changed, 6 insertions(+), 40 deletions(-) diff --git a/test/asynchronous/test_server_selection_in_window.py b/test/asynchronous/test_server_selection_in_window.py index a170180b47..f8c03633f8 100644 --- a/test/asynchronous/test_server_selection_in_window.py +++ b/test/asynchronous/test_server_selection_in_window.py @@ -20,6 +20,7 @@ import threading from pathlib import Path from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest +from test.asynchronous.helpers import ConcurrentRunner from test.utils import ( CMAPListener, OvertCommandListener, @@ -97,44 +98,26 @@ def tests(self, scenario_def): CustomSpecTestCreator(create_test, TestAllScenarios, TEST_PATH).create_tests() -if _IS_SYNC: - PARENT = threading.Thread -else: - PARENT = object - -class FinderThread(PARENT): +class FinderTask(ConcurrentRunner): def __init__(self, collection, iterations): super().__init__() self.daemon = True self.collection = collection self.iterations = iterations self.passed = False - self.task = None async def run(self): for _ in range(self.iterations): await self.collection.find_one({}) self.passed = True - def start(self): - if _IS_SYNC: - super().start() - else: - self.task = asyncio.create_task(self.run()) - - async def join(self): - if _IS_SYNC: - super().join() - else: - await self.task - class TestProse(AsyncIntegrationTest): def frequencies(self, client, listener, n_finds=10): coll = client.test.test N_TASKS = 10 - tasks = [FinderThread(coll, n_finds) for _ in range(N_TASKS)] + tasks = [FinderTask(coll, n_finds) for _ in range(N_TASKS)] for task in tasks: task.start() for task in tasks: diff --git a/test/test_server_selection_in_window.py b/test/test_server_selection_in_window.py index a531dbafb6..7ccd4b529e 100644 --- a/test/test_server_selection_in_window.py +++ b/test/test_server_selection_in_window.py @@ -20,6 +20,7 @@ import threading from pathlib import Path from test import IntegrationTest, client_context, unittest +from test.helpers import ConcurrentRunner from test.utils import ( CMAPListener, OvertCommandListener, @@ -97,44 +98,26 @@ def tests(self, scenario_def): CustomSpecTestCreator(create_test, TestAllScenarios, TEST_PATH).create_tests() -if _IS_SYNC: - PARENT = threading.Thread -else: - PARENT = object - -class FinderThread(PARENT): +class FinderTask(ConcurrentRunner): def __init__(self, collection, iterations): super().__init__() self.daemon = True self.collection = collection self.iterations = iterations self.passed = False - self.task = None def run(self): for _ in range(self.iterations): self.collection.find_one({}) self.passed = True - def start(self): - if _IS_SYNC: - super().start() - else: - self.task = asyncio.create_task(self.run()) - - def join(self): - if _IS_SYNC: - super().join() - else: - self.task - class TestProse(IntegrationTest): def frequencies(self, client, listener, n_finds=10): coll = client.test.test N_TASKS = 10 - tasks = [FinderThread(coll, n_finds) for _ in range(N_TASKS)] + tasks = [FinderTask(coll, n_finds) for _ in range(N_TASKS)] for task in tasks: task.start() for task in tasks: From 7aa61e48537485f6c324c6a00dc94f3be7c3f1ce Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Thu, 6 Feb 2025 12:17:45 -0800 Subject: [PATCH 04/14] create async utils_selection_tests --- .../test_server_selection_in_window.py | 18 +- test/asynchronous/utils_selection_tests.py | 271 ++++++++++++++++++ test/asynchronous/utils_spec_runner.py | 2 +- 3 files changed, 281 insertions(+), 10 deletions(-) create mode 100644 test/asynchronous/utils_selection_tests.py diff --git a/test/asynchronous/test_server_selection_in_window.py b/test/asynchronous/test_server_selection_in_window.py index f8c03633f8..9fbe17a1cc 100644 --- a/test/asynchronous/test_server_selection_in_window.py +++ b/test/asynchronous/test_server_selection_in_window.py @@ -21,14 +21,14 @@ from pathlib import Path from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest from test.asynchronous.helpers import ConcurrentRunner +from test.asynchronous.utils_selection_tests import create_topology +from test.asynchronous.utils_spec_runner import AsyncSpecTestCreator from test.utils import ( CMAPListener, OvertCommandListener, async_get_pool, async_wait_until, ) -from test.utils_selection_tests import create_topology -from test.utils_spec_runner import SpecTestCreator from pymongo.common import clean_node from pymongo.monitoring import ConnectionReadyEvent @@ -46,8 +46,8 @@ class TestAllScenarios(unittest.IsolatedAsyncioTestCase): - def run_scenario(self, scenario_def): - topology = create_topology(scenario_def) + async def run_scenario(self, scenario_def): + topology = await create_topology(scenario_def) # Update mock operation_count state: for mock in scenario_def["mocked_topology_state"]: @@ -61,7 +61,7 @@ def run_scenario(self, scenario_def): # Number of times to repeat server selection iterations = scenario_def["iterations"] for _ in range(iterations): - server = topology.select_server(pref, _Op.TEST, server_selection_timeout=0) + server = await topology.select_server(pref, _Op.TEST, server_selection_timeout=0) counts[server.description.address] += 1 # Verify expected_frequencies @@ -80,13 +80,13 @@ def run_scenario(self, scenario_def): def create_test(scenario_def, test, name): - def run_scenario(self): - self.run_scenario(scenario_def) + async def run_scenario(self): + await self.run_scenario(scenario_def) return run_scenario -class CustomSpecTestCreator(SpecTestCreator): +class CustomSpecTestCreator(AsyncSpecTestCreator): def tests(self, scenario_def): """Extract the tests from a spec file. @@ -164,7 +164,7 @@ async def test_load_balancing(self): "appName": "loadBalancingTest", }, } - with self.fail_point(delay_finds): + async with self.fail_point(delay_finds): nodes = async_client_context.client.nodes self.assertEqual(len(nodes), 1) delayed_server = next(iter(nodes)) diff --git a/test/asynchronous/utils_selection_tests.py b/test/asynchronous/utils_selection_tests.py new file mode 100644 index 0000000000..f6dd5e03b4 --- /dev/null +++ b/test/asynchronous/utils_selection_tests.py @@ -0,0 +1,271 @@ +# Copyright 2015-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for testing Server Selection and Max Staleness.""" +from __future__ import annotations + +import datetime +import os +import sys + +sys.path[0:0] = [""] + +from test import unittest +from test.pymongo_mocks import DummyMonitor +from test.utils import AsyncMockPool, parse_read_preference + +from bson import json_util +from pymongo.asynchronous.settings import TopologySettings +from pymongo.asynchronous.topology import Topology +from pymongo.common import HEARTBEAT_FREQUENCY, MIN_SUPPORTED_WIRE_VERSION, clean_node +from pymongo.errors import AutoReconnect, ConfigurationError +from pymongo.hello import Hello, HelloCompat +from pymongo.operations import _Op +from pymongo.server_description import ServerDescription +from pymongo.server_selectors import writable_server_selector + + +def get_addresses(server_list): + seeds = [] + hosts = [] + for server in server_list: + seeds.append(clean_node(server["address"])) + hosts.append(server["address"]) + return seeds, hosts + + +def make_last_write_date(server): + epoch = datetime.datetime.fromtimestamp(0, tz=datetime.timezone.utc).replace(tzinfo=None) + millis = server.get("lastWrite", {}).get("lastWriteDate") + if millis: + diff = ((millis % 1000) + 1000) % 1000 + seconds = (millis - diff) / 1000 + micros = diff * 1000 + return epoch + datetime.timedelta(seconds=seconds, microseconds=micros) + else: + # "Unknown" server. + return epoch + + +def make_server_description(server, hosts): + """Make a ServerDescription from server info in a JSON test.""" + server_type = server["type"] + if server_type in ("Unknown", "PossiblePrimary"): + return ServerDescription(clean_node(server["address"]), Hello({})) + + hello_response = {"ok": True, "hosts": hosts} + if server_type not in ("Standalone", "Mongos", "RSGhost"): + hello_response["setName"] = "rs" + + if server_type == "RSPrimary": + hello_response[HelloCompat.LEGACY_CMD] = True + elif server_type == "RSSecondary": + hello_response["secondary"] = True + elif server_type == "Mongos": + hello_response["msg"] = "isdbgrid" + elif server_type == "RSGhost": + hello_response["isreplicaset"] = True + elif server_type == "RSArbiter": + hello_response["arbiterOnly"] = True + + hello_response["lastWrite"] = {"lastWriteDate": make_last_write_date(server)} + + for field in "maxWireVersion", "tags", "idleWritePeriodMillis": + if field in server: + hello_response[field] = server[field] + + hello_response.setdefault("maxWireVersion", MIN_SUPPORTED_WIRE_VERSION) + + # Sets _last_update_time to now. + sd = ServerDescription( + clean_node(server["address"]), + Hello(hello_response), + round_trip_time=server["avg_rtt_ms"] / 1000.0, + ) + + if "lastUpdateTime" in server: + sd._last_update_time = server["lastUpdateTime"] / 1000.0 # ms to sec. + + return sd + + +def get_topology_type_name(scenario_def): + td = scenario_def["topology_description"] + name = td["type"] + if name == "Unknown": + # PyMongo never starts a topology in type Unknown. + return "Sharded" if len(td["servers"]) > 1 else "Single" + else: + return name + + +def get_topology_settings_dict(**kwargs): + settings = { + "monitor_class": DummyMonitor, + "heartbeat_frequency": HEARTBEAT_FREQUENCY, + "pool_class": AsyncMockPool, + } + settings.update(kwargs) + return settings + + +async def create_topology(scenario_def, **kwargs): + # Initialize topologies. + if "heartbeatFrequencyMS" in scenario_def: + frequency = int(scenario_def["heartbeatFrequencyMS"]) / 1000.0 + else: + frequency = HEARTBEAT_FREQUENCY + + seeds, hosts = get_addresses(scenario_def["topology_description"]["servers"]) + + topology_type = get_topology_type_name(scenario_def) + if topology_type == "LoadBalanced": + kwargs.setdefault("load_balanced", True) + # Force topology description to ReplicaSet + elif topology_type in ["ReplicaSetNoPrimary", "ReplicaSetWithPrimary"]: + kwargs.setdefault("replica_set_name", "rs") + settings = get_topology_settings_dict(heartbeat_frequency=frequency, seeds=seeds, **kwargs) + + # "Eligible servers" is defined in the server selection spec as + # the set of servers matching both the ReadPreference's mode + # and tag sets. + topology = Topology(TopologySettings(**settings)) + await topology.open() + + # Update topologies with server descriptions. + for server in scenario_def["topology_description"]["servers"]: + server_description = make_server_description(server, hosts) + await topology.on_change(server_description) + + # Assert that descriptions match + assert ( + scenario_def["topology_description"]["type"] == topology.description.topology_type_name + ), topology.description.topology_type_name + + return topology + + +def create_test(scenario_def): + async def run_scenario(self): + _, hosts = get_addresses(scenario_def["topology_description"]["servers"]) + # "Eligible servers" is defined in the server selection spec as + # the set of servers matching both the ReadPreference's mode + # and tag sets. + top_latency = await create_topology(scenario_def) + + # "In latency window" is defined in the server selection + # spec as the subset of suitable_servers that falls within the + # allowable latency window. + top_suitable = await create_topology(scenario_def, local_threshold_ms=1000000) + + # Create server selector. + if scenario_def.get("operation") == "write": + pref = writable_server_selector + else: + # Make first letter lowercase to match read_pref's modes. + pref_def = scenario_def["read_preference"] + if scenario_def.get("error"): + with self.assertRaises((ConfigurationError, ValueError)): + # Error can be raised when making Read Pref or selecting. + pref = parse_read_preference(pref_def) + await top_latency.select_server(pref, _Op.TEST) + return + + pref = parse_read_preference(pref_def) + + # Select servers. + if not scenario_def.get("suitable_servers"): + with self.assertRaises(AutoReconnect): + await top_suitable.select_server(pref, _Op.TEST, server_selection_timeout=0) + + return + + if not scenario_def["in_latency_window"]: + with self.assertRaises(AutoReconnect): + await top_latency.select_server(pref, _Op.TEST, server_selection_timeout=0) + + return + + actual_suitable_s = await top_suitable.select_servers( + pref, _Op.TEST, server_selection_timeout=0 + ) + actual_latency_s = await top_latency.select_servers( + pref, _Op.TEST, server_selection_timeout=0 + ) + + expected_suitable_servers = {} + for server in scenario_def["suitable_servers"]: + server_description = make_server_description(server, hosts) + expected_suitable_servers[server["address"]] = server_description + + actual_suitable_servers = {} + for s in actual_suitable_s: + actual_suitable_servers[ + "%s:%d" % (s.description.address[0], s.description.address[1]) + ] = s.description + + self.assertEqual(len(actual_suitable_servers), len(expected_suitable_servers)) + for k, actual in actual_suitable_servers.items(): + expected = expected_suitable_servers[k] + self.assertEqual(expected.address, actual.address) + self.assertEqual(expected.server_type, actual.server_type) + self.assertEqual(expected.round_trip_time, actual.round_trip_time) + self.assertEqual(expected.tags, actual.tags) + self.assertEqual(expected.all_hosts, actual.all_hosts) + + expected_latency_servers = {} + for server in scenario_def["in_latency_window"]: + server_description = make_server_description(server, hosts) + expected_latency_servers[server["address"]] = server_description + + actual_latency_servers = {} + for s in actual_latency_s: + actual_latency_servers[ + "%s:%d" % (s.description.address[0], s.description.address[1]) + ] = s.description + + self.assertEqual(len(actual_latency_servers), len(expected_latency_servers)) + for k, actual in actual_latency_servers.items(): + expected = expected_latency_servers[k] + self.assertEqual(expected.address, actual.address) + self.assertEqual(expected.server_type, actual.server_type) + self.assertEqual(expected.round_trip_time, actual.round_trip_time) + self.assertEqual(expected.tags, actual.tags) + self.assertEqual(expected.all_hosts, actual.all_hosts) + + return run_scenario + + +def create_selection_tests(test_dir): + class TestAllScenarios(unittest.TestCase): + pass + + for dirpath, _, filenames in os.walk(test_dir): + dirname = os.path.split(dirpath) + dirname = os.path.split(dirname[-2])[-1] + "_" + dirname[-1] + + for filename in filenames: + if os.path.splitext(filename)[1] != ".json": + continue + with open(os.path.join(dirpath, filename)) as scenario_stream: + scenario_def = json_util.loads(scenario_stream.read()) + + # Construct test from scenario. + new_test = create_test(scenario_def) + test_name = f"test_{dirname}_{os.path.splitext(filename)[0]}" + + new_test.__name__ = test_name + setattr(TestAllScenarios, new_test.__name__, new_test) + + return TestAllScenarios diff --git a/test/asynchronous/utils_spec_runner.py b/test/asynchronous/utils_spec_runner.py index d433f1a7e6..11d88850fc 100644 --- a/test/asynchronous/utils_spec_runner.py +++ b/test/asynchronous/utils_spec_runner.py @@ -229,7 +229,7 @@ async def _create_tests(self): str(test_def["description"].replace(" ", "_").replace(".", "_")), ) - new_test = await self._create_test(scenario_def, test_def, test_name) + new_test = self._create_test(scenario_def, test_def, test_name) new_test = self._ensure_min_max_server_version(scenario_def, new_test) new_test = self.ensure_run_on(scenario_def, new_test) From 64aa04875a74b5b01b15956b3b4113b632859f45 Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Thu, 6 Feb 2025 13:08:27 -0800 Subject: [PATCH 05/14] create utils_selection_tests_shared.py --- test/asynchronous/utils_selection_tests.py | 83 ++--------------- test/utils_selection_tests.py | 83 ++--------------- test/utils_selection_tests_shared.py | 100 +++++++++++++++++++++ 3 files changed, 112 insertions(+), 154 deletions(-) create mode 100644 test/utils_selection_tests_shared.py diff --git a/test/asynchronous/utils_selection_tests.py b/test/asynchronous/utils_selection_tests.py index f6dd5e03b4..7f9711ca0a 100644 --- a/test/asynchronous/utils_selection_tests.py +++ b/test/asynchronous/utils_selection_tests.py @@ -24,92 +24,21 @@ from test import unittest from test.pymongo_mocks import DummyMonitor from test.utils import AsyncMockPool, parse_read_preference +from test.utils_selection_tests_shared import ( + get_addresses, + get_topology_type_name, + make_server_description, +) from bson import json_util from pymongo.asynchronous.settings import TopologySettings from pymongo.asynchronous.topology import Topology -from pymongo.common import HEARTBEAT_FREQUENCY, MIN_SUPPORTED_WIRE_VERSION, clean_node +from pymongo.common import HEARTBEAT_FREQUENCY from pymongo.errors import AutoReconnect, ConfigurationError -from pymongo.hello import Hello, HelloCompat from pymongo.operations import _Op -from pymongo.server_description import ServerDescription from pymongo.server_selectors import writable_server_selector -def get_addresses(server_list): - seeds = [] - hosts = [] - for server in server_list: - seeds.append(clean_node(server["address"])) - hosts.append(server["address"]) - return seeds, hosts - - -def make_last_write_date(server): - epoch = datetime.datetime.fromtimestamp(0, tz=datetime.timezone.utc).replace(tzinfo=None) - millis = server.get("lastWrite", {}).get("lastWriteDate") - if millis: - diff = ((millis % 1000) + 1000) % 1000 - seconds = (millis - diff) / 1000 - micros = diff * 1000 - return epoch + datetime.timedelta(seconds=seconds, microseconds=micros) - else: - # "Unknown" server. - return epoch - - -def make_server_description(server, hosts): - """Make a ServerDescription from server info in a JSON test.""" - server_type = server["type"] - if server_type in ("Unknown", "PossiblePrimary"): - return ServerDescription(clean_node(server["address"]), Hello({})) - - hello_response = {"ok": True, "hosts": hosts} - if server_type not in ("Standalone", "Mongos", "RSGhost"): - hello_response["setName"] = "rs" - - if server_type == "RSPrimary": - hello_response[HelloCompat.LEGACY_CMD] = True - elif server_type == "RSSecondary": - hello_response["secondary"] = True - elif server_type == "Mongos": - hello_response["msg"] = "isdbgrid" - elif server_type == "RSGhost": - hello_response["isreplicaset"] = True - elif server_type == "RSArbiter": - hello_response["arbiterOnly"] = True - - hello_response["lastWrite"] = {"lastWriteDate": make_last_write_date(server)} - - for field in "maxWireVersion", "tags", "idleWritePeriodMillis": - if field in server: - hello_response[field] = server[field] - - hello_response.setdefault("maxWireVersion", MIN_SUPPORTED_WIRE_VERSION) - - # Sets _last_update_time to now. - sd = ServerDescription( - clean_node(server["address"]), - Hello(hello_response), - round_trip_time=server["avg_rtt_ms"] / 1000.0, - ) - - if "lastUpdateTime" in server: - sd._last_update_time = server["lastUpdateTime"] / 1000.0 # ms to sec. - - return sd - - -def get_topology_type_name(scenario_def): - td = scenario_def["topology_description"] - name = td["type"] - if name == "Unknown": - # PyMongo never starts a topology in type Unknown. - return "Sharded" if len(td["servers"]) > 1 else "Single" - else: - return name - - def get_topology_settings_dict(**kwargs): settings = { "monitor_class": DummyMonitor, diff --git a/test/utils_selection_tests.py b/test/utils_selection_tests.py index 2d21888e27..3cb353490c 100644 --- a/test/utils_selection_tests.py +++ b/test/utils_selection_tests.py @@ -24,92 +24,21 @@ from test import unittest from test.pymongo_mocks import DummyMonitor from test.utils import MockPool, parse_read_preference +from test.utils_selection_tests_shared import ( + get_addresses, + get_topology_type_name, + make_server_description, +) from bson import json_util -from pymongo.common import HEARTBEAT_FREQUENCY, MIN_SUPPORTED_WIRE_VERSION, clean_node +from pymongo.common import HEARTBEAT_FREQUENCY from pymongo.errors import AutoReconnect, ConfigurationError -from pymongo.hello import Hello, HelloCompat from pymongo.operations import _Op -from pymongo.server_description import ServerDescription from pymongo.server_selectors import writable_server_selector from pymongo.synchronous.settings import TopologySettings from pymongo.synchronous.topology import Topology -def get_addresses(server_list): - seeds = [] - hosts = [] - for server in server_list: - seeds.append(clean_node(server["address"])) - hosts.append(server["address"]) - return seeds, hosts - - -def make_last_write_date(server): - epoch = datetime.datetime.fromtimestamp(0, tz=datetime.timezone.utc).replace(tzinfo=None) - millis = server.get("lastWrite", {}).get("lastWriteDate") - if millis: - diff = ((millis % 1000) + 1000) % 1000 - seconds = (millis - diff) / 1000 - micros = diff * 1000 - return epoch + datetime.timedelta(seconds=seconds, microseconds=micros) - else: - # "Unknown" server. - return epoch - - -def make_server_description(server, hosts): - """Make a ServerDescription from server info in a JSON test.""" - server_type = server["type"] - if server_type in ("Unknown", "PossiblePrimary"): - return ServerDescription(clean_node(server["address"]), Hello({})) - - hello_response = {"ok": True, "hosts": hosts} - if server_type not in ("Standalone", "Mongos", "RSGhost"): - hello_response["setName"] = "rs" - - if server_type == "RSPrimary": - hello_response[HelloCompat.LEGACY_CMD] = True - elif server_type == "RSSecondary": - hello_response["secondary"] = True - elif server_type == "Mongos": - hello_response["msg"] = "isdbgrid" - elif server_type == "RSGhost": - hello_response["isreplicaset"] = True - elif server_type == "RSArbiter": - hello_response["arbiterOnly"] = True - - hello_response["lastWrite"] = {"lastWriteDate": make_last_write_date(server)} - - for field in "maxWireVersion", "tags", "idleWritePeriodMillis": - if field in server: - hello_response[field] = server[field] - - hello_response.setdefault("maxWireVersion", MIN_SUPPORTED_WIRE_VERSION) - - # Sets _last_update_time to now. - sd = ServerDescription( - clean_node(server["address"]), - Hello(hello_response), - round_trip_time=server["avg_rtt_ms"] / 1000.0, - ) - - if "lastUpdateTime" in server: - sd._last_update_time = server["lastUpdateTime"] / 1000.0 # ms to sec. - - return sd - - -def get_topology_type_name(scenario_def): - td = scenario_def["topology_description"] - name = td["type"] - if name == "Unknown": - # PyMongo never starts a topology in type Unknown. - return "Sharded" if len(td["servers"]) > 1 else "Single" - else: - return name - - def get_topology_settings_dict(**kwargs): settings = { "monitor_class": DummyMonitor, diff --git a/test/utils_selection_tests_shared.py b/test/utils_selection_tests_shared.py new file mode 100644 index 0000000000..dbaed1034f --- /dev/null +++ b/test/utils_selection_tests_shared.py @@ -0,0 +1,100 @@ +# Copyright 2015-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for testing Server Selection and Max Staleness.""" +from __future__ import annotations + +import datetime +import os +import sys + +sys.path[0:0] = [""] + +from pymongo.common import MIN_SUPPORTED_WIRE_VERSION, clean_node +from pymongo.hello import Hello, HelloCompat +from pymongo.server_description import ServerDescription + + +def get_addresses(server_list): + seeds = [] + hosts = [] + for server in server_list: + seeds.append(clean_node(server["address"])) + hosts.append(server["address"]) + return seeds, hosts + + +def make_last_write_date(server): + epoch = datetime.datetime.fromtimestamp(0, tz=datetime.timezone.utc).replace(tzinfo=None) + millis = server.get("lastWrite", {}).get("lastWriteDate") + if millis: + diff = ((millis % 1000) + 1000) % 1000 + seconds = (millis - diff) / 1000 + micros = diff * 1000 + return epoch + datetime.timedelta(seconds=seconds, microseconds=micros) + else: + # "Unknown" server. + return epoch + + +def make_server_description(server, hosts): + """Make a ServerDescription from server info in a JSON test.""" + server_type = server["type"] + if server_type in ("Unknown", "PossiblePrimary"): + return ServerDescription(clean_node(server["address"]), Hello({})) + + hello_response = {"ok": True, "hosts": hosts} + if server_type not in ("Standalone", "Mongos", "RSGhost"): + hello_response["setName"] = "rs" + + if server_type == "RSPrimary": + hello_response[HelloCompat.LEGACY_CMD] = True + elif server_type == "RSSecondary": + hello_response["secondary"] = True + elif server_type == "Mongos": + hello_response["msg"] = "isdbgrid" + elif server_type == "RSGhost": + hello_response["isreplicaset"] = True + elif server_type == "RSArbiter": + hello_response["arbiterOnly"] = True + + hello_response["lastWrite"] = {"lastWriteDate": make_last_write_date(server)} + + for field in "maxWireVersion", "tags", "idleWritePeriodMillis": + if field in server: + hello_response[field] = server[field] + + hello_response.setdefault("maxWireVersion", MIN_SUPPORTED_WIRE_VERSION) + + # Sets _last_update_time to now. + sd = ServerDescription( + clean_node(server["address"]), + Hello(hello_response), + round_trip_time=server["avg_rtt_ms"] / 1000.0, + ) + + if "lastUpdateTime" in server: + sd._last_update_time = server["lastUpdateTime"] / 1000.0 # ms to sec. + + return sd + + +def get_topology_type_name(scenario_def): + td = scenario_def["topology_description"] + name = td["type"] + if name == "Unknown": + # PyMongo never starts a topology in type Unknown. + return "Sharded" if len(td["servers"]) > 1 else "Single" + else: + return name From b56b534925cde1d970938b415d7d8f7327e83863 Mon Sep 17 00:00:00 2001 From: Iris <58442094+sleepyStick@users.noreply.github.com> Date: Thu, 6 Feb 2025 14:01:24 -0800 Subject: [PATCH 06/14] Update test/asynchronous/utils_selection_tests.py Co-authored-by: Noah Stapp --- test/asynchronous/utils_selection_tests.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/asynchronous/utils_selection_tests.py b/test/asynchronous/utils_selection_tests.py index 7f9711ca0a..110c10e18c 100644 --- a/test/asynchronous/utils_selection_tests.py +++ b/test/asynchronous/utils_selection_tests.py @@ -177,7 +177,7 @@ async def run_scenario(self): def create_selection_tests(test_dir): - class TestAllScenarios(unittest.TestCase): + class TestAllScenarios(AsyncPyMongoTestCase): pass for dirpath, _, filenames in os.walk(test_dir): From 104ae52d98a96ac16e149cb7ad381354d6b73c40 Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Fri, 7 Feb 2025 07:29:56 -0800 Subject: [PATCH 07/14] make encryption's create_test not async --- test/asynchronous/test_encryption.py | 2 +- test/asynchronous/utils_selection_tests.py | 1 + test/utils_selection_tests.py | 3 ++- 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/test/asynchronous/test_encryption.py b/test/asynchronous/test_encryption.py index 2b22bd8b76..335aa9d81c 100644 --- a/test/asynchronous/test_encryption.py +++ b/test/asynchronous/test_encryption.py @@ -739,7 +739,7 @@ def allowable_errors(self, op): return errors -async def create_test(scenario_def, test, name): +def create_test(scenario_def, test, name): @async_client_context.require_test_commands async def run_scenario(self): await self.run_scenario(scenario_def, test) diff --git a/test/asynchronous/utils_selection_tests.py b/test/asynchronous/utils_selection_tests.py index 110c10e18c..ffb6fb9167 100644 --- a/test/asynchronous/utils_selection_tests.py +++ b/test/asynchronous/utils_selection_tests.py @@ -18,6 +18,7 @@ import datetime import os import sys +from test.asynchronous import AsyncPyMongoTestCase sys.path[0:0] = [""] diff --git a/test/utils_selection_tests.py b/test/utils_selection_tests.py index 3cb353490c..01e2133b1a 100644 --- a/test/utils_selection_tests.py +++ b/test/utils_selection_tests.py @@ -18,6 +18,7 @@ import datetime import os import sys +from test.synchronous import PyMongoTestCase sys.path[0:0] = [""] @@ -173,7 +174,7 @@ def run_scenario(self): def create_selection_tests(test_dir): - class TestAllScenarios(unittest.TestCase): + class TestAllScenarios(PyMongoTestCase): pass for dirpath, _, filenames in os.walk(test_dir): From 1c45977f63d6c1871a1590cf16e828d64f160507 Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Fri, 7 Feb 2025 07:35:39 -0800 Subject: [PATCH 08/14] add _is_sync = false to utils_selection_tests and add to synchro --- test/asynchronous/utils_selection_tests.py | 2 ++ test/utils_selection_tests.py | 4 +++- tools/synchro.py | 1 + 3 files changed, 6 insertions(+), 1 deletion(-) diff --git a/test/asynchronous/utils_selection_tests.py b/test/asynchronous/utils_selection_tests.py index ffb6fb9167..71e287569a 100644 --- a/test/asynchronous/utils_selection_tests.py +++ b/test/asynchronous/utils_selection_tests.py @@ -39,6 +39,8 @@ from pymongo.operations import _Op from pymongo.server_selectors import writable_server_selector +_IS_SYNC = False + def get_topology_settings_dict(**kwargs): settings = { diff --git a/test/utils_selection_tests.py b/test/utils_selection_tests.py index 01e2133b1a..9667ea701b 100644 --- a/test/utils_selection_tests.py +++ b/test/utils_selection_tests.py @@ -18,7 +18,7 @@ import datetime import os import sys -from test.synchronous import PyMongoTestCase +from test import PyMongoTestCase sys.path[0:0] = [""] @@ -39,6 +39,8 @@ from pymongo.synchronous.settings import TopologySettings from pymongo.synchronous.topology import Topology +_IS_SYNC = True + def get_topology_settings_dict(**kwargs): settings = { diff --git a/tools/synchro.py b/tools/synchro.py index deb057105b..0ef0cd462c 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -246,6 +246,7 @@ def async_only_test(f: str) -> bool: "test_unified_format.py", "test_versioned_api_integration.py", "unified_format.py", + "utils_selection_tests.py", ] From f6f7ae898565e6f032122197eb622c5ccd37eac8 Mon Sep 17 00:00:00 2001 From: Iris <58442094+sleepyStick@users.noreply.github.com> Date: Fri, 7 Feb 2025 08:04:22 -0800 Subject: [PATCH 09/14] Update test/asynchronous/test_server_selection_in_window.py Co-authored-by: Noah Stapp --- test/asynchronous/test_server_selection_in_window.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/asynchronous/test_server_selection_in_window.py b/test/asynchronous/test_server_selection_in_window.py index 9fbe17a1cc..f9bf4e2fa3 100644 --- a/test/asynchronous/test_server_selection_in_window.py +++ b/test/asynchronous/test_server_selection_in_window.py @@ -119,7 +119,7 @@ def frequencies(self, client, listener, n_finds=10): N_TASKS = 10 tasks = [FinderTask(coll, n_finds) for _ in range(N_TASKS)] for task in tasks: - task.start() + await task.start() for task in tasks: task.join() for task in tasks: From 7f1472b5e2f770ba12ae6c096ea881b260cc0942 Mon Sep 17 00:00:00 2001 From: Iris <58442094+sleepyStick@users.noreply.github.com> Date: Fri, 7 Feb 2025 08:04:48 -0800 Subject: [PATCH 10/14] Update test/asynchronous/test_server_selection_in_window.py Co-authored-by: Noah Stapp --- test/asynchronous/test_server_selection_in_window.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/asynchronous/test_server_selection_in_window.py b/test/asynchronous/test_server_selection_in_window.py index f9bf4e2fa3..425c9496d3 100644 --- a/test/asynchronous/test_server_selection_in_window.py +++ b/test/asynchronous/test_server_selection_in_window.py @@ -121,7 +121,7 @@ def frequencies(self, client, listener, n_finds=10): for task in tasks: await task.start() for task in tasks: - task.join() + await task.join() for task in tasks: self.assertTrue(task.passed) From 30e6c6244b0eac6bf98e02b5e22773d801f78176 Mon Sep 17 00:00:00 2001 From: Iris <58442094+sleepyStick@users.noreply.github.com> Date: Fri, 7 Feb 2025 08:21:51 -0800 Subject: [PATCH 11/14] Update test/asynchronous/test_server_selection_in_window.py Co-authored-by: Noah Stapp --- test/asynchronous/test_server_selection_in_window.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/asynchronous/test_server_selection_in_window.py b/test/asynchronous/test_server_selection_in_window.py index 425c9496d3..4e165f9a7a 100644 --- a/test/asynchronous/test_server_selection_in_window.py +++ b/test/asynchronous/test_server_selection_in_window.py @@ -114,7 +114,7 @@ async def run(self): class TestProse(AsyncIntegrationTest): - def frequencies(self, client, listener, n_finds=10): + async def frequencies(self, client, listener, n_finds=10): coll = client.test.test N_TASKS = 10 tasks = [FinderTask(coll, n_finds) for _ in range(N_TASKS)] From 195d47722c8e36dd8fbb205fd0eb4ce17b5e89ff Mon Sep 17 00:00:00 2001 From: Iris <58442094+sleepyStick@users.noreply.github.com> Date: Mon, 10 Feb 2025 08:56:22 -0800 Subject: [PATCH 12/14] Update test/asynchronous/test_server_selection_in_window.py Co-authored-by: Noah Stapp --- test/asynchronous/test_server_selection_in_window.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/asynchronous/test_server_selection_in_window.py b/test/asynchronous/test_server_selection_in_window.py index 4e165f9a7a..6a3ccd3524 100644 --- a/test/asynchronous/test_server_selection_in_window.py +++ b/test/asynchronous/test_server_selection_in_window.py @@ -152,7 +152,7 @@ async def test_load_balancing(self): ) await async_wait_until(lambda: len(client.nodes) == 2, "discover both nodes") # Wait for both pools to be populated. - cmap_listener.wait_for_event(ConnectionReadyEvent, 20) + await cmap_listener.async_wait_for_event(ConnectionReadyEvent, 20) # Delay find commands on only one mongos. delay_finds = { "configureFailPoint": "failCommand", From 60c93fc5d2137ecb38ca1430d4e24667309c7b38 Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Mon, 10 Feb 2025 11:08:47 -0800 Subject: [PATCH 13/14] add missing await --- test/asynchronous/test_server_selection_in_window.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/asynchronous/test_server_selection_in_window.py b/test/asynchronous/test_server_selection_in_window.py index 6a3ccd3524..4d2b36a065 100644 --- a/test/asynchronous/test_server_selection_in_window.py +++ b/test/asynchronous/test_server_selection_in_window.py @@ -168,10 +168,10 @@ async def test_load_balancing(self): nodes = async_client_context.client.nodes self.assertEqual(len(nodes), 1) delayed_server = next(iter(nodes)) - freqs = self.frequencies(client, listener) + freqs = await self.frequencies(client, listener) self.assertLessEqual(freqs[delayed_server], 0.25) listener.reset() - freqs = self.frequencies(client, listener, n_finds=150) + freqs = await self.frequencies(client, listener, n_finds=150) self.assertAlmostEqual(freqs[delayed_server], 0.50, delta=0.15) From b31172da8a12b15a66849a8ec3d852a9764239a3 Mon Sep 17 00:00:00 2001 From: Iris <58442094+sleepyStick@users.noreply.github.com> Date: Tue, 11 Feb 2025 09:00:53 -0800 Subject: [PATCH 14/14] Update test/asynchronous/test_server_selection_in_window.py Co-authored-by: Noah Stapp --- test/asynchronous/test_server_selection_in_window.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/asynchronous/test_server_selection_in_window.py b/test/asynchronous/test_server_selection_in_window.py index 4d2b36a065..e2ae92a27c 100644 --- a/test/asynchronous/test_server_selection_in_window.py +++ b/test/asynchronous/test_server_selection_in_window.py @@ -159,7 +159,7 @@ async def test_load_balancing(self): "mode": {"times": 10000}, "data": { "failCommands": ["find"], - "blockAsyncConnection": True, + "blockConnection": True, "blockTimeMS": 500, "appName": "loadBalancingTest", },