diff --git a/test/asynchronous/test_mongos_load_balancing.py b/test/asynchronous/test_mongos_load_balancing.py
new file mode 100644
index 0000000000..0bc6a405f4
--- /dev/null
+++ b/test/asynchronous/test_mongos_load_balancing.py
@@ -0,0 +1,199 @@
+# Copyright 2015-present MongoDB, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Test AsyncMongoClient's mongos load balancing using a mock."""
+from __future__ import annotations
+
+import asyncio
+import sys
+import threading
+from test.asynchronous.helpers import ConcurrentRunner
+
+from pymongo.operations import _Op
+
+sys.path[0:0] = [""]
+
+from test.asynchronous import AsyncMockClientTest, async_client_context, connected, unittest
+from test.asynchronous.pymongo_mocks import AsyncMockClient
+from test.utils import async_wait_until
+
+from pymongo.errors import AutoReconnect, InvalidOperation
+from pymongo.server_selectors import writable_server_selector
+from pymongo.topology_description import TOPOLOGY_TYPE
+
+_IS_SYNC = False
+
+
+class SimpleOp(ConcurrentRunner):
+    def __init__(self, client):
+        super().__init__()
+        self.client = client
+        self.passed = False
+
+    async def run(self):
+        await self.client.db.command("ping")
+        self.passed = True  # No exception raised.
+
+
+async def do_simple_op(client, ntasks):
+    tasks = [SimpleOp(client) for _ in range(ntasks)]
+    for t in tasks:
+        await t.start()
+
+    for t in tasks:
+        await t.join()
+
+    for t in tasks:
+        assert t.passed
+
+
+async def writable_addresses(topology):
+    return {
+        server.description.address
+        for server in await topology.select_servers(writable_server_selector, _Op.TEST)
+    }
+
+
+class TestMongosLoadBalancing(AsyncMockClientTest):
+    @async_client_context.require_connection
+    @async_client_context.require_no_load_balancer
+    async def asyncSetUp(self):
+        await super().asyncSetUp()
+
+    def mock_client(self, **kwargs):
+        mock_client = AsyncMockClient(
+            standalones=[],
+            members=[],
+            mongoses=["a:1", "b:2", "c:3"],
+            host="a:1,b:2,c:3",
+            connect=False,
+            **kwargs,
+        )
+        self.addAsyncCleanup(mock_client.aclose)
+
+        # Latencies in seconds.
+        mock_client.mock_rtts["a:1"] = 0.020
+        mock_client.mock_rtts["b:2"] = 0.025
+        mock_client.mock_rtts["c:3"] = 0.045
+        return mock_client
+
+    async def test_lazy_connect(self):
+        # While connected() ensures we can trigger connection from the main
+        # thread and wait for the monitors, this test triggers connection from
+        # several threads at once to check for data races.
+        nthreads = 10
+        client = self.mock_client()
+        self.assertEqual(0, len(client.nodes))
+
+        # Trigger initial connection.
+        await do_simple_op(client, nthreads)
+        await async_wait_until(lambda: len(client.nodes) == 3, "connect to all mongoses")
+
+    async def test_failover(self):
+        ntasks = 10
+        client = await connected(self.mock_client(localThresholdMS=0.001))
+        await async_wait_until(lambda: len(client.nodes) == 3, "connect to all mongoses")
+
+        # Our chosen mongos goes down.
+        client.kill_host("a:1")
+
+        # Trigger failover to higher-latency nodes. AutoReconnect should be
+        # raised at most once in each thread.
+        passed = []
+
+        async def f():
+            try:
+                await client.db.command("ping")
+            except AutoReconnect:
+                # Second attempt succeeds.
+                await client.db.command("ping")
+
+            passed.append(True)
+
+        tasks = [ConcurrentRunner(target=f) for _ in range(ntasks)]
+        for t in tasks:
+            await t.start()
+
+        for t in tasks:
+            await t.join()
+
+        self.assertEqual(ntasks, len(passed))
+
+        # Down host removed from list.
+        self.assertEqual(2, len(client.nodes))
+
+    async def test_local_threshold(self):
+        client = await connected(self.mock_client(localThresholdMS=30))
+        self.assertEqual(30, client.options.local_threshold_ms)
+        await async_wait_until(lambda: len(client.nodes) == 3, "connect to all mongoses")
+        topology = client._topology
+
+        # All are within a 30-ms latency window, see self.mock_client().
+        self.assertEqual({("a", 1), ("b", 2), ("c", 3)}, await writable_addresses(topology))
+
+        # No error
+        await client.admin.command("ping")
+
+        client = await connected(self.mock_client(localThresholdMS=0))
+        self.assertEqual(0, client.options.local_threshold_ms)
+        # No error
+        await client.db.command("ping")
+        # Our chosen mongos goes down.
+        client.kill_host("{}:{}".format(*next(iter(client.nodes))))
+        try:
+            await client.db.command("ping")
+        except:
+            pass
+
+        # We eventually connect to a new mongos.
+        async def connect_to_new_mongos():
+            try:
+                return await client.db.command("ping")
+            except AutoReconnect:
+                pass
+
+        await async_wait_until(connect_to_new_mongos, "connect to a new mongos")
+
+    async def test_load_balancing(self):
+        # Although the server selection JSON tests already prove that
+        # select_servers works for sharded topologies, here we do an end-to-end
+        # test of discovering servers' round trip times and configuring
+        # localThresholdMS.
+        client = await connected(self.mock_client())
+        await async_wait_until(lambda: len(client.nodes) == 3, "connect to all mongoses")
+
+        # Prohibited for topology type Sharded.
+        with self.assertRaises(InvalidOperation):
+            await client.address
+
+        topology = client._topology
+        self.assertEqual(TOPOLOGY_TYPE.Sharded, topology.description.topology_type)
+
+        # a and b are within the 15-ms latency window, see self.mock_client().
+        self.assertEqual({("a", 1), ("b", 2)}, await writable_addresses(topology))
+
+        client.mock_rtts["a:1"] = 0.045
+
+        # Discover only b is within latency window.
+        async def predicate():
+            return {("b", 2)} == await writable_addresses(topology)
+
+        await async_wait_until(
+            predicate,
+            'discover server "a" is too far',
+        )
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/test/test_mongos_load_balancing.py b/test/test_mongos_load_balancing.py
index 7bc8225465..ca2f3cfd1e 100644
--- a/test/test_mongos_load_balancing.py
+++ b/test/test_mongos_load_balancing.py
@@ -15,8 +15,10 @@
 """Test MongoClient's mongos load balancing using a mock."""
 from __future__ import annotations
 
+import asyncio
 import sys
 import threading
+from test.helpers import ConcurrentRunner
 
 from pymongo.operations import _Op
 
@@ -30,14 +32,10 @@
 from pymongo.server_selectors import writable_server_selector
 from pymongo.topology_description import TOPOLOGY_TYPE
 
+_IS_SYNC = True
 
-@client_context.require_connection
-@client_context.require_no_load_balancer
-def setUpModule():
-    pass
-
 
-class SimpleOp(threading.Thread):
+class SimpleOp(ConcurrentRunner):
     def __init__(self, client):
         super().__init__()
         self.client = client
@@ -48,15 +46,15 @@ def run(self):
         self.passed = True  # No exception raised.
 
 
-def do_simple_op(client, nthreads):
-    threads = [SimpleOp(client) for _ in range(nthreads)]
-    for t in threads:
+def do_simple_op(client, ntasks):
+    tasks = [SimpleOp(client) for _ in range(ntasks)]
+    for t in tasks:
         t.start()
 
-    for t in threads:
+    for t in tasks:
         t.join()
 
-    for t in threads:
+    for t in tasks:
         assert t.passed
 
 
@@ -68,6 +66,11 @@ def writable_addresses(topology):
 
 
 class TestMongosLoadBalancing(MockClientTest):
+    @client_context.require_connection
+    @client_context.require_no_load_balancer
+    def setUp(self):
+        super().setUp()
+
     def mock_client(self, **kwargs):
         mock_client = MockClient(
             standalones=[],
@@ -98,7 +101,7 @@ def test_lazy_connect(self):
         wait_until(lambda: len(client.nodes) == 3, "connect to all mongoses")
 
     def test_failover(self):
-        nthreads = 10
+        ntasks = 10
         client = connected(self.mock_client(localThresholdMS=0.001))
         wait_until(lambda: len(client.nodes) == 3, "connect to all mongoses")
 
@@ -118,14 +121,14 @@ def f():
 
             passed.append(True)
 
-        threads = [threading.Thread(target=f) for _ in range(nthreads)]
-        for t in threads:
+        tasks = [ConcurrentRunner(target=f) for _ in range(ntasks)]
+        for t in tasks:
             t.start()
 
-        for t in threads:
+        for t in tasks:
             t.join()
 
-        self.assertEqual(nthreads, len(passed))
+        self.assertEqual(ntasks, len(passed))
 
         # Down host removed from list.
         self.assertEqual(2, len(client.nodes))
@@ -183,8 +186,11 @@ def test_load_balancing(self):
         client.mock_rtts["a:1"] = 0.045
 
         # Discover only b is within latency window.
+        def predicate():
+            return {("b", 2)} == writable_addresses(topology)
+
         wait_until(
-            lambda: {("b", 2)} == writable_addresses(topology),
+            predicate,
             'discover server "a" is too far',
         )
 
diff --git a/tools/synchro.py b/tools/synchro.py
index fe38b4dcfe..3721a3bab3 100644
--- a/tools/synchro.py
+++ b/tools/synchro.py
@@ -221,6 +221,7 @@ def async_only_test(f: str) -> bool:
         "test_logger.py",
         "test_max_staleness.py",
         "test_monitoring.py",
+        "test_mongos_load_balancing.py",
         "test_on_demand_csfle.py",
         "test_raw_bson.py",
         "test_read_concern.py",
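Reviewer note: the async tests above drive `SimpleOp` and `ConcurrentRunner` through `await t.start()` / `await t.join()`, mirroring the `threading.Thread` API that the synchronized test keeps using, which is what lets `tools/synchro.py` generate the sync file from the async one (hence the filename registered in the synchro.py hunk). As a rough mental model only — the real helper lives in `test/asynchronous/helpers.py` and its implementation may differ — a task-based `ConcurrentRunner` could be sketched like this:

```python
# Hypothetical sketch of a task-based ConcurrentRunner; NOT the actual
# test/asynchronous/helpers.py implementation. It mimics threading.Thread's
# target/run()/start()/join() surface with awaitable methods.
import asyncio


class ConcurrentRunner:
    def __init__(self, target=None):
        # Like threading.Thread(target=...); subclasses may override run() instead.
        self.target = target
        self.task = None

    async def run(self):
        # Default body: call the target coroutine, as Thread.run() calls its target.
        if self.target is not None:
            await self.target()

    async def start(self):
        # Schedule run() on the running event loop, analogous to Thread.start().
        self.task = asyncio.create_task(self.run())

    async def join(self):
        # Wait for the scheduled task to finish, analogous to Thread.join().
        if self.task is not None:
            await self.task
```

Because `start()`/`join()` are awaited in the async test but plain calls in the threading version, the generated sync code only needs `await` stripped and the import rewritten, which keeps the two test files line-for-line parallel.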