Skip to content

Commit f54d8bd

Browse files
committed
test_map_async_iterator: improve coverage
Replicates graphql/graphql-js@b418ad6
1 parent 167e269 commit f54d8bd

File tree

1 file changed

+131
-13
lines changed

1 file changed

+131
-13
lines changed

tests/subscription/test_map_async_iterator.py

Lines changed: 131 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,40 @@ async def anext(iterable):
1313

1414
def describe_map_async_iterator():
1515
@mark.asyncio
16-
async def maps_over_async_values():
16+
async def maps_over_async_generator():
17+
async def source():
18+
yield 1
19+
yield 2
20+
yield 3
21+
22+
doubles = MapAsyncIterator(source(), lambda x: x + x)
23+
24+
assert await anext(doubles) == 2
25+
assert await anext(doubles) == 4
26+
assert await anext(doubles) == 6
27+
with raises(StopAsyncIteration):
28+
assert await anext(doubles)
29+
30+
@mark.asyncio
31+
async def maps_over_async_iterator():
32+
items = [1, 2, 3]
33+
34+
class Iterator:
35+
def __aiter__(self):
36+
return self
37+
38+
async def __anext__(self):
39+
try:
40+
return items.pop(0)
41+
except IndexError:
42+
raise StopAsyncIteration
43+
44+
doubles = MapAsyncIterator(Iterator(), lambda x: x + x)
45+
46+
assert [value async for value in doubles] == [2, 4, 6]
47+
48+
@mark.asyncio
49+
async def compatible_with_async_for():
1750
async def source():
1851
yield 1
1952
yield 2
@@ -38,11 +71,11 @@ async def double(x):
3871
assert [value async for value in doubles] == [2, 4, 6]
3972

4073
@mark.asyncio
41-
async def allows_returning_early_from_async_values():
74+
async def allows_returning_early_from_mapped_async_generator():
4275
async def source():
4376
yield 1
4477
yield 2
45-
yield 3
78+
yield 3 # pragma: no cover
4679

4780
doubles = MapAsyncIterator(source(), lambda x: x + x)
4881

@@ -58,13 +91,41 @@ async def source():
5891
with raises(StopAsyncIteration):
5992
await anext(doubles)
6093

94+
@mark.asyncio
95+
async def allows_returning_early_from_mapped_async_iterator():
96+
items = [1, 2, 3]
97+
98+
class Iterator:
99+
def __aiter__(self):
100+
return self
101+
102+
async def __anext__(self):
103+
try:
104+
return items.pop(0)
105+
except IndexError: # pragma: no cover
106+
raise StopAsyncIteration
107+
108+
doubles = MapAsyncIterator(Iterator(), lambda x: x + x)
109+
110+
assert await anext(doubles) == 2
111+
assert await anext(doubles) == 4
112+
113+
# Early return
114+
await doubles.aclose()
115+
116+
# Subsequent next calls
117+
with raises(StopAsyncIteration):
118+
await anext(doubles)
119+
with raises(StopAsyncIteration):
120+
await anext(doubles)
121+
61122
@mark.asyncio
62123
async def passes_through_early_return_from_async_values():
63124
async def source():
64125
try:
65126
yield 1
66127
yield 2
67-
yield 3
128+
yield 3 # pragma: no cover
68129
finally:
69130
yield "Done"
70131
yield "Last"
@@ -83,13 +144,20 @@ async def source():
83144
assert await anext(doubles)
84145

85146
@mark.asyncio
86-
async def allows_throwing_errors_through_async_generators():
87-
async def source():
88-
yield 1
89-
yield 2
90-
yield 3
147+
async def allows_throwing_errors_through_async_iterators():
148+
items = [1, 2, 3]
91149

92-
doubles = MapAsyncIterator(source(), lambda x: x + x)
150+
class Iterator:
151+
def __aiter__(self):
152+
return self
153+
154+
async def __anext__(self):
155+
try:
156+
return items.pop(0)
157+
except IndexError: # pragma: no cover
158+
raise StopAsyncIteration
159+
160+
doubles = MapAsyncIterator(Iterator(), lambda x: x + x)
93161

94162
assert await anext(doubles) == 2
95163
assert await anext(doubles) == 4
@@ -111,7 +179,7 @@ async def source():
111179
try:
112180
yield 1
113181
yield 2
114-
yield 3
182+
yield 3 # pragma: no cover
115183
except Exception as e:
116184
yield e
117185

@@ -249,8 +317,8 @@ async def stops_async_iteration_on_close():
249317
async def source():
250318
yield 1
251319
await Event().wait() # Block forever
252-
yield 2
253-
yield 3
320+
yield 2 # pragma: no cover
321+
yield 3 # pragma: no cover
254322

255323
singles = source()
256324
doubles = MapAsyncIterator(singles, lambda x: x * 2)
@@ -271,3 +339,53 @@ async def source():
271339

272340
with raises(StopAsyncIteration):
273341
await anext(singles)
342+
343+
@mark.asyncio
344+
async def can_unset_closed_state_of_async_iterator():
345+
items = [1, 2, 3]
346+
347+
class Iterator:
348+
def __init__(self):
349+
self.is_closed = False
350+
351+
def __aiter__(self):
352+
return self
353+
354+
async def __anext__(self):
355+
if self.is_closed:
356+
raise StopAsyncIteration
357+
try:
358+
return items.pop(0)
359+
except IndexError:
360+
raise StopAsyncIteration
361+
362+
async def aclose(self):
363+
self.is_closed = True
364+
365+
iterator = Iterator()
366+
doubles = MapAsyncIterator(iterator, lambda x: x + x)
367+
368+
assert await anext(doubles) == 2
369+
assert await anext(doubles) == 4
370+
assert not iterator.is_closed
371+
await doubles.aclose()
372+
assert iterator.is_closed
373+
with raises(StopAsyncIteration):
374+
await anext(iterator)
375+
with raises(StopAsyncIteration):
376+
await anext(doubles)
377+
assert doubles.is_closed
378+
379+
iterator.is_closed = False
380+
doubles.is_closed = False
381+
assert not doubles.is_closed
382+
383+
assert await anext(doubles) == 6
384+
assert not doubles.is_closed
385+
assert not iterator.is_closed
386+
with raises(StopAsyncIteration):
387+
await anext(iterator)
388+
with raises(StopAsyncIteration):
389+
await anext(doubles)
390+
assert not doubles.is_closed
391+
assert not iterator.is_closed

0 commit comments

Comments
 (0)