Skip to content

Commit 48ed790

Browse files
committed
Cancel remaining iterator items on exceptions
1 parent b2f8fe3 commit 48ed790

File tree

2 files changed

+63
-7
lines changed

2 files changed

+63
-7
lines changed

src/graphql/execution/execute.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -911,13 +911,22 @@ async def complete_async_iterator_value(
911911
index = awaitable_indices[0]
912912
completed_results[index] = await completed_results[index]
913913
else:
914-
for index, result in zip(
915-
awaitable_indices,
916-
await gather(
917-
*(completed_results[index] for index in awaitable_indices)
918-
),
919-
):
920-
completed_results[index] = result
914+
tasks = [
915+
create_task(completed_results[index]) for index in awaitable_indices
916+
]
917+
918+
try:
919+
awaited_results = await gather(*tasks)
920+
except Exception:
921+
# Cancel unfinished tasks before raising the exception
922+
for task in tasks:
923+
if not task.done():
924+
task.cancel()
925+
await gather(*tasks, return_exceptions=True)
926+
raise
927+
928+
for index, sub_result in zip(awaitable_indices, awaited_results):
929+
completed_results[index] = sub_result
921930
return completed_results
922931

923932
def complete_list_value(

tests/execution/test_parallel.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,3 +281,50 @@ async def resolve_list(*args):
281281
await barrier.wait()
282282
await asyncio.sleep(0)
283283
assert not completed
284+
285+
@pytest.mark.asyncio
286+
async def cancel_async_iterator_on_exception():
287+
barrier = Barrier(2)
288+
completed = False
289+
290+
async def succeed(*_args):
291+
nonlocal completed
292+
await barrier.wait()
293+
completed = True # pragma: no cover
294+
295+
async def fail(*_args):
296+
raise RuntimeError("Oops")
297+
298+
async def resolve_iterator(*args):
299+
yield fail(*args)
300+
yield succeed(*args)
301+
302+
schema = GraphQLSchema(
303+
GraphQLObjectType(
304+
"Query",
305+
{
306+
"foo": GraphQLField(
307+
GraphQLList(GraphQLNonNull(GraphQLBoolean)),
308+
resolve=resolve_iterator,
309+
)
310+
},
311+
)
312+
)
313+
314+
ast = parse("{foo}")
315+
316+
awaitable_result = execute(schema, ast)
317+
assert isinstance(awaitable_result, Awaitable)
318+
result = await asyncio.wait_for(awaitable_result, 1)
319+
320+
assert result == (
321+
{"foo": None},
322+
[{"message": "Oops", "locations": [(1, 2)], "path": ["foo", 0]}],
323+
)
324+
325+
assert not completed
326+
327+
# Unblock succeed() and check that it does not complete
328+
await barrier.wait()
329+
await asyncio.sleep(0)
330+
assert not completed

0 commit comments

Comments
 (0)