Skip to content

Commit f7da106

Browse files
authored
Merge pull request #13 from ciscorn/fix-aiohttp
Handle asyncio's Task cancellation and aiohttp's websock close properly
2 parents 06e7c65 + 010531c commit f7da106

File tree

1 file changed

+27
-6
lines changed

1 file changed

+27
-6
lines changed

graphql_ws/aiohttp.py

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from inspect import isawaitable, isasyncgen
2+
from asyncio import ensure_future, wait, shield
23

3-
from asyncio import ensure_future
44
from aiohttp import WSMsgType
55
from graphql.execution.executors.asyncio import AsyncioExecutor
66

@@ -23,6 +23,10 @@ async def receive(self):
2323
return msg.data
2424
elif msg.type == WSMsgType.ERROR:
2525
raise ConnectionClosedException()
26+
elif msg.type == WSMsgType.CLOSING:
27+
raise ConnectionClosedException()
28+
elif msg.type == WSMsgType.CLOSED:
29+
raise ConnectionClosedException()
2630

2731
async def send(self, data):
2832
if self.closed:
@@ -38,25 +42,42 @@ async def close(self, code):
3842

3943

4044
class AiohttpSubscriptionServer(BaseSubscriptionServer):
45+
def __init__(self, schema, keep_alive=True, loop=None):
46+
self.loop = loop
47+
super().__init__(schema, keep_alive)
4148

4249
def get_graphql_params(self, *args, **kwargs):
4350
params = super(AiohttpSubscriptionServer,
4451
self).get_graphql_params(*args, **kwargs)
45-
return dict(params, return_promise=True, executor=AsyncioExecutor())
52+
return dict(params, return_promise=True, executor=AsyncioExecutor(loop=self.loop))
4653

47-
async def handle(self, ws, request_context=None):
54+
async def _handle(self, ws, request_context=None):
4855
connection_context = AiohttpConnectionContext(ws, request_context)
4956
await self.on_open(connection_context)
57+
pending_tasks = []
5058
while True:
5159
try:
5260
if connection_context.closed:
5361
raise ConnectionClosedException()
5462
message = await connection_context.receive()
5563
except ConnectionClosedException:
56-
self.on_close(connection_context)
57-
return
64+
break
65+
finally:
66+
pending_tasks = [t for t in pending_tasks if not t.done()]
5867

59-
ensure_future(self.on_message(connection_context, message))
68+
task = ensure_future(
69+
self.on_message(connection_context, message), loop=self.loop)
70+
pending_tasks.append(task)
71+
72+
self.on_close(connection_context)
73+
if pending_tasks:
74+
for task in pending_tasks:
75+
if not task.done():
76+
task.cancel()
77+
await wait(pending_tasks, loop=self.loop)
78+
79+
async def handle(self, ws, request_context=None):
80+
await shield(self._handle(ws, request_context), loop=self.loop)
6081

6182
async def on_open(self, connection_context):
6283
pass

0 commit comments

Comments
 (0)