diff --git a/test/asynchronous/test_encryption.py b/test/asynchronous/test_encryption.py index 2b22bd8b76..335aa9d81c 100644 --- a/test/asynchronous/test_encryption.py +++ b/test/asynchronous/test_encryption.py @@ -739,7 +739,7 @@ def allowable_errors(self, op): return errors -async def create_test(scenario_def, test, name): +def create_test(scenario_def, test, name): @async_client_context.require_test_commands async def run_scenario(self): await self.run_scenario(scenario_def, test) diff --git a/test/asynchronous/test_server_selection_in_window.py b/test/asynchronous/test_server_selection_in_window.py new file mode 100644 index 0000000000..e2ae92a27c --- /dev/null +++ b/test/asynchronous/test_server_selection_in_window.py @@ -0,0 +1,179 @@ +# Copyright 2020-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test the topology module's Server Selection Spec implementation.""" +from __future__ import annotations + +import asyncio +import os +import threading +from pathlib import Path +from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest +from test.asynchronous.helpers import ConcurrentRunner +from test.asynchronous.utils_selection_tests import create_topology +from test.asynchronous.utils_spec_runner import AsyncSpecTestCreator +from test.utils import ( + CMAPListener, + OvertCommandListener, + async_get_pool, + async_wait_until, +) + +from pymongo.common import clean_node +from pymongo.monitoring import ConnectionReadyEvent +from pymongo.operations import _Op +from pymongo.read_preferences import ReadPreference + +_IS_SYNC = False +# Location of JSON test specifications. +if _IS_SYNC: + TEST_PATH = os.path.join(Path(__file__).resolve().parent, "server_selection", "in_window") +else: + TEST_PATH = os.path.join( + Path(__file__).resolve().parent.parent, "server_selection", "in_window" + ) + + +class TestAllScenarios(unittest.IsolatedAsyncioTestCase): + async def run_scenario(self, scenario_def): + topology = await create_topology(scenario_def) + + # Update mock operation_count state: + for mock in scenario_def["mocked_topology_state"]: + address = clean_node(mock["address"]) + server = topology.get_server_by_address(address) + server.pool.operation_count = mock["operation_count"] + + pref = ReadPreference.NEAREST + counts = {address: 0 for address in topology._description.server_descriptions()} + + # Number of times to repeat server selection + iterations = scenario_def["iterations"] + for _ in range(iterations): + server = await topology.select_server(pref, _Op.TEST, server_selection_timeout=0) + counts[server.description.address] += 1 + + # Verify expected_frequencies + outcome = scenario_def["outcome"] + tolerance = outcome["tolerance"] + expected_frequencies = outcome["expected_frequencies"] + for host_str, freq in expected_frequencies.items(): + address = clean_node(host_str) + actual_freq = float(counts[address]) / iterations + if freq == 0: + # Should be exactly 0. + self.assertEqual(actual_freq, 0) + else: + # Should be within 'tolerance'. + self.assertAlmostEqual(actual_freq, freq, delta=tolerance) + + +def create_test(scenario_def, test, name): + async def run_scenario(self): + await self.run_scenario(scenario_def) + + return run_scenario + + +class CustomSpecTestCreator(AsyncSpecTestCreator): + def tests(self, scenario_def): + """Extract the tests from a spec file. + + Server selection in_window tests do not have a 'tests' field. + The whole file represents a single test case. + """ + return [scenario_def] + + +CustomSpecTestCreator(create_test, TestAllScenarios, TEST_PATH).create_tests() + + +class FinderTask(ConcurrentRunner): + def __init__(self, collection, iterations): + super().__init__() + self.daemon = True + self.collection = collection + self.iterations = iterations + self.passed = False + + async def run(self): + for _ in range(self.iterations): + await self.collection.find_one({}) + self.passed = True + + +class TestProse(AsyncIntegrationTest): + async def frequencies(self, client, listener, n_finds=10): + coll = client.test.test + N_TASKS = 10 + tasks = [FinderTask(coll, n_finds) for _ in range(N_TASKS)] + for task in tasks: + await task.start() + for task in tasks: + await task.join() + for task in tasks: + self.assertTrue(task.passed) + + events = listener.started_events + self.assertEqual(len(events), n_finds * N_TASKS) + nodes = client.nodes + self.assertEqual(len(nodes), 2) + freqs = {address: 0.0 for address in nodes} + for event in events: + freqs[event.connection_id] += 1 + for address in freqs: + freqs[address] = freqs[address] / float(len(events)) + return freqs + + @async_client_context.require_failCommand_appName + @async_client_context.require_multiple_mongoses + async def test_load_balancing(self): + listener = OvertCommandListener() + cmap_listener = CMAPListener() + # PYTHON-2584: Use a large localThresholdMS to avoid the impact of + # varying RTTs. + client = await self.async_rs_client( + async_client_context.mongos_seeds(), + appName="loadBalancingTest", + event_listeners=[listener, cmap_listener], + localThresholdMS=30000, + minPoolSize=10, + ) + await async_wait_until(lambda: len(client.nodes) == 2, "discover both nodes") + # Wait for both pools to be populated. + await cmap_listener.async_wait_for_event(ConnectionReadyEvent, 20) + # Delay find commands on only one mongos. + delay_finds = { + "configureFailPoint": "failCommand", + "mode": {"times": 10000}, + "data": { + "failCommands": ["find"], + "blockConnection": True, + "blockTimeMS": 500, + "appName": "loadBalancingTest", + }, + } + async with self.fail_point(delay_finds): + nodes = async_client_context.client.nodes + self.assertEqual(len(nodes), 1) + delayed_server = next(iter(nodes)) + freqs = await self.frequencies(client, listener) + self.assertLessEqual(freqs[delayed_server], 0.25) + listener.reset() + freqs = await self.frequencies(client, listener, n_finds=150) + self.assertAlmostEqual(freqs[delayed_server], 0.50, delta=0.15) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/asynchronous/utils_selection_tests.py b/test/asynchronous/utils_selection_tests.py new file mode 100644 index 0000000000..71e287569a --- /dev/null +++ b/test/asynchronous/utils_selection_tests.py @@ -0,0 +1,203 @@ +# Copyright 2015-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for testing Server Selection and Max Staleness.""" +from __future__ import annotations + +import datetime +import os +import sys +from test.asynchronous import AsyncPyMongoTestCase + +sys.path[0:0] = [""] + +from test import unittest +from test.pymongo_mocks import DummyMonitor +from test.utils import AsyncMockPool, parse_read_preference +from test.utils_selection_tests_shared import ( + get_addresses, + get_topology_type_name, + make_server_description, +) + +from bson import json_util +from pymongo.asynchronous.settings import TopologySettings +from pymongo.asynchronous.topology import Topology +from pymongo.common import HEARTBEAT_FREQUENCY +from pymongo.errors import AutoReconnect, ConfigurationError +from pymongo.operations import _Op +from pymongo.server_selectors import writable_server_selector + +_IS_SYNC = False + + +def get_topology_settings_dict(**kwargs): + settings = { + "monitor_class": DummyMonitor, + "heartbeat_frequency": HEARTBEAT_FREQUENCY, + "pool_class": AsyncMockPool, + } + settings.update(kwargs) + return settings + + +async def create_topology(scenario_def, **kwargs): + # Initialize topologies. + if "heartbeatFrequencyMS" in scenario_def: + frequency = int(scenario_def["heartbeatFrequencyMS"]) / 1000.0 + else: + frequency = HEARTBEAT_FREQUENCY + + seeds, hosts = get_addresses(scenario_def["topology_description"]["servers"]) + + topology_type = get_topology_type_name(scenario_def) + if topology_type == "LoadBalanced": + kwargs.setdefault("load_balanced", True) + # Force topology description to ReplicaSet + elif topology_type in ["ReplicaSetNoPrimary", "ReplicaSetWithPrimary"]: + kwargs.setdefault("replica_set_name", "rs") + settings = get_topology_settings_dict(heartbeat_frequency=frequency, seeds=seeds, **kwargs) + + # "Eligible servers" is defined in the server selection spec as + # the set of servers matching both the ReadPreference's mode + # and tag sets. + topology = Topology(TopologySettings(**settings)) + await topology.open() + + # Update topologies with server descriptions. + for server in scenario_def["topology_description"]["servers"]: + server_description = make_server_description(server, hosts) + await topology.on_change(server_description) + + # Assert that descriptions match + assert ( + scenario_def["topology_description"]["type"] == topology.description.topology_type_name + ), topology.description.topology_type_name + + return topology + + +def create_test(scenario_def): + async def run_scenario(self): + _, hosts = get_addresses(scenario_def["topology_description"]["servers"]) + # "Eligible servers" is defined in the server selection spec as + # the set of servers matching both the ReadPreference's mode + # and tag sets. + top_latency = await create_topology(scenario_def) + + # "In latency window" is defined in the server selection + # spec as the subset of suitable_servers that falls within the + # allowable latency window. + top_suitable = await create_topology(scenario_def, local_threshold_ms=1000000) + + # Create server selector. + if scenario_def.get("operation") == "write": + pref = writable_server_selector + else: + # Make first letter lowercase to match read_pref's modes. + pref_def = scenario_def["read_preference"] + if scenario_def.get("error"): + with self.assertRaises((ConfigurationError, ValueError)): + # Error can be raised when making Read Pref or selecting. + pref = parse_read_preference(pref_def) + await top_latency.select_server(pref, _Op.TEST) + return + + pref = parse_read_preference(pref_def) + + # Select servers. + if not scenario_def.get("suitable_servers"): + with self.assertRaises(AutoReconnect): + await top_suitable.select_server(pref, _Op.TEST, server_selection_timeout=0) + + return + + if not scenario_def["in_latency_window"]: + with self.assertRaises(AutoReconnect): + await top_latency.select_server(pref, _Op.TEST, server_selection_timeout=0) + + return + + actual_suitable_s = await top_suitable.select_servers( + pref, _Op.TEST, server_selection_timeout=0 + ) + actual_latency_s = await top_latency.select_servers( + pref, _Op.TEST, server_selection_timeout=0 + ) + + expected_suitable_servers = {} + for server in scenario_def["suitable_servers"]: + server_description = make_server_description(server, hosts) + expected_suitable_servers[server["address"]] = server_description + + actual_suitable_servers = {} + for s in actual_suitable_s: + actual_suitable_servers[ + "%s:%d" % (s.description.address[0], s.description.address[1]) + ] = s.description + + self.assertEqual(len(actual_suitable_servers), len(expected_suitable_servers)) + for k, actual in actual_suitable_servers.items(): + expected = expected_suitable_servers[k] + self.assertEqual(expected.address, actual.address) + self.assertEqual(expected.server_type, actual.server_type) + self.assertEqual(expected.round_trip_time, actual.round_trip_time) + self.assertEqual(expected.tags, actual.tags) + self.assertEqual(expected.all_hosts, actual.all_hosts) + + expected_latency_servers = {} + for server in scenario_def["in_latency_window"]: + server_description = make_server_description(server, hosts) + expected_latency_servers[server["address"]] = server_description + + actual_latency_servers = {} + for s in actual_latency_s: + actual_latency_servers[ + "%s:%d" % (s.description.address[0], s.description.address[1]) + ] = s.description + + self.assertEqual(len(actual_latency_servers), len(expected_latency_servers)) + for k, actual in actual_latency_servers.items(): + expected = expected_latency_servers[k] + self.assertEqual(expected.address, actual.address) + self.assertEqual(expected.server_type, actual.server_type) + self.assertEqual(expected.round_trip_time, actual.round_trip_time) + self.assertEqual(expected.tags, actual.tags) + self.assertEqual(expected.all_hosts, actual.all_hosts) + + return run_scenario + + +def create_selection_tests(test_dir): + class TestAllScenarios(AsyncPyMongoTestCase): + pass + + for dirpath, _, filenames in os.walk(test_dir): + dirname = os.path.split(dirpath) + dirname = os.path.split(dirname[-2])[-1] + "_" + dirname[-1] + + for filename in filenames: + if os.path.splitext(filename)[1] != ".json": + continue + with open(os.path.join(dirpath, filename)) as scenario_stream: + scenario_def = json_util.loads(scenario_stream.read()) + + # Construct test from scenario. + new_test = create_test(scenario_def) + test_name = f"test_{dirname}_{os.path.splitext(filename)[0]}" + + new_test.__name__ = test_name + setattr(TestAllScenarios, new_test.__name__, new_test) + + return TestAllScenarios diff --git a/test/asynchronous/utils_spec_runner.py b/test/asynchronous/utils_spec_runner.py index d433f1a7e6..11d88850fc 100644 --- a/test/asynchronous/utils_spec_runner.py +++ b/test/asynchronous/utils_spec_runner.py @@ -229,7 +229,7 @@ async def _create_tests(self): str(test_def["description"].replace(" ", "_").replace(".", "_")), ) - new_test = await self._create_test(scenario_def, test_def, test_name) + new_test = self._create_test(scenario_def, test_def, test_name) new_test = self._ensure_min_max_server_version(scenario_def, new_test) new_test = self.ensure_run_on(scenario_def, new_test) diff --git a/test/test_server_selection_in_window.py b/test/test_server_selection_in_window.py index 05772fa385..7ccd4b529e 100644 --- a/test/test_server_selection_in_window.py +++ b/test/test_server_selection_in_window.py @@ -15,9 +15,12 @@ """Test the topology module's Server Selection Spec implementation.""" from __future__ import annotations +import asyncio import os import threading +from pathlib import Path from test import IntegrationTest, client_context, unittest +from test.helpers import ConcurrentRunner from test.utils import ( CMAPListener, OvertCommandListener, @@ -32,10 +35,14 @@ from pymongo.operations import _Op from pymongo.read_preferences import ReadPreference +_IS_SYNC = True # Location of JSON test specifications. -TEST_PATH = os.path.join( - os.path.dirname(os.path.realpath(__file__)), os.path.join("server_selection", "in_window") -) +if _IS_SYNC: + TEST_PATH = os.path.join(Path(__file__).resolve().parent, "server_selection", "in_window") +else: + TEST_PATH = os.path.join( + Path(__file__).resolve().parent.parent, "server_selection", "in_window" + ) class TestAllScenarios(unittest.TestCase): @@ -92,7 +99,7 @@ def tests(self, scenario_def): CustomSpecTestCreator(create_test, TestAllScenarios, TEST_PATH).create_tests() -class FinderThread(threading.Thread): +class FinderTask(ConcurrentRunner): def __init__(self, collection, iterations): super().__init__() self.daemon = True @@ -109,17 +116,17 @@ def run(self): class TestProse(IntegrationTest): def frequencies(self, client, listener, n_finds=10): coll = client.test.test - N_THREADS = 10 - threads = [FinderThread(coll, n_finds) for _ in range(N_THREADS)] - for thread in threads: - thread.start() - for thread in threads: - thread.join() - for thread in threads: - self.assertTrue(thread.passed) + N_TASKS = 10 + tasks = [FinderTask(coll, n_finds) for _ in range(N_TASKS)] + for task in tasks: + task.start() + for task in tasks: + task.join() + for task in tasks: + self.assertTrue(task.passed) events = listener.started_events - self.assertEqual(len(events), n_finds * N_THREADS) + self.assertEqual(len(events), n_finds * N_TASKS) nodes = client.nodes self.assertEqual(len(nodes), 2) freqs = {address: 0.0 for address in nodes} diff --git a/test/utils_selection_tests.py b/test/utils_selection_tests.py index 2d21888e27..9667ea701b 100644 --- a/test/utils_selection_tests.py +++ b/test/utils_selection_tests.py @@ -18,96 +18,28 @@ import datetime import os import sys +from test import PyMongoTestCase sys.path[0:0] = [""] from test import unittest from test.pymongo_mocks import DummyMonitor from test.utils import MockPool, parse_read_preference +from test.utils_selection_tests_shared import ( + get_addresses, + get_topology_type_name, + make_server_description, +) from bson import json_util -from pymongo.common import HEARTBEAT_FREQUENCY, MIN_SUPPORTED_WIRE_VERSION, clean_node +from pymongo.common import HEARTBEAT_FREQUENCY from pymongo.errors import AutoReconnect, ConfigurationError -from pymongo.hello import Hello, HelloCompat from pymongo.operations import _Op -from pymongo.server_description import ServerDescription from pymongo.server_selectors import writable_server_selector from pymongo.synchronous.settings import TopologySettings from pymongo.synchronous.topology import Topology - -def get_addresses(server_list): - seeds = [] - hosts = [] - for server in server_list: - seeds.append(clean_node(server["address"])) - hosts.append(server["address"]) - return seeds, hosts - - -def make_last_write_date(server): - epoch = datetime.datetime.fromtimestamp(0, tz=datetime.timezone.utc).replace(tzinfo=None) - millis = server.get("lastWrite", {}).get("lastWriteDate") - if millis: - diff = ((millis % 1000) + 1000) % 1000 - seconds = (millis - diff) / 1000 - micros = diff * 1000 - return epoch + datetime.timedelta(seconds=seconds, microseconds=micros) - else: - # "Unknown" server. - return epoch - - -def make_server_description(server, hosts): - """Make a ServerDescription from server info in a JSON test.""" - server_type = server["type"] - if server_type in ("Unknown", "PossiblePrimary"): - return ServerDescription(clean_node(server["address"]), Hello({})) - - hello_response = {"ok": True, "hosts": hosts} - if server_type not in ("Standalone", "Mongos", "RSGhost"): - hello_response["setName"] = "rs" - - if server_type == "RSPrimary": - hello_response[HelloCompat.LEGACY_CMD] = True - elif server_type == "RSSecondary": - hello_response["secondary"] = True - elif server_type == "Mongos": - hello_response["msg"] = "isdbgrid" - elif server_type == "RSGhost": - hello_response["isreplicaset"] = True - elif server_type == "RSArbiter": - hello_response["arbiterOnly"] = True - - hello_response["lastWrite"] = {"lastWriteDate": make_last_write_date(server)} - - for field in "maxWireVersion", "tags", "idleWritePeriodMillis": - if field in server: - hello_response[field] = server[field] - - hello_response.setdefault("maxWireVersion", MIN_SUPPORTED_WIRE_VERSION) - - # Sets _last_update_time to now. - sd = ServerDescription( - clean_node(server["address"]), - Hello(hello_response), - round_trip_time=server["avg_rtt_ms"] / 1000.0, - ) - - if "lastUpdateTime" in server: - sd._last_update_time = server["lastUpdateTime"] / 1000.0 # ms to sec. - - return sd - - -def get_topology_type_name(scenario_def): - td = scenario_def["topology_description"] - name = td["type"] - if name == "Unknown": - # PyMongo never starts a topology in type Unknown. - return "Sharded" if len(td["servers"]) > 1 else "Single" - else: - return name +_IS_SYNC = True def get_topology_settings_dict(**kwargs): @@ -244,7 +176,7 @@ def run_scenario(self): def create_selection_tests(test_dir): - class TestAllScenarios(unittest.TestCase): + class TestAllScenarios(PyMongoTestCase): pass for dirpath, _, filenames in os.walk(test_dir): diff --git a/test/utils_selection_tests_shared.py b/test/utils_selection_tests_shared.py new file mode 100644 index 0000000000..dbaed1034f --- /dev/null +++ b/test/utils_selection_tests_shared.py @@ -0,0 +1,100 @@ +# Copyright 2015-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for testing Server Selection and Max Staleness.""" +from __future__ import annotations + +import datetime +import os +import sys + +sys.path[0:0] = [""] + +from pymongo.common import MIN_SUPPORTED_WIRE_VERSION, clean_node +from pymongo.hello import Hello, HelloCompat +from pymongo.server_description import ServerDescription + + +def get_addresses(server_list): + seeds = [] + hosts = [] + for server in server_list: + seeds.append(clean_node(server["address"])) + hosts.append(server["address"]) + return seeds, hosts + + +def make_last_write_date(server): + epoch = datetime.datetime.fromtimestamp(0, tz=datetime.timezone.utc).replace(tzinfo=None) + millis = server.get("lastWrite", {}).get("lastWriteDate") + if millis: + diff = ((millis % 1000) + 1000) % 1000 + seconds = (millis - diff) / 1000 + micros = diff * 1000 + return epoch + datetime.timedelta(seconds=seconds, microseconds=micros) + else: + # "Unknown" server. + return epoch + + +def make_server_description(server, hosts): + """Make a ServerDescription from server info in a JSON test.""" + server_type = server["type"] + if server_type in ("Unknown", "PossiblePrimary"): + return ServerDescription(clean_node(server["address"]), Hello({})) + + hello_response = {"ok": True, "hosts": hosts} + if server_type not in ("Standalone", "Mongos", "RSGhost"): + hello_response["setName"] = "rs" + + if server_type == "RSPrimary": + hello_response[HelloCompat.LEGACY_CMD] = True + elif server_type == "RSSecondary": + hello_response["secondary"] = True + elif server_type == "Mongos": + hello_response["msg"] = "isdbgrid" + elif server_type == "RSGhost": + hello_response["isreplicaset"] = True + elif server_type == "RSArbiter": + hello_response["arbiterOnly"] = True + + hello_response["lastWrite"] = {"lastWriteDate": make_last_write_date(server)} + + for field in "maxWireVersion", "tags", "idleWritePeriodMillis": + if field in server: + hello_response[field] = server[field] + + hello_response.setdefault("maxWireVersion", MIN_SUPPORTED_WIRE_VERSION) + + # Sets _last_update_time to now. + sd = ServerDescription( + clean_node(server["address"]), + Hello(hello_response), + round_trip_time=server["avg_rtt_ms"] / 1000.0, + ) + + if "lastUpdateTime" in server: + sd._last_update_time = server["lastUpdateTime"] / 1000.0 # ms to sec. + + return sd + + +def get_topology_type_name(scenario_def): + td = scenario_def["topology_description"] + name = td["type"] + if name == "Unknown": + # PyMongo never starts a topology in type Unknown. + return "Sharded" if len(td["servers"]) > 1 else "Single" + else: + return name diff --git a/tools/synchro.py b/tools/synchro.py index 77fdcce5ae..f8d0be80fd 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -233,6 +233,7 @@ def async_only_test(f: str) -> bool: "test_retryable_writes_unified.py", "test_run_command.py", "test_sdam_monitoring_spec.py", + "test_server_selection_in_window.py", "test_server_selection_logging.py", "test_session.py", "test_server_selection_rtt.py", @@ -245,6 +246,7 @@ def async_only_test(f: str) -> bool: "test_unified_format.py", "test_versioned_api_integration.py", "unified_format.py", + "utils_selection_tests.py", ]