Skip to content

Commit 2ce92ac

Browse files
committed
Support async generator responses in django channels
1 parent c9cda69 commit 2ce92ac

File tree

3 files changed

+56
-21
lines changed

3 files changed

+56
-21
lines changed
Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
import graphene
22
from rx import Observable
3+
from channels.layers import get_channel_layer
4+
from asgiref.sync import async_to_sync
5+
6+
channel_layer = get_channel_layer()
37

48

59
class Query(graphene.ObjectType):
@@ -9,15 +13,41 @@ def resolve_hello(self, info, **kwargs):
913
return "world"
1014

1115

16+
class TestMessageMutation(graphene.Mutation):
17+
class Arguments:
18+
input_text = graphene.String()
19+
20+
output_text = graphene.String()
21+
22+
def mutate(self, info, input_text):
23+
async_to_sync(channel_layer.group_send)("new_message", {"data": input_text})
24+
return TestMessageMutation(output_text=input_text)
25+
26+
27+
class Mutations(graphene.ObjectType):
28+
test_message = TestMessageMutation.Field()
29+
30+
1231
class Subscription(graphene.ObjectType):
1332
count_seconds = graphene.Int(up_to=graphene.Int())
33+
new_message = graphene.String()
1434

15-
def resolve_count_seconds(root, info, up_to=5):
35+
def resolve_count_seconds(self, info, up_to=5):
1636
return (
1737
Observable.interval(1000)
1838
.map(lambda i: "{0}".format(i))
1939
.take_while(lambda i: int(i) <= up_to)
2040
)
2141

42+
async def resolve_new_message(self, info):
43+
channel_name = await channel_layer.new_channel()
44+
await channel_layer.group_add("new_message", channel_name)
45+
try:
46+
while True:
47+
message = await channel_layer.receive(channel_name)
48+
yield message["data"]
49+
finally:
50+
await channel_layer.group_discard("new_message", channel_name)
51+
2252

23-
schema = graphene.Schema(query=Query, subscription=Subscription)
53+
schema = graphene.Schema(query=Query, mutation=Mutations, subscription=Subscription)

examples/django_channels2/django_channels2/settings.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,4 +26,5 @@
2626
ASGI_APPLICATION = "graphql_ws.django.routing.application"
2727

2828

29+
CHANNEL_LAYERS = {"default": {"BACKEND": "channels.layers.InMemoryChannelLayer"}}
2930
GRAPHENE = {"MIDDLEWARE": [], "SCHEMA": "django_channels2.schema.schema"}

graphql_ws/django/subscriptions.py

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
1-
from asgiref.sync import async_to_sync
1+
from inspect import isawaitable
22
from graphene_django.settings import graphene_settings
33
from graphql.execution.executors.asyncio import AsyncioExecutor
4-
from rx import Observer, Observable
4+
from rx import Observer
55
from ..base import BaseConnectionContext, BaseSubscriptionServer
6-
from ..constants import GQL_CONNECTION_ACK, GQL_CONNECTION_ERROR
6+
from ..constants import GQL_CONNECTION_ACK, GQL_CONNECTION_ERROR, GQL_COMPLETE
7+
from ..observable_aiter import setup_observable_extension
8+
9+
setup_observable_extension()
710

811

912
class SubscriptionObserver(Observer):
@@ -76,24 +79,25 @@ async def on_connection_init(self, connection_context, op_id, payload):
7679
await connection_context.close(1011)
7780

7881
async def on_start(self, connection_context, op_id, params):
79-
try:
80-
execution_result = await self.execute(
81-
connection_context.request_context, params
82+
execution_result = self.execute(connection_context.request_context, params)
83+
84+
if isawaitable(execution_result):
85+
execution_result = await execution_result
86+
87+
if not hasattr(execution_result, "__aiter__"):
88+
await self.send_execution_result(
89+
connection_context, op_id, execution_result
8290
)
83-
assert isinstance(
84-
execution_result, Observable
85-
), "A subscription must return an observable"
86-
execution_result.subscribe(
87-
SubscriptionObserver(
88-
connection_context,
89-
op_id,
90-
async_to_sync(self.send_execution_result),
91-
async_to_sync(self.send_error),
92-
async_to_sync(self.on_close),
91+
else:
92+
iterator = await execution_result.__aiter__()
93+
connection_context.register_operation(op_id, iterator)
94+
async for single_result in iterator:
95+
if not connection_context.has_operation(op_id):
96+
break
97+
await self.send_execution_result(
98+
connection_context, op_id, single_result
9399
)
94-
)
95-
except Exception as e:
96-
self.send_error(connection_context, op_id, str(e))
100+
await self.send_message(connection_context, op_id, GQL_COMPLETE)
97101

98102
async def on_close(self, connection_context):
99103
remove_operations = list(connection_context.operations.keys())

0 commit comments

Comments
 (0)