diff --git a/test/asynchronous/helpers.py b/test/asynchronous/helpers.py
index 46f66af62d..b5fc5d8ac4 100644
--- a/test/asynchronous/helpers.py
+++ b/test/asynchronous/helpers.py
@@ -42,6 +42,7 @@
 
 from bson.son import SON
 from pymongo import common, message
+from pymongo.read_preferences import ReadPreference
 from pymongo.ssl_support import HAVE_SSL, _ssl  # type:ignore[attr-defined]
 from pymongo.uri_parser import parse_uri
 
@@ -150,6 +151,16 @@ def _create_user(authdb, user, pwd=None, roles=None, **kwargs):
     return authdb.command(cmd)
 
 
+async def async_repl_set_step_down(client, **kwargs):
+    """Run replSetStepDown, first unfreezing a secondary with replSetFreeze."""
+    cmd = SON([("replSetStepDown", 1)])
+    cmd.update(kwargs)
+
+    # Unfreeze a secondary to ensure a speedy election.
+    await client.admin.command("replSetFreeze", 0, read_preference=ReadPreference.SECONDARY)
+    await client.admin.command(cmd)
+
+
 class client_knobs:
     def __init__(
         self,
diff --git a/test/asynchronous/test_connections_survive_primary_stepdown_spec.py b/test/asynchronous/test_connections_survive_primary_stepdown_spec.py
new file mode 100644
index 0000000000..289cf49751
--- /dev/null
+++ b/test/asynchronous/test_connections_survive_primary_stepdown_spec.py
@@ -0,0 +1,148 @@
+# Copyright 2019-present MongoDB, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Test compliance with the connections survive primary step down spec."""
+from __future__ import annotations
+
+import sys
+
+sys.path[0:0] = [""]
+
+from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest
+from test.asynchronous.helpers import async_repl_set_step_down
+from test.utils import (
+    CMAPListener,
+    async_ensure_all_connected,
+)
+
+from bson import SON
+from pymongo import monitoring
+from pymongo.asynchronous.collection import AsyncCollection
+from pymongo.errors import NotPrimaryError
+from pymongo.write_concern import WriteConcern
+
+_IS_SYNC = False
+
+
+class TestAsyncConnectionsSurvivePrimaryStepDown(AsyncIntegrationTest):
+    listener: CMAPListener
+    coll: AsyncCollection
+
+    @classmethod
+    @async_client_context.require_replica_set
+    async def _setup_class(cls):
+        await super()._setup_class()
+        cls.listener = CMAPListener()
+        cls.client = await cls.unmanaged_async_rs_or_single_client(
+            event_listeners=[cls.listener], retryWrites=False, heartbeatFrequencyMS=500
+        )
+
+        # Ensure connections to all servers in replica set. This is to test
+        # that the is_writable flag is properly updated for connections that
+        # survive a replica set election.
+        await async_ensure_all_connected(cls.client)
+        cls.listener.reset()
+
+        cls.db = cls.client.get_database("step-down", write_concern=WriteConcern("majority"))
+        cls.coll = cls.db.get_collection("step-down", write_concern=WriteConcern("majority"))
+
+    @classmethod
+    async def _tearDown_class(cls):
+        await cls.client.close()
+
+    async def asyncSetUp(self):
+        # Note that all ops use same write-concern as self.db (majority).
+        await self.db.drop_collection("step-down")
+        await self.db.create_collection("step-down")
+        self.listener.reset()
+
+    async def set_fail_point(self, command_args):
+        cmd = SON([("configureFailPoint", "failCommand")])
+        cmd.update(command_args)
+        await self.client.admin.command(cmd)
+
+    def verify_pool_cleared(self):
+        self.assertEqual(self.listener.event_count(monitoring.PoolClearedEvent), 1)
+
+    def verify_pool_not_cleared(self):
+        self.assertEqual(self.listener.event_count(monitoring.PoolClearedEvent), 0)
+
+    @async_client_context.require_version_min(4, 2, -1)
+    async def test_get_more_iteration(self):
+        # Insert 5 documents with WC majority.
+        await self.coll.insert_many([{"data": k} for k in range(5)])
+        # Start a find operation and retrieve first batch of results.
+        batch_size = 2
+        cursor = self.coll.find(batch_size=batch_size)
+        for _ in range(batch_size):
+            await cursor.next()
+        # Force step-down the primary.
+        await async_repl_set_step_down(self.client, replSetStepDown=5, force=True)
+        # Get next batch of results.
+        for _ in range(batch_size):
+            await cursor.next()
+        # Verify pool not cleared.
+        self.verify_pool_not_cleared()
+        # Attempt insertion to mark server description as stale and prevent a
+        # NotPrimaryError on the subsequent operation.
+        try:
+            await self.coll.insert_one({})
+        except NotPrimaryError:
+            pass
+        # Next insert should succeed on the new primary without clearing pool.
+        await self.coll.insert_one({})
+        self.verify_pool_not_cleared()
+
+    async def run_scenario(self, error_code, retry, pool_status_checker):
+        # Set fail point.
+        await self.set_fail_point(
+            {"mode": {"times": 1}, "data": {"failCommands": ["insert"], "errorCode": error_code}}
+        )
+        self.addAsyncCleanup(self.set_fail_point, {"mode": "off"})
+        # Insert record and verify failure.
+        with self.assertRaises(NotPrimaryError) as exc:
+            await self.coll.insert_one({"test": 1})
+        self.assertEqual(exc.exception.details["code"], error_code)  # type: ignore[call-overload]
+        # Retry before CMAPListener assertion if retry_before=True.
+        if retry:
+            await self.coll.insert_one({"test": 1})
+        # Verify pool cleared/not cleared.
+        pool_status_checker()
+        # Always retry here to ensure discovery of new primary.
+        await self.coll.insert_one({"test": 1})
+
+    @async_client_context.require_version_min(4, 2, -1)
+    @async_client_context.require_test_commands
+    async def test_not_primary_keep_connection_pool(self):
+        await self.run_scenario(10107, True, self.verify_pool_not_cleared)
+
+    @async_client_context.require_version_min(4, 0, 0)
+    @async_client_context.require_version_max(4, 1, 0, -1)
+    @async_client_context.require_test_commands
+    async def test_not_primary_reset_connection_pool(self):
+        await self.run_scenario(10107, False, self.verify_pool_cleared)
+
+    @async_client_context.require_version_min(4, 0, 0)
+    @async_client_context.require_test_commands
+    async def test_shutdown_in_progress(self):
+        await self.run_scenario(91, False, self.verify_pool_cleared)
+
+    @async_client_context.require_version_min(4, 0, 0)
+    @async_client_context.require_test_commands
+    async def test_interrupted_at_shutdown(self):
+        await self.run_scenario(11600, False, self.verify_pool_cleared)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/test/helpers.py b/test/helpers.py
index bf6186d1a0..11d5ab0374 100644
--- a/test/helpers.py
+++ b/test/helpers.py
@@ -42,6 +42,7 @@
 
 from bson.son import SON
 from pymongo import common, message
+from pymongo.read_preferences import ReadPreference
 from pymongo.ssl_support import HAVE_SSL, _ssl  # type:ignore[attr-defined]
 from pymongo.uri_parser import parse_uri
 
@@ -150,6 +151,16 @@ def _create_user(authdb, user, pwd=None, roles=None, **kwargs):
     return authdb.command(cmd)
 
 
+def repl_set_step_down(client, **kwargs):
+    """Run replSetStepDown, first unfreezing a secondary with replSetFreeze."""
+    cmd = SON([("replSetStepDown", 1)])
+    cmd.update(kwargs)
+
+    # Unfreeze a secondary to ensure a speedy election.
+    client.admin.command("replSetFreeze", 0, read_preference=ReadPreference.SECONDARY)
+    client.admin.command(cmd)
+
+
 class client_knobs:
     def __init__(
         self,
diff --git a/test/test_connections_survive_primary_stepdown_spec.py b/test/test_connections_survive_primary_stepdown_spec.py
index fba7675743..54cc4e0482 100644
--- a/test/test_connections_survive_primary_stepdown_spec.py
+++ b/test/test_connections_survive_primary_stepdown_spec.py
@@ -20,10 +20,10 @@
 sys.path[0:0] = [""]
 
 from test import IntegrationTest, client_context, unittest
+from test.helpers import repl_set_step_down
 from test.utils import (
     CMAPListener,
     ensure_all_connected,
-    repl_set_step_down,
 )
 
 from bson import SON
@@ -32,6 +32,8 @@
 from pymongo.synchronous.collection import Collection
 from pymongo.write_concern import WriteConcern
 
+_IS_SYNC = True
+
 
 class TestConnectionsSurvivePrimaryStepDown(IntegrationTest):
     listener: CMAPListener
@@ -39,8 +41,8 @@ class TestConnectionsSurvivePrimaryStepDown(IntegrationTest):
 
     @classmethod
     @client_context.require_replica_set
-    def setUpClass(cls):
-        super().setUpClass()
+    def _setup_class(cls):
+        super()._setup_class()
         cls.listener = CMAPListener()
         cls.client = cls.unmanaged_rs_or_single_client(
             event_listeners=[cls.listener], retryWrites=False, heartbeatFrequencyMS=500
@@ -56,7 +58,7 @@ def setUpClass(cls):
         cls.coll = cls.db.get_collection("step-down", write_concern=WriteConcern("majority"))
 
     @classmethod
-    def tearDownClass(cls):
+    def _tearDown_class(cls):
         cls.client.close()
 
     def setUp(self):
diff --git a/test/utils.py b/test/utils.py
index 9615034899..9c78cff3ad 100644
--- a/test/utils.py
+++ b/test/utils.py
@@ -599,6 +599,44 @@ def discover():
         )
 
 
+async def async_ensure_all_connected(client: AsyncMongoClient) -> None:
+    """Ensure that the client's connection pool has socket connections to all
+    members of a replica set. Raises ConfigurationError when called with a
+    non-replica set client.
+
+    Depending on the use-case, the caller may need to clear any event listeners
+    that are configured on the client.
+    """
+    hello: dict = await client.admin.command(HelloCompat.LEGACY_CMD)
+    if "setName" not in hello:
+        raise ConfigurationError("cluster is not a replica set")
+
+    target_host_list = set(hello["hosts"] + hello.get("passives", []))
+    connected_host_list = {hello["me"]}
+
+    # Run hello until we have connected to each host at least once.
+    async def discover():
+        i = 0
+        while i < 100 and connected_host_list != target_host_list:
+            hello: dict = await client.admin.command(
+                HelloCompat.LEGACY_CMD, read_preference=ReadPreference.SECONDARY
+            )
+            connected_host_list.update([hello["me"]])
+            i += 1
+        return connected_host_list
+
+    try:
+
+        async def predicate():
+            return target_host_list == await discover()
+
+        await async_wait_until(predicate, "connected to all hosts")
+    except AssertionError as exc:
+        raise AssertionError(
+            f"{exc}, {connected_host_list} != {target_host_list}, {client.topology_description}"
+        ) from exc
+
+
 def one(s):
     """Get one element of a set"""
     return next(iter(s))
@@ -761,16 +799,6 @@ async def async_wait_until(predicate, success_description, timeout=10):
         await asyncio.sleep(interval)
 
 
-def repl_set_step_down(client, **kwargs):
-    """Run replSetStepDown, first unfreezing a secondary with replSetFreeze."""
-    cmd = SON([("replSetStepDown", 1)])
-    cmd.update(kwargs)
-
-    # Unfreeze a secondary to ensure a speedy election.
-    client.admin.command("replSetFreeze", 0, read_preference=ReadPreference.SECONDARY)
-    client.admin.command(cmd)
-
-
 def is_mongos(client):
     res = client.admin.command(HelloCompat.LEGACY_CMD)
     return res.get("msg", "") == "isdbgrid"
diff --git a/tools/synchro.py b/tools/synchro.py
index 3333b0de2e..d8ec9ae46f 100644
--- a/tools/synchro.py
+++ b/tools/synchro.py
@@ -105,6 +105,8 @@
     "AsyncTestGridFile": "TestGridFile",
     "AsyncTestGridFileNoConnect": "TestGridFileNoConnect",
     "async_set_fail_point": "set_fail_point",
+    "async_ensure_all_connected": "ensure_all_connected",
+    "async_repl_set_step_down": "repl_set_step_down",
 }
 
 docstring_replacements: dict[tuple[str, str], str] = {
@@ -186,6 +188,7 @@ def async_only_test(f: str) -> bool:
     "test_client_bulk_write.py",
     "test_client_context.py",
     "test_collection.py",
+    "test_connections_survive_primary_stepdown_spec.py",
     "test_cursor.py",
     "test_database.py",
     "test_encryption.py",