|
1 |
| -import asyncio |
2 | 1 | from inspect import isawaitable
|
3 | 2 | from graphene_django.settings import graphene_settings
|
4 | 3 | from graphql.execution.executors.asyncio import AsyncioExecutor
|
@@ -64,49 +63,34 @@ async def on_start(self, connection_context, op_id, params):
|
64 | 63 | if isawaitable(execution_result):
|
65 | 64 | execution_result = await execution_result
|
66 | 65 |
|
67 |
| - if not hasattr(execution_result, "__aiter__"): |
| 66 | + if hasattr(execution_result, "__aiter__"): |
| 67 | + iterator = await execution_result.__aiter__() |
| 68 | + connection_context.register_operation(op_id, iterator) |
| 69 | + async for single_result in iterator: |
| 70 | + if not connection_context.has_operation(op_id): |
| 71 | + break |
| 72 | + await self.send_execution_result( |
| 73 | + connection_context, op_id, single_result |
| 74 | + ) |
| 75 | + else: |
68 | 76 | await self.send_execution_result(
|
69 | 77 | connection_context, op_id, execution_result
|
70 | 78 | )
|
71 |
| - await self.on_operation_complete(connection_context, op_id) |
72 |
| - return |
73 |
| - |
74 |
| - task = asyncio.ensure_future( |
75 |
| - self.run_op(connection_context, op_id, execution_result) |
76 |
| - ) |
77 |
| - connection_context.register_operation(op_id, task) |
78 |
| - |
79 |
| - async def run_op(self, connection_context, op_id, aiterable): |
80 |
| - async for single_result in aiterable: |
81 |
| - if not connection_context.has_operation(op_id): |
82 |
| - break |
83 |
| - await self.send_execution_result(connection_context, op_id, single_result) |
84 | 79 | await self.on_operation_complete(connection_context, op_id)
|
85 | 80 |
|
86 | 81 | async def on_close(self, connection_context):
|
87 |
| - # Unsubscribe from all the connection's current operations in parallel. |
88 |
| - unsubscribes = [ |
| 82 | + for op_id in connection_context.operations: |
89 | 83 | self.unsubscribe(connection_context, op_id)
|
90 |
| - for op_id in connection_context.operations |
91 |
| - ] |
92 |
| - cancelled_tasks = [task for task in await asyncio.gather(*unsubscribes) if task] |
93 |
| - # Wait around for all the tasks to actually cancel. |
94 |
| - if cancelled_tasks: |
95 |
| - await asyncio.wait(cancelled_tasks) |
96 | 84 |
|
97 | 85 | async def on_stop(self, connection_context, op_id):
|
98 |
| - task = await self.unsubscribe(connection_context, op_id) |
99 |
| - if task: |
100 |
| - await asyncio.wait([task]) |
| 86 | + await self.unsubscribe(connection_context, op_id) |
101 | 87 |
|
102 | 88 | async def unsubscribe(self, connection_context, op_id):
|
103 |
| - op = None |
104 | 89 | if connection_context.has_operation(op_id):
|
105 | 90 | op = connection_context.get_operation(op_id)
|
106 |
| - op.cancel() |
| 91 | + op.dispose() |
107 | 92 | connection_context.remove_operation(op_id)
|
108 | 93 | await self.on_operation_complete(connection_context, op_id)
|
109 |
| - return op |
110 | 94 |
|
111 | 95 | async def on_operation_complete(self, connection_context, op_id):
|
112 | 96 | await self.send_message(connection_context, op_id, GQL_COMPLETE)
|
|
0 commit comments