|
| 1 | +# Copyright 2015-present MongoDB, Inc. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +"""Test the topology module's Server Selection Spec implementation.""" |
| 16 | +from __future__ import annotations |
| 17 | + |
| 18 | +import os |
| 19 | +import sys |
| 20 | +from pathlib import Path |
| 21 | + |
| 22 | +from pymongo import AsyncMongoClient, ReadPreference |
| 23 | +from pymongo.asynchronous.settings import TopologySettings |
| 24 | +from pymongo.asynchronous.topology import Topology |
| 25 | +from pymongo.errors import ServerSelectionTimeoutError |
| 26 | +from pymongo.hello import HelloCompat |
| 27 | +from pymongo.operations import _Op |
| 28 | +from pymongo.server_selectors import writable_server_selector |
| 29 | +from pymongo.typings import strip_optional |
| 30 | + |
| 31 | +sys.path[0:0] = [""] |
| 32 | + |
| 33 | +from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest |
| 34 | +from test.asynchronous.utils_selection_tests import ( |
| 35 | + create_selection_tests, |
| 36 | + get_addresses, |
| 37 | + get_topology_settings_dict, |
| 38 | + make_server_description, |
| 39 | +) |
| 40 | +from test.utils import ( |
| 41 | + EventListener, |
| 42 | + FunctionCallRecorder, |
| 43 | + OvertCommandListener, |
| 44 | + async_wait_until, |
| 45 | +) |
| 46 | + |
| 47 | +_IS_SYNC = False |
| 48 | + |
| 49 | +# Location of JSON test specifications. |
| 50 | +if _IS_SYNC: |
| 51 | + TEST_PATH = os.path.join( |
| 52 | + Path(__file__).resolve().parent, "server_selection", "server_selection" |
| 53 | + ) |
| 54 | +else: |
| 55 | + TEST_PATH = os.path.join( |
| 56 | + Path(__file__).resolve().parent.parent, "server_selection", "server_selection" |
| 57 | + ) |
| 58 | + |
| 59 | + |
| 60 | +class SelectionStoreSelector: |
| 61 | + """No-op selector that keeps track of what was passed to it.""" |
| 62 | + |
| 63 | + def __init__(self): |
| 64 | + self.selection = None |
| 65 | + |
| 66 | + def __call__(self, selection): |
| 67 | + self.selection = selection |
| 68 | + return selection |
| 69 | + |
| 70 | + |
| 71 | +class TestAllScenarios(create_selection_tests(TEST_PATH)): # type: ignore |
| 72 | + pass |
| 73 | + |
| 74 | + |
| 75 | +class TestCustomServerSelectorFunction(AsyncIntegrationTest): |
| 76 | + @async_client_context.require_replica_set |
| 77 | + async def test_functional_select_max_port_number_host(self): |
| 78 | + # Selector that returns server with highest port number. |
| 79 | + def custom_selector(servers): |
| 80 | + ports = [s.address[1] for s in servers] |
| 81 | + idx = ports.index(max(ports)) |
| 82 | + return [servers[idx]] |
| 83 | + |
| 84 | + # Initialize client with appropriate listeners. |
| 85 | + listener = OvertCommandListener() |
| 86 | + client = await self.async_rs_or_single_client( |
| 87 | + server_selector=custom_selector, event_listeners=[listener] |
| 88 | + ) |
| 89 | + coll = client.get_database("testdb", read_preference=ReadPreference.NEAREST).coll |
| 90 | + self.addAsyncCleanup(client.drop_database, "testdb") |
| 91 | + |
| 92 | + # Wait the node list to be fully populated. |
| 93 | + async def all_hosts_started(): |
| 94 | + return len((await client.admin.command(HelloCompat.LEGACY_CMD))["hosts"]) == len( |
| 95 | + client._topology._description.readable_servers |
| 96 | + ) |
| 97 | + |
| 98 | + await async_wait_until(all_hosts_started, "receive heartbeat from all hosts") |
| 99 | + |
| 100 | + expected_port = max( |
| 101 | + [strip_optional(n.address[1]) for n in client._topology._description.readable_servers] |
| 102 | + ) |
| 103 | + |
| 104 | + # Insert 1 record and access it 10 times. |
| 105 | + await coll.insert_one({"name": "John Doe"}) |
| 106 | + for _ in range(10): |
| 107 | + await coll.find_one({"name": "John Doe"}) |
| 108 | + |
| 109 | + # Confirm all find commands are run against appropriate host. |
| 110 | + for command in listener.started_events: |
| 111 | + if command.command_name == "find": |
| 112 | + self.assertEqual(command.connection_id[1], expected_port) |
| 113 | + |
| 114 | + async def test_invalid_server_selector(self): |
| 115 | + # Client initialization must fail if server_selector is not callable. |
| 116 | + for selector_candidate in [[], 10, "string", {}]: |
| 117 | + with self.assertRaisesRegex(ValueError, "must be a callable"): |
| 118 | + AsyncMongoClient(connect=False, server_selector=selector_candidate) |
| 119 | + |
| 120 | + # None value for server_selector is OK. |
| 121 | + AsyncMongoClient(connect=False, server_selector=None) |
| 122 | + |
| 123 | + @async_client_context.require_replica_set |
| 124 | + async def test_selector_called(self): |
| 125 | + selector = FunctionCallRecorder(lambda x: x) |
| 126 | + |
| 127 | + # Client setup. |
| 128 | + mongo_client = await self.async_rs_or_single_client(server_selector=selector) |
| 129 | + test_collection = mongo_client.testdb.test_collection |
| 130 | + self.addAsyncCleanup(mongo_client.drop_database, "testdb") |
| 131 | + |
| 132 | + # Do N operations and test selector is called at least N times. |
| 133 | + await test_collection.insert_one({"age": 20, "name": "John"}) |
| 134 | + await test_collection.insert_one({"age": 31, "name": "Jane"}) |
| 135 | + await test_collection.update_one({"name": "Jane"}, {"$set": {"age": 21}}) |
| 136 | + await test_collection.find_one({"name": "Roe"}) |
| 137 | + self.assertGreaterEqual(selector.call_count, 4) |
| 138 | + |
| 139 | + @async_client_context.require_replica_set |
| 140 | + async def test_latency_threshold_application(self): |
| 141 | + selector = SelectionStoreSelector() |
| 142 | + |
| 143 | + scenario_def: dict = { |
| 144 | + "topology_description": { |
| 145 | + "type": "ReplicaSetWithPrimary", |
| 146 | + "servers": [ |
| 147 | + {"address": "b:27017", "avg_rtt_ms": 10000, "type": "RSSecondary", "tag": {}}, |
| 148 | + {"address": "c:27017", "avg_rtt_ms": 20000, "type": "RSSecondary", "tag": {}}, |
| 149 | + {"address": "a:27017", "avg_rtt_ms": 30000, "type": "RSPrimary", "tag": {}}, |
| 150 | + ], |
| 151 | + } |
| 152 | + } |
| 153 | + |
| 154 | + # Create & populate Topology such that all but one server is too slow. |
| 155 | + rtt_times = [srv["avg_rtt_ms"] for srv in scenario_def["topology_description"]["servers"]] |
| 156 | + min_rtt_idx = rtt_times.index(min(rtt_times)) |
| 157 | + seeds, hosts = get_addresses(scenario_def["topology_description"]["servers"]) |
| 158 | + settings = get_topology_settings_dict( |
| 159 | + heartbeat_frequency=1, local_threshold_ms=1, seeds=seeds, server_selector=selector |
| 160 | + ) |
| 161 | + topology = Topology(TopologySettings(**settings)) |
| 162 | + await topology.open() |
| 163 | + for server in scenario_def["topology_description"]["servers"]: |
| 164 | + server_description = make_server_description(server, hosts) |
| 165 | + await topology.on_change(server_description) |
| 166 | + |
| 167 | + # Invoke server selection and assert no filtering based on latency |
| 168 | + # prior to custom server selection logic kicking in. |
| 169 | + server = await topology.select_server(ReadPreference.NEAREST, _Op.TEST) |
| 170 | + assert selector.selection is not None |
| 171 | + self.assertEqual(len(selector.selection), len(topology.description.server_descriptions())) |
| 172 | + |
| 173 | + # Ensure proper filtering based on latency after custom selection. |
| 174 | + self.assertEqual(server.description.address, seeds[min_rtt_idx]) |
| 175 | + |
| 176 | + @async_client_context.require_replica_set |
| 177 | + async def test_server_selector_bypassed(self): |
| 178 | + selector = FunctionCallRecorder(lambda x: x) |
| 179 | + |
| 180 | + scenario_def = { |
| 181 | + "topology_description": { |
| 182 | + "type": "ReplicaSetNoPrimary", |
| 183 | + "servers": [ |
| 184 | + {"address": "b:27017", "avg_rtt_ms": 10000, "type": "RSSecondary", "tag": {}}, |
| 185 | + {"address": "c:27017", "avg_rtt_ms": 20000, "type": "RSSecondary", "tag": {}}, |
| 186 | + {"address": "a:27017", "avg_rtt_ms": 30000, "type": "RSSecondary", "tag": {}}, |
| 187 | + ], |
| 188 | + } |
| 189 | + } |
| 190 | + |
| 191 | + # Create & populate Topology such that no server is writeable. |
| 192 | + seeds, hosts = get_addresses(scenario_def["topology_description"]["servers"]) |
| 193 | + settings = get_topology_settings_dict( |
| 194 | + heartbeat_frequency=1, local_threshold_ms=1, seeds=seeds, server_selector=selector |
| 195 | + ) |
| 196 | + topology = Topology(TopologySettings(**settings)) |
| 197 | + await topology.open() |
| 198 | + for server in scenario_def["topology_description"]["servers"]: |
| 199 | + server_description = make_server_description(server, hosts) |
| 200 | + await topology.on_change(server_description) |
| 201 | + |
| 202 | + # Invoke server selection and assert no calls to our custom selector. |
| 203 | + with self.assertRaisesRegex(ServerSelectionTimeoutError, "No primary available for writes"): |
| 204 | + await topology.select_server( |
| 205 | + writable_server_selector, _Op.TEST, server_selection_timeout=0.1 |
| 206 | + ) |
| 207 | + self.assertEqual(selector.call_count, 0) |
| 208 | + |
| 209 | + |
| 210 | +if __name__ == "__main__": |
| 211 | + unittest.main() |
0 commit comments