|
1 |
| -from asgiref.sync import async_to_sync |
| 1 | +from inspect import isawaitable |
2 | 2 | from graphene_django.settings import graphene_settings
|
3 | 3 | from graphql.execution.executors.asyncio import AsyncioExecutor
|
4 |
| -from rx import Observer, Observable |
| 4 | +from rx import Observer |
5 | 5 | from ..base import BaseConnectionContext, BaseSubscriptionServer
|
6 |
| -from ..constants import GQL_CONNECTION_ACK, GQL_CONNECTION_ERROR |
| 6 | +from ..constants import GQL_CONNECTION_ACK, GQL_CONNECTION_ERROR, GQL_COMPLETE |
| 7 | +from ..observable_aiter import setup_observable_extension |
| 8 | + |
| 9 | +setup_observable_extension() |
7 | 10 |
|
8 | 11 |
|
9 | 12 | class SubscriptionObserver(Observer):
|
@@ -76,24 +79,25 @@ async def on_connection_init(self, connection_context, op_id, payload):
|
76 | 79 | await connection_context.close(1011)
|
77 | 80 |
|
78 | 81 | async def on_start(self, connection_context, op_id, params):
|
79 |
| - try: |
80 |
| - execution_result = await self.execute( |
81 |
| - connection_context.request_context, params |
| 82 | + execution_result = self.execute(connection_context.request_context, params) |
| 83 | + |
| 84 | + if isawaitable(execution_result): |
| 85 | + execution_result = await execution_result |
| 86 | + |
| 87 | + if not hasattr(execution_result, "__aiter__"): |
| 88 | + await self.send_execution_result( |
| 89 | + connection_context, op_id, execution_result |
82 | 90 | )
|
83 |
| - assert isinstance( |
84 |
| - execution_result, Observable |
85 |
| - ), "A subscription must return an observable" |
86 |
| - execution_result.subscribe( |
87 |
| - SubscriptionObserver( |
88 |
| - connection_context, |
89 |
| - op_id, |
90 |
| - async_to_sync(self.send_execution_result), |
91 |
| - async_to_sync(self.send_error), |
92 |
| - async_to_sync(self.on_close), |
| 91 | + else: |
| 92 | + iterator = await execution_result.__aiter__() |
| 93 | + connection_context.register_operation(op_id, iterator) |
| 94 | + async for single_result in iterator: |
| 95 | + if not connection_context.has_operation(op_id): |
| 96 | + break |
| 97 | + await self.send_execution_result( |
| 98 | + connection_context, op_id, single_result |
93 | 99 | )
|
94 |
| - ) |
95 |
| - except Exception as e: |
96 |
| - self.send_error(connection_context, op_id, str(e)) |
| 100 | + await self.send_message(connection_context, op_id, GQL_COMPLETE) |
97 | 101 |
|
98 | 102 | async def on_close(self, connection_context):
|
99 | 103 | remove_operations = list(connection_context.operations.keys())
|
|
0 commit comments