Skip to content

PYTHON-4745 - Test behavior of async task cancellation #2136

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Feb 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pymongo/asynchronous/change_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,8 @@ async def try_next(self) -> Optional[_DocumentType]:
if not _resumable(exc) and not exc.timeout:
await self.close()
raise
except Exception:
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
except BaseException:
Copy link
Member

@ShaneHarvey ShaneHarvey Feb 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Everywhere we catch BaseException, can you add a one line comment to explain it's intentional? Something like:

# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.

await self.close()
raise

Expand Down
3 changes: 2 additions & 1 deletion pymongo/asynchronous/client_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -697,7 +697,8 @@ async def callback(session, custom_arg, custom_kwarg=None):
)
try:
ret = await callback(self)
except Exception as exc:
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
except BaseException as exc:
if self.in_transaction:
await self.abort_transaction()
if (
Expand Down
3 changes: 2 additions & 1 deletion pymongo/asynchronous/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1126,7 +1126,8 @@ async def _send_message(self, operation: Union[_Query, _GetMore]) -> None:
self._killed = True
await self.close()
raise
except Exception:
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
except BaseException:
await self.close()
raise

Expand Down
8 changes: 7 additions & 1 deletion pymongo/asynchronous/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,7 +559,7 @@ async def command(
)
except (OperationFailure, NotPrimaryError):
raise
# Catch socket.error, KeyboardInterrupt, etc. and close ourselves.
# Catch socket.error, KeyboardInterrupt, CancelledError, etc. and close ourselves.
except BaseException as error:
self._raise_connection_failure(error)

Expand All @@ -576,6 +576,7 @@ async def send_message(self, message: bytes, max_doc_size: int) -> None:

try:
await async_sendall(self.conn, message)
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
except BaseException as error:
self._raise_connection_failure(error)

Expand All @@ -586,6 +587,7 @@ async def receive_message(self, request_id: Optional[int]) -> Union[_OpReply, _O
"""
try:
return await receive_message(self, request_id, self.max_message_size)
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
except BaseException as error:
self._raise_connection_failure(error)

Expand Down Expand Up @@ -1269,6 +1271,7 @@ async def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> A

try:
sock = await _configured_socket(self.address, self.opts)
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
except BaseException as error:
async with self.lock:
self.active_contexts.discard(tmp_context)
Expand Down Expand Up @@ -1308,6 +1311,7 @@ async def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> A
handler.contribute_socket(conn, completed_handshake=False)

await conn.authenticate()
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
except BaseException:
async with self.lock:
self.active_contexts.discard(conn.cancel_context)
Expand Down Expand Up @@ -1369,6 +1373,7 @@ async def checkout(
async with self.lock:
self.active_contexts.add(conn.cancel_context)
yield conn
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
except BaseException:
# Exception in caller. Ensure the connection gets returned.
# Note that when pinned is True, the session owns the
Expand Down Expand Up @@ -1515,6 +1520,7 @@ async def _get_conn(
async with self._max_connecting_cond:
self._pending -= 1
self._max_connecting_cond.notify()
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
except BaseException:
if conn:
# We checked out a socket but authentication failed.
Expand Down
2 changes: 2 additions & 0 deletions pymongo/periodic_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ async def _run(self) -> None:
if not await self._target():
self._stopped = True
break
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
except BaseException:
self._stopped = True
raise
Expand Down Expand Up @@ -230,6 +231,7 @@ def _run(self) -> None:
if not self._target():
self._stopped = True
break
# Catch KeyboardInterrupt, etc. and cleanup.
except BaseException:
with self._lock:
self._stopped = True
Expand Down
3 changes: 2 additions & 1 deletion pymongo/synchronous/change_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,8 @@ def try_next(self) -> Optional[_DocumentType]:
if not _resumable(exc) and not exc.timeout:
self.close()
raise
except Exception:
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
except BaseException:
self.close()
raise

Expand Down
3 changes: 2 additions & 1 deletion pymongo/synchronous/client_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -694,7 +694,8 @@ def callback(session, custom_arg, custom_kwarg=None):
self.start_transaction(read_concern, write_concern, read_preference, max_commit_time_ms)
try:
ret = callback(self)
except Exception as exc:
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
except BaseException as exc:
if self.in_transaction:
self.abort_transaction()
if (
Expand Down
3 changes: 2 additions & 1 deletion pymongo/synchronous/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1124,7 +1124,8 @@ def _send_message(self, operation: Union[_Query, _GetMore]) -> None:
self._killed = True
self.close()
raise
except Exception:
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
except BaseException:
self.close()
raise

Expand Down
8 changes: 7 additions & 1 deletion pymongo/synchronous/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,7 +559,7 @@ def command(
)
except (OperationFailure, NotPrimaryError):
raise
# Catch socket.error, KeyboardInterrupt, etc. and close ourselves.
# Catch socket.error, KeyboardInterrupt, CancelledError, etc. and close ourselves.
except BaseException as error:
self._raise_connection_failure(error)

Expand All @@ -576,6 +576,7 @@ def send_message(self, message: bytes, max_doc_size: int) -> None:

try:
sendall(self.conn, message)
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
except BaseException as error:
self._raise_connection_failure(error)

Expand All @@ -586,6 +587,7 @@ def receive_message(self, request_id: Optional[int]) -> Union[_OpReply, _OpMsg]:
"""
try:
return receive_message(self, request_id, self.max_message_size)
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
except BaseException as error:
self._raise_connection_failure(error)

Expand Down Expand Up @@ -1263,6 +1265,7 @@ def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> Connect

try:
sock = _configured_socket(self.address, self.opts)
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
except BaseException as error:
with self.lock:
self.active_contexts.discard(tmp_context)
Expand Down Expand Up @@ -1302,6 +1305,7 @@ def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> Connect
handler.contribute_socket(conn, completed_handshake=False)

conn.authenticate()
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
except BaseException:
with self.lock:
self.active_contexts.discard(conn.cancel_context)
Expand Down Expand Up @@ -1363,6 +1367,7 @@ def checkout(
with self.lock:
self.active_contexts.add(conn.cancel_context)
yield conn
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
except BaseException:
# Exception in caller. Ensure the connection gets returned.
# Note that when pinned is True, the session owns the
Expand Down Expand Up @@ -1509,6 +1514,7 @@ def _get_conn(
with self._max_connecting_cond:
self._pending -= 1
self._max_connecting_cond.notify()
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
except BaseException:
if conn:
# We checked out a socket but authentication failed.
Expand Down
126 changes: 126 additions & 0 deletions test/asynchronous/test_async_cancellation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
# Copyright 2025-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Test that async cancellation performed by users clean up resources correctly."""
from __future__ import annotations

import asyncio
import sys
from test.utils import async_get_pool, delay, one

sys.path[0:0] = [""]

from test.asynchronous import AsyncIntegrationTest, async_client_context, connected


class TestAsyncCancellation(AsyncIntegrationTest):
async def test_async_cancellation_closes_connection(self):
pool = await async_get_pool(self.client)
await self.client.db.test.insert_one({"x": 1})
self.addAsyncCleanup(self.client.db.test.delete_many, {})

conn = one(pool.conns)

async def task():
await self.client.db.test.find_one({"$where": delay(0.2)})

task = asyncio.create_task(task())

await asyncio.sleep(0.1)

task.cancel()
with self.assertRaises(asyncio.CancelledError):
await task

self.assertTrue(conn.closed)

@async_client_context.require_transactions
async def test_async_cancellation_aborts_transaction(self):
await self.client.db.test.insert_one({"x": 1})
self.addAsyncCleanup(self.client.db.test.delete_many, {})

session = self.client.start_session()

async def callback(session):
await self.client.db.test.find_one({"$where": delay(0.2)}, session=session)

async def task():
await session.with_transaction(callback)

task = asyncio.create_task(task())

await asyncio.sleep(0.1)

task.cancel()
with self.assertRaises(asyncio.CancelledError):
await task

self.assertFalse(session.in_transaction)

@async_client_context.require_failCommand_blockConnection
async def test_async_cancellation_closes_cursor(self):
await self.client.db.test.insert_many([{"x": 1}, {"x": 2}])
self.addAsyncCleanup(self.client.db.test.delete_many, {})

cursor = self.client.db.test.find({}, batch_size=1)
await cursor.next()

# Make sure getMore commands block
fail_command = {
"configureFailPoint": "failCommand",
"mode": "alwaysOn",
"data": {"failCommands": ["getMore"], "blockConnection": True, "blockTimeMS": 200},
}

async def task():
async with self.fail_point(fail_command):
await cursor.next()

task = asyncio.create_task(task())

await asyncio.sleep(0.1)

task.cancel()
with self.assertRaises(asyncio.CancelledError):
await task

self.assertTrue(cursor._killed)

@async_client_context.require_change_streams
@async_client_context.require_failCommand_blockConnection
async def test_async_cancellation_closes_change_stream(self):
self.addAsyncCleanup(self.client.db.test.delete_many, {})
change_stream = await self.client.db.test.watch(batch_size=2)

# Make sure getMore commands block
fail_command = {
"configureFailPoint": "failCommand",
"mode": "alwaysOn",
"data": {"failCommands": ["getMore"], "blockConnection": True, "blockTimeMS": 200},
}

async def task():
async with self.fail_point(fail_command):
await self.client.db.test.insert_many([{"x": 1}, {"x": 2}])
await change_stream.next()

task = asyncio.create_task(task())

await asyncio.sleep(0.1)

task.cancel()
with self.assertRaises(asyncio.CancelledError):
await task

self.assertTrue(change_stream._closed)
2 changes: 1 addition & 1 deletion tools/synchro.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@

def async_only_test(f: str) -> bool:
"""Return True for async tests that should not be converted to sync."""
return f in ["test_locks.py", "test_concurrency.py"]
return f in ["test_locks.py", "test_concurrency.py", "test_async_cancellation.py"]


test_files = [
Expand Down
Loading