Skip to content

Commit a3197d0

Browse files
committed
Recursively resolve Promises, fix async tests
1 parent 5ed4f1d commit a3197d0

File tree

2 files changed

+124
-33
lines changed

2 files changed

+124
-33
lines changed

graphql_ws/base_async.py

Lines changed: 53 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,62 @@
11
import asyncio
2+
import inspect
23
from abc import ABC, abstractmethod
3-
from inspect import isawaitable
4+
from types import CoroutineType, GeneratorType
5+
from typing import Any, Union, List, Dict
46
from weakref import WeakSet
57

68
from graphql.execution.executors.asyncio import AsyncioExecutor
9+
from promise import Promise
710

811
from graphql_ws import base
912

1013
from .constants import GQL_COMPLETE, GQL_CONNECTION_ACK, GQL_CONNECTION_ERROR
1114
from .observable_aiter import setup_observable_extension
1215

1316
setup_observable_extension()
17+
CO_ITERABLE_COROUTINE = inspect.CO_ITERABLE_COROUTINE
18+
19+
20+
# Copied from graphql-core v3.1.0 (graphql/pyutils/is_awaitable.py)
21+
def is_awaitable(value: Any) -> bool:
22+
"""Return true if object can be passed to an ``await`` expression.
23+
Instead of testing if the object is an instance of abc.Awaitable, it checks
24+
the existence of an `__await__` attribute. This is much faster.
25+
"""
26+
return (
27+
# check for coroutine objects
28+
isinstance(value, CoroutineType)
29+
# check for old-style generator based coroutine objects
30+
or isinstance(value, GeneratorType)
31+
and bool(value.gi_code.co_flags & CO_ITERABLE_COROUTINE)
32+
# check for other awaitables (e.g. futures)
33+
or hasattr(value, "__await__")
34+
)
35+
36+
37+
async def resolve(
38+
data: Any, _container: Union[List, Dict] = None, _key: Union[str, int] = None
39+
) -> None:
40+
"""
41+
Recursively wait on any awaitable children of a data element and resolve any
42+
Promises.
43+
"""
44+
if is_awaitable(data):
45+
data = await data
46+
if isinstance(data, Promise):
47+
data = data.value # type: Any
48+
if _container is not None:
49+
_container[_key] = data
50+
if isinstance(data, dict):
51+
items = data.items()
52+
elif isinstance(data, list):
53+
items = enumerate(data)
54+
else:
55+
items = None
56+
if items is not None:
57+
children = [resolve(child, _container=data, _key=key) for key, child in items]
58+
if children:
59+
await asyncio.wait(children)
1460

1561

1662
class BaseAsyncConnectionContext(base.BaseConnectionContext, ABC):
@@ -81,7 +127,7 @@ async def on_connection_init(self, connection_context, op_id, payload):
81127
async def on_start(self, connection_context, op_id, params):
82128
execution_result = self.execute(params)
83129

84-
if isawaitable(execution_result):
130+
if is_awaitable(execution_result):
85131
execution_result = await execution_result
86132

87133
if hasattr(execution_result, "__aiter__"):
@@ -120,3 +166,8 @@ async def on_stop(self, connection_context, op_id):
120166

121167
async def on_operation_complete(self, connection_context, op_id):
122168
pass
169+
170+
async def send_execution_result(self, connection_context, op_id, execution_result):
171+
# Resolve any pending promises
172+
await resolve(execution_result.data)
173+
await super().send_execution_result(connection_context, op_id, execution_result)

tests/test_base_async.py

Lines changed: 71 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,59 +1,99 @@
11
from unittest import mock
22

33
import json
4+
import promise
45

56
import pytest
67

7-
from graphql_ws import base
8+
from graphql_ws import base, base_async
89

10+
pytestmark = pytest.mark.asyncio
911

10-
def test_not_implemented():
11-
server = base.BaseSubscriptionServer(schema=None)
12-
with pytest.raises(NotImplementedError):
13-
server.on_connection_init(connection_context=None, op_id=1, payload={})
14-
with pytest.raises(NotImplementedError):
15-
server.on_open(connection_context=None)
16-
with pytest.raises(NotImplementedError):
17-
server.on_stop(connection_context=None, op_id=1)
1812

