diff --git a/pymongo/asynchronous/change_stream.py b/pymongo/asynchronous/change_stream.py index 719020c409..f405e91161 100644 --- a/pymongo/asynchronous/change_stream.py +++ b/pymongo/asynchronous/change_stream.py @@ -391,7 +391,8 @@ async def try_next(self) -> Optional[_DocumentType]: if not _resumable(exc) and not exc.timeout: await self.close() raise - except Exception: + # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. + except BaseException: await self.close() raise diff --git a/pymongo/asynchronous/client_session.py b/pymongo/asynchronous/client_session.py index 4c5171a350..b635cf4648 100644 --- a/pymongo/asynchronous/client_session.py +++ b/pymongo/asynchronous/client_session.py @@ -697,7 +697,8 @@ async def callback(session, custom_arg, custom_kwarg=None): ) try: ret = await callback(self) - except Exception as exc: + # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. + except BaseException as exc: if self.in_transaction: await self.abort_transaction() if ( diff --git a/pymongo/asynchronous/cursor.py b/pymongo/asynchronous/cursor.py index 9101197ce2..1b25bf4ee8 100644 --- a/pymongo/asynchronous/cursor.py +++ b/pymongo/asynchronous/cursor.py @@ -1126,7 +1126,8 @@ async def _send_message(self, operation: Union[_Query, _GetMore]) -> None: self._killed = True await self.close() raise - except Exception: + # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. + except BaseException: await self.close() raise diff --git a/pymongo/asynchronous/pool.py b/pymongo/asynchronous/pool.py index bf2f2b4946..39b3bfc042 100644 --- a/pymongo/asynchronous/pool.py +++ b/pymongo/asynchronous/pool.py @@ -559,7 +559,7 @@ async def command( ) except (OperationFailure, NotPrimaryError): raise - # Catch socket.error, KeyboardInterrupt, etc. and close ourselves. + # Catch socket.error, KeyboardInterrupt, CancelledError, etc. and close ourselves. except BaseException as error: self._raise_connection_failure(error) @@ -576,6 +576,7 @@ async def send_message(self, message: bytes, max_doc_size: int) -> None: try: await async_sendall(self.conn, message) + # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. except BaseException as error: self._raise_connection_failure(error) @@ -586,6 +587,7 @@ async def receive_message(self, request_id: Optional[int]) -> Union[_OpReply, _O """ try: return await receive_message(self, request_id, self.max_message_size) + # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. except BaseException as error: self._raise_connection_failure(error) @@ -1269,6 +1271,7 @@ async def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> A try: sock = await _configured_socket(self.address, self.opts) + # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. except BaseException as error: async with self.lock: self.active_contexts.discard(tmp_context) @@ -1308,6 +1311,7 @@ async def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> A handler.contribute_socket(conn, completed_handshake=False) await conn.authenticate() + # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. except BaseException: async with self.lock: self.active_contexts.discard(conn.cancel_context) @@ -1369,6 +1373,7 @@ async def checkout( async with self.lock: self.active_contexts.add(conn.cancel_context) yield conn + # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. except BaseException: # Exception in caller. Ensure the connection gets returned. # Note that when pinned is True, the session owns the @@ -1515,6 +1520,7 @@ async def _get_conn( async with self._max_connecting_cond: self._pending -= 1 self._max_connecting_cond.notify() + # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. except BaseException: if conn: # We checked out a socket but authentication failed. diff --git a/pymongo/periodic_executor.py b/pymongo/periodic_executor.py index 9b10f6e7e3..492f764d73 100644 --- a/pymongo/periodic_executor.py +++ b/pymongo/periodic_executor.py @@ -98,6 +98,7 @@ async def _run(self) -> None: if not await self._target(): self._stopped = True break + # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. except BaseException: self._stopped = True raise @@ -230,6 +231,7 @@ def _run(self) -> None: if not self._target(): self._stopped = True break + # Catch KeyboardInterrupt, etc. and cleanup. except BaseException: with self._lock: self._stopped = True diff --git a/pymongo/synchronous/change_stream.py b/pymongo/synchronous/change_stream.py index a971ad08c0..43aab39ee1 100644 --- a/pymongo/synchronous/change_stream.py +++ b/pymongo/synchronous/change_stream.py @@ -389,7 +389,8 @@ def try_next(self) -> Optional[_DocumentType]: if not _resumable(exc) and not exc.timeout: self.close() raise - except Exception: + # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. + except BaseException: self.close() raise diff --git a/pymongo/synchronous/client_session.py b/pymongo/synchronous/client_session.py index 298dd7b357..af7ff59b3d 100644 --- a/pymongo/synchronous/client_session.py +++ b/pymongo/synchronous/client_session.py @@ -694,7 +694,8 @@ def callback(session, custom_arg, custom_kwarg=None): self.start_transaction(read_concern, write_concern, read_preference, max_commit_time_ms) try: ret = callback(self) - except Exception as exc: + # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. + except BaseException as exc: if self.in_transaction: self.abort_transaction() if ( diff --git a/pymongo/synchronous/cursor.py b/pymongo/synchronous/cursor.py index cda093ee19..31c4604f89 100644 --- a/pymongo/synchronous/cursor.py +++ b/pymongo/synchronous/cursor.py @@ -1124,7 +1124,8 @@ def _send_message(self, operation: Union[_Query, _GetMore]) -> None: self._killed = True self.close() raise - except Exception: + # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. + except BaseException: self.close() raise diff --git a/pymongo/synchronous/pool.py b/pymongo/synchronous/pool.py index 05f930d480..7c55e04b22 100644 --- a/pymongo/synchronous/pool.py +++ b/pymongo/synchronous/pool.py @@ -559,7 +559,7 @@ def command( ) except (OperationFailure, NotPrimaryError): raise - # Catch socket.error, KeyboardInterrupt, etc. and close ourselves. + # Catch socket.error, KeyboardInterrupt, CancelledError, etc. and close ourselves. except BaseException as error: self._raise_connection_failure(error) @@ -576,6 +576,7 @@ def send_message(self, message: bytes, max_doc_size: int) -> None: try: sendall(self.conn, message) + # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. except BaseException as error: self._raise_connection_failure(error) @@ -586,6 +587,7 @@ def receive_message(self, request_id: Optional[int]) -> Union[_OpReply, _OpMsg]: """ try: return receive_message(self, request_id, self.max_message_size) + # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. except BaseException as error: self._raise_connection_failure(error) @@ -1263,6 +1265,7 @@ def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> Connect try: sock = _configured_socket(self.address, self.opts) + # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. except BaseException as error: with self.lock: self.active_contexts.discard(tmp_context) @@ -1302,6 +1305,7 @@ def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> Connect handler.contribute_socket(conn, completed_handshake=False) conn.authenticate() + # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. except BaseException: with self.lock: self.active_contexts.discard(conn.cancel_context) @@ -1363,6 +1367,7 @@ def checkout( with self.lock: self.active_contexts.add(conn.cancel_context) yield conn + # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. except BaseException: # Exception in caller. Ensure the connection gets returned. # Note that when pinned is True, the session owns the @@ -1509,6 +1514,7 @@ def _get_conn( with self._max_connecting_cond: self._pending -= 1 self._max_connecting_cond.notify() + # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. except BaseException: if conn: # We checked out a socket but authentication failed. diff --git a/test/asynchronous/test_async_cancellation.py b/test/asynchronous/test_async_cancellation.py new file mode 100644 index 0000000000..b73c7a8084 --- /dev/null +++ b/test/asynchronous/test_async_cancellation.py @@ -0,0 +1,126 @@ +# Copyright 2025-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test that async cancellation performed by users clean up resources correctly.""" +from __future__ import annotations + +import asyncio +import sys +from test.utils import async_get_pool, delay, one + +sys.path[0:0] = [""] + +from test.asynchronous import AsyncIntegrationTest, async_client_context, connected + + +class TestAsyncCancellation(AsyncIntegrationTest): + async def test_async_cancellation_closes_connection(self): + pool = await async_get_pool(self.client) + await self.client.db.test.insert_one({"x": 1}) + self.addAsyncCleanup(self.client.db.test.delete_many, {}) + + conn = one(pool.conns) + + async def task(): + await self.client.db.test.find_one({"$where": delay(0.2)}) + + task = asyncio.create_task(task()) + + await asyncio.sleep(0.1) + + task.cancel() + with self.assertRaises(asyncio.CancelledError): + await task + + self.assertTrue(conn.closed) + + @async_client_context.require_transactions + async def test_async_cancellation_aborts_transaction(self): + await self.client.db.test.insert_one({"x": 1}) + self.addAsyncCleanup(self.client.db.test.delete_many, {}) + + session = self.client.start_session() + + async def callback(session): + await self.client.db.test.find_one({"$where": delay(0.2)}, session=session) + + async def task(): + await session.with_transaction(callback) + + task = asyncio.create_task(task()) + + await asyncio.sleep(0.1) + + task.cancel() + with self.assertRaises(asyncio.CancelledError): + await task + + self.assertFalse(session.in_transaction) + + @async_client_context.require_failCommand_blockConnection + async def test_async_cancellation_closes_cursor(self): + await self.client.db.test.insert_many([{"x": 1}, {"x": 2}]) + self.addAsyncCleanup(self.client.db.test.delete_many, {}) + + cursor = self.client.db.test.find({}, batch_size=1) + await cursor.next() + + # Make sure getMore commands block + fail_command = { + "configureFailPoint": "failCommand", + "mode": "alwaysOn", + "data": {"failCommands": ["getMore"], "blockConnection": True, "blockTimeMS": 200}, + } + + async def task(): + async with self.fail_point(fail_command): + await cursor.next() + + task = asyncio.create_task(task()) + + await asyncio.sleep(0.1) + + task.cancel() + with self.assertRaises(asyncio.CancelledError): + await task + + self.assertTrue(cursor._killed) + + @async_client_context.require_change_streams + @async_client_context.require_failCommand_blockConnection + async def test_async_cancellation_closes_change_stream(self): + self.addAsyncCleanup(self.client.db.test.delete_many, {}) + change_stream = await self.client.db.test.watch(batch_size=2) + + # Make sure getMore commands block + fail_command = { + "configureFailPoint": "failCommand", + "mode": "alwaysOn", + "data": {"failCommands": ["getMore"], "blockConnection": True, "blockTimeMS": 200}, + } + + async def task(): + async with self.fail_point(fail_command): + await self.client.db.test.insert_many([{"x": 1}, {"x": 2}]) + await change_stream.next() + + task = asyncio.create_task(task()) + + await asyncio.sleep(0.1) + + task.cancel() + with self.assertRaises(asyncio.CancelledError): + await task + + self.assertTrue(change_stream._closed) diff --git a/tools/synchro.py b/tools/synchro.py index 06dc708e08..eaced04f41 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -168,7 +168,7 @@ def async_only_test(f: str) -> bool: """Return True for async tests that should not be converted to sync.""" - return f in ["test_locks.py", "test_concurrency.py"] + return f in ["test_locks.py", "test_concurrency.py", "test_async_cancellation.py"] test_files = [