Skip to content

Commit dc86b24

Browse files
committed
Merge branch 'graphql-transport-ws'
2 parents 7ef25ec + 2f0eb21 commit dc86b24

File tree

6 files changed

+102
-38
lines changed

6 files changed

+102
-38
lines changed

graphql_ws/base.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,17 @@
44
from graphql import format_error, graphql
55

66
from .constants import (
7+
GQL_COMPLETE,
78
GQL_CONNECTION_ERROR,
89
GQL_CONNECTION_INIT,
910
GQL_CONNECTION_TERMINATE,
1011
GQL_DATA,
1112
GQL_ERROR,
13+
GQL_NEXT,
1214
GQL_START,
1315
GQL_STOP,
16+
GQL_SUBSCRIBE,
17+
TRANSPORT_WS_PROTOCOL,
1418
)
1519

1620

@@ -19,10 +23,15 @@ class ConnectionClosedException(Exception):
1923

2024

2125
class BaseConnectionContext(object):
26+
transport_ws_protocol = False
27+
2228
def __init__(self, ws, request_context=None):
2329
self.ws = ws
2430
self.operations = {}
2531
self.request_context = request_context
32+
self.transport_ws_protocol = request_context and TRANSPORT_WS_PROTOCOL in (
33+
request_context.get("subprotocols") or []
34+
)
2635

2736
def has_operation(self, op_id):
2837
return op_id in self.operations
@@ -84,12 +93,16 @@ def process_message(self, connection_context, parsed_message):
8493
elif op_type == GQL_CONNECTION_TERMINATE:
8594
return self.on_connection_terminate(connection_context, op_id)
8695

87-
elif op_type == GQL_START:
96+
elif op_type == (
97+
GQL_SUBSCRIBE if connection_context.transport_ws_protocol else GQL_START
98+
):
8899
assert isinstance(payload, dict), "The payload must be a dict"
89100
params = self.get_graphql_params(connection_context, payload)
90101
return self.on_start(connection_context, op_id, params)
91102

92-
elif op_type == GQL_STOP:
103+
elif op_type == (
104+
GQL_COMPLETE if connection_context.transport_ws_protocol else GQL_STOP
105+
):
93106
return self.on_stop(connection_context, op_id)
94107

95108
else:
@@ -142,7 +155,12 @@ def build_message(self, id, op_type, payload):
142155

143156
def send_execution_result(self, connection_context, op_id, execution_result):
144157
result = self.execution_result_to_dict(execution_result)
145-
return self.send_message(connection_context, op_id, GQL_DATA, result)
158+
return self.send_message(
159+
connection_context,
160+
op_id,
161+
GQL_NEXT if connection_context.transport_ws_protocol else GQL_DATA,
162+
result,
163+
)
146164

147165
def execution_result_to_dict(self, execution_result):
148166
result = OrderedDict()

graphql_ws/constants.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
GRAPHQL_WS = "graphql-ws"
22
WS_PROTOCOL = GRAPHQL_WS
3+
TRANSPORT_WS_PROTOCOL = "graphql-transport-ws"
34

45
GQL_CONNECTION_INIT = "connection_init" # Client -> Server
56
GQL_CONNECTION_ACK = "connection_ack" # Server -> Client
@@ -8,8 +9,11 @@
89
# NOTE: This one here don't follow the standard due to connection optimization
910
GQL_CONNECTION_TERMINATE = "connection_terminate" # Client -> Server
1011
GQL_CONNECTION_KEEP_ALIVE = "ka" # Server -> Client
11-
GQL_START = "start" # Client -> Server
12-
GQL_DATA = "data" # Server -> Client
12+
GQL_START = "start" # Client -> Server (graphql-ws)
13+
GQL_SUBSCRIBE = "subscribe" # Client -> Server (graphql-transport-ws START equivalent)
14+
GQL_DATA = "data" # Server -> Client (graphql-ws)
15+
GQL_NEXT = "next" # Server -> Client (graphql-transport-ws DATA equivalent)
1316
GQL_ERROR = "error" # Server -> Client
1417
GQL_COMPLETE = "complete" # Server -> Client
15-
GQL_STOP = "stop" # Client -> Server
18+
# (and Client -> Server for graphql-transport-ws STOP equivalent)
19+
GQL_STOP = "stop" # Client -> Server (graphql-ws only)

graphql_ws/django/consumers.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,25 @@
22

33
from channels.generic.websocket import AsyncJsonWebsocketConsumer
44

5-
from ..constants import WS_PROTOCOL
5+
from ..constants import TRANSPORT_WS_PROTOCOL, WS_PROTOCOL
66
from .subscriptions import subscription_server
77

88

