Skip to content

PYTHON-5053 - AsyncMongoClient.close() should await all background tasks #2127

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Feb 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions pymongo/asynchronous/mongo_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1565,6 +1565,12 @@ async def close(self) -> None:
# TODO: PYTHON-1921 Encrypted MongoClients cannot be re-opened.
await self._encrypter.close()
self._closed = True
if not _IS_SYNC:
await asyncio.gather(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe we should be using return_exceptions=True on all these gather() calls:

If return_exceptions is True, exceptions are treated the same as successful results, and aggregated in the result list.

https://docs.python.org/3/library/asyncio-task.html#asyncio.gather

Otherwise we may accidentally propagate an exception we don't care about.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At this point should we be using asyncio.wait() instead of gather()?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. I think we can safely use asyncio.wait() instead here, since we don't need the aggregated result returned by gather.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh wait, asyncio.wait() explicitly does not support waiting for coroutines, only Task or Future objects. Sticking with gather seems a little simpler, even with return_exceptions=True.

self._topology.cleanup_monitors(), # type: ignore[func-returns-value]
self._kill_cursors_executor.join(), # type: ignore[func-returns-value]
return_exceptions=True,
)

if not _IS_SYNC:
# Add support for contextlib.aclosing.
Expand Down
9 changes: 7 additions & 2 deletions pymongo/asynchronous/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,9 @@ async def close(self) -> None:
"""
self.gc_safe_close()

async def join(self, timeout: Optional[int] = None) -> None:
async def join(self) -> None:
"""Wait for the monitor to stop."""
await self._executor.join(timeout)
await self._executor.join()

def request_check(self) -> None:
"""If the monitor is sleeping, wake it soon."""
Expand Down Expand Up @@ -189,6 +189,11 @@ def gc_safe_close(self) -> None:
self._rtt_monitor.gc_safe_close()
self.cancel_check()

async def join(self) -> None:
    """Wait for both the executor and the RTT monitor to finish.

    The two ``join()`` awaitables are awaited together. An exception raised
    by either one is collected by ``asyncio.gather`` instead of being
    propagated to the caller.
    """
    pending = (self._executor.join(), self._rtt_monitor.join())
    await asyncio.gather(*pending, return_exceptions=True)  # type: ignore[func-returns-value]

async def close(self) -> None:
self.gc_safe_close()
await self._rtt_monitor.close()
Expand Down
29 changes: 28 additions & 1 deletion pymongo/asynchronous/topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from __future__ import annotations

import asyncio
import logging
import os
import queue
Expand All @@ -29,7 +30,7 @@

from pymongo import _csot, common, helpers_shared, periodic_executor
from pymongo.asynchronous.client_session import _ServerSession, _ServerSessionPool
from pymongo.asynchronous.monitor import SrvMonitor
from pymongo.asynchronous.monitor import MonitorBase, SrvMonitor
from pymongo.asynchronous.pool import Pool
from pymongo.asynchronous.server import Server
from pymongo.errors import (
Expand Down Expand Up @@ -207,6 +208,9 @@ async def target() -> bool:
if self._settings.fqdn is not None and not self._settings.load_balanced:
self._srv_monitor = SrvMonitor(self, self._settings)

# Stores all monitor tasks that need to be joined on close or server selection
self._monitor_tasks: list[MonitorBase] = []

async def open(self) -> None:
"""Start monitoring, or restart after a fork.

