Skip to content

Commit c6f73a8

Browse files
committed
Don't swallow CancelledError in MapAsyncIterator
As discussed in #131, better than conversion to StopAsyncIteration.
1 parent 62ddc6c commit c6f73a8

File tree

2 files changed

+33
-28
lines changed

2 files changed

+33
-28
lines changed

src/graphql/subscription/map_async_iterator.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from asyncio import Event, ensure_future, Future, wait, CancelledError
1+
from asyncio import CancelledError, Event, Future, ensure_future, wait
22
from concurrent.futures import FIRST_COMPLETED
33
from inspect import isasyncgen, isawaitable
44
from typing import cast, Any, AsyncIterable, Callable, Optional, Set, Type, Union
@@ -43,19 +43,16 @@ async def __anext__(self) -> Any:
4343
aclose = ensure_future(self._close_event.wait())
4444
anext = ensure_future(self.iterator.__anext__())
4545

46-
# Suppress the StopAsyncIteration exception warning when the
47-
# iterator is cancelled.
48-
anext.add_done_callback(lambda *args: anext.exception())
4946
try:
5047
pending: Set[Future] = (
5148
await wait([aclose, anext], return_when=FIRST_COMPLETED)
5249
)[1]
53-
except CancelledError as e:
54-
# The iterator is cancelled
50+
except CancelledError:
51+
# cancel underlying tasks and close
5552
aclose.cancel()
5653
anext.cancel()
57-
self.is_closed = True
58-
raise StopAsyncIteration from e
54+
await self.aclose()
55+
raise # re-raise the cancellation
5956

6057
for task in pending:
6158
task.cancel()

tests/subscription/test_map_async_iterator.py

Lines changed: 28 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import sys
2-
from asyncio import Event, ensure_future, CancelledError, sleep, Queue
2+
from asyncio import CancelledError, Event, ensure_future, sleep
33

44
from pytest import mark, raises
55

@@ -459,39 +459,47 @@ async def aclose(self):
459459
assert not iterator.is_closed
460460

461461
@mark.asyncio
462-
async def cancel_async_iterator_while_waiting():
462+
async def can_cancel_async_iterator_while_waiting():
463463
class Iterator:
464464
def __init__(self):
465-
self.queue: Queue[int] = Queue()
466-
self.queue.put_nowait(1) # suppress coverage warning
467-
self.cancelled = False
465+
self.is_closed = False
466+
self.value = 1
468467

469468
def __aiter__(self):
470469
return self
471470

472471
async def __anext__(self):
473472
try:
474-
return await self.queue.get()
475-
except BaseException:
476-
self.cancelled = True
473+
await sleep(0.5)
474+
return self.value # pragma: no cover
475+
except CancelledError:
476+
self.value = -1
477+
raise
478+
479+
async def aclose(self):
480+
self.is_closed = True
477481

478482
iterator = Iterator()
479-
doubles = MapAsyncIterator(iterator, lambda x: x + x)
483+
doubles = MapAsyncIterator(iterator, lambda x: x + x) # pragma: no cover exit
484+
cancelled = False
480485

481486
async def iterator_task():
487+
nonlocal cancelled
482488
try:
483-
async for double in doubles:
484-
pass
485-
# If cancellation is handled using StopAsyncIteration, it will reach
486-
# here.
487-
except CancelledError: # pragma: no cover
488-
# Otherwise it should reach here
489-
pass
489+
async for _ in doubles:
490+
assert False # pragma: no cover
491+
except CancelledError:
492+
cancelled = True
490493

491494
task = ensure_future(iterator_task())
492-
await sleep(0.1)
493-
await doubles.aclose()
495+
await sleep(0.05)
496+
assert not cancelled
497+
assert not doubles.is_closed
498+
assert iterator.value == 1
499+
assert not iterator.is_closed
494500
task.cancel()
495-
await sleep(0.1)
496-
assert iterator.cancelled
501+
await sleep(0.05)
502+
assert cancelled
503+
assert iterator.value == -1
497504
assert doubles.is_closed
505+
assert iterator.is_closed

0 commit comments

Comments
 (0)