99
class GraphQLSubscriptionConsumer(AsyncJsonWebsocketConsumer):
1010
async def connect(self):
1111
self.connection_context = None
12-
if WS_PROTOCOL in self.scope["subprotocols"]:
13-
self.connection_context = await subscription_server.handle(
14-
ws=self, request_context=self.scope
15-
)
16-
await self.accept(subprotocol=WS_PROTOCOL)
17-
else:
12+
found_protocol = None
13+
for protocol in [WS_PROTOCOL, TRANSPORT_WS_PROTOCOL]:
14+
if protocol in self.scope["subprotocols"]:
15+
found_protocol = protocol
16+
break
17+
if not found_protocol:
1818
await self.close()
19+
return
20+
self.connection_context = await subscription_server.handle(
21+
ws=self, request_context=self.scope
22+
)
23+
await self.accept(subprotocol=found_protocol)
1924

2025
async def disconnect(self, code):
2126
if self.connection_context:

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ test =
5252
graphene>=2.0,<3
5353
gevent
5454
graphene_django
55-
mock; python_version<"3"
55+
mock; python_version<"3.8"
5656
django==1.11.*; python_version<"3"
5757
channels==1.*; python_version<"3"
5858
django==3.*; python_version>="3"

tests/test_base_async.py

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,10 @@
1010
pytestmark = pytest.mark.asyncio
1111

1212

13-
class AsyncMock(mock.MagicMock):
14-
async def __call__(self, *args, **kwargs):
15-
return super().__call__(*args, **kwargs)
13+
try:
14+
from unittest.mock import AsyncMock # Python 3.8+
15+
except ImportError:
16+
from mock import AsyncMock
1617

1718

1819
class TstServer(base_async.BaseAsyncSubscriptionServer):
@@ -26,75 +27,78 @@ def server():
2627

2728

2829
async def test_terminate(server: TstServer):
29-
context = AsyncMock()
30+
context = AsyncMock(spec=base_async.BaseAsyncConnectionContext)
3031
await server.on_connection_terminate(connection_context=context, op_id=1)
3132
context.close.assert_called_with(1011)
3233

3334

3435
async def test_send_error(server: TstServer):
35-
context = AsyncMock()
36-
context.has_operation = mock.Mock()
36+
context = AsyncMock(spec=base_async.BaseAsyncConnectionContext)
3737
await server.send_error(connection_context=context, op_id=1, error="test error")
3838
context.send.assert_called_with(
3939
{"id": 1, "type": "error", "payload": {"message": "test error"}}
4040
)
4141

4242

43-
async def test_message(server):
43+
async def test_message(server: TstServer):
4444
server.process_message = AsyncMock()
45-
context = AsyncMock()
45+
context = AsyncMock(spec=base_async.BaseAsyncConnectionContext)
4646
msg = {"id": 1, "type": base.GQL_CONNECTION_INIT, "payload": ""}
4747
await server.on_message(context, msg)
4848
server.process_message.assert_called_with(context, msg)
4949

5050

51-
async def test_message_str(server):
51+
async def test_message_str(server: TstServer):
5252
server.process_message = AsyncMock()
53-
context = AsyncMock()
53+
context = AsyncMock(spec=base_async.BaseAsyncConnectionContext)
5454
msg = {"id": 1, "type": base.GQL_CONNECTION_INIT, "payload": ""}
5555
await server.on_message(context, json.dumps(msg))
5656
server.process_message.assert_called_with(context, msg)
5757

5858

59-
async def test_message_invalid(server):
59+
async def test_message_invalid(server: TstServer):
6060
server.send_error = AsyncMock()
61-
await server.on_message(connection_context=None, message="'not-json")
61+
context = AsyncMock(spec=base_async.BaseAsyncConnectionContext)
62+
await server.on_message(context, message="'not-json")
6263
assert server.send_error.called
6364

6465

65-
async def test_resolver(server):
66+
async def test_resolver(server: TstServer):
6667
server.send_message = AsyncMock()
68+
context = AsyncMock(spec=base_async.BaseAsyncConnectionContext)
6769
result = mock.Mock()
6870
result.data = {"test": [1, 2]}
6971
result.errors = None
7072
await server.send_execution_result(
71-
connection_context=None, op_id=1, execution_result=result
73+
context, op_id=1, execution_result=result
7274
)
7375
assert server.send_message.called
7476

7577

7678
@pytest.mark.asyncio
77-
async def test_resolver_with_promise(server):
79+
async def test_resolver_with_promise(server: TstServer):
7880
server.send_message = AsyncMock()
81+
context = AsyncMock(spec=base_async.BaseAsyncConnectionContext)
7982
result = mock.Mock()
8083
result.data = {"test": [1, promise.Promise(lambda resolve, reject: resolve(2))]}
8184
result.errors = None
8285
await server.send_execution_result(
83-
connection_context=None, op_id=1, execution_result=result
86+
context, op_id=1, execution_result=result
8487
)
8588
assert server.send_message.called
8689
assert result.data == {"test": [1, 2]}
8790

