Skip to content

Commit 13fa361

Browse files
PYTHON-5101 Convert test.test_server_selection_in_window to async (#2119)
Co-authored-by: Noah Stapp <noah@noahstapp.com>
1 parent 1a7239c commit 13fa361

8 files changed

+515
-92
lines changed

test/asynchronous/test_encryption.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -739,7 +739,7 @@ def allowable_errors(self, op):
739739
return errors
740740

741741

742-
async def create_test(scenario_def, test, name):
742+
def create_test(scenario_def, test, name):
743743
@async_client_context.require_test_commands
744744
async def run_scenario(self):
745745
await self.run_scenario(scenario_def, test)
Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
# Copyright 2020-present MongoDB, Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Test the topology module's Server Selection Spec implementation."""
16+
from __future__ import annotations
17+
18+
import asyncio
19+
import os
20+
import threading
21+
from pathlib import Path
22+
from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest
23+
from test.asynchronous.helpers import ConcurrentRunner
24+
from test.asynchronous.utils_selection_tests import create_topology
25+
from test.asynchronous.utils_spec_runner import AsyncSpecTestCreator
26+
from test.utils import (
27+
CMAPListener,
28+
OvertCommandListener,
29+
async_get_pool,
30+
async_wait_until,
31+
)
32+
33+
from pymongo.common import clean_node
34+
from pymongo.monitoring import ConnectionReadyEvent
35+
from pymongo.operations import _Op
36+
from pymongo.read_preferences import ReadPreference
37+
38+
_IS_SYNC = False
39+
# Location of JSON test specifications.
40+
if _IS_SYNC:
41+
TEST_PATH = os.path.join(Path(__file__).resolve().parent, "server_selection", "in_window")
42+
else:
43+
TEST_PATH = os.path.join(
44+
Path(__file__).resolve().parent.parent, "server_selection", "in_window"
45+
)
46+
47+
48+
class TestAllScenarios(unittest.IsolatedAsyncioTestCase):
49+
async def run_scenario(self, scenario_def):
50+
topology = await create_topology(scenario_def)
51+
52+
# Update mock operation_count state:
53+
for mock in scenario_def["mocked_topology_state"]:
54+
address = clean_node(mock["address"])
55+
server = topology.get_server_by_address(address)
56+
server.pool.operation_count = mock["operation_count"]
57+
58+
pref = ReadPreference.NEAREST
59+
counts = {address: 0 for address in topology._description.server_descriptions()}
60+
61+
# Number of times to repeat server selection
62+
iterations = scenario_def["iterations"]
63+
for _ in range(iterations):
64+
server = await topology.select_server(pref, _Op.TEST, server_selection_timeout=0)
65+
counts[server.description.address] += 1
66+
67+
# Verify expected_frequencies
68+
outcome = scenario_def["outcome"]
69+
tolerance = outcome["tolerance"]
70+
expected_frequencies = outcome["expected_frequencies"]
71+
for host_str, freq in expected_frequencies.items():
72+
address = clean_node(host_str)
73+
actual_freq = float(counts[address]) / iterations
74+
if freq == 0:
75+
# Should be exactly 0.
76+
self.assertEqual(actual_freq, 0)
77+
else:
78+
# Should be within 'tolerance'.
79+
self.assertAlmostEqual(actual_freq, freq, delta=tolerance)
80+
81+
82+
def create_test(scenario_def, test, name):
83+
async def run_scenario(self):
84+
await self.run_scenario(scenario_def)
85+
86+
return run_scenario
87+
88+
89+
class CustomSpecTestCreator(AsyncSpecTestCreator):
90+
def tests(self, scenario_def):
91+
"""Extract the tests from a spec file.
92+
93+
Server selection in_window tests do not have a 'tests' field.
94+
The whole file represents a single test case.
95+
"""
96+
return [scenario_def]
97+
98+
99+
CustomSpecTestCreator(create_test, TestAllScenarios, TEST_PATH).create_tests()
100+
101+
102+
class FinderTask(ConcurrentRunner):
103+
def __init__(self, collection, iterations):
104+
super().__init__()
105+
self.daemon = True
106+
self.collection = collection
107+
self.iterations = iterations
108+
self.passed = False
109+
110+
async def run(self):
111+
for _ in range(self.iterations):
112+
await self.collection.find_one({})
113+
self.passed = True
114+
115+
116+
class TestProse(AsyncIntegrationTest):
117+
async def frequencies(self, client, listener, n_finds=10):
118+
coll = client.test.test
119+
N_TASKS = 10
120+
tasks = [FinderTask(coll, n_finds) for _ in range(N_TASKS)]
121+
for task in tasks:
122+
await task.start()
123+
for task in tasks:
124+
await task.join()
125+
for task in tasks:
126+
self.assertTrue(task.passed)
127+
128+
events = listener.started_events
129+
self.assertEqual(len(events), n_finds * N_TASKS)
130+
nodes = client.nodes
131+
self.assertEqual(len(nodes), 2)
132+
freqs = {address: 0.0 for address in nodes}
133+
for event in events:
134+
freqs[event.connection_id] += 1
135+
for address in freqs:
136+
freqs[address] = freqs[address] / float(len(events))
137+
return freqs
138+
139+
@async_client_context.require_failCommand_appName
140+
@async_client_context.require_multiple_mongoses
141+
async def test_load_balancing(self):
142+
listener = OvertCommandListener()
143+
cmap_listener = CMAPListener()
144+
# PYTHON-2584: Use a large localThresholdMS to avoid the impact of
145+
# varying RTTs.
146+
client = await self.async_rs_client(
147+
async_client_context.mongos_seeds(),
148+
appName="loadBalancingTest",
149+
event_listeners=[listener, cmap_listener],
150+
localThresholdMS=30000,
151+
minPoolSize=10,
152+
)
153+
await async_wait_until(lambda: len(client.nodes) == 2, "discover both nodes")
154+
# Wait for both pools to be populated.
155+
await cmap_listener.async_wait_for_event(ConnectionReadyEvent, 20)
156+
# Delay find commands on only one mongos.
157+
delay_finds = {
158+
"configureFailPoint": "failCommand",
159+
"mode": {"times": 10000},
160+
"data": {
161+
"failCommands": ["find"],
162+
"blockConnection": True,
163+
"blockTimeMS": 500,
164+
"appName": "loadBalancingTest",
165+
},
166+
}
167+
async with self.fail_point(delay_finds):
168+
nodes = async_client_context.client.nodes
169+
self.assertEqual(len(nodes), 1)
170+
delayed_server = next(iter(nodes))
171+
freqs = await self.frequencies(client, listener)
172+
self.assertLessEqual(freqs[delayed_server], 0.25)
173+
listener.reset()
174+
freqs = await self.frequencies(client, listener, n_finds=150)
175+
self.assertAlmostEqual(freqs[delayed_server], 0.50, delta=0.15)
176+
177+
178+
if __name__ == "__main__":
179+
unittest.main()
Lines changed: 203 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,203 @@
1+
# Copyright 2015-present MongoDB, Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Utilities for testing Server Selection and Max Staleness."""
16+
from __future__ import annotations
17+
18+
import datetime
19+
import os
20+
import sys
21+
from test.asynchronous import AsyncPyMongoTestCase
22+
23+
sys.path[0:0] = [""]
24+
25+
from test import unittest
26+
from test.pymongo_mocks import DummyMonitor
27+
from test.utils import AsyncMockPool, parse_read_preference
28+
from test.utils_selection_tests_shared import (
29+
get_addresses,
30+
get_topology_type_name,
31+
make_server_description,
32+
)
33+
34+
from bson import json_util
35+
from pymongo.asynchronous.settings import TopologySettings
36+
from pymongo.asynchronous.topology import Topology
37+
from pymongo.common import HEARTBEAT_FREQUENCY
38+
from pymongo.errors import AutoReconnect, ConfigurationError
39+
from pymongo.operations import _Op
40+
from pymongo.server_selectors import writable_server_selector
41+
42+
_IS_SYNC = False
43+
44+
45+
def get_topology_settings_dict(**kwargs):
46+
settings = {
47+
"monitor_class": DummyMonitor,
48+
"heartbeat_frequency": HEARTBEAT_FREQUENCY,
49+
"pool_class": AsyncMockPool,
50+
}
51+
settings.update(kwargs)
52+
return settings
53+
54+
55+
async def create_topology(scenario_def, **kwargs):
56+
# Initialize topologies.
57+
if "heartbeatFrequencyMS" in scenario_def:
58+
frequency = int(scenario_def["heartbeatFrequencyMS"]) / 1000.0
59+
else:
60+
frequency = HEARTBEAT_FREQUENCY
61+
62+
seeds, hosts = get_addresses(scenario_def["topology_description"]["servers"])
63+
64+
topology_type = get_topology_type_name(scenario_def)
65+
if topology_type == "LoadBalanced":
66+
kwargs.setdefault("load_balanced", True)
67+
# Force topology description to ReplicaSet
68+
elif topology_type in ["ReplicaSetNoPrimary", "ReplicaSetWithPrimary"]:
69+
kwargs.setdefault("replica_set_name", "rs")
70+
settings = get_topology_settings_dict(heartbeat_frequency=frequency, seeds=seeds, **kwargs)
71+
72+
# "Eligible servers" is defined in the server selection spec as
73+
# the set of servers matching both the ReadPreference's mode
74+
# and tag sets.
75+
topology = Topology(TopologySettings(**settings))
76+
await topology.open()
77+
78+
# Update topologies with server descriptions.
79+
for server in scenario_def["topology_description"]["servers"]:
80+
server_description = make_server_description(server, hosts)
81+
await topology.on_change(server_description)
82+
83+
# Assert that descriptions match
84+
assert (
85+
scenario_def["topology_description"]["type"] == topology.description.topology_type_name
86+
), topology.description.topology_type_name
87+
88+
return topology
89+
90+
91+
def create_test(scenario_def):
92+
async def run_scenario(self):
93+
_, hosts = get_addresses(scenario_def["topology_description"]["servers"])
94+
# "Eligible servers" is defined in the server selection spec as
95+
# the set of servers matching both the ReadPreference's mode
96+
# and tag sets.
97+
top_latency = await create_topology(scenario_def)
98+
99+
# "In latency window" is defined in the server selection
100+
# spec as the subset of suitable_servers that falls within the
101+
# allowable latency window.
102+
top_suitable = await create_topology(scenario_def, local_threshold_ms=1000000)
103+
104+
# Create server selector.
105+
if scenario_def.get("operation") == "write":
106+
pref = writable_server_selector
107+
else:
108+
# Make first letter lowercase to match read_pref's modes.
109+
pref_def = scenario_def["read_preference"]
110+
if scenario_def.get("error"):
111+
with self.assertRaises((ConfigurationError, ValueError)):
112+
# Error can be raised when making Read Pref or selecting.
113+
pref = parse_read_preference(pref_def)
114+
await top_latency.select_server(pref, _Op.TEST)
115+
return
116+
117+
pref = parse_read_preference(pref_def)
118+
119+
# Select servers.
120+
if not scenario_def.get("suitable_servers"):
121+
with self.assertRaises(AutoReconnect):
122+
await top_suitable.select_server(pref, _Op.TEST, server_selection_timeout=0)
123+
124+
return
125+
126+
if not scenario_def["in_latency_window"]:
127+
with self.assertRaises(AutoReconnect):
128+
await top_latency.select_server(pref, _Op.TEST, server_selection_timeout=0)
129+
130+
return
131+
132+
actual_suitable_s = await top_suitable.select_servers(
133+
pref, _Op.TEST, server_selection_timeout=0
134+
)
135+
actual_latency_s = await top_latency.select_servers(
136+
pref, _Op.TEST, server_selection_timeout=0
137+
)
138+
139+
expected_suitable_servers = {}
140+
for server in scenario_def["suitable_servers"]:
141+
server_description = make_server_description(server, hosts)
142+
expected_suitable_servers[server["address"]] = server_description
143+
144+
actual_suitable_servers = {}
145+
for s in actual_suitable_s:
146+
actual_suitable_servers[
147+
"%s:%d" % (s.description.address[0], s.description.address[1])
148+
] = s.description
149+
150+
self.assertEqual(len(actual_suitable_servers), len(expected_suitable_servers))
151+
for k, actual in actual_suitable_servers.items():
152+
expected = expected_suitable_servers[k]
153+
self.assertEqual(expected.address, actual.address)
154+
self.assertEqual(expected.server_type, actual.server_type)
155+
self.assertEqual(expected.round_trip_time, actual.round_trip_time)
156+
self.assertEqual(expected.tags, actual.tags)
157+
self.assertEqual(expected.all_hosts, actual.all_hosts)
158+
159+
expected_latency_servers = {}
160+
for server in scenario_def["in_latency_window"]:
161+
server_description = make_server_description(server, hosts)
162+
expected_latency_servers[server["address"]] = server_description
163+
164+
actual_latency_servers = {}
165+
for s in actual_latency_s:
166+
actual_latency_servers[
167+
"%s:%d" % (s.description.address[0], s.description.address[1])
168+
] = s.description
169+
170+
self.assertEqual(len(actual_latency_servers), len(expected_latency_servers))
171+
for k, actual in actual_latency_servers.items():
172+
expected = expected_latency_servers[k]
173+
self.assertEqual(expected.address, actual.address)
174+
self.assertEqual(expected.server_type, actual.server_type)
175+
self.assertEqual(expected.round_trip_time, actual.round_trip_time)
176+
self.assertEqual(expected.tags, actual.tags)
177+
self.assertEqual(expected.all_hosts, actual.all_hosts)
178+
179+
return run_scenario
180+
181+
182+
def create_selection_tests(test_dir):
183+
class TestAllScenarios(AsyncPyMongoTestCase):
184+
pass
185+
186+
for dirpath, _, filenames in os.walk(test_dir):
187+
dirname = os.path.split(dirpath)
188+
dirname = os.path.split(dirname[-2])[-1] + "_" + dirname[-1]
189+
190+
for filename in filenames:
191+
if os.path.splitext(filename)[1] != ".json":
192+
continue
193+
with open(os.path.join(dirpath, filename)) as scenario_stream:
194+
scenario_def = json_util.loads(scenario_stream.read())
195+
196+
# Construct test from scenario.
197+
new_test = create_test(scenario_def)
198+
test_name = f"test_{dirname}_{os.path.splitext(filename)[0]}"
199+
200+
new_test.__name__ = test_name
201+
setattr(TestAllScenarios, new_test.__name__, new_test)
202+
203+
return TestAllScenarios

0 commit comments

Comments
 (0)