From f73e0d25c5b476dd335c23542fc3cbc1d5586f63 Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Fri, 31 Jan 2025 19:28:14 -0800 Subject: [PATCH] Convert test.test_server_selection_in_window to async --- .../test_server_selection_in_window.py | 196 ++++++++++++++++++ test/test_server_selection_in_window.py | 50 +++-- tools/synchro.py | 1 + 3 files changed, 234 insertions(+), 13 deletions(-) create mode 100644 test/asynchronous/test_server_selection_in_window.py diff --git a/test/asynchronous/test_server_selection_in_window.py b/test/asynchronous/test_server_selection_in_window.py new file mode 100644 index 0000000000..5c2210b101 --- /dev/null +++ b/test/asynchronous/test_server_selection_in_window.py @@ -0,0 +1,196 @@ +# Copyright 2020-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test the topology module's Server Selection Spec implementation.""" +from __future__ import annotations + +import asyncio +import os +import threading +from pathlib import Path +from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest +from test.utils import ( + CMAPListener, + OvertCommandListener, + async_get_pool, + async_wait_until, +) +from test.utils_selection_tests import create_topology +from test.utils_spec_runner import SpecTestCreator + +from pymongo.common import clean_node +from pymongo.monitoring import ConnectionReadyEvent +from pymongo.operations import _Op +from pymongo.read_preferences import ReadPreference + +_IS_SYNC = False +# Location of JSON test specifications. +if _IS_SYNC: + TEST_PATH = os.path.join(Path(__file__).resolve().parent, "server_selection", "in_window") +else: + TEST_PATH = os.path.join( + Path(__file__).resolve().parent.parent, "server_selection", "in_window" + ) + + +class TestAllScenarios(unittest.IsolatedAsyncioTestCase): + def run_scenario(self, scenario_def): + topology = create_topology(scenario_def) + + # Update mock operation_count state: + for mock in scenario_def["mocked_topology_state"]: + address = clean_node(mock["address"]) + server = topology.get_server_by_address(address) + server.pool.operation_count = mock["operation_count"] + + pref = ReadPreference.NEAREST + counts = {address: 0 for address in topology._description.server_descriptions()} + + # Number of times to repeat server selection + iterations = scenario_def["iterations"] + for _ in range(iterations): + server = topology.select_server(pref, _Op.TEST, server_selection_timeout=0) + counts[server.description.address] += 1 + + # Verify expected_frequencies + outcome = scenario_def["outcome"] + tolerance = outcome["tolerance"] + expected_frequencies = outcome["expected_frequencies"] + for host_str, freq in expected_frequencies.items(): + address = clean_node(host_str) + actual_freq = float(counts[address]) / iterations + if freq == 0: + # Should be exactly 0. + self.assertEqual(actual_freq, 0) + else: + # Should be within 'tolerance'. + self.assertAlmostEqual(actual_freq, freq, delta=tolerance) + + +def create_test(scenario_def, test, name): + def run_scenario(self): + self.run_scenario(scenario_def) + + return run_scenario + + +class CustomSpecTestCreator(SpecTestCreator): + def tests(self, scenario_def): + """Extract the tests from a spec file. + + Server selection in_window tests do not have a 'tests' field. + The whole file represents a single test case. + """ + return [scenario_def] + + +CustomSpecTestCreator(create_test, TestAllScenarios, TEST_PATH).create_tests() + +if _IS_SYNC: + PARENT = threading.Thread +else: + PARENT = object + + +class FinderThread(PARENT): + def __init__(self, collection, iterations): + super().__init__() + self.daemon = True + self.collection = collection + self.iterations = iterations + self.passed = False + self.task = None + + async def run(self): + for _ in range(self.iterations): + await self.collection.find_one({}) + self.passed = True + + def start(self): + if _IS_SYNC: + super().start() + else: + self.task = asyncio.create_task(self.run()) + + async def join(self): + if _IS_SYNC: + super().join() + else: + await self.task + + +class TestProse(AsyncIntegrationTest): + def frequencies(self, client, listener, n_finds=10): + coll = client.test.test + N_TASKS = 10 + tasks = [FinderThread(coll, n_finds) for _ in range(N_TASKS)] + for task in tasks: + task.start() + for task in tasks: + task.join() + for task in tasks: + self.assertTrue(task.passed) + + events = listener.started_events + self.assertEqual(len(events), n_finds * N_TASKS) + nodes = client.nodes + self.assertEqual(len(nodes), 2) + freqs = {address: 0.0 for address in nodes} + for event in events: + freqs[event.connection_id] += 1 + for address in freqs: + freqs[address] = freqs[address] / float(len(events)) + return freqs + + @async_client_context.require_failCommand_appName + @async_client_context.require_multiple_mongoses + async def test_load_balancing(self): + listener = OvertCommandListener() + cmap_listener = CMAPListener() + # PYTHON-2584: Use a large localThresholdMS to avoid the impact of + # varying RTTs. + client = await self.async_rs_client( + async_client_context.mongos_seeds(), + appName="loadBalancingTest", + event_listeners=[listener, cmap_listener], + localThresholdMS=30000, + minPoolSize=10, + ) + await async_wait_until(lambda: len(client.nodes) == 2, "discover both nodes") + # Wait for both pools to be populated. + cmap_listener.wait_for_event(ConnectionReadyEvent, 20) + # Delay find commands on only one mongos. + delay_finds = { + "configureFailPoint": "failCommand", + "mode": {"times": 10000}, + "data": { + "failCommands": ["find"], + "blockAsyncConnection": True, + "blockTimeMS": 500, + "appName": "loadBalancingTest", + }, + } + with self.fail_point(delay_finds): + nodes = async_client_context.client.nodes + self.assertEqual(len(nodes), 1) + delayed_server = await anext(iter(nodes)) + freqs = self.frequencies(client, listener) + self.assertLessEqual(freqs[delayed_server], 0.25) + listener.reset() + freqs = self.frequencies(client, listener, n_finds=150) + self.assertAlmostEqual(freqs[delayed_server], 0.50, delta=0.15) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_server_selection_in_window.py b/test/test_server_selection_in_window.py index 05772fa385..a531dbafb6 100644 --- a/test/test_server_selection_in_window.py +++ b/test/test_server_selection_in_window.py @@ -15,8 +15,10 @@ """Test the topology module's Server Selection Spec implementation.""" from __future__ import annotations +import asyncio import os import threading +from pathlib import Path from test import IntegrationTest, client_context, unittest from test.utils import ( CMAPListener, @@ -32,10 +34,14 @@ from pymongo.operations import _Op from pymongo.read_preferences import ReadPreference +_IS_SYNC = True # Location of JSON test specifications. -TEST_PATH = os.path.join( - os.path.dirname(os.path.realpath(__file__)), os.path.join("server_selection", "in_window") -) +if _IS_SYNC: + TEST_PATH = os.path.join(Path(__file__).resolve().parent, "server_selection", "in_window") +else: + TEST_PATH = os.path.join( + Path(__file__).resolve().parent.parent, "server_selection", "in_window" + ) class TestAllScenarios(unittest.TestCase): @@ -91,35 +97,53 @@ def tests(self, scenario_def): CustomSpecTestCreator(create_test, TestAllScenarios, TEST_PATH).create_tests() +if _IS_SYNC: + PARENT = threading.Thread +else: + PARENT = object + -class FinderThread(threading.Thread): +class FinderThread(PARENT): def __init__(self, collection, iterations): super().__init__() self.daemon = True self.collection = collection self.iterations = iterations self.passed = False + self.task = None def run(self): for _ in range(self.iterations): self.collection.find_one({}) self.passed = True + def start(self): + if _IS_SYNC: + super().start() + else: + self.task = asyncio.create_task(self.run()) + + def join(self): + if _IS_SYNC: + super().join() + else: + self.task + class TestProse(IntegrationTest): def frequencies(self, client, listener, n_finds=10): coll = client.test.test - N_THREADS = 10 - threads = [FinderThread(coll, n_finds) for _ in range(N_THREADS)] - for thread in threads: - thread.start() - for thread in threads: - thread.join() - for thread in threads: - self.assertTrue(thread.passed) + N_TASKS = 10 + tasks = [FinderThread(coll, n_finds) for _ in range(N_TASKS)] + for task in tasks: + task.start() + for task in tasks: + task.join() + for task in tasks: + self.assertTrue(task.passed) events = listener.started_events - self.assertEqual(len(events), n_finds * N_THREADS) + self.assertEqual(len(events), n_finds * N_TASKS) nodes = client.nodes self.assertEqual(len(nodes), 2) freqs = {address: 0.0 for address in nodes} diff --git a/tools/synchro.py b/tools/synchro.py index eb44ef4ac0..16b28fffa7 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -224,6 +224,7 @@ def async_only_test(f: str) -> bool: "test_retryable_reads_unified.py", "test_retryable_writes.py", "test_retryable_writes_unified.py", + "test_server_selection_in_window.py", "test_session.py", "test_transactions.py", "unified_format.py",