Skip to content

Commit 7879f32

Browse files
committed
Simplify the django async futures
Promise observers are already futures, the only thing that needs to be a future is the on_message call from the receive_json Django consumer
1 parent 99bc3a8 commit 7879f32

File tree

2 files changed

+37
-31
lines changed

2 files changed

+37
-31
lines changed

graphql_ws/django/consumers.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,30 @@
1+
import asyncio
2+
import json
3+
14
from channels.generic.websocket import AsyncJsonWebsocketConsumer
5+
from promise import Promise
6+
27
from ..constants import WS_PROTOCOL
38
from .subscriptions import subscription_server
49

510

11+
class JSONPromiseEncoder(json.JSONEncoder):
12+
def default(self, o):
13+
if isinstance(o, Promise):
14+
return o.value
15+
return super(JSONPromiseEncoder, self).default(o)
16+
17+
18+
json_promise_encoder = JSONPromiseEncoder()
19+
20+
621
class GraphQLSubscriptionConsumer(AsyncJsonWebsocketConsumer):
722
async def connect(self):
823
self.connection_context = None
924
if WS_PROTOCOL in self.scope["subprotocols"]:
1025
self.connection_context = await subscription_server.handle(
11-
ws=self, request_context=self.scope)
26+
ws=self, request_context=self.scope
27+
)
1228
await self.accept(subprotocol=WS_PROTOCOL)
1329
else:
1430
await self.close()
@@ -18,4 +34,10 @@ async def disconnect(self, code):
1834
await subscription_server.on_close(self.connection_context)
1935

2036
async def receive_json(self, content):
21-
await subscription_server.on_message(self.connection_context, content)
37+
asyncio.ensure_future(
38+
subscription_server.on_message(self.connection_context, content)
39+
)
40+
41+
@classmethod
42+
async def encode_json(cls, content):
43+
return json_promise_encoder.encode(content)

graphql_ws/django/subscriptions.py

Lines changed: 13 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import asyncio
21
from inspect import isawaitable
32
from graphene_django.settings import graphene_settings
43
from graphql.execution.executors.asyncio import AsyncioExecutor
@@ -64,49 +63,34 @@ async def on_start(self, connection_context, op_id, params):
6463
if isawaitable(execution_result):
6564
execution_result = await execution_result
6665

67-
if not hasattr(execution_result, "__aiter__"):
66+
if hasattr(execution_result, "__aiter__"):
67+
iterator = await execution_result.__aiter__()
68+
connection_context.register_operation(op_id, iterator)
69+
async for single_result in iterator:
70+
if not connection_context.has_operation(op_id):
71+
break
72+
await self.send_execution_result(
73+
connection_context, op_id, single_result
74+
)
75+
else:
6876
await self.send_execution_result(
6977
connection_context, op_id, execution_result
7078
)
71-
await self.on_operation_complete(connection_context, op_id)
72-
return
73-
74-
task = asyncio.ensure_future(
75-
self.run_op(connection_context, op_id, execution_result)
76-
)
77-
connection_context.register_operation(op_id, task)
78-
79-
async def run_op(self, connection_context, op_id, aiterable):
80-
async for single_result in aiterable:
81-
if not connection_context.has_operation(op_id):
82-
break
83-
await self.send_execution_result(connection_context, op_id, single_result)
8479
await self.on_operation_complete(connection_context, op_id)
8580

8681
async def on_close(self, connection_context):
87-
# Unsubscribe from all the connection's current operations in parallel.
88-
unsubscribes = [
82+
for op_id in connection_context.operations:
8983
self.unsubscribe(connection_context, op_id)
90-
for op_id in connection_context.operations
91-
]
92-
cancelled_tasks = [task for task in await asyncio.gather(*unsubscribes) if task]
93-
# Wait around for all the tasks to actually cancel.
94-
if cancelled_tasks:
95-
await asyncio.wait(cancelled_tasks)
9684

9785
async def on_stop(self, connection_context, op_id):
98-
task = await self.unsubscribe(connection_context, op_id)
99-
if task:
100-
await asyncio.wait([task])
86+
await self.unsubscribe(connection_context, op_id)
10187

10288
async def unsubscribe(self, connection_context, op_id):
103-
op = None
10489
if connection_context.has_operation(op_id):
10590
op = connection_context.get_operation(op_id)
106-
op.cancel()
91+
op.dispose()
10792
connection_context.remove_operation(op_id)
10893
await self.on_operation_complete(connection_context, op_id)
109-
return op
11094

11195
async def on_operation_complete(self, connection_context, op_id):
11296
await self.send_message(connection_context, op_id, GQL_COMPLETE)

0 commit comments

Comments
 (0)