From 260105483633843e3daf44ab6b7c28300b88bb2e Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Tue, 4 Feb 2025 15:56:20 -0500 Subject: [PATCH 1/9] PYTHON-4745 - Document and Test Behavior when User Cancels Async Operation --- test/asynchronous/test_async_cancellation.py | 40 ++++++++++++++++++++ tools/synchro.py | 2 +- 2 files changed, 41 insertions(+), 1 deletion(-) create mode 100644 test/asynchronous/test_async_cancellation.py diff --git a/test/asynchronous/test_async_cancellation.py b/test/asynchronous/test_async_cancellation.py new file mode 100644 index 0000000000..3c82f28462 --- /dev/null +++ b/test/asynchronous/test_async_cancellation.py @@ -0,0 +1,40 @@ +# Copyright 2025-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test that async cancellation performed by users raises the expected error.""" +from __future__ import annotations + +import asyncio +import sys + +sys.path[0:0] = [""] + +from test.asynchronous import AsyncIntegrationTest + + +class TestAsyncCancellation(AsyncIntegrationTest): + async def test_async_cancellation(self): + async def task(): + while True: + await self.client.db.test.insert_one({"x": 1}) + await asyncio.sleep(0.005) + + task = asyncio.create_task(task()) + + # Make sure the task successfully runs a few operations to simulate a long-running user task + await asyncio.sleep(0.01) + + task.cancel() + with self.assertRaises(asyncio.CancelledError): + await task diff --git a/tools/synchro.py b/tools/synchro.py index 06dc708e08..eaced04f41 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -168,7 +168,7 @@ def async_only_test(f: str) -> bool: """Return True for async tests that should not be converted to sync.""" - return f in ["test_locks.py", "test_concurrency.py"] + return f in ["test_locks.py", "test_concurrency.py", "test_async_cancellation.py"] test_files = [ From e27062cecc4db959a2bc9018828ec428d5b79bdf Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Wed, 5 Feb 2025 11:29:57 -0500 Subject: [PATCH 2/9] WIP --- test/asynchronous/test_async_cancellation.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/test/asynchronous/test_async_cancellation.py b/test/asynchronous/test_async_cancellation.py index 3c82f28462..8ab471a0b4 100644 --- a/test/asynchronous/test_async_cancellation.py +++ b/test/asynchronous/test_async_cancellation.py @@ -17,17 +17,23 @@ import asyncio import sys +from test.utils import async_get_pool, get_pool, one sys.path[0:0] = [""] -from test.asynchronous import AsyncIntegrationTest +from test.asynchronous import AsyncIntegrationTest, connected class TestAsyncCancellation(AsyncIntegrationTest): - async def test_async_cancellation(self): + async def test_async_cancellation_does_not_close_connection(self): + client = await self.async_rs_or_single_client(maxPoolSize=1, retryReads=False) + pool = await async_get_pool(client) + await connected(client) + conn = one(pool.conns) + async def task(): while True: - await self.client.db.test.insert_one({"x": 1}) + await client.db.test.insert_one({"x": 1}) await asyncio.sleep(0.005) task = asyncio.create_task(task()) @@ -38,3 +44,5 @@ async def task(): task.cancel() with self.assertRaises(asyncio.CancelledError): await task + + self.assertFalse(conn.closed) From 91eb68ff5e0e9533e643f1e2313b97ad71c880e7 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Wed, 5 Feb 2025 15:12:42 -0500 Subject: [PATCH 3/9] WIP --- test/asynchronous/test_async_cancellation.py | 43 +++++++++++++++----- 1 file changed, 33 insertions(+), 10 deletions(-) diff --git a/test/asynchronous/test_async_cancellation.py b/test/asynchronous/test_async_cancellation.py index 8ab471a0b4..e5c3ebbb7c 100644 --- a/test/asynchronous/test_async_cancellation.py +++ b/test/asynchronous/test_async_cancellation.py @@ -17,32 +17,55 @@ import asyncio import sys -from test.utils import async_get_pool, get_pool, one +import traceback + +from test.utils import async_get_pool, get_pool, one, delay sys.path[0:0] = [""] -from test.asynchronous import AsyncIntegrationTest, connected +from test.asynchronous import AsyncIntegrationTest, connected, async_client_context class TestAsyncCancellation(AsyncIntegrationTest): - async def test_async_cancellation_does_not_close_connection(self): - client = await self.async_rs_or_single_client(maxPoolSize=1, retryReads=False) + async def test_async_cancellation_closes_connection(self): + client = await self.async_rs_or_single_client() pool = await async_get_pool(client) await connected(client) conn = one(pool.conns) async def task(): - while True: - await client.db.test.insert_one({"x": 1}) - await asyncio.sleep(0.005) + await client.db.test.find_one({"$where": delay(1.0)}) + + task = asyncio.create_task(task()) + + await asyncio.sleep(0.1) + + task.cancel() + with self.assertRaises(asyncio.CancelledError): + await task + + self.assertTrue(conn.closed) + + @async_client_context.require_transactions + async def test_async_cancellation_aborts_transaction(self): + client = await self.async_rs_or_single_client() + await connected(client) + + session = client.start_session() + + async def callback(session): + await client.db.test.find_one({"$where": delay(1.0)}) + + async def task(): + await session.with_transaction(callback) task = asyncio.create_task(task()) - # Make sure the task successfully runs a few operations to simulate a long-running user task - await asyncio.sleep(0.01) + await asyncio.sleep(0.1) task.cancel() with self.assertRaises(asyncio.CancelledError): await task - self.assertFalse(conn.closed) + self.assertFalse(session.in_transaction) + From fd18b7a59d04bd4a95289c7f444558952af31222 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Thu, 6 Feb 2025 12:14:23 -0500 Subject: [PATCH 4/9] Verify that connections, transactions, and cursors are terminated when cancelled --- pymongo/asynchronous/client_session.py | 2 +- pymongo/asynchronous/cursor.py | 2 +- pymongo/synchronous/client_session.py | 2 +- pymongo/synchronous/cursor.py | 2 +- test/asynchronous/test_async_cancellation.py | 46 +++++++++++++++++--- 5 files changed, 43 insertions(+), 11 deletions(-) diff --git a/pymongo/asynchronous/client_session.py b/pymongo/asynchronous/client_session.py index 4c5171a350..97b9b4ffd9 100644 --- a/pymongo/asynchronous/client_session.py +++ b/pymongo/asynchronous/client_session.py @@ -697,7 +697,7 @@ async def callback(session, custom_arg, custom_kwarg=None): ) try: ret = await callback(self) - except Exception as exc: + except BaseException as exc: if self.in_transaction: await self.abort_transaction() if ( diff --git a/pymongo/asynchronous/cursor.py b/pymongo/asynchronous/cursor.py index 9101197ce2..58e3ee6e9b 100644 --- a/pymongo/asynchronous/cursor.py +++ b/pymongo/asynchronous/cursor.py @@ -1126,7 +1126,7 @@ async def _send_message(self, operation: Union[_Query, _GetMore]) -> None: self._killed = True await self.close() raise - except Exception: + except BaseException: await self.close() raise diff --git a/pymongo/synchronous/client_session.py b/pymongo/synchronous/client_session.py index 298dd7b357..d922bafa53 100644 --- a/pymongo/synchronous/client_session.py +++ b/pymongo/synchronous/client_session.py @@ -694,7 +694,7 @@ def callback(session, custom_arg, custom_kwarg=None): self.start_transaction(read_concern, write_concern, read_preference, max_commit_time_ms) try: ret = callback(self) - except Exception as exc: + except BaseException as exc: if self.in_transaction: self.abort_transaction() if ( diff --git a/pymongo/synchronous/cursor.py b/pymongo/synchronous/cursor.py index cda093ee19..712f0ebb28 100644 --- a/pymongo/synchronous/cursor.py +++ b/pymongo/synchronous/cursor.py @@ -1124,7 +1124,7 @@ def _send_message(self, operation: Union[_Query, _GetMore]) -> None: self._killed = True self.close() raise - except Exception: + except BaseException: self.close() raise diff --git a/test/asynchronous/test_async_cancellation.py b/test/asynchronous/test_async_cancellation.py index e5c3ebbb7c..9be79f5220 100644 --- a/test/asynchronous/test_async_cancellation.py +++ b/test/asynchronous/test_async_cancellation.py @@ -17,24 +17,24 @@ import asyncio import sys -import traceback - -from test.utils import async_get_pool, get_pool, one, delay +from test.utils import async_get_pool, delay, one sys.path[0:0] = [""] -from test.asynchronous import AsyncIntegrationTest, connected, async_client_context +from test.asynchronous import AsyncIntegrationTest, async_client_context, connected class TestAsyncCancellation(AsyncIntegrationTest): async def test_async_cancellation_closes_connection(self): - client = await self.async_rs_or_single_client() + client = await self.async_rs_or_single_client(maxPoolSize=1) pool = await async_get_pool(client) await connected(client) conn = one(pool.conns) + await client.db.test.insert_one({"x": 1}) + self.addAsyncCleanup(client.db.test.drop) async def task(): - await client.db.test.find_one({"$where": delay(1.0)}) + await client.db.test.find_one({"$where": delay(0.2)}) task = asyncio.create_task(task()) @@ -50,11 +50,13 @@ async def task(): async def test_async_cancellation_aborts_transaction(self): client = await self.async_rs_or_single_client() await connected(client) + await client.db.test.insert_one({"x": 1}) + self.addAsyncCleanup(client.db.test.drop) session = client.start_session() async def callback(session): - await client.db.test.find_one({"$where": delay(1.0)}) + await client.db.test.find_one({"$where": delay(0.2)}, session=session) async def task(): await session.with_transaction(callback) @@ -69,3 +71,33 @@ async def task(): self.assertFalse(session.in_transaction) + async def test_async_cancellation_kills_cursor(self): + client = await self.async_rs_or_single_client() + await connected(client) + for _ in range(2): + await client.db.test.insert_one({"x": 1}) + self.addAsyncCleanup(client.db.test.drop) + + cursor = client.db.test.find({}, batch_size=1) + await cursor.next() + + # Make sure getMore commands block + fail_command = { + "configureFailPoint": "failCommand", + "mode": "alwaysOn", + "data": {"failCommands": ["getMore"], "blockConnection": True, "blockTimeMS": 200}, + } + + async def task(): + async with self.fail_point(fail_command): + await cursor.next() + + task = asyncio.create_task(task()) + + await asyncio.sleep(0.1) + + task.cancel() + with self.assertRaises(asyncio.CancelledError): + await task + + self.assertTrue(cursor._killed) From 490689738dcb5d57f694714f0ed38de9dc3ee482 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Thu, 6 Feb 2025 12:36:59 -0500 Subject: [PATCH 5/9] Test change_stream cancellation --- pymongo/asynchronous/change_stream.py | 2 +- pymongo/synchronous/change_stream.py | 2 +- test/asynchronous/test_async_cancellation.py | 34 ++++++++++++++++++-- 3 files changed, 34 insertions(+), 4 deletions(-) diff --git a/pymongo/asynchronous/change_stream.py b/pymongo/asynchronous/change_stream.py index 719020c409..e36eea542c 100644 --- a/pymongo/asynchronous/change_stream.py +++ b/pymongo/asynchronous/change_stream.py @@ -391,7 +391,7 @@ async def try_next(self) -> Optional[_DocumentType]: if not _resumable(exc) and not exc.timeout: await self.close() raise - except Exception: + except BaseException: await self.close() raise diff --git a/pymongo/synchronous/change_stream.py b/pymongo/synchronous/change_stream.py index a971ad08c0..e8f42eb830 100644 --- a/pymongo/synchronous/change_stream.py +++ b/pymongo/synchronous/change_stream.py @@ -389,7 +389,7 @@ def try_next(self) -> Optional[_DocumentType]: if not _resumable(exc) and not exc.timeout: self.close() raise - except Exception: + except BaseException: self.close() raise diff --git a/test/asynchronous/test_async_cancellation.py b/test/asynchronous/test_async_cancellation.py index 9be79f5220..fccf88df15 100644 --- a/test/asynchronous/test_async_cancellation.py +++ b/test/asynchronous/test_async_cancellation.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Test that async cancellation performed by users raises the expected error.""" +"""Test that async cancellation performed by users clean up resources correctly.""" from __future__ import annotations import asyncio @@ -71,7 +71,7 @@ async def task(): self.assertFalse(session.in_transaction) - async def test_async_cancellation_kills_cursor(self): + async def test_async_cancellation_closes_cursor(self): client = await self.async_rs_or_single_client() await connected(client) for _ in range(2): @@ -101,3 +101,33 @@ async def task(): await task self.assertTrue(cursor._killed) + + async def test_async_cancellation_closes_change_stream(self): + client = await self.async_rs_or_single_client() + await connected(client) + self.addAsyncCleanup(client.db.test.drop) + + change_stream = await client.db.test.watch(batch_size=2) + + # Make sure getMore commands block + fail_command = { + "configureFailPoint": "failCommand", + "mode": "alwaysOn", + "data": {"failCommands": ["getMore"], "blockConnection": True, "blockTimeMS": 200}, + } + + async def task(): + async with self.fail_point(fail_command): + for _ in range(2): + await client.db.test.insert_one({"x": 1}) + await change_stream.next() + + task = asyncio.create_task(task()) + + await asyncio.sleep(0.1) + + task.cancel() + with self.assertRaises(asyncio.CancelledError): + await task + + self.assertTrue(change_stream._closed) From 09b545d6e6f80e0e06377d5c12b033d4b3557e50 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Thu, 6 Feb 2025 12:45:50 -0500 Subject: [PATCH 6/9] Fix platform requirements --- test/asynchronous/test_async_cancellation.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/asynchronous/test_async_cancellation.py b/test/asynchronous/test_async_cancellation.py index fccf88df15..845686bc6b 100644 --- a/test/asynchronous/test_async_cancellation.py +++ b/test/asynchronous/test_async_cancellation.py @@ -71,6 +71,7 @@ async def task(): self.assertFalse(session.in_transaction) + @async_client_context.require_failCommand_blockConnection async def test_async_cancellation_closes_cursor(self): client = await self.async_rs_or_single_client() await connected(client) @@ -102,6 +103,7 @@ async def task(): self.assertTrue(cursor._killed) + @async_client_context.require_change_streams async def test_async_cancellation_closes_change_stream(self): client = await self.async_rs_or_single_client() await connected(client) From cbc1160a39d2447cb4a7c4389b3ffa508279b1ea Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Thu, 6 Feb 2025 14:42:22 -0500 Subject: [PATCH 7/9] More require updates --- test/asynchronous/test_async_cancellation.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/asynchronous/test_async_cancellation.py b/test/asynchronous/test_async_cancellation.py index 845686bc6b..d27d908ca3 100644 --- a/test/asynchronous/test_async_cancellation.py +++ b/test/asynchronous/test_async_cancellation.py @@ -104,6 +104,7 @@ async def task(): self.assertTrue(cursor._killed) @async_client_context.require_change_streams + @async_client_context.require_failCommand_blockConnection async def test_async_cancellation_closes_change_stream(self): client = await self.async_rs_or_single_client() await connected(client) From fcaf8f96f5f5e6af6e8cb8785b744ca8861016ee Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Mon, 10 Feb 2025 08:35:53 -0500 Subject: [PATCH 8/9] Add comments to catching BaseException, speed up tests --- pymongo/asynchronous/change_stream.py | 1 + pymongo/asynchronous/client_session.py | 1 + pymongo/asynchronous/cursor.py | 1 + pymongo/asynchronous/pool.py | 8 +++- pymongo/periodic_executor.py | 2 + pymongo/synchronous/change_stream.py | 1 + pymongo/synchronous/client_session.py | 1 + pymongo/synchronous/cursor.py | 1 + pymongo/synchronous/pool.py | 8 +++- test/asynchronous/test_async_cancellation.py | 40 +++++++++----------- 10 files changed, 40 insertions(+), 24 deletions(-) diff --git a/pymongo/asynchronous/change_stream.py b/pymongo/asynchronous/change_stream.py index e36eea542c..f405e91161 100644 --- a/pymongo/asynchronous/change_stream.py +++ b/pymongo/asynchronous/change_stream.py @@ -391,6 +391,7 @@ async def try_next(self) -> Optional[_DocumentType]: if not _resumable(exc) and not exc.timeout: await self.close() raise + # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. except BaseException: await self.close() raise diff --git a/pymongo/asynchronous/client_session.py b/pymongo/asynchronous/client_session.py index 97b9b4ffd9..b635cf4648 100644 --- a/pymongo/asynchronous/client_session.py +++ b/pymongo/asynchronous/client_session.py @@ -697,6 +697,7 @@ async def callback(session, custom_arg, custom_kwarg=None): ) try: ret = await callback(self) + # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. except BaseException as exc: if self.in_transaction: await self.abort_transaction() diff --git a/pymongo/asynchronous/cursor.py b/pymongo/asynchronous/cursor.py index 58e3ee6e9b..1b25bf4ee8 100644 --- a/pymongo/asynchronous/cursor.py +++ b/pymongo/asynchronous/cursor.py @@ -1126,6 +1126,7 @@ async def _send_message(self, operation: Union[_Query, _GetMore]) -> None: self._killed = True await self.close() raise + # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. except BaseException: await self.close() raise diff --git a/pymongo/asynchronous/pool.py b/pymongo/asynchronous/pool.py index bf2f2b4946..39b3bfc042 100644 --- a/pymongo/asynchronous/pool.py +++ b/pymongo/asynchronous/pool.py @@ -559,7 +559,7 @@ async def command( ) except (OperationFailure, NotPrimaryError): raise - # Catch socket.error, KeyboardInterrupt, etc. and close ourselves. + # Catch socket.error, KeyboardInterrupt, CancelledError, etc. and close ourselves. except BaseException as error: self._raise_connection_failure(error) @@ -576,6 +576,7 @@ async def send_message(self, message: bytes, max_doc_size: int) -> None: try: await async_sendall(self.conn, message) + # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. except BaseException as error: self._raise_connection_failure(error) @@ -586,6 +587,7 @@ async def receive_message(self, request_id: Optional[int]) -> Union[_OpReply, _O """ try: return await receive_message(self, request_id, self.max_message_size) + # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. except BaseException as error: self._raise_connection_failure(error) @@ -1269,6 +1271,7 @@ async def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> A try: sock = await _configured_socket(self.address, self.opts) + # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. except BaseException as error: async with self.lock: self.active_contexts.discard(tmp_context) @@ -1308,6 +1311,7 @@ async def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> A handler.contribute_socket(conn, completed_handshake=False) await conn.authenticate() + # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. except BaseException: async with self.lock: self.active_contexts.discard(conn.cancel_context) @@ -1369,6 +1373,7 @@ async def checkout( async with self.lock: self.active_contexts.add(conn.cancel_context) yield conn + # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. except BaseException: # Exception in caller. Ensure the connection gets returned. # Note that when pinned is True, the session owns the @@ -1515,6 +1520,7 @@ async def _get_conn( async with self._max_connecting_cond: self._pending -= 1 self._max_connecting_cond.notify() + # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. except BaseException: if conn: # We checked out a socket but authentication failed. diff --git a/pymongo/periodic_executor.py b/pymongo/periodic_executor.py index 9b10f6e7e3..492f764d73 100644 --- a/pymongo/periodic_executor.py +++ b/pymongo/periodic_executor.py @@ -98,6 +98,7 @@ async def _run(self) -> None: if not await self._target(): self._stopped = True break + # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. except BaseException: self._stopped = True raise @@ -230,6 +231,7 @@ def _run(self) -> None: if not self._target(): self._stopped = True break + # Catch KeyboardInterrupt, etc. and cleanup. except BaseException: with self._lock: self._stopped = True diff --git a/pymongo/synchronous/change_stream.py b/pymongo/synchronous/change_stream.py index e8f42eb830..43aab39ee1 100644 --- a/pymongo/synchronous/change_stream.py +++ b/pymongo/synchronous/change_stream.py @@ -389,6 +389,7 @@ def try_next(self) -> Optional[_DocumentType]: if not _resumable(exc) and not exc.timeout: self.close() raise + # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. except BaseException: self.close() raise diff --git a/pymongo/synchronous/client_session.py b/pymongo/synchronous/client_session.py index d922bafa53..af7ff59b3d 100644 --- a/pymongo/synchronous/client_session.py +++ b/pymongo/synchronous/client_session.py @@ -694,6 +694,7 @@ def callback(session, custom_arg, custom_kwarg=None): self.start_transaction(read_concern, write_concern, read_preference, max_commit_time_ms) try: ret = callback(self) + # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. except BaseException as exc: if self.in_transaction: self.abort_transaction() diff --git a/pymongo/synchronous/cursor.py b/pymongo/synchronous/cursor.py index 712f0ebb28..31c4604f89 100644 --- a/pymongo/synchronous/cursor.py +++ b/pymongo/synchronous/cursor.py @@ -1124,6 +1124,7 @@ def _send_message(self, operation: Union[_Query, _GetMore]) -> None: self._killed = True self.close() raise + # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. except BaseException: self.close() raise diff --git a/pymongo/synchronous/pool.py b/pymongo/synchronous/pool.py index 05f930d480..7c55e04b22 100644 --- a/pymongo/synchronous/pool.py +++ b/pymongo/synchronous/pool.py @@ -559,7 +559,7 @@ def command( ) except (OperationFailure, NotPrimaryError): raise - # Catch socket.error, KeyboardInterrupt, etc. and close ourselves. + # Catch socket.error, KeyboardInterrupt, CancelledError, etc. and close ourselves. except BaseException as error: self._raise_connection_failure(error) @@ -576,6 +576,7 @@ def send_message(self, message: bytes, max_doc_size: int) -> None: try: sendall(self.conn, message) + # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. except BaseException as error: self._raise_connection_failure(error) @@ -586,6 +587,7 @@ def receive_message(self, request_id: Optional[int]) -> Union[_OpReply, _OpMsg]: """ try: return receive_message(self, request_id, self.max_message_size) + # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. except BaseException as error: self._raise_connection_failure(error) @@ -1263,6 +1265,7 @@ def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> Connect try: sock = _configured_socket(self.address, self.opts) + # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. except BaseException as error: with self.lock: self.active_contexts.discard(tmp_context) @@ -1302,6 +1305,7 @@ def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> Connect handler.contribute_socket(conn, completed_handshake=False) conn.authenticate() + # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. except BaseException: with self.lock: self.active_contexts.discard(conn.cancel_context) @@ -1363,6 +1367,7 @@ def checkout( with self.lock: self.active_contexts.add(conn.cancel_context) yield conn + # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. except BaseException: # Exception in caller. Ensure the connection gets returned. # Note that when pinned is True, the session owns the @@ -1509,6 +1514,7 @@ def _get_conn( with self._max_connecting_cond: self._pending -= 1 self._max_connecting_cond.notify() + # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. except BaseException: if conn: # We checked out a socket but authentication failed. diff --git a/test/asynchronous/test_async_cancellation.py b/test/asynchronous/test_async_cancellation.py index d27d908ca3..a3f31e8e94 100644 --- a/test/asynchronous/test_async_cancellation.py +++ b/test/asynchronous/test_async_cancellation.py @@ -26,15 +26,14 @@ class TestAsyncCancellation(AsyncIntegrationTest): async def test_async_cancellation_closes_connection(self): - client = await self.async_rs_or_single_client(maxPoolSize=1) - pool = await async_get_pool(client) - await connected(client) + pool = await async_get_pool(self.client) + await connected(self.client) conn = one(pool.conns) - await client.db.test.insert_one({"x": 1}) - self.addAsyncCleanup(client.db.test.drop) + await self.client.db.test.insert_one({"x": 1}) + self.addAsyncCleanup(self.client.db.test.drop) async def task(): - await client.db.test.find_one({"$where": delay(0.2)}) + await self.client.db.test.find_one({"$where": delay(0.2)}) task = asyncio.create_task(task()) @@ -48,15 +47,14 @@ async def task(): @async_client_context.require_transactions async def test_async_cancellation_aborts_transaction(self): - client = await self.async_rs_or_single_client() - await connected(client) - await client.db.test.insert_one({"x": 1}) - self.addAsyncCleanup(client.db.test.drop) + await connected(self.client) + await self.client.db.test.insert_one({"x": 1}) + self.addAsyncCleanup(self.client.db.test.drop) - session = client.start_session() + session = self.client.start_session() async def callback(session): - await client.db.test.find_one({"$where": delay(0.2)}, session=session) + await self.client.db.test.find_one({"$where": delay(0.2)}, session=session) async def task(): await session.with_transaction(callback) @@ -73,13 +71,12 @@ async def task(): @async_client_context.require_failCommand_blockConnection async def test_async_cancellation_closes_cursor(self): - client = await self.async_rs_or_single_client() - await connected(client) + await connected(self.client) for _ in range(2): - await client.db.test.insert_one({"x": 1}) - self.addAsyncCleanup(client.db.test.drop) + await self.client.db.test.insert_one({"x": 1}) + self.addAsyncCleanup(self.client.db.test.drop) - cursor = client.db.test.find({}, batch_size=1) + cursor = self.client.db.test.find({}, batch_size=1) await cursor.next() # Make sure getMore commands block @@ -106,11 +103,10 @@ async def task(): @async_client_context.require_change_streams @async_client_context.require_failCommand_blockConnection async def test_async_cancellation_closes_change_stream(self): - client = await self.async_rs_or_single_client() - await connected(client) - self.addAsyncCleanup(client.db.test.drop) + await connected(self.client) + self.addAsyncCleanup(self.client.db.test.drop) - change_stream = await client.db.test.watch(batch_size=2) + change_stream = await self.client.db.test.watch(batch_size=2) # Make sure getMore commands block fail_command = { @@ -122,7 +118,7 @@ async def test_async_cancellation_closes_change_stream(self): async def task(): async with self.fail_point(fail_command): for _ in range(2): - await client.db.test.insert_one({"x": 1}) + await self.client.db.test.insert_one({"x": 1}) await change_stream.next() task = asyncio.create_task(task()) From 4fd6b13dcfcafdbfc64e035272afdb13b1cdcb9f Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Mon, 10 Feb 2025 15:33:38 -0500 Subject: [PATCH 9/9] Performance improvements --- test/asynchronous/test_async_cancellation.py | 22 +++++++------------- 1 file changed, 8 insertions(+), 14 deletions(-) diff --git a/test/asynchronous/test_async_cancellation.py b/test/asynchronous/test_async_cancellation.py index a3f31e8e94..b73c7a8084 100644 --- a/test/asynchronous/test_async_cancellation.py +++ b/test/asynchronous/test_async_cancellation.py @@ -27,10 +27,10 @@ class TestAsyncCancellation(AsyncIntegrationTest): async def test_async_cancellation_closes_connection(self): pool = await async_get_pool(self.client) - await connected(self.client) - conn = one(pool.conns) await self.client.db.test.insert_one({"x": 1}) - self.addAsyncCleanup(self.client.db.test.drop) + self.addAsyncCleanup(self.client.db.test.delete_many, {}) + + conn = one(pool.conns) async def task(): await self.client.db.test.find_one({"$where": delay(0.2)}) @@ -47,9 +47,8 @@ async def task(): @async_client_context.require_transactions async def test_async_cancellation_aborts_transaction(self): - await connected(self.client) await self.client.db.test.insert_one({"x": 1}) - self.addAsyncCleanup(self.client.db.test.drop) + self.addAsyncCleanup(self.client.db.test.delete_many, {}) session = self.client.start_session() @@ -71,10 +70,8 @@ async def task(): @async_client_context.require_failCommand_blockConnection async def test_async_cancellation_closes_cursor(self): - await connected(self.client) - for _ in range(2): - await self.client.db.test.insert_one({"x": 1}) - self.addAsyncCleanup(self.client.db.test.drop) + await self.client.db.test.insert_many([{"x": 1}, {"x": 2}]) + self.addAsyncCleanup(self.client.db.test.delete_many, {}) cursor = self.client.db.test.find({}, batch_size=1) await cursor.next() @@ -103,9 +100,7 @@ async def task(): @async_client_context.require_change_streams @async_client_context.require_failCommand_blockConnection async def test_async_cancellation_closes_change_stream(self): - await connected(self.client) - self.addAsyncCleanup(self.client.db.test.drop) - + self.addAsyncCleanup(self.client.db.test.delete_many, {}) change_stream = await self.client.db.test.watch(batch_size=2) # Make sure getMore commands block @@ -117,8 +112,7 @@ async def test_async_cancellation_closes_change_stream(self): async def task(): async with self.fail_point(fail_command): - for _ in range(2): - await self.client.db.test.insert_one({"x": 1}) + await self.client.db.test.insert_many([{"x": 1}, {"x": 2}]) await change_stream.next() task = asyncio.create_task(task())