From e624306ff6daa8c15de9b826fbfdd3ca5ea99a8d Mon Sep 17 00:00:00 2001 From: Taku Fukada Date: Mon, 9 Jul 2018 11:59:14 +0900 Subject: [PATCH] Fix handler for websockets lib not to ignore pending tasks same as #13 --- graphql_ws/aiohttp.py | 14 ++++++-------- graphql_ws/websockets_lib.py | 29 ++++++++++++++++++++++------- 2 files changed, 28 insertions(+), 15 deletions(-) diff --git a/graphql_ws/aiohttp.py b/graphql_ws/aiohttp.py index 631150c..a7c4610 100644 --- a/graphql_ws/aiohttp.py +++ b/graphql_ws/aiohttp.py @@ -54,7 +54,7 @@ def get_graphql_params(self, *args, **kwargs): async def _handle(self, ws, request_context=None): connection_context = AiohttpConnectionContext(ws, request_context) await self.on_open(connection_context) - pending_tasks = [] + pending = set() while True: try: if connection_context.closed: @@ -63,18 +63,16 @@ async def _handle(self, ws, request_context=None): except ConnectionClosedException: break finally: - pending_tasks = [t for t in pending_tasks if not t.done()] + if pending: + (_, pending) = await wait(pending, timeout=0, loop=self.loop) task = ensure_future( self.on_message(connection_context, message), loop=self.loop) - pending_tasks.append(task) + pending.add(task) self.on_close(connection_context) - if pending_tasks: - for task in pending_tasks: - if not task.done(): - task.cancel() - await wait(pending_tasks, loop=self.loop) + for task in pending: + task.cancel() async def handle(self, ws, request_context=None): await shield(self._handle(ws, request_context), loop=self.loop) diff --git a/graphql_ws/websockets_lib.py b/graphql_ws/websockets_lib.py index f41a1bb..6812d8a 100644 --- a/graphql_ws/websockets_lib.py +++ b/graphql_ws/websockets_lib.py @@ -1,6 +1,6 @@ -from inspect import isawaitable, isasyncgen +from inspect import isawaitable -from asyncio import ensure_future +from asyncio import ensure_future, wait, shield from websockets import ConnectionClosed from graphql.execution.executors.asyncio import AsyncioExecutor @@ -38,25 +38,40 @@ async def close(self, code): class WsLibSubscriptionServer(BaseSubscriptionServer): + def __init__(self, schema, keep_alive=True, loop=None): + self.loop = loop + super().__init__(schema, keep_alive) def get_graphql_params(self, *args, **kwargs): params = super(WsLibSubscriptionServer, self).get_graphql_params(*args, **kwargs) - return dict(params, return_promise=True, executor=AsyncioExecutor()) + return dict(params, return_promise=True, executor=AsyncioExecutor(loop=self.loop)) - async def handle(self, ws, request_context=None): + async def _handle(self, ws, request_context): connection_context = WsLibConnectionContext(ws, request_context) await self.on_open(connection_context) + pending = set() while True: try: if connection_context.closed: raise ConnectionClosedException() message = await connection_context.receive() except ConnectionClosedException: - self.on_close(connection_context) - return + break + finally: + if pending: + (_, pending) = await wait(pending, timeout=0, loop=self.loop) + + task = ensure_future( + self.on_message(connection_context, message), loop=self.loop) + pending.add(task) - ensure_future(self.on_message(connection_context, message)) + self.on_close(connection_context) + for task in pending: + task.cancel() + + async def handle(self, ws, request_context=None): + await shield(self._handle(ws, request_context), loop=self.loop) async def on_open(self, connection_context): pass