Skip to content

Commit 665eb9a

Browse files
authored
PYTHON-5105 - Convert test.test_srv_polling to async (#2124)
1 parent 702c86c commit 665eb9a

File tree

3 files changed

+378
-7
lines changed

3 files changed

+378
-7
lines changed

test/asynchronous/test_srv_polling.py

Lines changed: 361 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,361 @@
1+
# Copyright 2019-present MongoDB, Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Run the SRV support tests."""
16+
from __future__ import annotations
17+
18+
import asyncio
19+
import sys
20+
import time
21+
from typing import Any
22+
23+
sys.path[0:0] = [""]
24+
25+
from test.asynchronous import AsyncPyMongoTestCase, client_knobs, unittest
26+
from test.utils import FunctionCallRecorder, async_wait_until
27+
28+
import pymongo
29+
from pymongo import common
30+
from pymongo.errors import ConfigurationError
31+
from pymongo.srv_resolver import _have_dnspython
32+
33+
_IS_SYNC = False
34+
35+
WAIT_TIME = 0.1
36+
37+
38+
class SrvPollingKnobs:
39+
def __init__(
40+
self,
41+
ttl_time=None,
42+
min_srv_rescan_interval=None,
43+
nodelist_callback=None,
44+
count_resolver_calls=False,
45+
):
46+
self.ttl_time = ttl_time
47+
self.min_srv_rescan_interval = min_srv_rescan_interval
48+
self.nodelist_callback = nodelist_callback
49+
self.count_resolver_calls = count_resolver_calls
50+
51+
self.old_min_srv_rescan_interval = None
52+
self.old_dns_resolver_response = None
53+
54+
def enable(self):
55+
self.old_min_srv_rescan_interval = common.MIN_SRV_RESCAN_INTERVAL
56+
self.old_dns_resolver_response = pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl
57+
58+
if self.min_srv_rescan_interval is not None:
59+
common.MIN_SRV_RESCAN_INTERVAL = self.min_srv_rescan_interval
60+
61+
def mock_get_hosts_and_min_ttl(resolver, *args):
62+
assert self.old_dns_resolver_response is not None
63+
nodes, ttl = self.old_dns_resolver_response(resolver)
64+
if self.nodelist_callback is not None:
65+
nodes = self.nodelist_callback()
66+
if self.ttl_time is not None:
67+
ttl = self.ttl_time
68+
return nodes, ttl
69+
70+
patch_func: Any
71+
if self.count_resolver_calls:
72+
patch_func = FunctionCallRecorder(mock_get_hosts_and_min_ttl)
73+
else:
74+
patch_func = mock_get_hosts_and_min_ttl
75+
76+
pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl = patch_func # type: ignore
77+
78+
def __enter__(self):
79+
self.enable()
80+
81+
def disable(self):
82+
common.MIN_SRV_RESCAN_INTERVAL = self.old_min_srv_rescan_interval # type: ignore
83+
pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl = ( # type: ignore
84+
self.old_dns_resolver_response
85+
)
86+
87+
def __exit__(self, exc_type, exc_val, exc_tb):
88+
self.disable()
89+
90+
91+
class TestSrvPolling(AsyncPyMongoTestCase):
92+
BASE_SRV_RESPONSE = [
93+
("localhost.test.build.10gen.cc", 27017),
94+
("localhost.test.build.10gen.cc", 27018),
95+
]
96+
97+
CONNECTION_STRING = "mongodb+srv://test1.test.build.10gen.cc"
98+
99+
async def asyncSetUp(self):
100+
# Patch timeouts to ensure short rescan SRV interval.
101+
self.client_knobs = client_knobs(
102+
heartbeat_frequency=WAIT_TIME,
103+
min_heartbeat_interval=WAIT_TIME,
104+
events_queue_frequency=WAIT_TIME,
105+
)
106+
self.client_knobs.enable()
107+
108+
async def asyncTearDown(self):
109+
self.client_knobs.disable()
110+
111+
def get_nodelist(self, client):
112+
return client._topology.description.server_descriptions().keys()
113+
114+
async def assert_nodelist_change(self, expected_nodelist, client, timeout=(100 * WAIT_TIME)):
115+
"""Check if the client._topology eventually sees all nodes in the
116+
expected_nodelist.
117+
"""
118+
119+
def predicate():
120+
nodelist = self.get_nodelist(client)
121+
if set(expected_nodelist) == set(nodelist):
122+
return True
123+
return False
124+
125+
await async_wait_until(predicate, "see expected nodelist", timeout=timeout)
126+
127+
async def assert_nodelist_nochange(self, expected_nodelist, client, timeout=(100 * WAIT_TIME)):
128+
"""Check if the client._topology ever deviates from seeing all nodes
129+
in the expected_nodelist. Consistency is checked after sleeping for
130+
(WAIT_TIME * 10) seconds. Also check that the resolver is called at
131+
least once.
132+
"""
133+
134+
def predicate():
135+
if set(expected_nodelist) == set(self.get_nodelist(client)):
136+
return pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl.call_count >= 1
137+
return False
138+
139+
await async_wait_until(predicate, "Node list equals expected nodelist", timeout=timeout)
140+
141+
nodelist = self.get_nodelist(client)
142+
if set(expected_nodelist) != set(nodelist):
143+
msg = "Client nodelist %s changed unexpectedly (expected %s)"
144+
raise self.fail(msg % (nodelist, expected_nodelist))
145+
self.assertGreaterEqual(
146+
pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl.call_count, # type: ignore
147+
1,
148+
"resolver was never called",
149+
)
150+
return True
151+
152+
async def run_scenario(self, dns_response, expect_change):
153+
self.assertEqual(_have_dnspython(), True)
154+
if callable(dns_response):
155+
dns_resolver_response = dns_response
156+
else:
157+
158+
def dns_resolver_response():
159+
return dns_response
160+
161+
if expect_change:
162+
assertion_method = self.assert_nodelist_change
163+
count_resolver_calls = False
164+
expected_response = dns_response
165+
else:
166+
assertion_method = self.assert_nodelist_nochange
167+
count_resolver_calls = True
168+
expected_response = self.BASE_SRV_RESPONSE
169+
170+
# Patch timeouts to ensure short test running times.
171+
with SrvPollingKnobs(ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME):
172+
client = self.simple_client(self.CONNECTION_STRING)
173+
await client.aconnect()
174+
await self.assert_nodelist_change(self.BASE_SRV_RESPONSE, client)
175+
# Patch list of hosts returned by DNS query.
176+
with SrvPollingKnobs(
177+
nodelist_callback=dns_resolver_response, count_resolver_calls=count_resolver_calls
178+
):
179+
await assertion_method(expected_response, client)
180+
181+
async def test_addition(self):
182+
response = self.BASE_SRV_RESPONSE[:]
183+
response.append(("localhost.test.build.10gen.cc", 27019))
184+
await self.run_scenario(response, True)
185+
186+
async def test_removal(self):
187+
response = self.BASE_SRV_RESPONSE[:]
188+
response.remove(("localhost.test.build.10gen.cc", 27018))
189+
await self.run_scenario(response, True)
190+
191+
async def test_replace_one(self):
192+
response = self.BASE_SRV_RESPONSE[:]
193+
response.remove(("localhost.test.build.10gen.cc", 27018))
194+
response.append(("localhost.test.build.10gen.cc", 27019))
195+
await self.run_scenario(response, True)
196+
197+
async def test_replace_both_with_one(self):
198+
response = [("localhost.test.build.10gen.cc", 27019)]
199+
await self.run_scenario(response, True)
200+
201+
async def test_replace_both_with_two(self):
202+
response = [
203+
("localhost.test.build.10gen.cc", 27019),
204+
("localhost.test.build.10gen.cc", 27020),
205+
]
206+
await self.run_scenario(response, True)
207+
208+
async def test_dns_failures(self):
209+
from dns import exception
210+
211+
for exc in (exception.FormError, exception.TooBig, exception.Timeout):
212+
213+
def response_callback(*args):
214+
raise exc("DNS Failure!")
215+
216+
await self.run_scenario(response_callback, False)
217+
218+
async def test_dns_record_lookup_empty(self):
219+
response: list = []
220+
await self.run_scenario(response, False)
221+
222+
async def _test_recover_from_initial(self, initial_callback):
223+
# Construct a valid final response callback distinct from base.
224+
response_final = self.BASE_SRV_RESPONSE[:]
225+
response_final.pop()
226+
227+
def final_callback():
228+
return response_final
229+
230+
with SrvPollingKnobs(
231+
ttl_time=WAIT_TIME,
232+
min_srv_rescan_interval=WAIT_TIME,
233+
nodelist_callback=initial_callback,
234+
count_resolver_calls=True,
235+
):
236+
# Client uses unpatched method to get initial nodelist
237+
client = self.simple_client(self.CONNECTION_STRING)
238+
await client.aconnect()
239+
# Invalid DNS resolver response should not change nodelist.
240+
await self.assert_nodelist_nochange(self.BASE_SRV_RESPONSE, client)
241+
242+
with SrvPollingKnobs(
243+
ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME, nodelist_callback=final_callback
244+
):
245+
# Nodelist should reflect new valid DNS resolver response.
246+
await self.assert_nodelist_change(response_final, client)
247+
248+
async def test_recover_from_initially_empty_seedlist(self):
249+
def empty_seedlist():
250+
return []
251+
252+
await self._test_recover_from_initial(empty_seedlist)
253+
254+
async def test_recover_from_initially_erroring_seedlist(self):
255+
def erroring_seedlist():
256+
raise ConfigurationError
257+
258+
await self._test_recover_from_initial(erroring_seedlist)
259+
260+
async def test_10_all_dns_selected(self):
261+
response = [
262+
("localhost.test.build.10gen.cc", 27017),
263+
("localhost.test.build.10gen.cc", 27019),
264+
("localhost.test.build.10gen.cc", 27020),
265+
]
266+
267+
def nodelist_callback():
268+
return response
269+
270+
with SrvPollingKnobs(ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME):
271+
client = self.simple_client(self.CONNECTION_STRING, srvMaxHosts=0)
272+
await client.aconnect()
273+
with SrvPollingKnobs(nodelist_callback=nodelist_callback):
274+
await self.assert_nodelist_change(response, client)
275+
276+
async def test_11_all_dns_selected(self):
277+
response = [
278+
("localhost.test.build.10gen.cc", 27019),
279+
("localhost.test.build.10gen.cc", 27020),
280+
]
281+
282+
def nodelist_callback():
283+
return response
284+
285+
with SrvPollingKnobs(ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME):
286+
client = self.simple_client(self.CONNECTION_STRING, srvMaxHosts=2)
287+
await client.aconnect()
288+
with SrvPollingKnobs(nodelist_callback=nodelist_callback):
289+
await self.assert_nodelist_change(response, client)
290+
291+
async def test_12_new_dns_randomly_selected(self):
292+
response = [
293+
("localhost.test.build.10gen.cc", 27020),
294+
("localhost.test.build.10gen.cc", 27019),
295+
("localhost.test.build.10gen.cc", 27017),
296+
]
297+
298+
def nodelist_callback():
299+
return response
300+
301+
with SrvPollingKnobs(ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME):
302+
client = self.simple_client(self.CONNECTION_STRING, srvMaxHosts=2)
303+
await client.aconnect()
304+
with SrvPollingKnobs(nodelist_callback=nodelist_callback):
305+
await asyncio.sleep(2 * common.MIN_SRV_RESCAN_INTERVAL)
306+
final_topology = set(client.topology_description.server_descriptions())
307+
self.assertIn(("localhost.test.build.10gen.cc", 27017), final_topology)
308+
self.assertEqual(len(final_topology), 2)
309+
310+
async def test_does_not_flipflop(self):
311+
with SrvPollingKnobs(ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME):
312+
client = self.simple_client(self.CONNECTION_STRING, srvMaxHosts=1)
313+
await client.aconnect()
314+
old = set(client.topology_description.server_descriptions())
315+
await asyncio.sleep(4 * WAIT_TIME)
316+
new = set(client.topology_description.server_descriptions())
317+
self.assertSetEqual(old, new)
318+
319+
async def test_srv_service_name(self):
320+
# Construct a valid final response callback distinct from base.
321+
response = [
322+
("localhost.test.build.10gen.cc.", 27019),
323+
("localhost.test.build.10gen.cc.", 27020),
324+
]
325+
326+
def nodelist_callback():
327+
return response
328+
329+
with SrvPollingKnobs(ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME):
330+
client = self.simple_client(
331+
"mongodb+srv://test22.test.build.10gen.cc/?srvServiceName=customname"
332+
)
333+
await client.aconnect()
334+
with SrvPollingKnobs(nodelist_callback=nodelist_callback):
335+
await self.assert_nodelist_change(response, client)
336+
337+
async def test_srv_waits_to_poll(self):
338+
modified = [("localhost.test.build.10gen.cc", 27019)]
339+
340+
def resolver_response():
341+
return modified
342+
343+
with SrvPollingKnobs(
344+
ttl_time=WAIT_TIME,
345+
min_srv_rescan_interval=WAIT_TIME,
346+
nodelist_callback=resolver_response,
347+
):
348+
client = self.simple_client(self.CONNECTION_STRING)
349+
await client.aconnect()
350+
with self.assertRaises(AssertionError):
351+
await self.assert_nodelist_change(modified, client, timeout=WAIT_TIME / 2)
352+
353+
def test_import_dns_resolver(self):
354+
# Regression test for PYTHON-4407
355+
import dns.resolver
356+
357+
self.assertTrue(hasattr(dns.resolver, "resolve"))
358+
359+
360+
if __name__ == "__main__":
361+
unittest.main()

0 commit comments

Comments
 (0)