Skip to content

Commit 9620f3a

Browse files
committed
Fix async tests
1 parent 50f4357 commit 9620f3a

File tree

4 files changed

+26
-21
lines changed

4 files changed

+26
-21
lines changed

graphql_ws/base.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ class ConnectionClosedException(Exception):
2323

2424

2525
class BaseConnectionContext(object):
26+
transport_ws_protocol = False
27+
2628
def __init__(self, ws, request_context=None):
2729
self.ws = ws
2830
self.operations = {}

setup.cfg

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,8 @@ test =
5050
pytest-asyncio; python_version>="3.4"
5151
graphene>=2.0,<3
5252
gevent
53-
graphene>=2.0
5453
graphene_django
55-
mock; python_version<"3"
54+
mock; python_version<"3.8"
5655
django==1.11.*; python_version<"3"
5756
channels==1.*; python_version<"3"
5857
django==2.*; python_version>="3"

tests/test_base_async.py

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,10 @@
1010
pytestmark = pytest.mark.asyncio
1111

1212

13-
class AsyncMock(mock.MagicMock):
14-
async def __call__(self, *args, **kwargs):
15-
return super().__call__(*args, **kwargs)
13+
try:
14+
from unittest.mock import AsyncMock # Python 3.8+
15+
except ImportError:
16+
from mock import AsyncMock
1617

1718

1819
class TstServer(base_async.BaseAsyncSubscriptionServer):
@@ -26,75 +27,78 @@ def server():
2627

2728

2829
async def test_terminate(server: TstServer):
29-
context = AsyncMock()
30+
context = AsyncMock(spec=base_async.BaseAsyncConnectionContext)
3031
await server.on_connection_terminate(connection_context=context, op_id=1)
3132
context.close.assert_called_with(1011)
3233

3334

3435
async def test_send_error(server: TstServer):
35-
context = AsyncMock()
36-
context.has_operation = mock.Mock()
36+
context = AsyncMock(spec=base_async.BaseAsyncConnectionContext)
3737
await server.send_error(connection_context=context, op_id=1, error="test error")
3838
context.send.assert_called_with(
3939
{"id": 1, "type": "error", "payload": {"message": "test error"}}
4040
)
4141

4242

43-
async def test_message(server):
43+
async def test_message(server: TstServer):
4444
server.process_message = AsyncMock()
45-
context = AsyncMock()
45+
context = AsyncMock(spec=base_async.BaseAsyncConnectionContext)
4646
msg = {"id": 1, "type": base.GQL_CONNECTION_INIT, "payload": ""}
4747
await server.on_message(context, msg)
4848
server.process_message.assert_called_with(context, msg)
4949

5050

51-
async def test_message_str(server):
51+
async def test_message_str(server: TstServer):
5252
server.process_message = AsyncMock()
53-
context = AsyncMock()
53+
context = AsyncMock(spec=base_async.BaseAsyncConnectionContext)
5454
msg = {"id": 1, "type": base.GQL_CONNECTION_INIT, "payload": ""}
5555
await server.on_message(context, json.dumps(msg))
5656
server.process_message.assert_called_with(context, msg)
5757

5858

59-
async def test_message_invalid(server):
59+
async def test_message_invalid(server: TstServer):
6060
server.send_error = AsyncMock()
61-
await server.on_message(connection_context=None, message="'not-json")
61+
context = AsyncMock(spec=base_async.BaseAsyncConnectionContext)
62+
await server.on_message(context, message="'not-json")
6263
assert server.send_error.called
6364

6465

65-
async def test_resolver(server):
66+
async def test_resolver(server: TstServer):
6667
server.send_message = AsyncMock()
68+
context = AsyncMock(spec=base_async.BaseAsyncConnectionContext)
6769
result = mock.Mock()
6870
result.data = {"test": [1, 2]}
6971
result.errors = None
7072
await server.send_execution_result(
71-
connection_context=None, op_id=1, execution_result=result
73+
context, op_id=1, execution_result=result
7274
)
7375
assert server.send_message.called
7476

7577

7678
@pytest.mark.asyncio
77-
async def test_resolver_with_promise(server):
79+
async def test_resolver_with_promise(server: TstServer):
7880
server.send_message = AsyncMock()
81+
context = AsyncMock(spec=base_async.BaseAsyncConnectionContext)
7982
result = mock.Mock()
8083
result.data = {"test": [1, promise.Promise(lambda resolve, reject: resolve(2))]}
8184
result.errors = None
8285
await server.send_execution_result(
83-
connection_context=None, op_id=1, execution_result=result
86+
context, op_id=1, execution_result=result
8487
)
8588
assert server.send_message.called
8689
assert result.data == {'test': [1, 2]}
8790

8891

89-
async def test_resolver_with_nested_promise(server):
92+
async def test_resolver_with_nested_promise(server: TstServer):
9093
server.send_message = AsyncMock()
94+
context = AsyncMock(spec=base_async.BaseAsyncConnectionContext)
9195
result = mock.Mock()
9296
inner = promise.Promise(lambda resolve, reject: resolve(2))
9397
outer = promise.Promise(lambda resolve, reject: resolve({'in': inner}))
9498
result.data = {"test": [1, outer]}
9599
result.errors = None
96100
await server.send_execution_result(
97-
connection_context=None, op_id=1, execution_result=result
101+
context, op_id=1, execution_result=result
98102
)
99103
assert server.send_message.called
100104
assert result.data == {'test': [1, {'in': 2}]}

tests/test_graphql_ws.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ def test_build_message_partial(ss):
165165
ss.build_message(id=None, op_type=None, payload=None)
166166

167167

168-
def test_send_execution_result(ss):
168+
def test_send_execution_result(ss, cc):
169169
ss.execution_result_to_dict = mock.Mock()
170170
ss.execution_result_to_dict.return_value = {"res": "ult"}
171171
ss.send_message = mock.Mock()

0 commit comments

Comments
 (0)