Expand Down Expand Up @@ -241,6 +245,8 @@ async def open(self) -> None:
# Close servers and clear the pools.
for server in self._servers.values():
await server.close()
if not _IS_SYNC:
self._monitor_tasks.append(server._monitor)
# Reset the session pool to avoid duplicate sessions in
# the child process.
self._session_pool.reset()
Expand Down Expand Up @@ -283,6 +289,10 @@ async def select_servers(
else:
server_timeout = server_selection_timeout

# Clean up any completed monitor tasks safely
if not _IS_SYNC and self._monitor_tasks:
await self.cleanup_monitors()

async with self._lock:
server_descriptions = await self._select_servers_loop(
selector, server_timeout, operation, operation_id, address
Expand Down Expand Up @@ -520,6 +530,8 @@ async def _process_change(
and self._description.topology_type not in SRV_POLLING_TOPOLOGIES
):
await self._srv_monitor.close()
if not _IS_SYNC:
self._monitor_tasks.append(self._srv_monitor)

# Clear the pool from a failed heartbeat.
if reset_pool:
Expand Down Expand Up @@ -695,6 +707,8 @@ async def close(self) -> None:
old_td = self._description
for server in self._servers.values():
await server.close()
if not _IS_SYNC:
self._monitor_tasks.append(server._monitor)

# Mark all servers Unknown.
self._description = self._description.reset()
Expand All @@ -705,6 +719,8 @@ async def close(self) -> None:
# Stop SRV polling thread.
if self._srv_monitor:
await self._srv_monitor.close()
if not _IS_SYNC:
self._monitor_tasks.append(self._srv_monitor)

self._opened = False
self._closed = True
Expand Down Expand Up @@ -944,6 +960,8 @@ async def _update_servers(self) -> None:
for address, server in list(self._servers.items()):
if not self._description.has_server(address):
await server.close()
if not _IS_SYNC:
self._monitor_tasks.append(server._monitor)
self._servers.pop(address)

def _create_pool_for_server(self, address: _Address) -> Pool:
Expand Down Expand Up @@ -1031,6 +1049,15 @@ def _error_message(self, selector: Callable[[Selection], Selection]) -> str:
else:
return ",".join(str(server.error) for server in servers if server.error)

async def cleanup_monitors(self) -> None:
    """Join every monitor currently queued in ``self._monitor_tasks``.

    The queue is emptied first (last queued, first taken). The ``join()``
    awaitables of the drained monitors are then awaited together.
    Exceptions raised by individual joins are collected by
    ``asyncio.gather`` rather than raised.
    """
    drained = []
    while True:
        try:
            drained.append(self._monitor_tasks.pop())
        except IndexError:
            # Nothing left to drain.
            break
    joins = [monitor.join() for monitor in drained]
    await asyncio.gather(*joins, return_exceptions=True)  # type: ignore[func-returns-value]

def __repr__(self) -> str:
msg = ""
if not self._opened:
Expand Down
2 changes: 2 additions & 0 deletions pymongo/periodic_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ def close(self, dummy: Any = None) -> None:
callback; see monitor.py.
"""
self._stopped = True
if self._task is not None:
self._task.cancel()

async def join(self, timeout: Optional[int] = None) -> None:
if self._task is not None:
Expand Down
6 changes: 6 additions & 0 deletions pymongo/synchronous/mongo_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1559,6 +1559,12 @@ def close(self) -> None:
# TODO: PYTHON-1921 Encrypted MongoClients cannot be re-opened.
self._encrypter.close()
self._closed = True
if not _IS_SYNC:
asyncio.gather(
self._topology.cleanup_monitors(), # type: ignore[func-returns-value]
self._kill_cursors_executor.join(), # type: ignore[func-returns-value]
return_exceptions=True,
)

if not _IS_SYNC:
# Add support for contextlib.closing.
Expand Down
7 changes: 5 additions & 2 deletions pymongo/synchronous/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,9 @@ def close(self) -> None:
"""
self.gc_safe_close()

def join(self, timeout: Optional[int] = None) -> None:
def join(self) -> None:
"""Wait for the monitor to stop."""
self._executor.join(timeout)
self._executor.join()

def request_check(self) -> None:
"""If the monitor is sleeping, wake it soon."""
Expand Down Expand Up @@ -189,6 +189,9 @@ def gc_safe_close(self) -> None:
self._rtt_monitor.gc_safe_close()
self.cancel_check()

def join(self) -> None:
    """Wait for the monitor's executor and its RTT monitor to stop.

    The two are joined one after the other. An error raised while joining
    one of them is not propagated and does not prevent the other from being
    joined, which is what the previous ``return_exceptions=True`` argument
    was asking for.
    """
    # The previous implementation passed the results of these blocking
    # ``join()`` calls to ``asyncio.gather``. Those results are not
    # awaitables, and the gather itself was never awaited, so the call could
    # only raise ``TypeError`` once both joins had already run.
    # NOTE(review): if this module is generated from the asynchronous
    # variant, the generator input needs the matching change -- confirm.
    for worker in (self._executor, self._rtt_monitor):
        try:
            worker.join()
        except Exception:
            # Deliberate best effort: a failed join must not block shutdown.
            continue

def close(self) -> None:
self.gc_safe_close()
self._rtt_monitor.close()
Expand Down
29 changes: 28 additions & 1 deletion pymongo/synchronous/topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from __future__ import annotations

import asyncio
import logging
import os
import queue
Expand Down Expand Up @@ -61,7 +62,7 @@
writable_server_selector,
)
from pymongo.synchronous.client_session import _ServerSession, _ServerSessionPool
from pymongo.synchronous.monitor import SrvMonitor
from pymongo.synchronous.monitor import MonitorBase, SrvMonitor
from pymongo.synchronous.pool import Pool
from pymongo.synchronous.server import Server
from pymongo.topology_description import (
Expand Down Expand Up @@ -207,6 +208,9 @@ def target() -> bool:
if self._settings.fqdn is not None and not self._settings.load_balanced:
self._srv_monitor = SrvMonitor(self, self._settings)

# Stores all monitor tasks that need to be joined on close or server selection
self._monitor_tasks: list[MonitorBase] = []

def open(self) -> None:
"""Start monitoring, or restart after a fork.

Expand Down Expand Up @@ -241,6 +245,8 @@ def open(self) -> None:
# Close servers and clear the pools.
for server in self._servers.values():
server.close()
if not _IS_SYNC:
self._monitor_tasks.append(server._monitor)
# Reset the session pool to avoid duplicate sessions in
# the child process.
self._session_pool.reset()
Expand Down Expand Up @@ -283,6 +289,10 @@ def select_servers(
else:
server_timeout = server_selection_timeout

# Clean up any completed monitor tasks safely
if not _IS_SYNC and self._monitor_tasks:
self.cleanup_monitors()

with self._lock:
server_descriptions = self._select_servers_loop(
selector, server_timeout, operation, operation_id, address
Expand Down Expand Up @@ -520,6 +530,8 @@ def _process_change(
and self._description.topology_type not in SRV_POLLING_TOPOLOGIES
):
self._srv_monitor.close()
if not _IS_SYNC:
self._monitor_tasks.append(self._srv_monitor)

# Clear the pool from a failed heartbeat.
if reset_pool:
Expand Down Expand Up @@ -693,6 +705,8 @@ def close(self) -> None:
old_td = self._description
for server in self._servers.values():
server.close()
if not _IS_SYNC:
self._monitor_tasks.append(server._monitor)

# Mark all servers Unknown.
self._description = self._description.reset()
Expand All @@ -703,6 +717,8 @@ def close(self) -> None:
# Stop SRV polling thread.
if self._srv_monitor:
self._srv_monitor.close()
if not _IS_SYNC:
self._monitor_tasks.append(self._srv_monitor)

self._opened = False
self._closed = True
Expand Down Expand Up @@ -942,6 +958,8 @@ def _update_servers(self) -> None:
for address, server in list(self._servers.items()):
if not self._description.has_server(address):
server.close()
if not _IS_SYNC:
self._monitor_tasks.append(server._monitor)
self._servers.pop(address)

def _create_pool_for_server(self, address: _Address) -> Pool:
Expand Down Expand Up @@ -1029,6 +1047,15 @@ def _error_message(self, selector: Callable[[Selection], Selection]) -> str:
else:
return ",".join(str(server.error) for server in servers if server.error)

def cleanup_monitors(self) -> None:
    """Join every monitor currently queued in ``self._monitor_tasks``.

    The queue is emptied first (last queued, first taken) and each drained
    monitor is then joined in turn. An error raised by one ``join()`` is not
    propagated and does not stop the remaining monitors from being joined,
    which is what the previous ``return_exceptions=True`` argument was
    asking for.
    """
    tasks = []
    try:
        while self._monitor_tasks:
            tasks.append(self._monitor_tasks.pop())
    except IndexError:
        # The list was emptied between the truthiness check and the pop.
        pass
    # The previous implementation passed the results of the blocking
    # ``join()`` calls to an un-awaited ``asyncio.gather``. Those results are
    # not awaitables, so the call raised ``TypeError`` whenever a monitor was
    # queued, and it needed an event loop even when none was.
    # NOTE(review): if this module is generated from the asynchronous
    # variant, the generator input needs the matching change -- confirm.
    for monitor in tasks:
        try:
            monitor.join()
        except Exception:
            # Deliberate best effort: keep joining the remaining monitors.
            continue

def __repr__(self) -> str:
msg = ""
if not self._opened:
Expand Down
Loading