Skip to content

Commit 62ddc6c

Browse files
authored
Handle case when MapAsyncIterator is cancelled (#131)
1 parent 99404d4 commit 62ddc6c

File tree

2 files changed

+54
-5
lines changed

2 files changed

+54
-5
lines changed

src/graphql/subscription/map_async_iterator.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from asyncio import Event, ensure_future, Future, wait
1+
from asyncio import Event, ensure_future, Future, wait, CancelledError
22
from concurrent.futures import FIRST_COMPLETED
33
from inspect import isasyncgen, isawaitable
44
from typing import cast, Any, AsyncIterable, Callable, Optional, Set, Type, Union
@@ -43,9 +43,20 @@ async def __anext__(self) -> Any:
4343
aclose = ensure_future(self._close_event.wait())
4444
anext = ensure_future(self.iterator.__anext__())
4545

46-
pending: Set[Future] = (
47-
await wait([aclose, anext], return_when=FIRST_COMPLETED)
48-
)[1]
46+
# Suppress the StopAsyncIteration exception warning when the
47+
# iterator is cancelled.
48+
anext.add_done_callback(lambda *args: anext.exception())
49+
try:
50+
pending: Set[Future] = (
51+
await wait([aclose, anext], return_when=FIRST_COMPLETED)
52+
)[1]
53+
except CancelledError as e:
54+
# The iterator is cancelled
55+
aclose.cancel()
56+
anext.cancel()
57+
self.is_closed = True
58+
raise StopAsyncIteration from e
59+
4960
for task in pending:
5061
task.cancel()
5162

tests/subscription/test_map_async_iterator.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import sys
2-
from asyncio import Event, ensure_future, sleep
2+
from asyncio import Event, ensure_future, CancelledError, sleep, Queue
33

44
from pytest import mark, raises
55

@@ -457,3 +457,41 @@ async def aclose(self):
457457
await anext(doubles)
458458
assert not doubles.is_closed
459459
assert not iterator.is_closed
460+
461+
@mark.asyncio
462+
async def cancel_async_iterator_while_waiting():
463+
class Iterator:
464+
def __init__(self):
465+
self.queue: Queue[int] = Queue()
466+
self.queue.put_nowait(1) # suppress coverage warning
467+
self.cancelled = False
468+
469+
def __aiter__(self):
470+
return self
471+
472+
async def __anext__(self):
473+
try:
474+
return await self.queue.get()
475+
except BaseException:
476+
self.cancelled = True
477+
478+
iterator = Iterator()
479+
doubles = MapAsyncIterator(iterator, lambda x: x + x)
480+
481+
async def iterator_task():
482+
try:
483+
async for double in doubles:
484+
pass
485+
# If cancellation is handled using StopAsyncIteration, it will reach
486+
# here.
487+
except CancelledError: # pragma: no cover
488+
# Otherwise it should reach here
489+
pass
490+
491+
task = ensure_future(iterator_task())
492+
await sleep(0.1)
493+
await doubles.aclose()
494+
task.cancel()
495+
await sleep(0.1)
496+
assert iterator.cancelled
497+
assert doubles.is_closed

0 commit comments

Comments
 (0)