Skip to content

Commit 55cf6ec

Browse files
committed
Extract task cancellation as utility function
1 parent e48d160 commit 55cf6ec

File tree

4 files changed

+182
-60
lines changed

4 files changed

+182
-60
lines changed

src/graphql/execution/execute.py

Lines changed: 11 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,7 @@
44

55
from asyncio import (
66
CancelledError,
7-
create_task,
87
ensure_future,
9-
gather,
108
shield,
119
wait_for,
1210
)
@@ -52,6 +50,7 @@
5250
RefMap,
5351
Undefined,
5452
async_reduce,
53+
gather_with_cancel,
5554
inspect,
5655
is_iterable,
5756
)
@@ -466,21 +465,9 @@ async def get_results() -> dict[str, Any]:
466465
field = awaitable_fields[0]
467466
results[field] = await results[field]
468467
else:
469-
tasks = [
470-
create_task(results[field]) # type: ignore[arg-type]
471-
for field in awaitable_fields
472-
]
473-
474-
try:
475-
awaited_results = await gather(*tasks)
476-
except Exception:
477-
# Cancel unfinished tasks before raising the exception
478-
for task in tasks:
479-
if not task.done():
480-
task.cancel()
481-
await gather(*tasks, return_exceptions=True)
482-
raise
483-
468+
awaited_results = await gather_with_cancel(
469+
*(results[field] for field in awaitable_fields)
470+
)
484471
results.update(zip(awaitable_fields, awaited_results))
485472

