diff --git a/pymongo/asynchronous/monitor.py b/pymongo/asynchronous/monitor.py index 15289af4dc..d7f87b718a 100644 --- a/pymongo/asynchronous/monitor.py +++ b/pymongo/asynchronous/monitor.py @@ -21,11 +21,11 @@ import logging import time import weakref -from typing import TYPE_CHECKING, Any, Mapping, Optional, cast +from typing import TYPE_CHECKING, Any, Optional from pymongo import common, periodic_executor from pymongo._csot import MovingMinimum -from pymongo.errors import NetworkTimeout, NotPrimaryError, OperationFailure, _OperationCancelled +from pymongo.errors import NetworkTimeout, _OperationCancelled from pymongo.hello import Hello from pymongo.lock import _async_create_lock from pymongo.logger import _SDAM_LOGGER, _debug_log, _SDAMStatusMessage @@ -255,13 +255,7 @@ async def _check_server(self) -> ServerDescription: self._conn_id = None start = time.monotonic() try: - try: - return await self._check_once() - except (OperationFailure, NotPrimaryError) as exc: - # Update max cluster time even when hello fails. - details = cast(Mapping[str, Any], exc.details) - await self._topology.receive_cluster_time(details.get("$clusterTime")) - raise + return await self._check_once() except ReferenceError: raise except Exception as error: @@ -358,7 +352,6 @@ async def _check_with_socket(self, conn: AsyncConnection) -> tuple[Hello, float] Can raise ConnectionFailure or OperationFailure. """ - cluster_time = self._topology.max_cluster_time() start = time.monotonic() if conn.more_to_come: # Read the next streaming hello (MongoDB 4.4+). @@ -368,13 +361,12 @@ async def _check_with_socket(self, conn: AsyncConnection) -> tuple[Hello, float] ): # Initiate streaming hello (MongoDB 4.4+). response = await conn._hello( - cluster_time, self._server_description.topology_version, self._settings.heartbeat_frequency, ) else: # New connection handshake or polling hello (MongoDB <4.4). - response = await conn._hello(cluster_time, None, None) + response = await conn._hello(None, None) duration = _monotonic_duration(start) return response, duration diff --git a/pymongo/asynchronous/network.py b/pymongo/asynchronous/network.py index d17aead120..c7a5580eca 100644 --- a/pymongo/asynchronous/network.py +++ b/pymongo/asynchronous/network.py @@ -207,6 +207,10 @@ async def command( ) response_doc = unpacked_docs[0] + if not conn.ready: + cluster_time = response_doc.get("$clusterTime") + if cluster_time: + conn._cluster_time = cluster_time if client: await client._process_response(response_doc, session) if check: diff --git a/pymongo/asynchronous/pool.py b/pymongo/asynchronous/pool.py index 1da695c5c8..698558aa5d 100644 --- a/pymongo/asynchronous/pool.py +++ b/pymongo/asynchronous/pool.py @@ -102,7 +102,7 @@ from pymongo.pyopenssl_context import _sslConn from pymongo.read_concern import ReadConcern from pymongo.read_preferences import _ServerMode - from pymongo.typings import ClusterTime, _Address, _CollationIn + from pymongo.typings import _Address, _CollationIn from pymongo.write_concern import WriteConcern try: @@ -310,6 +310,8 @@ def __init__( self.connect_rtt = 0.0 self._client_id = pool._client_id self.creation_time = time.monotonic() + # For gossiping $clusterTime from the connection handshake to the client. + self._cluster_time = None def set_conn_timeout(self, timeout: Optional[float]) -> None: """Cache last timeout to avoid duplicate calls to conn.settimeout.""" @@ -374,11 +376,10 @@ def hello_cmd(self) -> dict[str, Any]: return {HelloCompat.LEGACY_CMD: 1, "helloOk": True} async def hello(self) -> Hello: - return await self._hello(None, None, None) + return await self._hello(None, None) async def _hello( self, - cluster_time: Optional[ClusterTime], topology_version: Optional[Any], heartbeat_frequency: Optional[int], ) -> Hello[dict[str, Any]]: @@ -401,9 +402,6 @@ async def _hello( if self.opts.connect_timeout: self.set_conn_timeout(self.opts.connect_timeout + heartbeat_frequency) - if not performing_handshake and cluster_time is not None: - cmd["$clusterTime"] = cluster_time - creds = self.opts._credentials if creds: if creds.mechanism == "DEFAULT" and creds.username: @@ -1316,6 +1314,9 @@ async def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> A conn.close_conn(ConnectionClosedReason.ERROR) raise + if handler: + await handler.client._topology.receive_cluster_time(conn._cluster_time) + return conn @contextlib.asynccontextmanager diff --git a/pymongo/asynchronous/topology.py b/pymongo/asynchronous/topology.py index 19fc76b0d3..bb003bbfde 100644 --- a/pymongo/asynchronous/topology.py +++ b/pymongo/asynchronous/topology.py @@ -501,7 +501,6 @@ async def _process_change( self._description = new_td await self._update_servers() - self._receive_cluster_time_no_lock(server_description.cluster_time) if self._publish_tp and not suppress_event: assert self._events is not None diff --git a/pymongo/synchronous/monitor.py b/pymongo/synchronous/monitor.py index 802ba4742f..c39a57c392 100644 --- a/pymongo/synchronous/monitor.py +++ b/pymongo/synchronous/monitor.py @@ -21,11 +21,11 @@ import logging import time import weakref -from typing import TYPE_CHECKING, Any, Mapping, Optional, cast +from typing import TYPE_CHECKING, Any, Optional from pymongo import common, periodic_executor from pymongo._csot import MovingMinimum -from pymongo.errors import NetworkTimeout, NotPrimaryError, OperationFailure, _OperationCancelled +from pymongo.errors import NetworkTimeout, _OperationCancelled from pymongo.hello import Hello from pymongo.lock import _create_lock from pymongo.logger import _SDAM_LOGGER, _debug_log, _SDAMStatusMessage @@ -253,13 +253,7 @@ def _check_server(self) -> ServerDescription: self._conn_id = None start = time.monotonic() try: - try: - return self._check_once() - except (OperationFailure, NotPrimaryError) as exc: - # Update max cluster time even when hello fails. - details = cast(Mapping[str, Any], exc.details) - self._topology.receive_cluster_time(details.get("$clusterTime")) - raise + return self._check_once() except ReferenceError: raise except Exception as error: @@ -356,7 +350,6 @@ def _check_with_socket(self, conn: Connection) -> tuple[Hello, float]: Can raise ConnectionFailure or OperationFailure. """ - cluster_time = self._topology.max_cluster_time() start = time.monotonic() if conn.more_to_come: # Read the next streaming hello (MongoDB 4.4+). @@ -366,13 +359,12 @@ def _check_with_socket(self, conn: Connection) -> tuple[Hello, float]: ): # Initiate streaming hello (MongoDB 4.4+). response = conn._hello( - cluster_time, self._server_description.topology_version, self._settings.heartbeat_frequency, ) else: # New connection handshake or polling hello (MongoDB <4.4). - response = conn._hello(cluster_time, None, None) + response = conn._hello(None, None) duration = _monotonic_duration(start) return response, duration diff --git a/pymongo/synchronous/network.py b/pymongo/synchronous/network.py index 7206dca735..543b069bfc 100644 --- a/pymongo/synchronous/network.py +++ b/pymongo/synchronous/network.py @@ -207,6 +207,10 @@ def command( ) response_doc = unpacked_docs[0] + if not conn.ready: + cluster_time = response_doc.get("$clusterTime") + if cluster_time: + conn._cluster_time = cluster_time if client: client._process_response(response_doc, session) if check: diff --git a/pymongo/synchronous/pool.py b/pymongo/synchronous/pool.py index 978f0ae391..e575710ff5 100644 --- a/pymongo/synchronous/pool.py +++ b/pymongo/synchronous/pool.py @@ -102,7 +102,7 @@ from pymongo.synchronous.auth import _AuthContext from pymongo.synchronous.client_session import ClientSession from pymongo.synchronous.mongo_client import MongoClient, _MongoClientErrorHandler - from pymongo.typings import ClusterTime, _Address, _CollationIn + from pymongo.typings import _Address, _CollationIn from pymongo.write_concern import WriteConcern try: @@ -310,6 +310,8 @@ def __init__( self.connect_rtt = 0.0 self._client_id = pool._client_id self.creation_time = time.monotonic() + # For gossiping $clusterTime from the connection handshake to the client. + self._cluster_time = None def set_conn_timeout(self, timeout: Optional[float]) -> None: """Cache last timeout to avoid duplicate calls to conn.settimeout.""" @@ -374,11 +376,10 @@ def hello_cmd(self) -> dict[str, Any]: return {HelloCompat.LEGACY_CMD: 1, "helloOk": True} def hello(self) -> Hello: - return self._hello(None, None, None) + return self._hello(None, None) def _hello( self, - cluster_time: Optional[ClusterTime], topology_version: Optional[Any], heartbeat_frequency: Optional[int], ) -> Hello[dict[str, Any]]: @@ -401,9 +402,6 @@ def _hello( if self.opts.connect_timeout: self.set_conn_timeout(self.opts.connect_timeout + heartbeat_frequency) - if not performing_handshake and cluster_time is not None: - cmd["$clusterTime"] = cluster_time - creds = self.opts._credentials if creds: if creds.mechanism == "DEFAULT" and creds.username: @@ -1310,6 +1308,9 @@ def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> Connect conn.close_conn(ConnectionClosedReason.ERROR) raise + if handler: + handler.client._topology.receive_cluster_time(conn._cluster_time) + return conn @contextlib.contextmanager diff --git a/pymongo/synchronous/topology.py b/pymongo/synchronous/topology.py index 6a8503c6c0..2bc8934540 100644 --- a/pymongo/synchronous/topology.py +++ b/pymongo/synchronous/topology.py @@ -501,7 +501,6 @@ def _process_change( self._description = new_td self._update_servers() - self._receive_cluster_time_no_lock(server_description.cluster_time) if self._publish_tp and not suppress_event: assert self._events is not None diff --git a/test/asynchronous/test_session.py b/test/asynchronous/test_session.py index 03d1032b5b..568d392cd5 100644 --- a/test/asynchronous/test_session.py +++ b/test/asynchronous/test_session.py @@ -36,8 +36,10 @@ async_client_context, unittest, ) +from test.asynchronous.helpers import client_knobs from test.utils import ( EventListener, + HeartbeatEventListener, OvertCommandListener, async_wait_until, ) @@ -1135,12 +1137,10 @@ async def asyncSetUp(self): if "$clusterTime" not in (await async_client_context.hello): raise SkipTest("$clusterTime not supported") + # Sessions prose test: 3) $clusterTime in commands async def test_cluster_time(self): listener = SessionTestListener() - # Prevent heartbeats from updating $clusterTime between operations. - client = await self.async_rs_or_single_client( - event_listeners=[listener], heartbeatFrequencyMS=999999 - ) + client = await self.async_rs_or_single_client(event_listeners=[listener]) collection = client.pymongo_test.collection # Prepare for tests of find() and aggregate(). await collection.insert_many([{} for _ in range(10)]) @@ -1219,6 +1219,40 @@ async def aggregate(): f"{f.__name__} sent wrong $clusterTime with {event.command_name}", ) + # Sessions prose test: 20) Drivers do not gossip `$clusterTime` on SDAM commands + async def test_cluster_time_not_used_by_sdam(self): + heartbeat_listener = HeartbeatEventListener() + cmd_listener = OvertCommandListener() + with client_knobs(min_heartbeat_interval=0.01): + c1 = await self.async_single_client( + event_listeners=[heartbeat_listener, cmd_listener], heartbeatFrequencyMS=10 + ) + cluster_time = (await c1.admin.command({"ping": 1}))["$clusterTime"] + self.assertEqual(c1._topology.max_cluster_time(), cluster_time) + + # Advance the server's $clusterTime by performing an insert via another client. + await self.db.test.insert_one({"advance": "$clusterTime"}) + # Wait until the client C1 processes the next pair of SDAM heartbeat started + succeeded events. + heartbeat_listener.reset() + + async def next_heartbeat(): + events = heartbeat_listener.events + for i in range(len(events) - 1): + if isinstance(events[i], monitoring.ServerHeartbeatStartedEvent): + if isinstance(events[i + 1], monitoring.ServerHeartbeatSucceededEvent): + return True + return False + + await async_wait_until( + next_heartbeat, "never found pair of heartbeat started + succeeded events" + ) + # Assert that C1's max $clusterTime is still the same and has not been updated by SDAM. + cmd_listener.reset() + await c1.admin.command({"ping": 1}) + started = cmd_listener.started_events[0] + self.assertEqual(started.command_name, "ping") + self.assertEqual(started.command["$clusterTime"], cluster_time) + if __name__ == "__main__": unittest.main() diff --git a/test/test_discovery_and_monitoring.py b/test/test_discovery_and_monitoring.py index ce7a52f1a0..70dcfc5b48 100644 --- a/test/test_discovery_and_monitoring.py +++ b/test/test_discovery_and_monitoring.py @@ -244,7 +244,7 @@ class TestClusterTimeComparison(unittest.TestCase): def test_cluster_time_comparison(self): t = create_mock_topology("mongodb://host") - def send_cluster_time(time, inc, should_update): + def send_cluster_time(time, inc): old = t.max_cluster_time() new = {"clusterTime": Timestamp(time, inc)} got_hello( @@ -259,16 +259,14 @@ def send_cluster_time(time, inc, should_update): ) actual = t.max_cluster_time() - if should_update: - self.assertEqual(actual, new) - else: - self.assertEqual(actual, old) - - send_cluster_time(0, 1, True) - send_cluster_time(2, 2, True) - send_cluster_time(2, 1, False) - send_cluster_time(1, 3, False) - send_cluster_time(2, 3, True) + # We never update $clusterTime from monitoring connections. + self.assertEqual(actual, old) + + send_cluster_time(0, 1) + send_cluster_time(2, 2) + send_cluster_time(2, 1) + send_cluster_time(1, 3) + send_cluster_time(2, 3) class TestIgnoreStaleErrors(IntegrationTest): diff --git a/test/test_session.py b/test/test_session.py index 175a282495..e80ab41896 100644 --- a/test/test_session.py +++ b/test/test_session.py @@ -36,8 +36,10 @@ client_context, unittest, ) +from test.helpers import client_knobs from test.utils import ( EventListener, + HeartbeatEventListener, OvertCommandListener, wait_until, ) @@ -1121,10 +1123,10 @@ def setUp(self): if "$clusterTime" not in (client_context.hello): raise SkipTest("$clusterTime not supported") + # Sessions prose test: 3) $clusterTime in commands def test_cluster_time(self): listener = SessionTestListener() - # Prevent heartbeats from updating $clusterTime between operations. - client = self.rs_or_single_client(event_listeners=[listener], heartbeatFrequencyMS=999999) + client = self.rs_or_single_client(event_listeners=[listener]) collection = client.pymongo_test.collection # Prepare for tests of find() and aggregate(). collection.insert_many([{} for _ in range(10)]) @@ -1203,6 +1205,38 @@ def aggregate(): f"{f.__name__} sent wrong $clusterTime with {event.command_name}", ) + # Sessions prose test: 20) Drivers do not gossip `$clusterTime` on SDAM commands + def test_cluster_time_not_used_by_sdam(self): + heartbeat_listener = HeartbeatEventListener() + cmd_listener = OvertCommandListener() + with client_knobs(min_heartbeat_interval=0.01): + c1 = self.single_client( + event_listeners=[heartbeat_listener, cmd_listener], heartbeatFrequencyMS=10 + ) + cluster_time = (c1.admin.command({"ping": 1}))["$clusterTime"] + self.assertEqual(c1._topology.max_cluster_time(), cluster_time) + + # Advance the server's $clusterTime by performing an insert via another client. + self.db.test.insert_one({"advance": "$clusterTime"}) + # Wait until the client C1 processes the next pair of SDAM heartbeat started + succeeded events. + heartbeat_listener.reset() + + def next_heartbeat(): + events = heartbeat_listener.events + for i in range(len(events) - 1): + if isinstance(events[i], monitoring.ServerHeartbeatStartedEvent): + if isinstance(events[i + 1], monitoring.ServerHeartbeatSucceededEvent): + return True + return False + + wait_until(next_heartbeat, "never found pair of heartbeat started + succeeded events") + # Assert that C1's max $clusterTime is still the same and has not been updated by SDAM. + cmd_listener.reset() + c1.admin.command({"ping": 1}) + started = cmd_listener.started_events[0] + self.assertEqual(started.command_name, "ping") + self.assertEqual(started.command["$clusterTime"], cluster_time) + if __name__ == "__main__": unittest.main()