Skip to content

Commit b2f8fe3

Browse files
committed
Cancel remaining list items on exceptions
1 parent 57c083d commit b2f8fe3

File tree

2 files changed

+71
-14
lines changed

2 files changed

+71
-14
lines changed

src/graphql/execution/execute.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -466,10 +466,10 @@ async def get_results() -> dict[str, Any]:
466466
field = awaitable_fields[0]
467467
results[field] = await results[field]
468468
else:
469-
tasks = {
470-
create_task(results[field]): field # type: ignore[arg-type]
469+
tasks = [
470+
create_task(results[field]) # type: ignore[arg-type]
471471
for field in awaitable_fields
472-
}
472+
]
473473

474474
try:
475475
awaited_results = await gather(*tasks)
@@ -1014,12 +1014,21 @@ async def get_completed_results() -> list[Any]:
10141014
index = awaitable_indices[0]
10151015
completed_results[index] = await completed_results[index]
10161016
else:
1017-
for index, sub_result in zip(
1018-
awaitable_indices,
1019-
await gather(
1020-
*(completed_results[index] for index in awaitable_indices)
1021-
),
1022-
):
1017+
tasks = [
1018+
create_task(completed_results[index]) for index in awaitable_indices
1019+
]
1020+
1021+
try:
1022+
awaited_results = await gather(*tasks)
1023+
except Exception:
1024+
# Cancel unfinished tasks before raising the exception
1025+
for task in tasks:
1026+
if not task.done():
1027+
task.cancel()
1028+
await gather(*tasks, return_exceptions=True)
1029+
raise
1030+
1031+
for index, sub_result in zip(awaitable_indices, awaited_results):
10231032
completed_results[index] = sub_result
10241033
return completed_results
10251034

tests/execution/test_parallel.py

Lines changed: 53 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ async def resolve(*_args):
7676
# raises TimeoutError if not parallel
7777
awaitable_result = execute(schema, ast)
7878
assert isinstance(awaitable_result, Awaitable)
79-
result = await asyncio.wait_for(awaitable_result, 1.0)
79+
result = await asyncio.wait_for(awaitable_result, 1)
8080

8181
assert result == ({"foo": True, "bar": True}, None)
8282

@@ -125,7 +125,7 @@ async def resolve_list(*args):
125125
# raises TimeoutError if not parallel
126126
awaitable_result = execute(schema, ast)
127127
assert isinstance(awaitable_result, Awaitable)
128-
result = await asyncio.wait_for(awaitable_result, 1.0)
128+
result = await asyncio.wait_for(awaitable_result, 1)
129129

130130
assert result == ({"foo": [True, True]}, None)
131131

@@ -188,15 +188,15 @@ async def is_type_of_baz(obj, *_args):
188188
# raises TimeoutError if not parallel
189189
awaitable_result = execute(schema, ast)
190190
assert isinstance(awaitable_result, Awaitable)
191-
result = await asyncio.wait_for(awaitable_result, 1.0)
191+
result = await asyncio.wait_for(awaitable_result, 1)
192192

193193
assert result == (
194194
{"foo": [{"foo": "bar", "foobar": 1}, {"foo": "baz", "foobaz": 2}]},
195195
None,
196196
)
197197

198198
@pytest.mark.asyncio
199-
async def cancel_on_exception():
199+
async def cancel_selection_sets_on_exception():
200200
barrier = Barrier(2)
201201
completed = False
202202

@@ -222,13 +222,61 @@ async def fail(*_args):
222222

223223
awaitable_result = execute(schema, ast)
224224
assert isinstance(awaitable_result, Awaitable)
225-
result = await asyncio.wait_for(awaitable_result, 1.0)
225+
result = await asyncio.wait_for(awaitable_result, 1)
226226

227227
assert result == (
228228
None,
229229
[{"message": "Oops", "locations": [(1, 2)], "path": ["foo"]}],
230230
)
231231

232+
assert not completed
233+
234+
# Unblock succeed() and check that it does not complete
235+
await barrier.wait()
236+
await asyncio.sleep(0)
237+
assert not completed
238+
239+
@pytest.mark.asyncio
240+
async def cancel_lists_on_exception():
241+
barrier = Barrier(2)
242+
completed = False
243+
244+
async def succeed(*_args):
245+
nonlocal completed
246+
await barrier.wait()
247+
completed = True # pragma: no cover
248+
249+
async def fail(*_args):
250+
raise RuntimeError("Oops")
251+
252+
async def resolve_list(*args):
253+
return [fail(*args), succeed(*args)]
254+
255+
schema = GraphQLSchema(
256+
GraphQLObjectType(
257+
"Query",
258+
{
259+
"foo": GraphQLField(
260+
GraphQLList(GraphQLNonNull(GraphQLBoolean)),
261+
resolve=resolve_list,
262+
)
263+
},
264+
)
265+
)
266+
267+
ast = parse("{foo}")
268+
269+
awaitable_result = execute(schema, ast)
270+
assert isinstance(awaitable_result, Awaitable)
271+
result = await asyncio.wait_for(awaitable_result, 1)
272+
273+
assert result == (
274+
{"foo": None},
275+
[{"message": "Oops", "locations": [(1, 2)], "path": ["foo", 0]}],
276+
)
277+
278+
assert not completed
279+
232280
# Unblock succeed() and check that it does not complete
233281
await barrier.wait()
234282
await asyncio.sleep(0)

0 commit comments

Comments
 (0)