2
2
3
3
from asyncio import Event , as_completed , ensure_future , gather , shield , sleep , wait_for
4
4
from collections .abc import Mapping
5
+ from contextlib import suppress
5
6
from inspect import isawaitable
6
7
from typing import (
7
8
Any ,
17
18
NamedTuple ,
18
19
Optional ,
19
20
Sequence ,
21
+ Set ,
20
22
Tuple ,
21
23
Type ,
22
24
Union ,
@@ -673,6 +675,7 @@ def __init__(
673
675
self .middleware_manager = middleware_manager
674
676
if is_awaitable :
675
677
self .is_awaitable = is_awaitable
678
+ self ._canceled_iterators : Set [AsyncIterator ] = set ()
676
679
self ._subfields_cache : Dict [Tuple , FieldsAndPatches ] = {}
677
680
678
681
@classmethod
@@ -1006,6 +1009,7 @@ async def await_completed() -> Any:
1006
1009
except Exception as raw_error :
1007
1010
error = located_error (raw_error , field_nodes , path .as_list ())
1008
1011
handle_field_error (error , return_type , errors )
1012
+ self .filter_subsequent_payloads (path )
1009
1013
return None
1010
1014
1011
1015
return await_completed ()
@@ -1014,6 +1018,7 @@ async def await_completed() -> Any:
1014
1018
except Exception as raw_error :
1015
1019
error = located_error (raw_error , field_nodes , path .as_list ())
1016
1020
handle_field_error (error , return_type , errors )
1021
+ self .filter_subsequent_payloads (path )
1017
1022
return None
1018
1023
1019
1024
def build_resolve_info (
@@ -1305,6 +1310,7 @@ def complete_list_value(
1305
1310
and index >= stream .initial_count
1306
1311
):
1307
1312
previous_async_payload_record = self .execute_stream_field (
1313
+ path ,
1308
1314
item_path ,
1309
1315
item ,
1310
1316
field_nodes ,
@@ -1334,6 +1340,7 @@ async def await_completed(item: Any, item_path: Path) -> Any:
1334
1340
raw_error , field_nodes , item_path .as_list ()
1335
1341
)
1336
1342
handle_field_error (error , item_type , errors )
1343
+ self .filter_subsequent_payloads (item_path )
1337
1344
return None
1338
1345
1339
1346
completed_item = await_completed (item , item_path )
@@ -1357,12 +1364,14 @@ async def await_completed(item: Any, item_path: Path) -> Any:
1357
1364
raw_error , field_nodes , item_path .as_list ()
1358
1365
)
1359
1366
handle_field_error (error , item_type , errors )
1367
+ self .filter_subsequent_payloads (item_path )
1360
1368
return None
1361
1369
1362
1370
completed_item = await_completed (completed_item , item_path )
1363
1371
except Exception as raw_error :
1364
1372
error = located_error (raw_error , field_nodes , item_path .as_list ())
1365
1373
handle_field_error (error , item_type , errors )
1374
+ self .filter_subsequent_payloads (item_path )
1366
1375
completed_item = None
1367
1376
1368
1377
if is_awaitable (completed_item ):
@@ -1694,14 +1703,17 @@ async def await_data(
1694
1703
def execute_stream_field (
1695
1704
self ,
1696
1705
path : Path ,
1706
+ item_path : Path ,
1697
1707
item : AwaitableOrValue [Any ],
1698
1708
field_nodes : List [FieldNode ],
1699
1709
info : GraphQLResolveInfo ,
1700
1710
item_type : GraphQLOutputType ,
1701
1711
label : Optional [str ] = None ,
1702
1712
parent_context : Optional [AsyncPayloadRecord ] = None ,
1703
1713
) -> AsyncPayloadRecord :
1704
- async_payload_record = StreamRecord (label , path , None , parent_context , self )
1714
+ async_payload_record = StreamRecord (
1715
+ label , item_path , None , parent_context , self
1716
+ )
1705
1717
completed_item : Any
1706
1718
completed_items : Any
1707
1719
try :
@@ -1713,7 +1725,7 @@ async def await_completed_item() -> Any:
1713
1725
item_type ,
1714
1726
field_nodes ,
1715
1727
info ,
1716
- path ,
1728
+ item_path ,
1717
1729
await item ,
1718
1730
async_payload_record ,
1719
1731
)
@@ -1727,7 +1739,12 @@ async def await_completed_item() -> Any:
1727
1739
1728
1740
else :
1729
1741
completed_item = self .complete_value (
1730
- item_type , field_nodes , info , path , item , async_payload_record
1742
+ item_type ,
1743
+ field_nodes ,
1744
+ info ,
1745
+ item_path ,
1746
+ item ,
1747
+ async_payload_record ,
1731
1748
)
1732
1749
1733
1750
if self .is_awaitable (completed_item ):
@@ -1739,24 +1756,31 @@ async def await_completed_item() -> Any:
1739
1756
except Exception as raw_error :
1740
1757
# noinspection PyShadowingNames
1741
1758
error = located_error (
1742
- raw_error , field_nodes , path .as_list ()
1759
+ raw_error , field_nodes , item_path .as_list ()
1743
1760
)
1744
1761
handle_field_error (
1745
1762
error , item_type , async_payload_record .errors
1746
1763
)
1764
+ self .filter_subsequent_payloads (
1765
+ item_path , async_payload_record
1766
+ )
1747
1767
return None
1748
1768
1749
1769
complete_item = await_completed_item ()
1750
1770
1751
1771
else :
1752
1772
complete_item = completed_item
1753
1773
except Exception as raw_error :
1754
- error = located_error (raw_error , field_nodes , path .as_list ())
1774
+ error = located_error (raw_error , field_nodes , item_path .as_list ())
1755
1775
handle_field_error (error , item_type , async_payload_record .errors )
1776
+ self .filter_subsequent_payloads ( # pragma: no cover
1777
+ item_path , async_payload_record
1778
+ )
1756
1779
complete_item = None # pragma: no cover
1757
1780
1758
1781
except GraphQLError as error :
1759
1782
async_payload_record .errors .append (error )
1783
+ self .filter_subsequent_payloads (item_path , async_payload_record )
1760
1784
async_payload_record .add_items (None )
1761
1785
return async_payload_record
1762
1786
@@ -1768,6 +1792,7 @@ async def await_completed_items() -> Optional[List[Any]]:
1768
1792
return [await complete_item ] # type: ignore
1769
1793
except GraphQLError as error :
1770
1794
async_payload_record .errors .append (error )
1795
+ self .filter_subsequent_payloads (path , async_payload_record )
1771
1796
return None
1772
1797
1773
1798
completed_items = await_completed_items ()
@@ -1786,6 +1811,8 @@ async def execute_stream_iterator_item(
1786
1811
async_payload_record : StreamRecord ,
1787
1812
field_path : Path ,
1788
1813
) -> Any :
1814
+ if iterator in self ._canceled_iterators :
1815
+ raise StopAsyncIteration
1789
1816
try :
1790
1817
item = await anext (iterator )
1791
1818
completed_item = self .complete_value (
@@ -1799,12 +1826,13 @@ async def execute_stream_iterator_item(
1799
1826
)
1800
1827
1801
1828
except StopAsyncIteration as raw_error :
1802
- async_payload_record .set_ist_completed_iterator ()
1829
+ async_payload_record .set_is_completed_iterator ()
1803
1830
raise StopAsyncIteration from raw_error
1804
1831
1805
1832
except Exception as raw_error :
1806
1833
error = located_error (raw_error , field_nodes , field_path .as_list ())
1807
1834
handle_field_error (error , item_type , async_payload_record .errors )
1835
+ self .filter_subsequent_payloads (field_path , async_payload_record )
1808
1836
1809
1837
async def execute_stream_iterator (
1810
1838
self ,
@@ -1830,30 +1858,50 @@ async def execute_stream_iterator(
1830
1858
iterator , field_modes , info , item_type , async_payload_record , field_path
1831
1859
)
1832
1860
1833
- # noinspection PyShadowingNames
1834
- async def items (
1835
- data : Awaitable [Any ], async_payload_record : StreamRecord
1836
- ) -> AwaitableOrValue [Optional [List [Any ]]]:
1837
- try :
1838
- return [await data ]
1839
- except GraphQLError as error :
1840
- async_payload_record .errors .append (error )
1841
- return None
1842
-
1843
1861
try :
1844
- async_payload_record .add_items (
1845
- await items (awaitable_data , async_payload_record )
1846
- )
1862
+ data = await awaitable_data
1847
1863
except StopAsyncIteration :
1848
1864
if async_payload_record .errors :
1849
- async_payload_record .add_items ([ None ] ) # pragma: no cover
1865
+ async_payload_record .add_items (None ) # pragma: no cover
1850
1866
else :
1851
1867
del self .subsequent_payloads [async_payload_record ]
1852
1868
break
1869
+ except GraphQLError as error :
1870
+ # entire stream has errored and bubbled upwards
1871
+ self .filter_subsequent_payloads (path , async_payload_record )
1872
+ if iterator : # pragma: no cover else
1873
+ with suppress (Exception ):
1874
+ await iterator .aclose () # type: ignore
1875
+ # running generators cannot be closed since Python 3.8,
1876
+ # so we need to remember that this iterator is already canceled
1877
+ self ._canceled_iterators .add (iterator )
1878
+ async_payload_record .add_items (None )
1879
+ async_payload_record .errors .append (error )
1880
+ break
1881
+
1882
+ async_payload_record .add_items ([data ])
1853
1883
1854
1884
previous_async_payload_record = async_payload_record
1855
1885
index += 1
1856
1886
1887
+ def filter_subsequent_payloads (
1888
+ self ,
1889
+ null_path : Optional [Path ] = None ,
1890
+ current_async_record : Optional [AsyncPayloadRecord ] = None ,
1891
+ ) -> None :
1892
+ null_path_list = null_path .as_list () if null_path else []
1893
+ for async_record in list (self .subsequent_payloads ):
1894
+ if async_record is current_async_record :
1895
+ # don't remove payload from where error originates
1896
+ continue
1897
+ if async_record .path [: len (null_path_list )] != null_path_list :
1898
+ # async_record points to a path unaffected by this payload
1899
+ continue
1900
+ # async_record path points to nulled error field
1901
+ if isinstance (async_record , StreamRecord ) and async_record .iterator :
1902
+ self ._canceled_iterators .add (async_record .iterator )
1903
+ del self .subsequent_payloads [async_record ]
1904
+
1857
1905
def get_completed_incremental_results (self ) -> List [IncrementalResult ]:
1858
1906
incremental_results : List [IncrementalResult ] = []
1859
1907
append_result = incremental_results .append
@@ -2661,12 +2709,16 @@ async def wait(self) -> Optional[Dict[str, Any]]:
2661
2709
if self .parent_context :
2662
2710
await self .parent_context .completed .wait ()
2663
2711
_data = self ._data
2664
- data = (
2665
- await _data if self ._context .is_awaitable (_data ) else _data # type: ignore
2666
- )
2667
- self .data = data
2668
- await sleep (ASYNC_DELAY ) # always defer completion a little bit
2669
- self .completed .set ()
2712
+ try :
2713
+ data = (
2714
+ await _data # type: ignore
2715
+ if self ._context .is_awaitable (_data )
2716
+ else _data
2717
+ )
2718
+ finally :
2719
+ await sleep (ASYNC_DELAY ) # always defer completion a little bit
2720
+ self .data = data
2721
+ self .completed .set ()
2670
2722
return data
2671
2723
2672
2724
def add_data (self , data : AwaitableOrValue [Optional [Dict [str , Any ]]]) -> None :
@@ -2728,21 +2780,23 @@ async def wait(self) -> Optional[List[str]]:
2728
2780
if self .parent_context :
2729
2781
await self .parent_context .completed .wait ()
2730
2782
_items = self ._items
2731
- items = (
2732
- await _items # type: ignore
2733
- if self ._context .is_awaitable (_items )
2734
- else _items
2735
- )
2736
- self .items = items
2737
- await sleep (ASYNC_DELAY ) # always defer completion a little bit
2738
- self .completed .set ()
2783
+ try :
2784
+ items = (
2785
+ await _items # type: ignore
2786
+ if self ._context .is_awaitable (_items )
2787
+ else _items
2788
+ )
2789
+ finally :
2790
+ await sleep (ASYNC_DELAY ) # always defer completion a little bit
2791
+ self .items = items
2792
+ self .completed .set ()
2739
2793
return items
2740
2794
2741
2795
def add_items (self , items : AwaitableOrValue [Optional [List [Any ]]]) -> None :
2742
2796
self ._items = items
2743
2797
self ._items_added .set ()
2744
2798
2745
- def set_ist_completed_iterator (self ) -> None :
2799
+ def set_is_completed_iterator (self ) -> None :
2746
2800
self .is_completed_iterator = True
2747
2801
self ._items_added .set ()
2748
2802
0 commit comments