486473
return results
@@ -911,20 +898,9 @@ async def complete_async_iterator_value(
911898
index = awaitable_indices[0]
912899
completed_results[index] = await completed_results[index]
913900
else:
914-
tasks = [
915-
create_task(completed_results[index]) for index in awaitable_indices
916-
]
917-
918-
try:
919-
awaited_results = await gather(*tasks)
920-
except Exception:
921-
# Cancel unfinished tasks before raising the exception
922-
for task in tasks:
923-
if not task.done():
924-
task.cancel()
925-
await gather(*tasks, return_exceptions=True)
926-
raise
927-
901+
awaited_results = await gather_with_cancel(
902+
*(completed_results[index] for index in awaitable_indices)
903+
)
928904
for index, sub_result in zip(awaitable_indices, awaited_results):
929905
completed_results[index] = sub_result
930906
return completed_results
@@ -1023,20 +999,9 @@ async def get_completed_results() -> list[Any]:
1023999
index = awaitable_indices[0]
10241000
completed_results[index] = await completed_results[index]
10251001
else:
1026-
tasks = [
1027-
create_task(completed_results[index]) for index in awaitable_indices
1028-
]
1029-
1030-
try:
1031-
awaited_results = await gather(*tasks)
1032-
except Exception:
1033-
# Cancel unfinished tasks before raising the exception
1034-
for task in tasks:
1035-
if not task.done():
1036-
task.cancel()
1037-
await gather(*tasks, return_exceptions=True)
1038-
raise
1039-
1002+
awaited_results = await gather_with_cancel(
1003+
*(completed_results[index] for index in awaitable_indices)
1004+
)
10401005
for index, sub_result in zip(awaitable_indices, awaited_results):
10411006
completed_results[index] = sub_result
10421007
return completed_results
@@ -2123,21 +2088,7 @@ def default_type_resolver(
21232088
if awaitable_is_type_of_results:
21242089
# noinspection PyShadowingNames
21252090
async def get_type() -> str | None:
2126-
tasks = [
2127-
create_task(result) # type: ignore[arg-type]
2128-
for result in awaitable_is_type_of_results
2129-
]
2130-
2131-
try:
2132-
is_type_of_results = await gather(*tasks)
2133-
except Exception:
2134-
# Cancel unfinished tasks before raising the exception
2135-
for task in tasks:
2136-
if not task.done():
2137-
task.cancel()
2138-
await gather(*tasks, return_exceptions=True)
2139-
raise
2140-
2091+
is_type_of_results = await gather_with_cancel(*awaitable_is_type_of_results)
21412092
for is_type_of_result, type_ in zip(is_type_of_results, awaitable_types):
21422093
if is_type_of_result:
21432094
return type_.name

src/graphql/pyutils/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
"""
1010

1111
from .async_reduce import async_reduce
12+
from .gather_with_cancel import gather_with_cancel
1213
from .convert_case import camel_to_snake, snake_to_camel
1314
from .cached_property import cached_property
1415
from .description import (
@@ -52,6 +53,7 @@
5253
"cached_property",
5354
"camel_to_snake",
5455
"did_you_mean",
56+
"gather_with_cancel",
5557
"group_by",
5658
"identity_func",
5759
"inspect",
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
"""Run awaitables concurrently with cancellation support."""
2+
3+
from __future__ import annotations
4+
5+
from asyncio import Task, create_task, gather
6+
from typing import Any, Awaitable
7+
8+
__all__ = ["gather_with_cancel"]
9+
10+
11+
async def gather_with_cancel(*awaitables: Awaitable[Any]) -> list[Any]:
12+
"""Run awaitable objects in the sequence concurrently.
13+
14+
The first raised exception is immediately propagated to the task that awaits
15+
on this function and all pending awaitables in the sequence will be cancelled.
16+
17+
This is different from the default behavior or `asyncio.gather` which waits
18+
for all tasks to complete even if one of them raises an exception. It is also
19+
different from `asyncio.gather` with `return_exceptions` set, which does not
20+
cancel the other tasks when one of them raises an exception.
21+
"""
22+
try:
23+
tasks: list[Task[Any]] = [
24+
aw if isinstance(aw, Task) else create_task(aw) # type: ignore[arg-type]
25+
for aw in awaitables
26+
]
27+
except TypeError:
28+
return await gather(*awaitables) # type: ignore[arg-type]
29+
try:
30+
return await gather(*tasks)
31+
except Exception:
32+
for task in tasks:
33+
if not task.done():
34+
task.cancel()
35+
await gather(*tasks, return_exceptions=True)
36+
raise
Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
from asyncio import Event, create_task, gather, sleep, wait_for
2+
from typing import Callable
3+
4+
import pytest
5+
6+
from graphql.pyutils import gather_with_cancel, is_awaitable
7+
8+
9+
class Controller:
10+
def reset(self, wait=False):
11+
self.event = Event()
12+
if not wait:
13+
self.event.set()
14+
self.returned = []
15+
16+
17+
controller = Controller()
18+
19+
20+
async def coroutine(value: int) -> int:
21+
"""Simple coroutine that returns a value."""
22+
if value > 2:
23+
raise RuntimeError("Oops")
24+
await controller.event.wait()
25+
controller.returned.append(value)
26+
return value
27+
28+
29+
class CustomAwaitable:
30+
"""Custom awaitable that return a value."""
31+
32+
def __init__(self, value: int):
33+
self.value = value
34+
self.coroutine = coroutine(value)
35+
36+
def __await__(self):
37+
return self.coroutine.__await__()
38+
39+
40+
awaitable_factories: dict[str, Callable] = {
41+
"coroutine": coroutine,
42+
"task": lambda value: create_task(coroutine(value)),
43+
"custom": lambda value: CustomAwaitable(value),
44+
}
45+
46+
with_all_types_of_awaitables = pytest.mark.parametrize(
47+
"type_of_awaitable", awaitable_factories
48+
)
49+
50+
51+
def describe_gather_with_cancel():
52+
@with_all_types_of_awaitables
53+
@pytest.mark.asyncio
54+
async def gathers_all_values(type_of_awaitable: str):
55+
return # !!!s
56+
factory = awaitable_factories[type_of_awaitable]
57+
values = list(range(3))
58+
59+
controller.reset()
60+
aws = [factory(i) for i in values]
61+
62+
assert await gather(*aws) == values
63+
assert controller.returned == values
64+
65+
controller.reset()
66+
aws = [factory(i) for i in values]
67+
68+
result = gather_with_cancel(*aws)
69+
assert is_awaitable(result)
70+
71+
awaited = await wait_for(result, 1)
72+
assert awaited == values
73+
74+
@with_all_types_of_awaitables
75+
@pytest.mark.asyncio
76+
async def raises_on_exception(type_of_awaitable: str):
77+
return # !!!
78+
factory = awaitable_factories[type_of_awaitable]
79+
values = list(range(4))
80+
81+
controller.reset()
82+
aws = [factory(i) for i in values]
83+
84+
with pytest.raises(RuntimeError, match="Oops"):
85+
await gather(*aws)
86+
assert controller.returned == values[:-1]
87+
88+
controller.reset()
89+
aws = [factory(i) for i in values]
90+
91+
result = gather_with_cancel(*aws)
92+
assert is_awaitable(result)
93+
94+
with pytest.raises(RuntimeError, match="Oops"):
95+
await wait_for(result, 1)
96+
assert controller.returned == values[:-1]
97+
98+
@with_all_types_of_awaitables
99+
@pytest.mark.asyncio
100+
async def cancels_on_exception(type_of_awaitable: str):
101+
factory = awaitable_factories[type_of_awaitable]
102+
values = list(range(4))
103+
104+
controller.reset(wait=True)
105+
aws = [factory(i) for i in values]
106+
107+
with pytest.raises(RuntimeError, match="Oops"):
108+
await gather(*aws)
109+
assert not controller.returned
110+
111+
# check that the standard gather continues to produce results
112+
controller.event.set()
113+
await sleep(0)
114+
assert controller.returned == values[:-1]
115+
116+
controller.reset(wait=True)
117+
aws = [factory(i) for i in values]
118+
119+
result = gather_with_cancel(*aws)
120+
assert is_awaitable(result)
121+
122+
with pytest.raises(RuntimeError, match="Oops"):
123+
await wait_for(result, 1)
124+
assert not controller.returned
125+
126+
# check that gather_with_cancel stops producing results
127+
controller.event.set()
128+
await sleep(0)
129+
if type_of_awaitable == "custom":
130+
# Cancellation of custom awaitables is not supported
131+
assert controller.returned == values[:-1]
132+
else:
133+
assert not controller.returned

0 commit comments

Comments
 (0)