diff --git a/src/graphql/subscription/map_async_iterator.py b/src/graphql/subscription/map_async_iterator.py index 35fd42a8..39d6d321 100644 --- a/src/graphql/subscription/map_async_iterator.py +++ b/src/graphql/subscription/map_async_iterator.py @@ -1,4 +1,4 @@ -from asyncio import Event, ensure_future, Future, wait +from asyncio import Event, ensure_future, Future, wait, CancelledError from concurrent.futures import FIRST_COMPLETED from inspect import isasyncgen, isawaitable from typing import cast, Any, AsyncIterable, Callable, Optional, Set, Type, Union @@ -43,9 +43,20 @@ async def __anext__(self) -> Any: aclose = ensure_future(self._close_event.wait()) anext = ensure_future(self.iterator.__anext__()) - pending: Set[Future] = ( - await wait([aclose, anext], return_when=FIRST_COMPLETED) - )[1] + # Suppress the StopAsyncIteration exception warning when the + # iterator is cancelled. + anext.add_done_callback(lambda *args: anext.exception()) + try: + pending: Set[Future] = ( + await wait([aclose, anext], return_when=FIRST_COMPLETED) + )[1] + except CancelledError as e: + # The iterator is cancelled + aclose.cancel() + anext.cancel() + self.is_closed = True + raise StopAsyncIteration from e + for task in pending: task.cancel() diff --git a/tests/subscription/test_map_async_iterator.py b/tests/subscription/test_map_async_iterator.py index 4df371af..c9e3fca0 100644 --- a/tests/subscription/test_map_async_iterator.py +++ b/tests/subscription/test_map_async_iterator.py @@ -1,5 +1,5 @@ import sys -from asyncio import Event, ensure_future, sleep +from asyncio import Event, ensure_future, CancelledError, sleep, Queue from pytest import mark, raises # type: ignore @@ -457,3 +457,41 @@ async def aclose(self): await anext(doubles) assert not doubles.is_closed assert not iterator.is_closed + + @mark.asyncio + async def cancel_async_iterator_while_waiting(): + class Iterator: + def __init__(self): + self.queue: Queue[int] = Queue() + self.queue.put_nowait(1) # suppress coverage warning + self.cancelled = False + + def __aiter__(self): + return self + + async def __anext__(self): + try: + return await self.queue.get() + except BaseException: + self.cancelled = True + + iterator = Iterator() + doubles = MapAsyncIterator(iterator, lambda x: x + x) + + async def iterator_task(): + try: + async for double in doubles: + pass + # If cancellation is handled using StopAsyncIteration, it will reach + # here. + except CancelledError: # pragma: no cover + # Otherwise it should reach here + pass + + task = ensure_future(iterator_task()) + await sleep(0.1) + await doubles.aclose() + task.cancel() + await sleep(0.1) + assert iterator.cancelled + assert doubles.is_closed