8891

89-
async def test_resolver_with_nested_promise(server):
92+
async def test_resolver_with_nested_promise(server: TstServer):
9093
server.send_message = AsyncMock()
94+
context = AsyncMock(spec=base_async.BaseAsyncConnectionContext)
9195
result = mock.Mock()
9296
inner = promise.Promise(lambda resolve, reject: resolve(2))
9397
outer = promise.Promise(lambda resolve, reject: resolve({"in": inner}))
9498
result.data = {"test": [1, outer]}
9599
result.errors = None
96100
await server.send_execution_result(
97-
connection_context=None, op_id=1, execution_result=result
101+
context, op_id=1, execution_result=result
98102
)
99103
assert server.send_message.called
100104
assert result.data == {"test": [1, {"in": 2}]}

tests/test_graphql_ws.py

Lines changed: 39 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -77,15 +77,20 @@ def test_terminate(self, ss, cc):
7777
ss.process_message(cc, {"id": "1", "type": constants.GQL_CONNECTION_TERMINATE})
7878
ss.on_connection_terminate.assert_called_with(cc, "1")
7979

80-
def test_start(self, ss, cc):
80+
@pytest.mark.parametrize(
81+
"transport_ws_protocol,expected_type",
82+
((False, constants.GQL_START), (True, constants.GQL_SUBSCRIBE)),
83+
)
84+
def test_start(self, ss, cc, transport_ws_protocol, expected_type):
8185
ss.get_graphql_params = mock.Mock()
8286
ss.get_graphql_params.return_value = {"params": True}
8387
cc.has_operation = mock.Mock()
8488
cc.has_operation.return_value = False
89+
cc.transport_ws_protocol = transport_ws_protocol
8590
ss.unsubscribe = mock.Mock()
8691
ss.on_start = mock.Mock()
8792
ss.process_message(
88-
cc, {"id": "1", "type": constants.GQL_START, "payload": {"a": "b"}}
93+
cc, {"id": "1", "type": expected_type, "payload": {"a": "b"}}
8994
)
9095
assert not ss.unsubscribe.called
9196
ss.on_start.assert_called_with(cc, "1", {"params": True})
@@ -117,9 +122,32 @@ def test_start_bad_graphql_params(self, ss, cc):
117122
assert isinstance(ss.send_error.call_args[0][2], Exception)
118123
assert not ss.on_start.called
119124

120-
def test_stop(self, ss, cc):
125+
@pytest.mark.parametrize(
126+
"transport_ws_protocol,stop_type,invalid_stop_type",
127+
(
128+
(False, constants.GQL_STOP, constants.GQL_COMPLETE),
129+
(True, constants.GQL_COMPLETE, constants.GQL_STOP),
130+
),
131+
)
132+
def test_stop(
133+
self,
134+
ss,
135+
cc,
136+
transport_ws_protocol,
137+
stop_type,
138+
invalid_stop_type,
139+
):
121140
ss.on_stop = mock.Mock()
122-
ss.process_message(cc, {"id": "1", "type": constants.GQL_STOP})
141+
ss.send_error = mock.Mock()
142+
cc.transport_ws_protocol = transport_ws_protocol
143+
144+
ss.process_message(cc, {"id": "1", "type": invalid_stop_type})
145+
assert ss.send_error.called
146+
assert ss.send_error.call_args[0][:2] == (cc, "1")
147+
assert isinstance(ss.send_error.call_args[0][2], Exception)
148+
assert not ss.on_stop.called
149+
150+
ss.process_message(cc, {"id": "1", "type": stop_type})
123151
ss.on_stop.assert_called_with(cc, "1")
124152

125153
def test_invalid(self, ss, cc):
@@ -165,13 +193,18 @@ def test_build_message_partial(ss):
165193
ss.build_message(id=None, op_type=None, payload=None)
166194

167195

168-
def test_send_execution_result(ss):
196+
@pytest.mark.parametrize(
197+
"transport_ws_protocol,expected_type",
198+
((False, constants.GQL_DATA), (True, constants.GQL_NEXT)),
199+
)
200+
def test_send_execution_result(ss, cc, transport_ws_protocol, expected_type):
201+
cc.transport_ws_protocol = transport_ws_protocol
169202
ss.execution_result_to_dict = mock.Mock()
170203
ss.execution_result_to_dict.return_value = {"res": "ult"}
171204
ss.send_message = mock.Mock()
172205
ss.send_message.return_value = "returned"
173206
assert "returned" == ss.send_execution_result(cc, "1", "result")
174-
ss.send_message.assert_called_with(cc, "1", constants.GQL_DATA, {"res": "ult"})
207+
ss.send_message.assert_called_with(cc, "1", expected_type, {"res": "ult"})
175208

176209

177210
def test_execution_result_to_dict(ss):

0 commit comments

Comments
 (0)