Skip to content

Commit 3e7b361

Browse files
committed
Filter subsequent payloads when parent field is null
Replicates graphql/graphql-js@b9a2695
1 parent 62749e5 commit 3e7b361

File tree

2 files changed

+447
-63
lines changed

2 files changed

+447
-63
lines changed

src/graphql/execution/execute.py

Lines changed: 89 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from asyncio import Event, as_completed, ensure_future, gather, shield, sleep, wait_for
44
from collections.abc import Mapping
5+
from contextlib import suppress
56
from inspect import isawaitable
67
from typing import (
78
Any,
@@ -17,6 +18,7 @@
1718
NamedTuple,
1819
Optional,
1920
Sequence,
21+
Set,
2022
Tuple,
2123
Type,
2224
Union,
@@ -673,6 +675,7 @@ def __init__(
673675
self.middleware_manager = middleware_manager
674676
if is_awaitable:
675677
self.is_awaitable = is_awaitable
678+
self._canceled_iterators: Set[AsyncIterator] = set()
676679
self._subfields_cache: Dict[Tuple, FieldsAndPatches] = {}
677680

678681
@classmethod
@@ -1006,6 +1009,7 @@ async def await_completed() -> Any:
10061009
except Exception as raw_error:
10071010
error = located_error(raw_error, field_nodes, path.as_list())
10081011
handle_field_error(error, return_type, errors)
1012+
self.filter_subsequent_payloads(path)
10091013
return None
10101014

10111015
return await_completed()
@@ -1014,6 +1018,7 @@ async def await_completed() -> Any:
10141018
except Exception as raw_error:
10151019
error = located_error(raw_error, field_nodes, path.as_list())
10161020
handle_field_error(error, return_type, errors)
1021+
self.filter_subsequent_payloads(path)
10171022
return None
10181023

10191024
def build_resolve_info(
@@ -1305,6 +1310,7 @@ def complete_list_value(
13051310
and index >= stream.initial_count
13061311
):
13071312
previous_async_payload_record = self.execute_stream_field(
1313+
path,
13081314
item_path,
13091315
item,
13101316
field_nodes,
@@ -1334,6 +1340,7 @@ async def await_completed(item: Any, item_path: Path) -> Any:
13341340
raw_error, field_nodes, item_path.as_list()
13351341
)
13361342
handle_field_error(error, item_type, errors)
1343+
self.filter_subsequent_payloads(item_path)
13371344
return None
13381345

13391346
completed_item = await_completed(item, item_path)
@@ -1357,12 +1364,14 @@ async def await_completed(item: Any, item_path: Path) -> Any:
13571364
raw_error, field_nodes, item_path.as_list()
13581365
)
13591366
handle_field_error(error, item_type, errors)
1367+
self.filter_subsequent_payloads(item_path)
13601368
return None
13611369

13621370
completed_item = await_completed(completed_item, item_path)
13631371
except Exception as raw_error:
13641372
error = located_error(raw_error, field_nodes, item_path.as_list())
13651373
handle_field_error(error, item_type, errors)
1374+
self.filter_subsequent_payloads(item_path)
13661375
completed_item = None
13671376

13681377
if is_awaitable(completed_item):
@@ -1694,14 +1703,17 @@ async def await_data(
16941703
def execute_stream_field(
16951704
self,
16961705
path: Path,
1706+
item_path: Path,
16971707
item: AwaitableOrValue[Any],
16981708
field_nodes: List[FieldNode],
16991709
info: GraphQLResolveInfo,
17001710
item_type: GraphQLOutputType,
17011711
label: Optional[str] = None,
17021712
parent_context: Optional[AsyncPayloadRecord] = None,
17031713
) -> AsyncPayloadRecord:
1704-
async_payload_record = StreamRecord(label, path, None, parent_context, self)
1714+
async_payload_record = StreamRecord(
1715+
label, item_path, None, parent_context, self
1716+
)
17051717
completed_item: Any
17061718
completed_items: Any
17071719
try:
@@ -1713,7 +1725,7 @@ async def await_completed_item() -> Any:
17131725
item_type,
17141726
field_nodes,
17151727
info,
1716-
path,
1728+
item_path,
17171729
await item,
17181730
async_payload_record,
17191731
)
@@ -1727,7 +1739,12 @@ async def await_completed_item() -> Any:
17271739

17281740
else:
17291741
completed_item = self.complete_value(
1730-
item_type, field_nodes, info, path, item, async_payload_record
1742+
item_type,
1743+
field_nodes,
1744+
info,
1745+
item_path,
1746+
item,
1747+
async_payload_record,
17311748
)
17321749

17331750
if self.is_awaitable(completed_item):
@@ -1739,24 +1756,31 @@ async def await_completed_item() -> Any:
17391756
except Exception as raw_error:
17401757
# noinspection PyShadowingNames
17411758
error = located_error(
1742-
raw_error, field_nodes, path.as_list()
1759+
raw_error, field_nodes, item_path.as_list()
17431760
)
17441761
handle_field_error(
17451762
error, item_type, async_payload_record.errors
17461763
)
1764+
self.filter_subsequent_payloads(
1765+
item_path, async_payload_record
1766+
)
17471767
return None
17481768

17491769
complete_item = await_completed_item()
17501770

17511771
else:
17521772
complete_item = completed_item
17531773
except Exception as raw_error:
1754-
error = located_error(raw_error, field_nodes, path.as_list())
1774+
error = located_error(raw_error, field_nodes, item_path.as_list())
17551775
handle_field_error(error, item_type, async_payload_record.errors)
1776+
self.filter_subsequent_payloads( # pragma: no cover
1777+
item_path, async_payload_record
1778+
)
17561779
complete_item = None # pragma: no cover
17571780

17581781
except GraphQLError as error:
17591782
async_payload_record.errors.append(error)
1783+
self.filter_subsequent_payloads(item_path, async_payload_record)
17601784
async_payload_record.add_items(None)
17611785
return async_payload_record
17621786

@@ -1768,6 +1792,7 @@ async def await_completed_items() -> Optional[List[Any]]:
17681792
return [await complete_item] # type: ignore
17691793
except GraphQLError as error:
17701794
async_payload_record.errors.append(error)
1795+
self.filter_subsequent_payloads(path, async_payload_record)
17711796
return None
17721797

17731798
completed_items = await_completed_items()
@@ -1786,6 +1811,8 @@ async def execute_stream_iterator_item(
17861811
async_payload_record: StreamRecord,
17871812
field_path: Path,
17881813
) -> Any:
1814+
if iterator in self._canceled_iterators:
1815+
raise StopAsyncIteration
17891816
try:
17901817
item = await anext(iterator)
17911818
completed_item = self.complete_value(
@@ -1799,12 +1826,13 @@ async def execute_stream_iterator_item(
17991826
)
18001827

18011828
except StopAsyncIteration as raw_error:
1802-
async_payload_record.set_ist_completed_iterator()
1829+
async_payload_record.set_is_completed_iterator()
18031830
raise StopAsyncIteration from raw_error
18041831

18051832
except Exception as raw_error:
18061833
error = located_error(raw_error, field_nodes, field_path.as_list())
18071834
handle_field_error(error, item_type, async_payload_record.errors)
1835+
self.filter_subsequent_payloads(field_path, async_payload_record)
18081836

18091837
async def execute_stream_iterator(
18101838
self,
@@ -1830,30 +1858,50 @@ async def execute_stream_iterator(
18301858
iterator, field_modes, info, item_type, async_payload_record, field_path
18311859
)
18321860

1833-
# noinspection PyShadowingNames
1834-
async def items(
1835-
data: Awaitable[Any], async_payload_record: StreamRecord
1836-
) -> AwaitableOrValue[Optional[List[Any]]]:
1837-
try:
1838-
return [await data]
1839-
except GraphQLError as error:
1840-
async_payload_record.errors.append(error)
1841-
return None
1842-
18431861
try:
1844-
async_payload_record.add_items(
1845-
await items(awaitable_data, async_payload_record)
1846-
)
1862+
data = await awaitable_data
18471863
except StopAsyncIteration:
18481864
if async_payload_record.errors:
1849-
async_payload_record.add_items([None]) # pragma: no cover
1865+
async_payload_record.add_items(None) # pragma: no cover
18501866
else:
18511867
del self.subsequent_payloads[async_payload_record]
18521868
break
1869+
except GraphQLError as error:
1870+
# entire stream has errored and bubbled upwards
1871+
self.filter_subsequent_payloads(path, async_payload_record)
1872+
if iterator: # pragma: no cover else
1873+
with suppress(Exception):
1874+
await iterator.aclose() # type: ignore
1875+
# running generators cannot be closed since Python 3.8,
1876+
# so we need to remember that this iterator is already canceled
1877+
self._canceled_iterators.add(iterator)
1878+
async_payload_record.add_items(None)
1879+
async_payload_record.errors.append(error)
1880+
break
1881+
1882+
async_payload_record.add_items([data])
18531883

18541884
previous_async_payload_record = async_payload_record
18551885
index += 1
18561886

1887+
def filter_subsequent_payloads(
1888+
self,
1889+
null_path: Optional[Path] = None,
1890+
current_async_record: Optional[AsyncPayloadRecord] = None,
1891+
) -> None:
1892+
null_path_list = null_path.as_list() if null_path else []
1893+
for async_record in list(self.subsequent_payloads):
1894+
if async_record is current_async_record:
1895+
# don't remove payload from where error originates
1896+
continue
1897+
if async_record.path[: len(null_path_list)] != null_path_list:
1898+
# async_record points to a path unaffected by this payload
1899+
continue
1900+
# async_record path points to nulled error field
1901+
if isinstance(async_record, StreamRecord) and async_record.iterator:
1902+
self._canceled_iterators.add(async_record.iterator)
1903+
del self.subsequent_payloads[async_record]
1904+
18571905
def get_completed_incremental_results(self) -> List[IncrementalResult]:
18581906
incremental_results: List[IncrementalResult] = []
18591907
append_result = incremental_results.append
@@ -2661,12 +2709,16 @@ async def wait(self) -> Optional[Dict[str, Any]]:
26612709
if self.parent_context:
26622710
await self.parent_context.completed.wait()
26632711
_data = self._data
2664-
data = (
2665-
await _data if self._context.is_awaitable(_data) else _data # type: ignore
2666-
)
2667-
self.data = data
2668-
await sleep(ASYNC_DELAY) # always defer completion a little bit
2669-
self.completed.set()
2712+
try:
2713+
data = (
2714+
await _data # type: ignore
2715+
if self._context.is_awaitable(_data)
2716+
else _data
2717+
)
2718+
finally:
2719+
await sleep(ASYNC_DELAY) # always defer completion a little bit
2720+
self.data = data
2721+
self.completed.set()
26702722
return data
26712723

26722724
def add_data(self, data: AwaitableOrValue[Optional[Dict[str, Any]]]) -> None:
@@ -2728,21 +2780,23 @@ async def wait(self) -> Optional[List[str]]:
27282780
if self.parent_context:
27292781
await self.parent_context.completed.wait()
27302782
_items = self._items
2731-
items = (
2732-
await _items # type: ignore
2733-
if self._context.is_awaitable(_items)
2734-
else _items
2735-
)
2736-
self.items = items
2737-
await sleep(ASYNC_DELAY) # always defer completion a little bit
2738-
self.completed.set()
2783+
try:
2784+
items = (
2785+
await _items # type: ignore
2786+
if self._context.is_awaitable(_items)
2787+
else _items
2788+
)
2789+
finally:
2790+
await sleep(ASYNC_DELAY) # always defer completion a little bit
2791+
self.items = items
2792+
self.completed.set()
27392793
return items
27402794

27412795
def add_items(self, items: AwaitableOrValue[Optional[List[Any]]]) -> None:
27422796
self._items = items
27432797
self._items_added.set()
27442798

2745-
def set_ist_completed_iterator(self) -> None:
2799+
def set_is_completed_iterator(self) -> None:
27462800
self.is_completed_iterator = True
27472801
self._items_added.set()
27482802

0 commit comments

Comments
 (0)