Skip to content

PYTHON-5101 - Convert test.test_server_selection_in_window to async #2118

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
196 changes: 196 additions & 0 deletions test/asynchronous/test_server_selection_in_window.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
# Copyright 2020-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Test the topology module's Server Selection Spec implementation."""
from __future__ import annotations

import asyncio
import os
import threading
from pathlib import Path
from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest
from test.utils import (
CMAPListener,
OvertCommandListener,
async_get_pool,
async_wait_until,
)
from test.utils_selection_tests import create_topology
from test.utils_spec_runner import SpecTestCreator

from pymongo.common import clean_node
from pymongo.monitoring import ConnectionReadyEvent
from pymongo.operations import _Op
from pymongo.read_preferences import ReadPreference

_IS_SYNC = False
# Location of JSON test specifications.
if _IS_SYNC:
TEST_PATH = os.path.join(Path(__file__).resolve().parent, "server_selection", "in_window")
else:
TEST_PATH = os.path.join(
Path(__file__).resolve().parent.parent, "server_selection", "in_window"
)


class TestAllScenarios(unittest.IsolatedAsyncioTestCase):
def run_scenario(self, scenario_def):
topology = create_topology(scenario_def)

# Update mock operation_count state:
for mock in scenario_def["mocked_topology_state"]:
address = clean_node(mock["address"])
server = topology.get_server_by_address(address)
server.pool.operation_count = mock["operation_count"]

pref = ReadPreference.NEAREST
counts = {address: 0 for address in topology._description.server_descriptions()}

# Number of times to repeat server selection
iterations = scenario_def["iterations"]
for _ in range(iterations):
server = topology.select_server(pref, _Op.TEST, server_selection_timeout=0)
counts[server.description.address] += 1

# Verify expected_frequencies
outcome = scenario_def["outcome"]
tolerance = outcome["tolerance"]
expected_frequencies = outcome["expected_frequencies"]
for host_str, freq in expected_frequencies.items():
address = clean_node(host_str)
actual_freq = float(counts[address]) / iterations
if freq == 0:
# Should be exactly 0.
self.assertEqual(actual_freq, 0)
else:
# Should be within 'tolerance'.
self.assertAlmostEqual(actual_freq, freq, delta=tolerance)


def create_test(scenario_def, test, name):
def run_scenario(self):
self.run_scenario(scenario_def)

return run_scenario


class CustomSpecTestCreator(SpecTestCreator):
def tests(self, scenario_def):
"""Extract the tests from a spec file.

Server selection in_window tests do not have a 'tests' field.
The whole file represents a single test case.
"""
return [scenario_def]


CustomSpecTestCreator(create_test, TestAllScenarios, TEST_PATH).create_tests()

if _IS_SYNC:
PARENT = threading.Thread
else:
PARENT = object


class FinderThread(PARENT):
def __init__(self, collection, iterations):
super().__init__()
self.daemon = True
self.collection = collection
self.iterations = iterations
self.passed = False
self.task = None

async def run(self):
for _ in range(self.iterations):
await self.collection.find_one({})
self.passed = True

def start(self):
if _IS_SYNC:
super().start()
else:
self.task = asyncio.create_task(self.run())

async def join(self):
if _IS_SYNC:
super().join()
else:
await self.task


class TestProse(AsyncIntegrationTest):
def frequencies(self, client, listener, n_finds=10):
coll = client.test.test
N_TASKS = 10
tasks = [FinderThread(coll, n_finds) for _ in range(N_TASKS)]
for task in tasks:
task.start()
for task in tasks:
task.join()
for task in tasks:
self.assertTrue(task.passed)

events = listener.started_events
self.assertEqual(len(events), n_finds * N_TASKS)
nodes = client.nodes
self.assertEqual(len(nodes), 2)
freqs = {address: 0.0 for address in nodes}
for event in events:
freqs[event.connection_id] += 1
for address in freqs:
freqs[address] = freqs[address] / float(len(events))
return freqs

