From ff0d5dc3b86084571b1f27004605ee547612231f Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Thu, 30 Jan 2025 17:21:03 -0800 Subject: [PATCH 1/8] Convert test.test_mongos_load_balancing to async --- .../test_mongos_load_balancing.py | 223 ++++++++++++++++++ test/test_mongos_load_balancing.py | 58 +++-- tools/synchro.py | 1 + 3 files changed, 268 insertions(+), 14 deletions(-) create mode 100644 test/asynchronous/test_mongos_load_balancing.py diff --git a/test/asynchronous/test_mongos_load_balancing.py b/test/asynchronous/test_mongos_load_balancing.py new file mode 100644 index 0000000000..dc2e96ddc8 --- /dev/null +++ b/test/asynchronous/test_mongos_load_balancing.py @@ -0,0 +1,223 @@ +# Copyright 2015-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test AsyncMongoClient's mongos load balancing using a mock.""" +from __future__ import annotations + +import asyncio +import sys +import threading + +from pymongo.operations import _Op + +sys.path[0:0] = [""] + +from test.asynchronous import AsyncMockClientTest, async_client_context, connected, unittest +from test.asynchronous.pymongo_mocks import AsyncMockClient +from test.utils import async_wait_until + +from pymongo.errors import AutoReconnect, InvalidOperation +from pymongo.server_selectors import writable_server_selector +from pymongo.topology_description import TOPOLOGY_TYPE + +_IS_SYNC = False + + +@async_client_context.require_connection +@async_client_context.require_no_load_balancer +def asyncSetUpModule(): + pass + + +if _IS_SYNC: + + class SimpleOp(threading.Thread): + def __init__(self, client): + super().__init__() + self.client = client + self.passed = False + + def run(self): + self.client.db.command("ping") + self.passed = True # No exception raised. +else: + + class SimpleOp: + def __init__(self, client): + self.task = asyncio.create_task(self.run()) + self.client = client + self.passed = False + + async def run(self): + await self.client.db.command("ping") + self.passed = True # No exception raised. + + def start(self): + pass + + async def join(self): + await self.task + + +async def do_simple_op(client, nthreads): + threads = [SimpleOp(client) for _ in range(nthreads)] + for t in threads: + t.start() + + for t in threads: + await t.join() + + for t in threads: + assert t.passed + + +async def writable_addresses(topology): + return { + server.description.address + for server in await topology.select_servers(writable_server_selector, _Op.TEST) + } + + +class TestMongosLoadBalancing(AsyncMockClientTest): + def mock_client(self, **kwargs): + mock_client = AsyncMockClient( + standalones=[], + members=[], + mongoses=["a:1", "b:2", "c:3"], + host="a:1,b:2,c:3", + connect=False, + **kwargs, + ) + self.addAsyncCleanup(mock_client.aclose) + + # Latencies in seconds. + mock_client.mock_rtts["a:1"] = 0.020 + mock_client.mock_rtts["b:2"] = 0.025 + mock_client.mock_rtts["c:3"] = 0.045 + return mock_client + + async def test_lazy_connect(self): + # While connected() ensures we can trigger connection from the main + # thread and wait for the monitors, this test triggers connection from + # several threads at once to check for data races. + nthreads = 10 + client = self.mock_client() + self.assertEqual(0, len(client.nodes)) + + # Trigger initial connection. + await do_simple_op(client, nthreads) + await async_wait_until(lambda: len(client.nodes) == 3, "connect to all mongoses") + + async def test_failover(self): + nthreads = 10 + client = await connected(self.mock_client(localThresholdMS=0.001)) + await async_wait_until(lambda: len(client.nodes) == 3, "connect to all mongoses") + + # Our chosen mongos goes down. + client.kill_host("a:1") + + # Trigger failover to higher-latency nodes. AutoReconnect should be + # raised at most once in each thread. + passed = [] + + async def f(): + try: + await client.db.command("ping") + except AutoReconnect: + # Second attempt succeeds. + await client.db.command("ping") + + passed.append(True) + + if _IS_SYNC: + threads = [threading.Thread(target=f) for _ in range(nthreads)] + for t in threads: + t.start() + + for t in threads: + t.join() + else: + tasks = [asyncio.create_task(f()) for _ in range(nthreads)] + for t in tasks: + await t + + self.assertEqual(nthreads, len(passed)) + + # Down host removed from list. + self.assertEqual(2, len(client.nodes)) + + async def test_local_threshold(self): + client = await connected(self.mock_client(localThresholdMS=30)) + self.assertEqual(30, client.options.local_threshold_ms) + await async_wait_until(lambda: len(client.nodes) == 3, "connect to all mongoses") + topology = client._topology + + # All are within a 30-ms latency window, see self.mock_client(). + self.assertEqual({("a", 1), ("b", 2), ("c", 3)}, await writable_addresses(topology)) + + # No error + await client.admin.command("ping") + + client = await connected(self.mock_client(localThresholdMS=0)) + self.assertEqual(0, client.options.local_threshold_ms) + # No error + await client.db.command("ping") + # Our chosen mongos goes down. + client.kill_host("{}:{}".format(*next(iter(client.nodes)))) + try: + await client.db.command("ping") + except: + pass + + # We eventually connect to a new mongos. + async def connect_to_new_mongos(): + try: + return await client.db.command("ping") + except AutoReconnect: + pass + + await async_wait_until(connect_to_new_mongos, "connect to a new mongos") + + async def test_load_balancing(self): + # Although the server selection JSON tests already prove that + # select_servers works for sharded topologies, here we do an end-to-end + # test of discovering servers' round trip times and configuring + # localThresholdMS. + client = await connected(self.mock_client()) + await async_wait_until(lambda: len(client.nodes) == 3, "connect to all mongoses") + + # Prohibited for topology type Sharded. + with self.assertRaises(InvalidOperation): + await client.address + + topology = client._topology + self.assertEqual(TOPOLOGY_TYPE.Sharded, topology.description.topology_type) + + # a and b are within the 15-ms latency window, see self.mock_client(). + self.assertEqual({("a", 1), ("b", 2)}, await writable_addresses(topology)) + + client.mock_rtts["a:1"] = 0.045 + + # Discover only b is within latency window. + async def predicate(): + return {("b", 2)} == await writable_addresses(topology) + + await async_wait_until( + predicate, + 'discover server "a" is too far', + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_mongos_load_balancing.py b/test/test_mongos_load_balancing.py index 7bc8225465..584ace73f9 100644 --- a/test/test_mongos_load_balancing.py +++ b/test/test_mongos_load_balancing.py @@ -15,6 +15,7 @@ """Test MongoClient's mongos load balancing using a mock.""" from __future__ import annotations +import asyncio import sys import threading @@ -30,6 +31,8 @@ from pymongo.server_selectors import writable_server_selector from pymongo.topology_description import TOPOLOGY_TYPE +_IS_SYNC = True + @client_context.require_connection @client_context.require_no_load_balancer @@ -37,15 +40,34 @@ def setUpModule(): pass -class SimpleOp(threading.Thread): - def __init__(self, client): - super().__init__() - self.client = client - self.passed = False +if _IS_SYNC: + + class SimpleOp(threading.Thread): + def __init__(self, client): + super().__init__() + self.client = client + self.passed = False + + def run(self): + self.client.db.command("ping") + self.passed = True # No exception raised. +else: + + class SimpleOp: + def __init__(self, client): + self.task = asyncio.create_task(self.run()) + self.client = client + self.passed = False - def run(self): - self.client.db.command("ping") - self.passed = True # No exception raised. + def run(self): + self.client.db.command("ping") + self.passed = True # No exception raised. + + def start(self): + pass + + def join(self): + self.task def do_simple_op(client, nthreads): @@ -118,12 +140,17 @@ def f(): passed.append(True) - threads = [threading.Thread(target=f) for _ in range(nthreads)] - for t in threads: - t.start() + if _IS_SYNC: + threads = [threading.Thread(target=f) for _ in range(nthreads)] + for t in threads: + t.start() - for t in threads: - t.join() + for t in threads: + t.join() + else: + tasks = [asyncio.create_task(f()) for _ in range(nthreads)] + for t in tasks: + t self.assertEqual(nthreads, len(passed)) @@ -183,8 +210,11 @@ def test_load_balancing(self): client.mock_rtts["a:1"] = 0.045 # Discover only b is within latency window. + def predicate(): + return {("b", 2)} == writable_addresses(topology) + wait_until( - lambda: {("b", 2)} == writable_addresses(topology), + predicate, 'discover server "a" is too far', ) diff --git a/tools/synchro.py b/tools/synchro.py index dc272929ad..f876547d93 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -215,6 +215,7 @@ def async_only_test(f: str) -> bool: "test_gridfs_spec.py", "test_logger.py", "test_monitoring.py", + "test_mongos_load_balancing.py", "test_raw_bson.py", "test_retryable_reads.py", "test_retryable_writes.py", From 690dcf79d1dcb3ca3c944925210edcda822eeb35 Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Thu, 30 Jan 2025 17:35:18 -0800 Subject: [PATCH 2/8] modify where task is created / ran --- test/asynchronous/test_mongos_load_balancing.py | 4 ++-- test/test_mongos_load_balancing.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/test/asynchronous/test_mongos_load_balancing.py b/test/asynchronous/test_mongos_load_balancing.py index dc2e96ddc8..8e8ae457b5 100644 --- a/test/asynchronous/test_mongos_load_balancing.py +++ b/test/asynchronous/test_mongos_load_balancing.py @@ -55,7 +55,7 @@ def run(self): class SimpleOp: def __init__(self, client): - self.task = asyncio.create_task(self.run()) + self.task: asyncio.Task self.client = client self.passed = False @@ -64,7 +64,7 @@ async def run(self): self.passed = True # No exception raised. def start(self): - pass + self.task = asyncio.create_task(self.run()) async def join(self): await self.task diff --git a/test/test_mongos_load_balancing.py b/test/test_mongos_load_balancing.py index 584ace73f9..709d28fc5c 100644 --- a/test/test_mongos_load_balancing.py +++ b/test/test_mongos_load_balancing.py @@ -55,7 +55,7 @@ def run(self): class SimpleOp: def __init__(self, client): - self.task = asyncio.create_task(self.run()) + self.task: asyncio.Task self.client = client self.passed = False @@ -64,7 +64,7 @@ def run(self): self.passed = True # No exception raised. def start(self): - pass + self.task = asyncio.create_task(self.run()) def join(self): self.task From 988663f4f3f23e5d8d5f1f72d31c2a83df5c93fd Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Thu, 30 Jan 2025 18:01:13 -0800 Subject: [PATCH 3/8] apparently order matters for typing tests? --- .../test_mongos_load_balancing.py | 24 +++++++++---------- test/test_mongos_load_balancing.py | 22 ++++++++--------- 2 files changed, 23 insertions(+), 23 deletions(-) diff --git a/test/asynchronous/test_mongos_load_balancing.py b/test/asynchronous/test_mongos_load_balancing.py index 8e8ae457b5..a4ea58dc1d 100644 --- a/test/asynchronous/test_mongos_load_balancing.py +++ b/test/asynchronous/test_mongos_load_balancing.py @@ -40,18 +40,7 @@ def asyncSetUpModule(): pass -if _IS_SYNC: - - class SimpleOp(threading.Thread): - def __init__(self, client): - super().__init__() - self.client = client - self.passed = False - - def run(self): - self.client.db.command("ping") - self.passed = True # No exception raised. -else: +if not _IS_SYNC: class SimpleOp: def __init__(self, client): @@ -68,6 +57,17 @@ def start(self): async def join(self): await self.task +else: + + class SimpleOp(threading.Thread): + def __init__(self, client): + super().__init__() + self.client = client + self.passed = False + + def run(self): + self.client.db.command("ping") + self.passed = True # No exception raised. async def do_simple_op(client, nthreads): diff --git a/test/test_mongos_load_balancing.py b/test/test_mongos_load_balancing.py index 709d28fc5c..b187a04e48 100644 --- a/test/test_mongos_load_balancing.py +++ b/test/test_mongos_load_balancing.py @@ -40,22 +40,28 @@ def setUpModule(): pass -if _IS_SYNC: +if not _IS_SYNC: - class SimpleOp(threading.Thread): + class SimpleOp: def __init__(self, client): - super().__init__() + self.task: asyncio.Task self.client = client self.passed = False def run(self): self.client.db.command("ping") self.passed = True # No exception raised. + + def start(self): + self.task = asyncio.create_task(self.run()) + + def join(self): + self.task else: - class SimpleOp: + class SimpleOp(threading.Thread): def __init__(self, client): - self.task: asyncio.Task + super().__init__() self.client = client self.passed = False @@ -63,12 +69,6 @@ def run(self): self.client.db.command("ping") self.passed = True # No exception raised. - def start(self): - self.task = asyncio.create_task(self.run()) - - def join(self): - self.task - def do_simple_op(client, nthreads): threads = [SimpleOp(client) for _ in range(nthreads)] From 5cd8be8a7262fe7609edc5491303330edaebd4b7 Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Fri, 31 Jan 2025 10:56:32 -0800 Subject: [PATCH 4/8] response to review --- .../test_mongos_load_balancing.py | 25 +++++++++---------- test/test_mongos_load_balancing.py | 25 +++++++++---------- 2 files changed, 24 insertions(+), 26 deletions(-) diff --git a/test/asynchronous/test_mongos_load_balancing.py b/test/asynchronous/test_mongos_load_balancing.py index a4ea58dc1d..954f80f06d 100644 --- a/test/asynchronous/test_mongos_load_balancing.py +++ b/test/asynchronous/test_mongos_load_balancing.py @@ -34,17 +34,11 @@ _IS_SYNC = False -@async_client_context.require_connection -@async_client_context.require_no_load_balancer -def asyncSetUpModule(): - pass - - if not _IS_SYNC: class SimpleOp: def __init__(self, client): - self.task: asyncio.Task + self.task = asyncio.create_task(self.run()) self.client = client self.passed = False @@ -53,7 +47,7 @@ async def run(self): self.passed = True # No exception raised. def start(self): - self.task = asyncio.create_task(self.run()) + pass async def join(self): await self.task @@ -70,15 +64,15 @@ def run(self): self.passed = True # No exception raised. -async def do_simple_op(client, nthreads): - threads = [SimpleOp(client) for _ in range(nthreads)] - for t in threads: +async def do_simple_op(client, ntasks): + tasks = [SimpleOp(client) for _ in range(ntasks)] + for t in tasks: t.start() - for t in threads: + for t in tasks: await t.join() - for t in threads: + for t in tasks: assert t.passed @@ -90,6 +84,11 @@ async def writable_addresses(topology): class TestMongosLoadBalancing(AsyncMockClientTest): + @async_client_context.require_connection + @async_client_context.require_no_load_balancer + async def asyncSetUp(self): + await super().asyncSetUp() + def mock_client(self, **kwargs): mock_client = AsyncMockClient( standalones=[], diff --git a/test/test_mongos_load_balancing.py b/test/test_mongos_load_balancing.py index b187a04e48..636bd0b2c8 100644 --- a/test/test_mongos_load_balancing.py +++ b/test/test_mongos_load_balancing.py @@ -34,17 +34,11 @@ _IS_SYNC = True -@client_context.require_connection -@client_context.require_no_load_balancer -def setUpModule(): - pass - - if not _IS_SYNC: class SimpleOp: def __init__(self, client): - self.task: asyncio.Task + self.task = asyncio.create_task(self.run()) self.client = client self.passed = False @@ -53,7 +47,7 @@ def run(self): self.passed = True # No exception raised. def start(self): - self.task = asyncio.create_task(self.run()) + pass def join(self): self.task @@ -70,15 +64,15 @@ def run(self): self.passed = True # No exception raised. -def do_simple_op(client, nthreads): - threads = [SimpleOp(client) for _ in range(nthreads)] - for t in threads: +def do_simple_op(client, ntasks): + tasks = [SimpleOp(client) for _ in range(ntasks)] + for t in tasks: t.start() - for t in threads: + for t in tasks: t.join() - for t in threads: + for t in tasks: assert t.passed @@ -90,6 +84,11 @@ def writable_addresses(topology): class TestMongosLoadBalancing(MockClientTest): + @client_context.require_connection + @client_context.require_no_load_balancer + def setUp(self): + super().setUp() + def mock_client(self, **kwargs): mock_client = MockClient( standalones=[], From 499fe4232a99ef3c32a4578f52cea2eb03b0a388 Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Fri, 31 Jan 2025 12:10:44 -0800 Subject: [PATCH 5/8] create task in start --- test/asynchronous/test_mongos_load_balancing.py | 4 ++-- test/test_mongos_load_balancing.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/test/asynchronous/test_mongos_load_balancing.py b/test/asynchronous/test_mongos_load_balancing.py index 954f80f06d..4f5257a4d2 100644 --- a/test/asynchronous/test_mongos_load_balancing.py +++ b/test/asynchronous/test_mongos_load_balancing.py @@ -38,7 +38,7 @@ class SimpleOp: def __init__(self, client): - self.task = asyncio.create_task(self.run()) + self.task = None self.client = client self.passed = False @@ -47,7 +47,7 @@ async def run(self): self.passed = True # No exception raised. def start(self): - pass + self.task = asyncio.create_task(self.run()) async def join(self): await self.task diff --git a/test/test_mongos_load_balancing.py b/test/test_mongos_load_balancing.py index 636bd0b2c8..50dbcde2ef 100644 --- a/test/test_mongos_load_balancing.py +++ b/test/test_mongos_load_balancing.py @@ -38,7 +38,7 @@ class SimpleOp: def __init__(self, client): - self.task = asyncio.create_task(self.run()) + self.task = None self.client = client self.passed = False @@ -47,7 +47,7 @@ def run(self): self.passed = True # No exception raised. def start(self): - pass + self.task = asyncio.create_task(self.run()) def join(self): self.task From 41d2aab17188f65e37dbe92bd211333ca385808e Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Fri, 31 Jan 2025 17:09:35 -0800 Subject: [PATCH 6/8] refactor SimpleOp --- .../test_mongos_load_balancing.py | 43 +++++++++---------- test/test_mongos_load_balancing.py | 43 +++++++++---------- 2 files changed, 42 insertions(+), 44 deletions(-) diff --git a/test/asynchronous/test_mongos_load_balancing.py b/test/asynchronous/test_mongos_load_balancing.py index 4f5257a4d2..f3166f45f6 100644 --- a/test/asynchronous/test_mongos_load_balancing.py +++ b/test/asynchronous/test_mongos_load_balancing.py @@ -33,35 +33,34 @@ _IS_SYNC = False +if _IS_SYNC: + PARENT = threading.Thread +else: + PARENT = object -if not _IS_SYNC: - class SimpleOp: - def __init__(self, client): - self.task = None - self.client = client - self.passed = False +class SimpleOp(PARENT): + def __init__(self, client): + super().__init__() + self.client = client + self.passed = False + self.task = None - async def run(self): - await self.client.db.command("ping") - self.passed = True # No exception raised. + async def run(self): + await self.client.db.command("ping") + self.passed = True # No exception raised. - def start(self): + def start(self): + if _IS_SYNC: + super().start() + else: self.task = asyncio.create_task(self.run()) - async def join(self): + async def join(self): + if _IS_SYNC: + super().join() + else: await self.task -else: - - class SimpleOp(threading.Thread): - def __init__(self, client): - super().__init__() - self.client = client - self.passed = False - - def run(self): - self.client.db.command("ping") - self.passed = True # No exception raised. async def do_simple_op(client, ntasks): diff --git a/test/test_mongos_load_balancing.py b/test/test_mongos_load_balancing.py index 50dbcde2ef..1c1de7735d 100644 --- a/test/test_mongos_load_balancing.py +++ b/test/test_mongos_load_balancing.py @@ -33,35 +33,34 @@ _IS_SYNC = True +if _IS_SYNC: + PARENT = threading.Thread +else: + PARENT = object -if not _IS_SYNC: - class SimpleOp: - def __init__(self, client): - self.task = None - self.client = client - self.passed = False +class SimpleOp(PARENT): + def __init__(self, client): + super().__init__() + self.client = client + self.passed = False + self.task = None - def run(self): - self.client.db.command("ping") - self.passed = True # No exception raised. + def run(self): + self.client.db.command("ping") + self.passed = True # No exception raised. - def start(self): + def start(self): + if _IS_SYNC: + super().start() + else: self.task = asyncio.create_task(self.run()) - def join(self): + def join(self): + if _IS_SYNC: + super().join() + else: self.task -else: - - class SimpleOp(threading.Thread): - def __init__(self, client): - super().__init__() - self.client = client - self.passed = False - - def run(self): - self.client.db.command("ping") - self.passed = True # No exception raised. def do_simple_op(client, ntasks): From 5cd7d02acf060ae6184b7b8a9ad9a2a77d065651 Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Wed, 5 Feb 2025 18:23:45 -0800 Subject: [PATCH 7/8] use ConcurrentRunner --- test/asynchronous/helpers.py | 2 +- .../test_mongos_load_balancing.py | 23 +++---------------- test/helpers.py | 2 +- test/test_mongos_load_balancing.py | 21 ++--------------- 4 files changed, 7 insertions(+), 41 deletions(-) diff --git a/test/asynchronous/helpers.py b/test/asynchronous/helpers.py index a35c71b107..32fa254106 100644 --- a/test/asynchronous/helpers.py +++ b/test/asynchronous/helpers.py @@ -395,7 +395,7 @@ def __init__(self, **kwargs): async def start(self): self.task = create_task(self.run(), name=self.name) - async def join(self, timeout: float | None = 0): # type: ignore[override] + async def join(self, timeout: float | None = None): # type: ignore[override] if self.task is not None: await asyncio.wait([self.task], timeout=timeout) diff --git a/test/asynchronous/test_mongos_load_balancing.py b/test/asynchronous/test_mongos_load_balancing.py index f3166f45f6..c10f5b74e3 100644 --- a/test/asynchronous/test_mongos_load_balancing.py +++ b/test/asynchronous/test_mongos_load_balancing.py @@ -18,6 +18,7 @@ import asyncio import sys import threading +from test.asynchronous.helpers import ConcurrentRunner from pymongo.operations import _Op @@ -33,40 +34,22 @@ _IS_SYNC = False -if _IS_SYNC: - PARENT = threading.Thread -else: - PARENT = object - -class SimpleOp(PARENT): +class SimpleOp(ConcurrentRunner): def __init__(self, client): super().__init__() self.client = client self.passed = False - self.task = None async def run(self): await self.client.db.command("ping") self.passed = True # No exception raised. - def start(self): - if _IS_SYNC: - super().start() - else: - self.task = asyncio.create_task(self.run()) - - async def join(self): - if _IS_SYNC: - super().join() - else: - await self.task - async def do_simple_op(client, ntasks): tasks = [SimpleOp(client) for _ in range(ntasks)] for t in tasks: - t.start() + await t.start() for t in tasks: await t.join() diff --git a/test/helpers.py b/test/helpers.py index 705843efcd..6fdfd2ba42 100644 --- a/test/helpers.py +++ b/test/helpers.py @@ -395,7 +395,7 @@ def __init__(self, **kwargs): def start(self): self.task = create_task(self.run(), name=self.name) - def join(self, timeout: float | None = 0): # type: ignore[override] + def join(self, timeout: float | None = None): # type: ignore[override] if self.task is not None: asyncio.wait([self.task], timeout=timeout) diff --git a/test/test_mongos_load_balancing.py b/test/test_mongos_load_balancing.py index 1c1de7735d..4ad9325444 100644 --- a/test/test_mongos_load_balancing.py +++ b/test/test_mongos_load_balancing.py @@ -18,6 +18,7 @@ import asyncio import sys import threading +from test.helpers import ConcurrentRunner from pymongo.operations import _Op @@ -33,35 +34,17 @@ _IS_SYNC = True -if _IS_SYNC: - PARENT = threading.Thread -else: - PARENT = object - -class SimpleOp(PARENT): +class SimpleOp(ConcurrentRunner): def __init__(self, client): super().__init__() self.client = client self.passed = False - self.task = None def run(self): self.client.db.command("ping") self.passed = True # No exception raised. - def start(self): - if _IS_SYNC: - super().start() - else: - self.task = asyncio.create_task(self.run()) - - def join(self): - if _IS_SYNC: - super().join() - else: - self.task - def do_simple_op(client, ntasks): tasks = [SimpleOp(client) for _ in range(ntasks)] From b0c878c00654754b53ed8bc35755e19179a6decb Mon Sep 17 00:00:00 2001 From: Iris Ho Date: Thu, 6 Feb 2025 10:04:08 -0800 Subject: [PATCH 8/8] use ConcurrentRunner --- .../test_mongos_load_balancing.py | 19 +++++++------------ test/test_mongos_load_balancing.py | 19 +++++++------------ 2 files changed, 14 insertions(+), 24 deletions(-) diff --git a/test/asynchronous/test_mongos_load_balancing.py b/test/asynchronous/test_mongos_load_balancing.py index c10f5b74e3..0bc6a405f4 100644 --- a/test/asynchronous/test_mongos_load_balancing.py +++ b/test/asynchronous/test_mongos_load_balancing.py @@ -101,7 +101,7 @@ async def test_lazy_connect(self): await async_wait_until(lambda: len(client.nodes) == 3, "connect to all mongoses") async def test_failover(self): - nthreads = 10 + ntasks = 10 client = await connected(self.mock_client(localThresholdMS=0.001)) await async_wait_until(lambda: len(client.nodes) == 3, "connect to all mongoses") @@ -121,19 +121,14 @@ async def f(): passed.append(True) - if _IS_SYNC: - threads = [threading.Thread(target=f) for _ in range(nthreads)] - for t in threads: - t.start() + tasks = [ConcurrentRunner(target=f) for _ in range(ntasks)] + for t in tasks: + await t.start() - for t in threads: - t.join() - else: - tasks = [asyncio.create_task(f()) for _ in range(nthreads)] - for t in tasks: - await t + for t in tasks: + await t.join() - self.assertEqual(nthreads, len(passed)) + self.assertEqual(ntasks, len(passed)) # Down host removed from list. self.assertEqual(2, len(client.nodes)) diff --git a/test/test_mongos_load_balancing.py b/test/test_mongos_load_balancing.py index 4ad9325444..ca2f3cfd1e 100644 --- a/test/test_mongos_load_balancing.py +++ b/test/test_mongos_load_balancing.py @@ -101,7 +101,7 @@ def test_lazy_connect(self): wait_until(lambda: len(client.nodes) == 3, "connect to all mongoses") def test_failover(self): - nthreads = 10 + ntasks = 10 client = connected(self.mock_client(localThresholdMS=0.001)) wait_until(lambda: len(client.nodes) == 3, "connect to all mongoses") @@ -121,19 +121,14 @@ def f(): passed.append(True) - if _IS_SYNC: - threads = [threading.Thread(target=f) for _ in range(nthreads)] - for t in threads: - t.start() + tasks = [ConcurrentRunner(target=f) for _ in range(ntasks)] + for t in tasks: + t.start() - for t in threads: - t.join() - else: - tasks = [asyncio.create_task(f()) for _ in range(nthreads)] - for t in tasks: - t + for t in tasks: + t.join() - self.assertEqual(nthreads, len(passed)) + self.assertEqual(ntasks, len(passed)) # Down host removed from list. self.assertEqual(2, len(client.nodes))