Skip to content

Commit f0a2727

Browse files
committed
Behave correctly by cancelling async tasks
1 parent dc478be commit f0a2727

File tree

1 file changed

+21
-9
lines changed

1 file changed

+21
-9
lines changed

graphql_ws/django/subscriptions.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from asyncio import ensure_future
1+
import asyncio
22
from inspect import isawaitable
33
from graphene_django.settings import graphene_settings
44
from graphql.execution.executors.asyncio import AsyncioExecutor
@@ -10,7 +10,6 @@
1010

1111

1212
class ChannelsConnectionContext(BaseConnectionContext):
13-
1413
async def send(self, data):
1514
await self.ws.send_json(data)
1615

@@ -73,25 +72,38 @@ async def on_start(self, connection_context, op_id, params):
7372
return
7473

7574
iterator = await execution_result.__aiter__()
76-
ensure_future(self.run_op(connection_context, op_id, iterator))
75+
task = asyncio.ensure_future(self.run_op(connection_context, op_id, iterator))
76+
connection_context.register_operation(op_id, task)
7777

7878
async def run_op(self, connection_context, op_id, iterator):
79-
connection_context.register_operation(op_id, iterator)
8079
async for single_result in iterator:
8180
if not connection_context.has_operation(op_id):
8281
break
83-
await self.send_execution_result(
84-
connection_context, op_id, single_result
85-
)
82+
await self.send_execution_result(connection_context, op_id, single_result)
8683
await self.send_message(connection_context, op_id, GQL_COMPLETE)
8784

8885
async def on_close(self, connection_context):
8986
remove_operations = list(connection_context.operations.keys())
87+
cancelled_tasks = []
9088
for op_id in remove_operations:
91-
self.unsubscribe(connection_context, op_id)
89+
task = await self.unsubscribe(connection_context, op_id)
90+
if task:
91+
cancelled_tasks.append(task)
92+
# Wait around for all the tasks to actually cancel.
93+
await asyncio.gather(*cancelled_tasks, return_exceptions=True)
9294

9395
async def on_stop(self, connection_context, op_id):
94-
self.unsubscribe(connection_context, op_id)
96+
task = await self.unsubscribe(connection_context, op_id)
97+
await asyncio.gather(task, return_exceptions=True)
98+
99+
async def unsubscribe(self, connection_context, op_id):
100+
op = None
101+
if connection_context.has_operation(op_id):
102+
op = connection_context.get_operation(op_id)
103+
op.cancel()
104+
connection_context.remove_operation(op_id)
105+
self.on_operation_complete(connection_context, op_id)
106+
return op
95107

96108

97109
subscription_server = ChannelsSubscriptionServer(schema=graphene_settings.SCHEMA)

0 commit comments

Comments
 (0)