Skip to content

Commit fe91dbb

Browse files
committed
Use lower level asyncio.wait, abstract the on_complete command
1 parent f0a2727 commit fe91dbb

File tree

1 file changed

+8
-5
lines changed

1 file changed

+8
-5
lines changed

graphql_ws/django/subscriptions.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ async def on_start(self, connection_context, op_id, params):
6868
await self.send_execution_result(
6969
connection_context, op_id, execution_result
7070
)
71-
await self.send_message(connection_context, op_id, GQL_COMPLETE)
71+
await self.on_operation_complete(connection_context, op_id)
7272
return
7373

7474
iterator = await execution_result.__aiter__()
@@ -80,7 +80,7 @@ async def run_op(self, connection_context, op_id, iterator):
8080
if not connection_context.has_operation(op_id):
8181
break
8282
await self.send_execution_result(connection_context, op_id, single_result)
83-
await self.send_message(connection_context, op_id, GQL_COMPLETE)
83+
await self.on_operation_complete(connection_context, op_id)
8484

8585
async def on_close(self, connection_context):
8686
remove_operations = list(connection_context.operations.keys())
@@ -90,20 +90,23 @@ async def on_close(self, connection_context):
9090
if task:
9191
cancelled_tasks.append(task)
9292
# Wait around for all the tasks to actually cancel.
93-
await asyncio.gather(*cancelled_tasks, return_exceptions=True)
93+
await asyncio.wait(cancelled_tasks)
9494

9595
async def on_stop(self, connection_context, op_id):
9696
task = await self.unsubscribe(connection_context, op_id)
97-
await asyncio.gather(task, return_exceptions=True)
97+
await asyncio.wait([task])
9898

9999
async def unsubscribe(self, connection_context, op_id):
100100
op = None
101101
if connection_context.has_operation(op_id):
102102
op = connection_context.get_operation(op_id)
103103
op.cancel()
104104
connection_context.remove_operation(op_id)
105-
self.on_operation_complete(connection_context, op_id)
105+
await self.on_operation_complete(connection_context, op_id)
106106
return op
107107

108+
async def on_operation_complete(self, connection_context, op_id):
109+
await self.send_message(connection_context, op_id, GQL_COMPLETE)
110+
108111

109112
subscription_server = ChannelsSubscriptionServer(schema=graphene_settings.SCHEMA)

0 commit comments

Comments
 (0)