diff --git a/test/__init__.py b/test/__init__.py index b49eee99ac..6eda00bdec 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -593,7 +593,7 @@ def supports_secondary_read_pref(self): if self.has_secondaries: return True if self.is_mongos: - shard = self.client.config.shards.find_one()["host"] # type:ignore[index] + shard = (self.client.config.shards.find_one())["host"] # type:ignore[index] num_members = shard.count(",") + 1 return num_members > 1 return False diff --git a/test/asynchronous/__init__.py b/test/asynchronous/__init__.py index a6ba29baaa..b3b0ca93e1 100644 --- a/test/asynchronous/__init__.py +++ b/test/asynchronous/__init__.py @@ -592,10 +592,10 @@ async def check(): @property async def supports_secondary_read_pref(self): - if self.has_secondaries: + if await self.has_secondaries: return True if self.is_mongos: - shard = await self.client.config.shards.find_one()["host"] # type:ignore[index] + shard = (await self.client.config.shards.find_one())["host"] # type:ignore[index] num_members = shard.count(",") + 1 return num_members > 1 return False diff --git a/test/asynchronous/test_read_preferences.py b/test/asynchronous/test_read_preferences.py new file mode 100644 index 0000000000..077bc21eaf --- /dev/null +++ b/test/asynchronous/test_read_preferences.py @@ -0,0 +1,730 @@ +# Copyright 2011-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test the replica_set_connection module.""" +from __future__ import annotations + +import contextlib +import copy +import pickle +import random +import sys +from typing import Any + +from pymongo.operations import _Op + +sys.path[0:0] = [""] + +from test.asynchronous import ( + AsyncIntegrationTest, + SkipTest, + async_client_context, + connected, + unittest, +) +from test.utils import ( + OvertCommandListener, + async_wait_until, + one, +) +from test.version import Version + +from bson.son import SON +from pymongo.asynchronous.helpers import anext +from pymongo.asynchronous.mongo_client import AsyncMongoClient +from pymongo.errors import ConfigurationError, OperationFailure +from pymongo.message import _maybe_add_read_preference +from pymongo.read_preferences import ( + MovingAverage, + Nearest, + Primary, + PrimaryPreferred, + ReadPreference, + Secondary, + SecondaryPreferred, +) +from pymongo.server_description import ServerDescription +from pymongo.server_selectors import Selection, readable_server_selector +from pymongo.server_type import SERVER_TYPE +from pymongo.write_concern import WriteConcern + +_IS_SYNC = False + + +class TestSelections(AsyncIntegrationTest): + @async_client_context.require_connection + async def test_bool(self): + client = await self.async_single_client() + + async def predicate(): + return await client.address + + await async_wait_until(predicate, "discover primary") + selection = Selection.from_topology_description(client._topology.description) + + self.assertTrue(selection) + self.assertFalse(selection.with_server_descriptions([])) + + +class TestReadPreferenceObjects(unittest.TestCase): + prefs = [ + Primary(), + PrimaryPreferred(), + Secondary(), + Nearest(tag_sets=[{"a": 1}, {"b": 2}]), + SecondaryPreferred(max_staleness=30), + ] + + def test_pickle(self): + for pref in self.prefs: + self.assertEqual(pref, pickle.loads(pickle.dumps(pref))) + + def test_copy(self): + for pref in self.prefs: + self.assertEqual(pref, copy.copy(pref)) + + def test_deepcopy(self): + for pref in self.prefs: + self.assertEqual(pref, copy.deepcopy(pref)) + + +class TestReadPreferencesBase(AsyncIntegrationTest): + @async_client_context.require_secondaries_count(1) + async def asyncSetUp(self): + await super().asyncSetUp() + # Insert some data so we can use cursors in read_from_which_host + await self.client.pymongo_test.test.drop() + await self.client.get_database( + "pymongo_test", write_concern=WriteConcern(w=async_client_context.w) + ).test.insert_many([{"_id": i} for i in range(10)]) + + self.addAsyncCleanup(self.client.pymongo_test.test.drop) + + async def read_from_which_host(self, client): + """Do a find() on the client and return which host was used""" + cursor = client.pymongo_test.test.find() + await anext(cursor) + return cursor.address + + async def read_from_which_kind(self, client): + """Do a find() on the client and return 'primary' or 'secondary' + depending on which the client used. + """ + address = await self.read_from_which_host(client) + if address == await client.primary: + return "primary" + elif address in await client.secondaries: + return "secondary" + else: + self.fail( + f"Cursor used address {address}, expected either primary " + f"{client.primary} or secondaries {client.secondaries}" + ) + + async def assertReadsFrom(self, expected, **kwargs): + c = await self.async_rs_client(**kwargs) + + async def predicate(): + return len(c.nodes - await c.arbiters) == async_client_context.w + + await async_wait_until(predicate, "discovered all nodes") + + used = await self.read_from_which_kind(c) + self.assertEqual(expected, used, f"Cursor used {used}, expected {expected}") + + +class TestSingleSecondaryOk(TestReadPreferencesBase): + async def test_reads_from_secondary(self): + host, port = next(iter(await self.client.secondaries)) + # Direct connection to a secondary. + client = await self.async_single_client(host, port) + self.assertFalse(await client.is_primary) + + # Regardless of read preference, we should be able to do + # "reads" with a direct connection to a secondary. + # See server-selection.rst#topology-type-single. + self.assertEqual(client.read_preference, ReadPreference.PRIMARY) + + db = client.pymongo_test + coll = db.test + + # Test find and find_one. + self.assertIsNotNone(await coll.find_one()) + self.assertEqual(10, len(await coll.find().to_list())) + + # Test some database helpers. + self.assertIsNotNone(await db.list_collection_names()) + self.assertIsNotNone(await db.validate_collection("test")) + self.assertIsNotNone(await db.command("ping")) + + # Test some collection helpers. + self.assertEqual(10, await coll.count_documents({})) + self.assertEqual(10, len(await coll.distinct("_id"))) + self.assertIsNotNone(await coll.aggregate([])) + self.assertIsNotNone(await coll.index_information()) + + +class TestReadPreferences(TestReadPreferencesBase): + async def test_mode_validation(self): + for mode in ( + ReadPreference.PRIMARY, + ReadPreference.PRIMARY_PREFERRED, + ReadPreference.SECONDARY, + ReadPreference.SECONDARY_PREFERRED, + ReadPreference.NEAREST, + ): + self.assertEqual( + mode, (await self.async_rs_client(read_preference=mode)).read_preference + ) + + with self.assertRaises(TypeError): + await self.async_rs_client(read_preference="foo") + + async def test_tag_sets_validation(self): + S = Secondary(tag_sets=[{}]) + self.assertEqual( + [{}], (await self.async_rs_client(read_preference=S)).read_preference.tag_sets + ) + + S = Secondary(tag_sets=[{"k": "v"}]) + self.assertEqual( + [{"k": "v"}], (await self.async_rs_client(read_preference=S)).read_preference.tag_sets + ) + + S = Secondary(tag_sets=[{"k": "v"}, {}]) + self.assertEqual( + [{"k": "v"}, {}], + (await self.async_rs_client(read_preference=S)).read_preference.tag_sets, + ) + + self.assertRaises(ValueError, Secondary, tag_sets=[]) + + # One dict not ok, must be a list of dicts + self.assertRaises(TypeError, Secondary, tag_sets={"k": "v"}) + + self.assertRaises(TypeError, Secondary, tag_sets="foo") + + self.assertRaises(TypeError, Secondary, tag_sets=["foo"]) + + async def test_threshold_validation(self): + self.assertEqual( + 17, + ( + await self.async_rs_client(localThresholdMS=17, connect=False) + ).options.local_threshold_ms, + ) + + self.assertEqual( + 42, + ( + await self.async_rs_client(localThresholdMS=42, connect=False) + ).options.local_threshold_ms, + ) + + self.assertEqual( + 666, + ( + await self.async_rs_client(localThresholdMS=666, connect=False) + ).options.local_threshold_ms, + ) + + self.assertEqual( + 0, + ( + await self.async_rs_client(localThresholdMS=0, connect=False) + ).options.local_threshold_ms, + ) + + with self.assertRaises(ValueError): + await self.async_rs_client(localthresholdms=-1) + + async def test_zero_latency(self): + ping_times: set = set() + # Generate unique ping times. + while len(ping_times) < len(self.client.nodes): + ping_times.add(random.random()) + for ping_time, host in zip(ping_times, self.client.nodes): + ServerDescription._host_to_round_trip_time[host] = ping_time + try: + client = await connected( + await self.async_rs_client(readPreference="nearest", localThresholdMS=0) + ) + await async_wait_until( + lambda: client.nodes == self.client.nodes, "discovered all nodes" + ) + host = await self.read_from_which_host(client) + for _ in range(5): + self.assertEqual(host, await self.read_from_which_host(client)) + finally: + ServerDescription._host_to_round_trip_time.clear() + + async def test_primary(self): + await self.assertReadsFrom("primary", read_preference=ReadPreference.PRIMARY) + + async def test_primary_with_tags(self): + # Tags not allowed with PRIMARY + with self.assertRaises(ConfigurationError): + await self.async_rs_client(tag_sets=[{"dc": "ny"}]) + + async def test_primary_preferred(self): + await self.assertReadsFrom("primary", read_preference=ReadPreference.PRIMARY_PREFERRED) + + async def test_secondary(self): + await self.assertReadsFrom("secondary", read_preference=ReadPreference.SECONDARY) + + async def test_secondary_preferred(self): + await self.assertReadsFrom("secondary", read_preference=ReadPreference.SECONDARY_PREFERRED) + + async def test_nearest(self): + # With high localThresholdMS, expect to read from any + # member + c = await self.async_rs_client( + read_preference=ReadPreference.NEAREST, localThresholdMS=10000 + ) # 10 seconds + + data_members = {await self.client.primary} | await self.client.secondaries + + # This is a probabilistic test; track which members we've read from so + # far, and keep reading until we've used all the members or give up. + # Chance of using only 2 of 3 members 10k times if there's no bug = + # 3 * (2/3)**10000, very low. + used: set = set() + i = 0 + while data_members.difference(used) and i < 10000: + address = await self.read_from_which_host(c) + used.add(address) + i += 1 + + not_used = data_members.difference(used) + latencies = ", ".join( + "%s: %sms" % (server.description.address, server.description.round_trip_time) + for server in await (await c._get_topology()).select_servers( + readable_server_selector, _Op.TEST + ) + ) + + self.assertFalse( + not_used, + "Expected to use primary and all secondaries for mode NEAREST," + f" but didn't use {not_used}\nlatencies: {latencies}", + ) + + +class ReadPrefTester(AsyncMongoClient): + def __init__(self, *args, **kwargs): + self.has_read_from = set() + client_options = async_client_context.client_options + client_options.update(kwargs) + super().__init__(*args, **client_options) + + async def _conn_for_reads(self, read_preference, session, operation): + context = await super()._conn_for_reads(read_preference, session, operation) + return context + + @contextlib.asynccontextmanager + async def _conn_from_server(self, read_preference, server, session): + context = super()._conn_from_server(read_preference, server, session) + async with context as (conn, read_preference): + await self.record_a_read(conn.address) + yield conn, read_preference + + async def record_a_read(self, address): + server = await (await self._get_topology()).select_server_by_address(address, _Op.TEST, 0) + self.has_read_from.add(server) + + +_PREF_MAP = [ + (Primary, SERVER_TYPE.RSPrimary), + (PrimaryPreferred, SERVER_TYPE.RSPrimary), + (Secondary, SERVER_TYPE.RSSecondary), + (SecondaryPreferred, SERVER_TYPE.RSSecondary), + (Nearest, "any"), +] + + +class TestCommandAndReadPreference(AsyncIntegrationTest): + c: ReadPrefTester + client_version: Version + + @async_client_context.require_secondaries_count(1) + async def asyncSetUp(self): + await super().asyncSetUp() + self.c = ReadPrefTester( + # Ignore round trip times, to test ReadPreference modes only. + localThresholdMS=1000 * 1000, + ) + self.client_version = await Version.async_from_client(self.c) + # mapReduce fails if the collection does not exist. + coll = self.c.pymongo_test.get_collection( + "test", write_concern=WriteConcern(w=async_client_context.w) + ) + await coll.insert_one({}) + + async def asyncTearDown(self): + await self.c.drop_database("pymongo_test") + await self.c.close() + + async def executed_on_which_server(self, client, fn, *args, **kwargs): + """Execute fn(*args, **kwargs) and return the Server instance used.""" + client.has_read_from.clear() + await fn(*args, **kwargs) + self.assertEqual(1, len(client.has_read_from)) + return one(client.has_read_from) + + async def assertExecutedOn(self, server_type, client, fn, *args, **kwargs): + server = await self.executed_on_which_server(client, fn, *args, **kwargs) + self.assertEqual( + SERVER_TYPE._fields[server_type], SERVER_TYPE._fields[server.description.server_type] + ) + + async def _test_fn(self, server_type, fn): + for _ in range(10): + if server_type == "any": + used = set() + for _ in range(1000): + server = await self.executed_on_which_server(self.c, fn) + used.add(server.description.address) + if len(used) == len(await self.c.secondaries) + 1: + # Success + break + + assert await self.c.primary is not None + unused = (await self.c.secondaries).union({await self.c.primary}).difference(used) + if unused: + self.fail("Some members not used for NEAREST: %s" % (unused)) + else: + await self.assertExecutedOn(server_type, self.c, fn) + + async def _test_primary_helper(self, func): + # Helpers that ignore read preference. + await self._test_fn(SERVER_TYPE.RSPrimary, func) + + async def _test_coll_helper(self, secondary_ok, coll, meth, *args, **kwargs): + for mode, server_type in _PREF_MAP: + new_coll = coll.with_options(read_preference=mode()) + + async def func(): + return await getattr(new_coll, meth)(*args, **kwargs) + + if secondary_ok: + await self._test_fn(server_type, func) + else: + await self._test_fn(SERVER_TYPE.RSPrimary, func) + + async def test_command(self): + # Test that the generic command helper obeys the read preference + # passed to it. + for mode, server_type in _PREF_MAP: + + async def func(): + return await self.c.pymongo_test.command("dbStats", read_preference=mode()) + + await self._test_fn(server_type, func) + + async def test_create_collection(self): + # create_collection runs listCollections on the primary to check if + # the collection already exists. + async def func(): + return await self.c.pymongo_test.create_collection( + "some_collection%s" % random.randint(0, sys.maxsize) + ) + + await self._test_primary_helper(func) + + async def test_count_documents(self): + await self._test_coll_helper(True, self.c.pymongo_test.test, "count_documents", {}) + + async def test_estimated_document_count(self): + await self._test_coll_helper(True, self.c.pymongo_test.test, "estimated_document_count") + + async def test_distinct(self): + await self._test_coll_helper(True, self.c.pymongo_test.test, "distinct", "a") + + async def test_aggregate(self): + await self._test_coll_helper( + True, self.c.pymongo_test.test, "aggregate", [{"$project": {"_id": 1}}] + ) + + async def test_aggregate_write(self): + # 5.0 servers support $out on secondaries. + secondary_ok = async_client_context.version.at_least(5, 0) + await self._test_coll_helper( + secondary_ok, + self.c.pymongo_test.test, + "aggregate", + [{"$project": {"_id": 1}}, {"$out": "agg_write_test"}], + ) + + +class TestMovingAverage(unittest.TestCase): + def test_moving_average(self): + avg = MovingAverage() + self.assertIsNone(avg.get()) + avg.add_sample(10) + self.assertAlmostEqual(10, avg.get()) # type: ignore + avg.add_sample(20) + self.assertAlmostEqual(12, avg.get()) # type: ignore + avg.add_sample(30) + self.assertAlmostEqual(15.6, avg.get()) # type: ignore + + +class TestMongosAndReadPreference(AsyncIntegrationTest): + def test_read_preference_document(self): + pref = Primary() + self.assertEqual(pref.document, {"mode": "primary"}) + + pref = PrimaryPreferred() + self.assertEqual(pref.document, {"mode": "primaryPreferred"}) + pref = PrimaryPreferred(tag_sets=[{"dc": "sf"}]) + self.assertEqual(pref.document, {"mode": "primaryPreferred", "tags": [{"dc": "sf"}]}) + pref = PrimaryPreferred(tag_sets=[{"dc": "sf"}], max_staleness=30) + self.assertEqual( + pref.document, + {"mode": "primaryPreferred", "tags": [{"dc": "sf"}], "maxStalenessSeconds": 30}, + ) + + pref = Secondary() + self.assertEqual(pref.document, {"mode": "secondary"}) + pref = Secondary(tag_sets=[{"dc": "sf"}]) + self.assertEqual(pref.document, {"mode": "secondary", "tags": [{"dc": "sf"}]}) + pref = Secondary(tag_sets=[{"dc": "sf"}], max_staleness=30) + self.assertEqual( + pref.document, {"mode": "secondary", "tags": [{"dc": "sf"}], "maxStalenessSeconds": 30} + ) + + pref = SecondaryPreferred() + self.assertEqual(pref.document, {"mode": "secondaryPreferred"}) + pref = SecondaryPreferred(tag_sets=[{"dc": "sf"}]) + self.assertEqual(pref.document, {"mode": "secondaryPreferred", "tags": [{"dc": "sf"}]}) + pref = SecondaryPreferred(tag_sets=[{"dc": "sf"}], max_staleness=30) + self.assertEqual( + pref.document, + {"mode": "secondaryPreferred", "tags": [{"dc": "sf"}], "maxStalenessSeconds": 30}, + ) + + pref = Nearest() + self.assertEqual(pref.document, {"mode": "nearest"}) + pref = Nearest(tag_sets=[{"dc": "sf"}]) + self.assertEqual(pref.document, {"mode": "nearest", "tags": [{"dc": "sf"}]}) + pref = Nearest(tag_sets=[{"dc": "sf"}], max_staleness=30) + self.assertEqual( + pref.document, {"mode": "nearest", "tags": [{"dc": "sf"}], "maxStalenessSeconds": 30} + ) + + with self.assertRaises(TypeError): + # Float is prohibited. + Nearest(max_staleness=1.5) # type: ignore + + with self.assertRaises(ValueError): + Nearest(max_staleness=0) + + with self.assertRaises(ValueError): + Nearest(max_staleness=-2) + + def test_read_preference_document_hedge(self): + cases = { + "primaryPreferred": PrimaryPreferred, + "secondary": Secondary, + "secondaryPreferred": SecondaryPreferred, + "nearest": Nearest, + } + for mode, cls in cases.items(): + with self.assertRaises(TypeError): + cls(hedge=[]) # type: ignore + + pref = cls(hedge={}) + self.assertEqual(pref.document, {"mode": mode}) + out = _maybe_add_read_preference({}, pref) + if cls == SecondaryPreferred: + # SecondaryPreferred without hedge doesn't add $readPreference. + self.assertEqual(out, {}) + else: + self.assertEqual(out, SON([("$query", {}), ("$readPreference", pref.document)])) + + hedge: dict[str, Any] = {"enabled": True} + pref = cls(hedge=hedge) + self.assertEqual(pref.document, {"mode": mode, "hedge": hedge}) + out = _maybe_add_read_preference({}, pref) + self.assertEqual(out, SON([("$query", {}), ("$readPreference", pref.document)])) + + hedge = {"enabled": False} + pref = cls(hedge=hedge) + self.assertEqual(pref.document, {"mode": mode, "hedge": hedge}) + out = _maybe_add_read_preference({}, pref) + self.assertEqual(out, SON([("$query", {}), ("$readPreference", pref.document)])) + + hedge = {"enabled": False, "extra": "option"} + pref = cls(hedge=hedge) + self.assertEqual(pref.document, {"mode": mode, "hedge": hedge}) + out = _maybe_add_read_preference({}, pref) + self.assertEqual(out, SON([("$query", {}), ("$readPreference", pref.document)])) + + async def test_send_hedge(self): + cases = { + "primaryPreferred": PrimaryPreferred, + "secondaryPreferred": SecondaryPreferred, + "nearest": Nearest, + } + if await async_client_context.supports_secondary_read_pref: + cases["secondary"] = Secondary + listener = OvertCommandListener() + client = await self.async_rs_client(event_listeners=[listener]) + await client.admin.command("ping") + for _mode, cls in cases.items(): + pref = cls(hedge={"enabled": True}) + coll = client.test.get_collection("test", read_preference=pref) + listener.reset() + await coll.find_one() + started = listener.started_events + self.assertEqual(len(started), 1, started) + cmd = started[0].command + if async_client_context.is_rs or async_client_context.is_mongos: + self.assertIn("$readPreference", cmd) + self.assertEqual(cmd["$readPreference"], pref.document) + else: + self.assertNotIn("$readPreference", cmd) + + def test_maybe_add_read_preference(self): + # Primary doesn't add $readPreference + out = _maybe_add_read_preference({}, Primary()) + self.assertEqual(out, {}) + + pref = PrimaryPreferred() + out = _maybe_add_read_preference({}, pref) + self.assertEqual(out, SON([("$query", {}), ("$readPreference", pref.document)])) + pref = PrimaryPreferred(tag_sets=[{"dc": "nyc"}]) + out = _maybe_add_read_preference({}, pref) + self.assertEqual(out, SON([("$query", {}), ("$readPreference", pref.document)])) + + pref = Secondary() + out = _maybe_add_read_preference({}, pref) + self.assertEqual(out, SON([("$query", {}), ("$readPreference", pref.document)])) + pref = Secondary(tag_sets=[{"dc": "nyc"}]) + out = _maybe_add_read_preference({}, pref) + self.assertEqual(out, SON([("$query", {}), ("$readPreference", pref.document)])) + + # SecondaryPreferred without tag_sets or max_staleness doesn't add + # $readPreference + pref = SecondaryPreferred() + out = _maybe_add_read_preference({}, pref) + self.assertEqual(out, {}) + pref = SecondaryPreferred(tag_sets=[{"dc": "nyc"}]) + out = _maybe_add_read_preference({}, pref) + self.assertEqual(out, SON([("$query", {}), ("$readPreference", pref.document)])) + pref = SecondaryPreferred(max_staleness=120) + out = _maybe_add_read_preference({}, pref) + self.assertEqual(out, SON([("$query", {}), ("$readPreference", pref.document)])) + + pref = Nearest() + out = _maybe_add_read_preference({}, pref) + self.assertEqual(out, SON([("$query", {}), ("$readPreference", pref.document)])) + pref = Nearest(tag_sets=[{"dc": "nyc"}]) + out = _maybe_add_read_preference({}, pref) + self.assertEqual(out, SON([("$query", {}), ("$readPreference", pref.document)])) + + criteria = SON([("$query", {}), ("$orderby", SON([("_id", 1)]))]) + pref = Nearest() + out = _maybe_add_read_preference(criteria, pref) + self.assertEqual( + out, + SON( + [ + ("$query", {}), + ("$orderby", SON([("_id", 1)])), + ("$readPreference", pref.document), + ] + ), + ) + pref = Nearest(tag_sets=[{"dc": "nyc"}]) + out = _maybe_add_read_preference(criteria, pref) + self.assertEqual( + out, + SON( + [ + ("$query", {}), + ("$orderby", SON([("_id", 1)])), + ("$readPreference", pref.document), + ] + ), + ) + + @async_client_context.require_mongos + async def test_mongos(self): + res = await async_client_context.client.config.shards.find_one() + assert res is not None + shard = res["host"] + num_members = shard.count(",") + 1 + if num_members == 1: + raise SkipTest("Need a replica set shard to test.") + coll = async_client_context.client.pymongo_test.get_collection( + "test", write_concern=WriteConcern(w=num_members) + ) + await coll.drop() + res = await coll.insert_many([{} for _ in range(5)]) + first_id = res.inserted_ids[0] + last_id = res.inserted_ids[-1] + + # Note - this isn't a perfect test since there's no way to + # tell what shard member a query ran on. + for pref in (Primary(), PrimaryPreferred(), Secondary(), SecondaryPreferred(), Nearest()): + qcoll = coll.with_options(read_preference=pref) + results = await qcoll.find().sort([("_id", 1)]).to_list() + self.assertEqual(first_id, results[0]["_id"]) + self.assertEqual(last_id, results[-1]["_id"]) + results = await qcoll.find().sort([("_id", -1)]).to_list() + self.assertEqual(first_id, results[-1]["_id"]) + self.assertEqual(last_id, results[0]["_id"]) + + @async_client_context.require_mongos + async def test_mongos_max_staleness(self): + # Sanity check that we're sending maxStalenessSeconds + coll = async_client_context.client.pymongo_test.get_collection( + "test", read_preference=SecondaryPreferred(max_staleness=120) + ) + # No error + await coll.find_one() + + coll = async_client_context.client.pymongo_test.get_collection( + "test", read_preference=SecondaryPreferred(max_staleness=10) + ) + try: + await coll.find_one() + except OperationFailure as exc: + self.assertEqual(160, exc.code) + else: + self.fail("mongos accepted invalid staleness") + + coll = ( + await self.async_single_client( + readPreference="secondaryPreferred", maxStalenessSeconds=120 + ) + ).pymongo_test.test + # No error + await coll.find_one() + + coll = ( + await self.async_single_client( + readPreference="secondaryPreferred", maxStalenessSeconds=10 + ) + ).pymongo_test.test + try: + await coll.find_one() + except OperationFailure as exc: + self.assertEqual(160, exc.code) + else: + self.fail("mongos accepted invalid staleness") + + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_read_preferences.py b/test/test_read_preferences.py index 32883399e1..0d38f3f00d 100644 --- a/test/test_read_preferences.py +++ b/test/test_read_preferences.py @@ -26,7 +26,13 @@ sys.path[0:0] = [""] -from test import IntegrationTest, SkipTest, client_context, connected, unittest +from test import ( + IntegrationTest, + SkipTest, + client_context, + connected, + unittest, +) from test.utils import ( OvertCommandListener, one, @@ -49,16 +55,22 @@ from pymongo.server_description import ServerDescription from pymongo.server_selectors import Selection, readable_server_selector from pymongo.server_type import SERVER_TYPE +from pymongo.synchronous.helpers import next from pymongo.synchronous.mongo_client import MongoClient from pymongo.write_concern import WriteConcern +_IS_SYNC = True + class TestSelections(IntegrationTest): @client_context.require_connection def test_bool(self): client = self.single_client() - wait_until(lambda: client.address, "discover primary") + def predicate(): + return client.address + + wait_until(predicate, "discover primary") selection = Selection.from_topology_description(client._topology.description) self.assertTrue(selection) @@ -88,11 +100,7 @@ def test_deepcopy(self): class TestReadPreferencesBase(IntegrationTest): - @classmethod @client_context.require_secondaries_count(1) - def setUpClass(cls): - super().setUpClass() - def setUp(self): super().setUp() # Insert some data so we can use cursors in read_from_which_host @@ -123,11 +131,14 @@ def read_from_which_kind(self, client): f"Cursor used address {address}, expected either primary " f"{client.primary} or secondaries {client.secondaries}" ) - return None def assertReadsFrom(self, expected, **kwargs): c = self.rs_client(**kwargs) - wait_until(lambda: len(c.nodes - c.arbiters) == client_context.w, "discovered all nodes") + + def predicate(): + return len(c.nodes - c.arbiters) == client_context.w + + wait_until(predicate, "discovered all nodes") used = self.read_from_which_kind(c) self.assertEqual(expected, used, f"Cursor used {used}, expected {expected}") @@ -150,7 +161,7 @@ def test_reads_from_secondary(self): # Test find and find_one. self.assertIsNotNone(coll.find_one()) - self.assertEqual(10, len(list(coll.find()))) + self.assertEqual(10, len(coll.find().to_list())) # Test some database helpers. self.assertIsNotNone(db.list_collection_names()) @@ -173,20 +184,22 @@ def test_mode_validation(self): ReadPreference.SECONDARY_PREFERRED, ReadPreference.NEAREST, ): - self.assertEqual(mode, self.rs_client(read_preference=mode).read_preference) + self.assertEqual(mode, (self.rs_client(read_preference=mode)).read_preference) - self.assertRaises(TypeError, self.rs_client, read_preference="foo") + with self.assertRaises(TypeError): + self.rs_client(read_preference="foo") def test_tag_sets_validation(self): S = Secondary(tag_sets=[{}]) - self.assertEqual([{}], self.rs_client(read_preference=S).read_preference.tag_sets) + self.assertEqual([{}], (self.rs_client(read_preference=S)).read_preference.tag_sets) S = Secondary(tag_sets=[{"k": "v"}]) - self.assertEqual([{"k": "v"}], self.rs_client(read_preference=S).read_preference.tag_sets) + self.assertEqual([{"k": "v"}], (self.rs_client(read_preference=S)).read_preference.tag_sets) S = Secondary(tag_sets=[{"k": "v"}, {}]) self.assertEqual( - [{"k": "v"}, {}], self.rs_client(read_preference=S).read_preference.tag_sets + [{"k": "v"}, {}], + (self.rs_client(read_preference=S)).read_preference.tag_sets, ) self.assertRaises(ValueError, Secondary, tag_sets=[]) @@ -200,22 +213,27 @@ def test_tag_sets_validation(self): def test_threshold_validation(self): self.assertEqual( - 17, self.rs_client(localThresholdMS=17, connect=False).options.local_threshold_ms + 17, + (self.rs_client(localThresholdMS=17, connect=False)).options.local_threshold_ms, ) self.assertEqual( - 42, self.rs_client(localThresholdMS=42, connect=False).options.local_threshold_ms + 42, + (self.rs_client(localThresholdMS=42, connect=False)).options.local_threshold_ms, ) self.assertEqual( - 666, self.rs_client(localThresholdMS=666, connect=False).options.local_threshold_ms + 666, + (self.rs_client(localThresholdMS=666, connect=False)).options.local_threshold_ms, ) self.assertEqual( - 0, self.rs_client(localThresholdMS=0, connect=False).options.local_threshold_ms + 0, + (self.rs_client(localThresholdMS=0, connect=False)).options.local_threshold_ms, ) - self.assertRaises(ValueError, self.rs_client, localthresholdms=-1) + with self.assertRaises(ValueError): + self.rs_client(localthresholdms=-1) def test_zero_latency(self): ping_times: set = set() @@ -238,7 +256,8 @@ def test_primary(self): def test_primary_with_tags(self): # Tags not allowed with PRIMARY - self.assertRaises(ConfigurationError, self.rs_client, tag_sets=[{"dc": "ny"}]) + with self.assertRaises(ConfigurationError): + self.rs_client(tag_sets=[{"dc": "ny"}]) def test_primary_preferred(self): self.assertReadsFrom("primary", read_preference=ReadPreference.PRIMARY_PREFERRED) @@ -272,7 +291,7 @@ def test_nearest(self): not_used = data_members.difference(used) latencies = ", ".join( "%s: %sms" % (server.description.address, server.description.round_trip_time) - for server in c._get_topology().select_servers(readable_server_selector, _Op.TEST) + for server in (c._get_topology()).select_servers(readable_server_selector, _Op.TEST) ) self.assertFalse( @@ -289,12 +308,9 @@ def __init__(self, *args, **kwargs): client_options.update(kwargs) super().__init__(*args, **client_options) - @contextlib.contextmanager def _conn_for_reads(self, read_preference, session, operation): context = super()._conn_for_reads(read_preference, session, operation) - with context as (conn, read_preference): - self.record_a_read(conn.address) - yield conn, read_preference + return context @contextlib.contextmanager def _conn_from_server(self, read_preference, server, session): @@ -304,7 +320,7 @@ def _conn_from_server(self, read_preference, server, session): yield conn, read_preference def record_a_read(self, address): - server = self._get_topology().select_server_by_address(address, _Op.TEST, 0) + server = (self._get_topology()).select_server_by_address(address, _Op.TEST, 0) self.has_read_from.add(server) @@ -321,25 +337,23 @@ class TestCommandAndReadPreference(IntegrationTest): c: ReadPrefTester client_version: Version - @classmethod @client_context.require_secondaries_count(1) - def setUpClass(cls): - super().setUpClass() - cls.c = ReadPrefTester( + def setUp(self): + super().setUp() + self.c = ReadPrefTester( # Ignore round trip times, to test ReadPreference modes only. localThresholdMS=1000 * 1000, ) - cls.client_version = Version.from_client(cls.c) + self.client_version = Version.from_client(self.c) # mapReduce fails if the collection does not exist. - coll = cls.c.pymongo_test.get_collection( + coll = self.c.pymongo_test.get_collection( "test", write_concern=WriteConcern(w=client_context.w) ) coll.insert_one({}) - @classmethod - def tearDownClass(cls): - cls.c.drop_database("pymongo_test") - cls.c.close() + def tearDown(self): + self.c.drop_database("pymongo_test") + self.c.close() def executed_on_which_server(self, client, fn, *args, **kwargs): """Execute fn(*args, **kwargs) and return the Server instance used.""" @@ -366,7 +380,7 @@ def _test_fn(self, server_type, fn): break assert self.c.primary is not None - unused = self.c.secondaries.union({self.c.primary}).difference(used) + unused = (self.c.secondaries).union({self.c.primary}).difference(used) if unused: self.fail("Some members not used for NEAREST: %s" % (unused)) else: @@ -401,11 +415,12 @@ def func(): def test_create_collection(self): # create_collection runs listCollections on the primary to check if # the collection already exists. - self._test_primary_helper( - lambda: self.c.pymongo_test.create_collection( + def func(): + return self.c.pymongo_test.create_collection( "some_collection%s" % random.randint(0, sys.maxsize) ) - ) + + self._test_primary_helper(func) def test_count_documents(self): self._test_coll_helper(True, self.c.pymongo_test.test, "count_documents", {}) @@ -545,7 +560,6 @@ def test_send_hedge(self): cases["secondary"] = Secondary listener = OvertCommandListener() client = self.rs_client(event_listeners=[listener]) - self.addCleanup(client.close) client.admin.command("ping") for _mode, cls in cases.items(): pref = cls(hedge={"enabled": True}) @@ -645,10 +659,10 @@ def test_mongos(self): # tell what shard member a query ran on. for pref in (Primary(), PrimaryPreferred(), Secondary(), SecondaryPreferred(), Nearest()): qcoll = coll.with_options(read_preference=pref) - results = list(qcoll.find().sort([("_id", 1)])) + results = qcoll.find().sort([("_id", 1)]).to_list() self.assertEqual(first_id, results[0]["_id"]) self.assertEqual(last_id, results[-1]["_id"]) - results = list(qcoll.find().sort([("_id", -1)])) + results = qcoll.find().sort([("_id", -1)]).to_list() self.assertEqual(first_id, results[-1]["_id"]) self.assertEqual(last_id, results[0]["_id"]) @@ -671,14 +685,14 @@ def test_mongos_max_staleness(self): else: self.fail("mongos accepted invalid staleness") - coll = self.single_client( - readPreference="secondaryPreferred", maxStalenessSeconds=120 + coll = ( + self.single_client(readPreference="secondaryPreferred", maxStalenessSeconds=120) ).pymongo_test.test # No error coll.find_one() - coll = self.single_client( - readPreference="secondaryPreferred", maxStalenessSeconds=10 + coll = ( + self.single_client(readPreference="secondaryPreferred", maxStalenessSeconds=10) ).pymongo_test.test try: coll.find_one() diff --git a/tools/synchro.py b/tools/synchro.py index 6317cb84fb..7c26fab523 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -220,6 +220,7 @@ def async_only_test(f: str) -> bool: "test_on_demand_csfle.py", "test_raw_bson.py", "test_read_concern.py", + "test_read_preferences.py", "test_read_write_concern_spec.py", "test_retryable_reads.py", "test_retryable_reads_unified.py",