13+
class AsyncMock(mock.MagicMock):
14+
async def __call__(self, *args, **kwargs):
15+
return super().__call__(*args, **kwargs)
1916

20-
def test_terminate():
21-
server = base.BaseSubscriptionServer(schema=None)
2217

23-
context = mock.Mock()
24-
server.on_connection_terminate(connection_context=context, op_id=1)
18+
class TestServer(base_async.BaseAsyncSubscriptionServer):
19+
def handle(self, *args, **kwargs):
20+
pass
21+
22+
23+
@pytest.fixture
24+
def server():
25+
return TestServer(schema=None)
26+
27+
28+
async def test_terminate(server: TestServer):
29+
context = AsyncMock()
30+
await server.on_connection_terminate(connection_context=context, op_id=1)
2531
context.close.assert_called_with(1011)
2632

2733

28-
def test_send_error():
29-
server = base.BaseSubscriptionServer(schema=None)
30-
context = mock.Mock()
31-
server.send_error(connection_context=context, op_id=1, error="test error")
34+
async def test_send_error(server: TestServer):
35+
context = AsyncMock()
36+
await server.send_error(connection_context=context, op_id=1, error="test error")
3237
context.send.assert_called_with(
3338
{"id": 1, "type": "error", "payload": {"message": "test error"}}
3439
)
3540

3641

37-
def test_message():
38-
server = base.BaseSubscriptionServer(schema=None)
39-
server.process_message = mock.Mock()
40-
context = mock.Mock()
42+
async def test_message(server):
43+
server.process_message = AsyncMock()
44+
context = AsyncMock()
4145
msg = {"id": 1, "type": base.GQL_CONNECTION_INIT, "payload": ""}
42-
server.on_message(context, msg)
46+
await server.on_message(context, msg)
4347
server.process_message.assert_called_with(context, msg)
4448

4549

46-
def test_message_str():
47-
server = base.BaseSubscriptionServer(schema=None)
48-
server.process_message = mock.Mock()
49-
context = mock.Mock()
50+
async def test_message_str(server):
51+
server.process_message = AsyncMock()
52+
context = AsyncMock()
5053
msg = {"id": 1, "type": base.GQL_CONNECTION_INIT, "payload": ""}
51-
server.on_message(context, json.dumps(msg))
54+
await server.on_message(context, json.dumps(msg))
5255
server.process_message.assert_called_with(context, msg)
5356

5457

55-
def test_message_invalid():
56-
server = base.BaseSubscriptionServer(schema=None)
57-
server.send_error = mock.Mock()
58-
server.on_message(connection_context=None, message="'not-json")
58+
async def test_message_invalid(server):
59+
server.send_error = AsyncMock()
60+
await server.on_message(connection_context=None, message="'not-json")
5961
assert server.send_error.called
62+
63+
64+
async def test_resolver(server):
65+
server.send_message = AsyncMock()
66+
result = mock.Mock()
67+
result.data = {"test": [1, 2]}
68+
result.errors = None
69+
await server.send_execution_result(
70+
connection_context=None, op_id=1, execution_result=result
71+
)
72+
assert server.send_message.called
73+
74+
75+
@pytest.mark.asyncio
76+
async def test_resolver_with_promise(server):
77+
server.send_message = AsyncMock()
78+
result = mock.Mock()
79+
result.data = {"test": [1, promise.Promise(lambda resolve, reject: resolve(2))]}
80+
result.errors = None
81+
await server.send_execution_result(
82+
connection_context=None, op_id=1, execution_result=result
83+
)
84+
assert server.send_message.called
85+
assert result.data == {'test': [1, 2]}
86+
87+
88+
async def test_resolver_with_nested_promise(server):
89+
server.send_message = AsyncMock()
90+
result = mock.Mock()
91+
inner = promise.Promise(lambda resolve, reject: resolve(2))
92+
outer = promise.Promise(lambda resolve, reject: resolve({'in': inner}))
93+
result.data = {"test": [1, outer]}
94+
result.errors = None
95+
await server.send_execution_result(
96+
connection_context=None, op_id=1, execution_result=result
97+
)
98+
assert server.send_message.called
99+
assert result.data == {'test': [1, {'in': 2}]}

0 commit comments

Comments
 (0)