diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index cf7de19c2f..365fc62100 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -1565,6 +1565,12 @@ async def close(self) -> None: # TODO: PYTHON-1921 Encrypted MongoClients cannot be re-opened. await self._encrypter.close() self._closed = True + if not _IS_SYNC: + await asyncio.gather( + self._topology.cleanup_monitors(), # type: ignore[func-returns-value] + self._kill_cursors_executor.join(), # type: ignore[func-returns-value] + return_exceptions=True, + ) if not _IS_SYNC: # Add support for contextlib.aclosing. diff --git a/pymongo/asynchronous/monitor.py b/pymongo/asynchronous/monitor.py index ad1bc70aba..abde7a9055 100644 --- a/pymongo/asynchronous/monitor.py +++ b/pymongo/asynchronous/monitor.py @@ -112,9 +112,9 @@ async def close(self) -> None: """ self.gc_safe_close() - async def join(self, timeout: Optional[int] = None) -> None: + async def join(self) -> None: """Wait for the monitor to stop.""" - await self._executor.join(timeout) + await self._executor.join() def request_check(self) -> None: """If the monitor is sleeping, wake it soon.""" @@ -189,6 +189,11 @@ def gc_safe_close(self) -> None: self._rtt_monitor.gc_safe_close() self.cancel_check() + async def join(self) -> None: + await asyncio.gather( + self._executor.join(), self._rtt_monitor.join(), return_exceptions=True + ) # type: ignore[func-returns-value] + async def close(self) -> None: self.gc_safe_close() await self._rtt_monitor.close() diff --git a/pymongo/asynchronous/topology.py b/pymongo/asynchronous/topology.py index 6d67710a7e..3033377de5 100644 --- a/pymongo/asynchronous/topology.py +++ b/pymongo/asynchronous/topology.py @@ -16,6 +16,7 @@ from __future__ import annotations +import asyncio import logging import os import queue @@ -29,7 +30,7 @@ from pymongo import _csot, common, helpers_shared, periodic_executor from pymongo.asynchronous.client_session import _ServerSession, _ServerSessionPool -from pymongo.asynchronous.monitor import SrvMonitor +from pymongo.asynchronous.monitor import MonitorBase, SrvMonitor from pymongo.asynchronous.pool import Pool from pymongo.asynchronous.server import Server from pymongo.errors import ( @@ -207,6 +208,9 @@ async def target() -> bool: if self._settings.fqdn is not None and not self._settings.load_balanced: self._srv_monitor = SrvMonitor(self, self._settings) + # Stores all monitor tasks that need to be joined on close or server selection + self._monitor_tasks: list[MonitorBase] = [] + async def open(self) -> None: """Start monitoring, or restart after a fork. @@ -241,6 +245,8 @@ async def open(self) -> None: # Close servers and clear the pools. for server in self._servers.values(): await server.close() + if not _IS_SYNC: + self._monitor_tasks.append(server._monitor) # Reset the session pool to avoid duplicate sessions in # the child process. self._session_pool.reset() @@ -283,6 +289,10 @@ async def select_servers( else: server_timeout = server_selection_timeout + # Cleanup any completed monitor tasks safely + if not _IS_SYNC and self._monitor_tasks: + await self.cleanup_monitors() + async with self._lock: server_descriptions = await self._select_servers_loop( selector, server_timeout, operation, operation_id, address @@ -520,6 +530,8 @@ async def _process_change( and self._description.topology_type not in SRV_POLLING_TOPOLOGIES ): await self._srv_monitor.close() + if not _IS_SYNC: + self._monitor_tasks.append(self._srv_monitor) # Clear the pool from a failed heartbeat. if reset_pool: @@ -695,6 +707,8 @@ async def close(self) -> None: old_td = self._description for server in self._servers.values(): await server.close() + if not _IS_SYNC: + self._monitor_tasks.append(server._monitor) # Mark all servers Unknown. self._description = self._description.reset() @@ -705,6 +719,8 @@ async def close(self) -> None: # Stop SRV polling thread. if self._srv_monitor: await self._srv_monitor.close() + if not _IS_SYNC: + self._monitor_tasks.append(self._srv_monitor) self._opened = False self._closed = True @@ -944,6 +960,8 @@ async def _update_servers(self) -> None: for address, server in list(self._servers.items()): if not self._description.has_server(address): await server.close() + if not _IS_SYNC: + self._monitor_tasks.append(server._monitor) self._servers.pop(address) def _create_pool_for_server(self, address: _Address) -> Pool: @@ -1031,6 +1049,15 @@ def _error_message(self, selector: Callable[[Selection], Selection]) -> str: else: return ",".join(str(server.error) for server in servers if server.error) + async def cleanup_monitors(self) -> None: + tasks = [] + try: + while self._monitor_tasks: + tasks.append(self._monitor_tasks.pop()) + except IndexError: + pass + await asyncio.gather(*[t.join() for t in tasks], return_exceptions=True) # type: ignore[func-returns-value] + def __repr__(self) -> str: msg = "" if not self._opened: diff --git a/pymongo/periodic_executor.py b/pymongo/periodic_executor.py index 9b10f6e7e3..f51a988728 100644 --- a/pymongo/periodic_executor.py +++ b/pymongo/periodic_executor.py @@ -75,6 +75,8 @@ def close(self, dummy: Any = None) -> None: callback; see monitor.py. """ self._stopped = True + if self._task is not None: + self._task.cancel() async def join(self, timeout: Optional[int] = None) -> None: if self._task is not None: diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index 706623c214..8cd08ab725 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -1559,6 +1559,12 @@ def close(self) -> None: # TODO: PYTHON-1921 Encrypted MongoClients cannot be re-opened. self._encrypter.close() self._closed = True + if not _IS_SYNC: + asyncio.gather( + self._topology.cleanup_monitors(), # type: ignore[func-returns-value] + self._kill_cursors_executor.join(), # type: ignore[func-returns-value] + return_exceptions=True, + ) if not _IS_SYNC: # Add support for contextlib.closing. diff --git a/pymongo/synchronous/monitor.py b/pymongo/synchronous/monitor.py index df4130d4ab..211635d8b8 100644 --- a/pymongo/synchronous/monitor.py +++ b/pymongo/synchronous/monitor.py @@ -112,9 +112,9 @@ def close(self) -> None: """ self.gc_safe_close() - def join(self, timeout: Optional[int] = None) -> None: + def join(self) -> None: """Wait for the monitor to stop.""" - self._executor.join(timeout) + self._executor.join() def request_check(self) -> None: """If the monitor is sleeping, wake it soon.""" @@ -189,6 +189,9 @@ def gc_safe_close(self) -> None: self._rtt_monitor.gc_safe_close() self.cancel_check() + def join(self) -> None: + asyncio.gather(self._executor.join(), self._rtt_monitor.join(), return_exceptions=True) # type: ignore[func-returns-value] + def close(self) -> None: self.gc_safe_close() self._rtt_monitor.close() diff --git a/pymongo/synchronous/topology.py b/pymongo/synchronous/topology.py index b03269ae43..09b61f6d05 100644 --- a/pymongo/synchronous/topology.py +++ b/pymongo/synchronous/topology.py @@ -16,6 +16,7 @@ from __future__ import annotations +import asyncio import logging import os import queue @@ -61,7 +62,7 @@ writable_server_selector, ) from pymongo.synchronous.client_session import _ServerSession, _ServerSessionPool -from pymongo.synchronous.monitor import SrvMonitor +from pymongo.synchronous.monitor import MonitorBase, SrvMonitor from pymongo.synchronous.pool import Pool from pymongo.synchronous.server import Server from pymongo.topology_description import ( @@ -207,6 +208,9 @@ def target() -> bool: if self._settings.fqdn is not None and not self._settings.load_balanced: self._srv_monitor = SrvMonitor(self, self._settings) + # Stores all monitor tasks that need to be joined on close or server selection + self._monitor_tasks: list[MonitorBase] = [] + def open(self) -> None: """Start monitoring, or restart after a fork. @@ -241,6 +245,8 @@ def open(self) -> None: # Close servers and clear the pools. for server in self._servers.values(): server.close() + if not _IS_SYNC: + self._monitor_tasks.append(server._monitor) # Reset the session pool to avoid duplicate sessions in # the child process. self._session_pool.reset() @@ -283,6 +289,10 @@ def select_servers( else: server_timeout = server_selection_timeout + # Cleanup any completed monitor tasks safely + if not _IS_SYNC and self._monitor_tasks: + self.cleanup_monitors() + with self._lock: server_descriptions = self._select_servers_loop( selector, server_timeout, operation, operation_id, address @@ -520,6 +530,8 @@ def _process_change( and self._description.topology_type not in SRV_POLLING_TOPOLOGIES ): self._srv_monitor.close() + if not _IS_SYNC: + self._monitor_tasks.append(self._srv_monitor) # Clear the pool from a failed heartbeat. if reset_pool: @@ -693,6 +705,8 @@ def close(self) -> None: old_td = self._description for server in self._servers.values(): server.close() + if not _IS_SYNC: + self._monitor_tasks.append(server._monitor) # Mark all servers Unknown. self._description = self._description.reset() @@ -703,6 +717,8 @@ def close(self) -> None: # Stop SRV polling thread. if self._srv_monitor: self._srv_monitor.close() + if not _IS_SYNC: + self._monitor_tasks.append(self._srv_monitor) self._opened = False self._closed = True @@ -942,6 +958,8 @@ def _update_servers(self) -> None: for address, server in list(self._servers.items()): if not self._description.has_server(address): server.close() + if not _IS_SYNC: + self._monitor_tasks.append(server._monitor) self._servers.pop(address) def _create_pool_for_server(self, address: _Address) -> Pool: @@ -1029,6 +1047,15 @@ def _error_message(self, selector: Callable[[Selection], Selection]) -> str: else: return ",".join(str(server.error) for server in servers if server.error) + def cleanup_monitors(self) -> None: + tasks = [] + try: + while self._monitor_tasks: + tasks.append(self._monitor_tasks.pop()) + except IndexError: + pass + asyncio.gather(*[t.join() for t in tasks], return_exceptions=True) # type: ignore[func-returns-value] + def __repr__(self) -> str: msg = "" if not self._opened: