|
1 |
| -from asyncio import ensure_future |
| 1 | +import asyncio |
2 | 2 | from inspect import isawaitable
|
3 | 3 | from graphene_django.settings import graphene_settings
|
4 | 4 | from graphql.execution.executors.asyncio import AsyncioExecutor
|
|
10 | 10 |
|
11 | 11 |
|
12 | 12 | class ChannelsConnectionContext(BaseConnectionContext):
|
13 |
| - |
14 | 13 | async def send(self, data):
|
15 | 14 | await self.ws.send_json(data)
|
16 | 15 |
|
@@ -73,25 +72,38 @@ async def on_start(self, connection_context, op_id, params):
|
73 | 72 | return
|
74 | 73 |
|
75 | 74 | iterator = await execution_result.__aiter__()
|
76 |
| - ensure_future(self.run_op(connection_context, op_id, iterator)) |
| 75 | + task = asyncio.ensure_future(self.run_op(connection_context, op_id, iterator)) |
| 76 | + connection_context.register_operation(op_id, task) |
77 | 77 |
|
78 | 78 | async def run_op(self, connection_context, op_id, iterator):
|
79 |
| - connection_context.register_operation(op_id, iterator) |
80 | 79 | async for single_result in iterator:
|
81 | 80 | if not connection_context.has_operation(op_id):
|
82 | 81 | break
|
83 |
| - await self.send_execution_result( |
84 |
| - connection_context, op_id, single_result |
85 |
| - ) |
| 82 | + await self.send_execution_result(connection_context, op_id, single_result) |
86 | 83 | await self.send_message(connection_context, op_id, GQL_COMPLETE)
|
87 | 84 |
|
88 | 85 | async def on_close(self, connection_context):
|
89 | 86 | remove_operations = list(connection_context.operations.keys())
|
| 87 | + cancelled_tasks = [] |
90 | 88 | for op_id in remove_operations:
|
91 |
| - self.unsubscribe(connection_context, op_id) |
| 89 | + task = await self.unsubscribe(connection_context, op_id) |
| 90 | + if task: |
| 91 | + cancelled_tasks.append(task) |
| 92 | + # Wait around for all the tasks to actually cancel. |
| 93 | + await asyncio.gather(*cancelled_tasks, return_exceptions=True) |
92 | 94 |
|
93 | 95 | async def on_stop(self, connection_context, op_id):
|
94 |
| - self.unsubscribe(connection_context, op_id) |
| 96 | + task = await self.unsubscribe(connection_context, op_id) |
| 97 | + await asyncio.gather(task, return_exceptions=True) |
| 98 | + |
| 99 | + async def unsubscribe(self, connection_context, op_id): |
| 100 | + op = None |
| 101 | + if connection_context.has_operation(op_id): |
| 102 | + op = connection_context.get_operation(op_id) |
| 103 | + op.cancel() |
| 104 | + connection_context.remove_operation(op_id) |
| 105 | + self.on_operation_complete(connection_context, op_id) |
| 106 | + return op |
95 | 107 |
|
96 | 108 |
|
97 | 109 | subscription_server = ChannelsSubscriptionServer(schema=graphene_settings.SCHEMA)
|
0 commit comments