diff --git a/pymongo/network_layer.py b/pymongo/network_layer.py index c1db31f89c..11c66bf16e 100644 --- a/pymongo/network_layer.py +++ b/pymongo/network_layer.py @@ -267,18 +267,25 @@ async def async_receive_data( else: read_task = create_task(_async_receive(sock, length, loop)) # type: ignore[arg-type] tasks = [read_task, cancellation_task] - done, pending = await asyncio.wait( - tasks, timeout=timeout, return_when=asyncio.FIRST_COMPLETED - ) - for task in pending: - task.cancel() - if pending: - await asyncio.wait(pending) - if len(done) == 0: - raise socket.timeout("timed out") - if read_task in done: - return read_task.result() - raise _OperationCancelled("operation cancelled") + try: + done, pending = await asyncio.wait( + tasks, timeout=timeout, return_when=asyncio.FIRST_COMPLETED + ) + for task in pending: + task.cancel() + if pending: + await asyncio.wait(pending) + if len(done) == 0: + raise socket.timeout("timed out") + if read_task in done: + return read_task.result() + raise _OperationCancelled("operation cancelled") + except asyncio.CancelledError: + for task in tasks: + task.cancel() + await asyncio.wait(tasks) + raise + finally: sock.settimeout(sock_timeout) diff --git a/pymongo/periodic_executor.py b/pymongo/periodic_executor.py index 2f89b91deb..9b10f6e7e3 100644 --- a/pymongo/periodic_executor.py +++ b/pymongo/periodic_executor.py @@ -78,14 +78,7 @@ def close(self, dummy: Any = None) -> None: async def join(self, timeout: Optional[int] = None) -> None: if self._task is not None: - try: - await asyncio.wait_for(self._task, timeout=timeout) # type-ignore: [arg-type] - except asyncio.TimeoutError: - # Task timed out - pass - except asyncio.exceptions.CancelledError: - # Task was already finished, or not yet started. - raise + await asyncio.wait([self._task], timeout=timeout) # type-ignore: [arg-type] def wake(self) -> None: """Execute the target function soon.""" diff --git a/test/asynchronous/test_connections_survive_primary_stepdown_spec.py b/test/asynchronous/test_connections_survive_primary_stepdown_spec.py index 4795d3937a..7c11742a90 100644 --- a/test/asynchronous/test_connections_survive_primary_stepdown_spec.py +++ b/test/asynchronous/test_connections_survive_primary_stepdown_spec.py @@ -22,7 +22,6 @@ from test.asynchronous import ( AsyncIntegrationTest, async_client_context, - reset_client_context, unittest, ) from test.asynchronous.helpers import async_repl_set_step_down diff --git a/test/test_connections_survive_primary_stepdown_spec.py b/test/test_connections_survive_primary_stepdown_spec.py index 1fb08cbed5..9cac633301 100644 --- a/test/test_connections_survive_primary_stepdown_spec.py +++ b/test/test_connections_survive_primary_stepdown_spec.py @@ -22,7 +22,6 @@ from test import ( IntegrationTest, client_context, - reset_client_context, unittest, ) from test.helpers import repl_set_step_down