1
1
from inspect import isawaitable , isasyncgen
2
+ from asyncio import ensure_future , wait , shield
2
3
3
- from asyncio import ensure_future
4
4
from aiohttp import WSMsgType
5
5
from graphql .execution .executors .asyncio import AsyncioExecutor
6
6
@@ -23,6 +23,10 @@ async def receive(self):
23
23
return msg .data
24
24
elif msg .type == WSMsgType .ERROR :
25
25
raise ConnectionClosedException ()
26
+ elif msg .type == WSMsgType .CLOSING :
27
+ raise ConnectionClosedException ()
28
+ elif msg .type == WSMsgType .CLOSED :
29
+ raise ConnectionClosedException ()
26
30
27
31
async def send (self , data ):
28
32
if self .closed :
@@ -38,25 +42,42 @@ async def close(self, code):
38
42
39
43
40
44
class AiohttpSubscriptionServer (BaseSubscriptionServer ):
45
+ def __init__ (self , schema , keep_alive = True , loop = None ):
46
+ self .loop = loop
47
+ super ().__init__ (schema , keep_alive )
41
48
42
49
def get_graphql_params (self , * args , ** kwargs ):
43
50
params = super (AiohttpSubscriptionServer ,
44
51
self ).get_graphql_params (* args , ** kwargs )
45
- return dict (params , return_promise = True , executor = AsyncioExecutor ())
52
+ return dict (params , return_promise = True , executor = AsyncioExecutor (loop = self . loop ))
46
53
47
- async def handle (self , ws , request_context = None ):
54
+ async def _handle (self , ws , request_context = None ):
48
55
connection_context = AiohttpConnectionContext (ws , request_context )
49
56
await self .on_open (connection_context )
57
+ pending_tasks = []
50
58
while True :
51
59
try :
52
60
if connection_context .closed :
53
61
raise ConnectionClosedException ()
54
62
message = await connection_context .receive ()
55
63
except ConnectionClosedException :
56
- self .on_close (connection_context )
57
- return
64
+ break
65
+ finally :
66
+ pending_tasks = [t for t in pending_tasks if not t .done ()]
58
67
59
- ensure_future (self .on_message (connection_context , message ))
68
+ task = ensure_future (
69
+ self .on_message (connection_context , message ), loop = self .loop )
70
+ pending_tasks .append (task )
71
+
72
+ self .on_close (connection_context )
73
+ if pending_tasks :
74
+ for task in pending_tasks :
75
+ if not task .done ():
76
+ task .cancel ()
77
+ await wait (pending_tasks , loop = self .loop )
78
+
79
+ async def handle (self , ws , request_context = None ):
80
+ await shield (self ._handle (ws , request_context ), loop = self .loop )
60
81
61
82
async def on_open (self , connection_context ):
62
83
pass
0 commit comments