@async_client_context.require_failCommand_appName
@async_client_context.require_multiple_mongoses
async def test_load_balancing(self):
listener = OvertCommandListener()
cmap_listener = CMAPListener()
# PYTHON-2584: Use a large localThresholdMS to avoid the impact of
# varying RTTs.
client = await self.async_rs_client(
async_client_context.mongos_seeds(),
appName="loadBalancingTest",
event_listeners=[listener, cmap_listener],
localThresholdMS=30000,
minPoolSize=10,
)
await async_wait_until(lambda: len(client.nodes) == 2, "discover both nodes")
# Wait for both pools to be populated.
cmap_listener.wait_for_event(ConnectionReadyEvent, 20)
# Delay find commands on only one mongos.
delay_finds = {
"configureFailPoint": "failCommand",
"mode": {"times": 10000},
"data": {
"failCommands": ["find"],
"blockAsyncConnection": True,
"blockTimeMS": 500,
"appName": "loadBalancingTest",
},
}
with self.fail_point(delay_finds):
nodes = async_client_context.client.nodes
self.assertEqual(len(nodes), 1)
delayed_server = await anext(iter(nodes))
freqs = self.frequencies(client, listener)
self.assertLessEqual(freqs[delayed_server], 0.25)
listener.reset()
freqs = self.frequencies(client, listener, n_finds=150)
self.assertAlmostEqual(freqs[delayed_server], 0.50, delta=0.15)


if __name__ == "__main__":
unittest.main()
50 changes: 37 additions & 13 deletions test/test_server_selection_in_window.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@
"""Test the topology module's Server Selection Spec implementation."""
from __future__ import annotations

import asyncio
import os
import threading
from pathlib import Path
from test import IntegrationTest, client_context, unittest
from test.utils import (
CMAPListener,
Expand All @@ -32,10 +34,14 @@
from pymongo.operations import _Op
from pymongo.read_preferences import ReadPreference

_IS_SYNC = True
# Location of JSON test specifications.
TEST_PATH = os.path.join(
os.path.dirname(os.path.realpath(__file__)), os.path.join("server_selection", "in_window")
)
if _IS_SYNC:
TEST_PATH = os.path.join(Path(__file__).resolve().parent, "server_selection", "in_window")
else:
TEST_PATH = os.path.join(
Path(__file__).resolve().parent.parent, "server_selection", "in_window"
)


class TestAllScenarios(unittest.TestCase):
Expand Down Expand Up @@ -91,35 +97,53 @@ def tests(self, scenario_def):

CustomSpecTestCreator(create_test, TestAllScenarios, TEST_PATH).create_tests()

if _IS_SYNC:
PARENT = threading.Thread
else:
PARENT = object


class FinderThread(threading.Thread):
class FinderThread(PARENT):
def __init__(self, collection, iterations):
super().__init__()
self.daemon = True
self.collection = collection
self.iterations = iterations
self.passed = False
self.task = None

def run(self):
for _ in range(self.iterations):
self.collection.find_one({})
self.passed = True

def start(self):
if _IS_SYNC:
super().start()
else:
self.task = asyncio.create_task(self.run())

def join(self):
if _IS_SYNC:
super().join()
else:
self.task


class TestProse(IntegrationTest):
def frequencies(self, client, listener, n_finds=10):
coll = client.test.test
N_THREADS = 10
threads = [FinderThread(coll, n_finds) for _ in range(N_THREADS)]
for thread in threads:
thread.start()
for thread in threads:
thread.join()
for thread in threads:
self.assertTrue(thread.passed)
N_TASKS = 10
tasks = [FinderThread(coll, n_finds) for _ in range(N_TASKS)]
for task in tasks:
task.start()
for task in tasks:
task.join()
for task in tasks:
self.assertTrue(task.passed)

events = listener.started_events
self.assertEqual(len(events), n_finds * N_THREADS)
self.assertEqual(len(events), n_finds * N_TASKS)
nodes = client.nodes
self.assertEqual(len(nodes), 2)
freqs = {address: 0.0 for address in nodes}
Expand Down
1 change: 1 addition & 0 deletions tools/synchro.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,7 @@ def async_only_test(f: str) -> bool:
"test_retryable_reads_unified.py",
"test_retryable_writes.py",
"test_retryable_writes_unified.py",
"test_server_selection_in_window.py",
"test_session.py",
"test_transactions.py",
"unified_format.py",
Expand Down
Loading