diff --git a/graphql_ws/aiohttp.py b/graphql_ws/aiohttp.py index 4af5720..631150c 100644 --- a/graphql_ws/aiohttp.py +++ b/graphql_ws/aiohttp.py @@ -1,6 +1,6 @@ from inspect import isawaitable, isasyncgen +from asyncio import ensure_future, wait, shield -from asyncio import ensure_future from aiohttp import WSMsgType from graphql.execution.executors.asyncio import AsyncioExecutor @@ -23,6 +23,10 @@ async def receive(self): return msg.data elif msg.type == WSMsgType.ERROR: raise ConnectionClosedException() + elif msg.type == WSMsgType.CLOSING: + raise ConnectionClosedException() + elif msg.type == WSMsgType.CLOSED: + raise ConnectionClosedException() async def send(self, data): if self.closed: @@ -38,25 +42,42 @@ async def close(self, code): class AiohttpSubscriptionServer(BaseSubscriptionServer): + def __init__(self, schema, keep_alive=True, loop=None): + self.loop = loop + super().__init__(schema, keep_alive) def get_graphql_params(self, *args, **kwargs): params = super(AiohttpSubscriptionServer, self).get_graphql_params(*args, **kwargs) - return dict(params, return_promise=True, executor=AsyncioExecutor()) + return dict(params, return_promise=True, executor=AsyncioExecutor(loop=self.loop)) - async def handle(self, ws, request_context=None): + async def _handle(self, ws, request_context=None): connection_context = AiohttpConnectionContext(ws, request_context) await self.on_open(connection_context) + pending_tasks = [] while True: try: if connection_context.closed: raise ConnectionClosedException() message = await connection_context.receive() except ConnectionClosedException: - self.on_close(connection_context) - return + break + finally: + pending_tasks = [t for t in pending_tasks if not t.done()] - ensure_future(self.on_message(connection_context, message)) + task = ensure_future( + self.on_message(connection_context, message), loop=self.loop) + pending_tasks.append(task) + + self.on_close(connection_context) + if pending_tasks: + for task in pending_tasks: + if not task.done(): + task.cancel() + await wait(pending_tasks, loop=self.loop) + + async def handle(self, ws, request_context=None): + await shield(self._handle(ws, request_context), loop=self.loop) async def on_open(self, connection_context): pass