Skip to content

Commit ca463c9

Browse files
committed
Split base classes into sync and async
Deduplicate code
1 parent 923fc08 commit ca463c9

File tree

9 files changed

+312
-456
lines changed

9 files changed

+312
-456
lines changed

graphql_ws/__init__.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,3 @@
55
__author__ = """Syrus Akbary"""
66
__email__ = 'me@syrusakbary.com'
77
__version__ = '0.3.1'
8-
9-
10-
from .base import BaseConnectionContext, BaseSubscriptionServer # noqa: F401

graphql_ws/aiohttp.py

Lines changed: 13 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,13 @@
1-
from inspect import isawaitable
2-
from asyncio import ensure_future, wait, shield
1+
import json
2+
from asyncio import ensure_future, shield
33

44
from aiohttp import WSMsgType
5-
from graphql.execution.executors.asyncio import AsyncioExecutor
65

7-
from .base import (
8-
ConnectionClosedException, BaseConnectionContext, BaseSubscriptionServer)
9-
from .observable_aiter import setup_observable_extension
6+
from .base import ConnectionClosedException
7+
from .base_async import BaseAsyncConnectionContext, BaseAsyncSubscriptionServer
108

11-
from .constants import (
12-
GQL_CONNECTION_ACK,
13-
GQL_CONNECTION_ERROR,
14-
GQL_COMPLETE
15-
)
169

17-
setup_observable_extension()
18-
19-
20-
class AiohttpConnectionContext(BaseConnectionContext):
10+
class AiohttpConnectionContext(BaseAsyncConnectionContext):
2111
async def receive(self):
2212
msg = await self.ws.receive()
2313
if msg.type == WSMsgType.TEXT:
@@ -32,7 +22,7 @@ async def receive(self):
3222
async def send(self, data):
3323
if self.closed:
3424
return
35-
await self.ws.send_str(data)
25+
await self.ws.send_str(json.dumps(data))
3626

3727
@property
3828
def closed(self):
@@ -42,81 +32,24 @@ async def close(self, code):
4232
await self.ws.close(code=code)
4333

4434

45-
class AiohttpSubscriptionServer(BaseSubscriptionServer):
46-
def __init__(self, schema, keep_alive=True, loop=None):
47-
self.loop = loop
48-
super().__init__(schema, keep_alive)
49-
50-
def get_graphql_params(self, *args, **kwargs):
51-
params = super(AiohttpSubscriptionServer,
52-
self).get_graphql_params(*args, **kwargs)
53-
return dict(params, return_promise=True,
54-
executor=AsyncioExecutor(loop=self.loop))
55-
35+
class AiohttpSubscriptionServer(BaseAsyncSubscriptionServer):
5636
async def _handle(self, ws, request_context=None):
5737
connection_context = AiohttpConnectionContext(ws, request_context)
5838
await self.on_open(connection_context)
59-
pending = set()
6039
while True:
6140
try:
6241
if connection_context.closed:
6342
raise ConnectionClosedException()
6443
message = await connection_context.receive()
6544
except ConnectionClosedException:
6645
break
67-
finally:
68-
if pending:
69-
(_, pending) = await wait(pending, timeout=0, loop=self.loop)
7046

71-
task = ensure_future(
72-
self.on_message(connection_context, message), loop=self.loop)
73-
pending.add(task)
74-
75-
self.on_close(connection_context)
76-
for task in pending:
77-
task.cancel()
47+
connection_context.remember_task(
48+
ensure_future(
49+
self.on_message(connection_context, message), loop=self.loop
50+
)
51+
)
52+
await self.on_close(connection_context)
7853

7954
async def handle(self, ws, request_context=None):
8055
await shield(self._handle(ws, request_context), loop=self.loop)
81-
82-
async def on_open(self, connection_context):
83-
pass
84-
85-
def on_close(self, connection_context):
86-
remove_operations = list(connection_context.operations.keys())
87-
for op_id in remove_operations:
88-
self.unsubscribe(connection_context, op_id)
89-
90-
async def on_connect(self, connection_context, payload):
91-
pass
92-
93-
async def on_connection_init(self, connection_context, op_id, payload):
94-
try:
95-
await self.on_connect(connection_context, payload)
96-
await self.send_message(connection_context, op_type=GQL_CONNECTION_ACK)
97-
except Exception as e:
98-
await self.send_error(connection_context, op_id, e, GQL_CONNECTION_ERROR)
99-
await connection_context.close(1011)
100-
101-
async def on_start(self, connection_context, op_id, params):
102-
execution_result = self.execute(
103-
connection_context.request_context, params)
104-
105-
if isawaitable(execution_result):
106-
execution_result = await execution_result
107-
108-
if not hasattr(execution_result, '__aiter__'):
109-
await self.send_execution_result(
110-
connection_context, op_id, execution_result)
111-
else:
112-
iterator = await execution_result.__aiter__()
113-
connection_context.register_operation(op_id, iterator)
114-
async for single_result in iterator:
115-
if not connection_context.has_operation(op_id):
116-
break
117-
await self.send_execution_result(
118-
connection_context, op_id, single_result)
119-
await self.send_message(connection_context, op_id, GQL_COMPLETE)
120-
121-
async def on_stop(self, connection_context, op_id):
122-
self.unsubscribe(connection_context, op_id)

graphql_ws/base.py

Lines changed: 61 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
11
import json
22
from collections import OrderedDict
33

4-
from graphql import graphql, format_error
4+
from graphql import format_error
55

66
from .constants import (
7+
GQL_CONNECTION_ERROR,
78
GQL_CONNECTION_INIT,
89
GQL_CONNECTION_TERMINATE,
10+
GQL_DATA,
11+
GQL_ERROR,
912
GQL_START,
1013
GQL_STOP,
11-
GQL_ERROR,
12-
GQL_CONNECTION_ERROR,
13-
GQL_DATA
1414
)
1515

1616

@@ -51,33 +51,16 @@ def close(self, code):
5151

5252

5353
class BaseSubscriptionServer(object):
54+
graphql_executor = None
5455

5556
def __init__(self, schema, keep_alive=True):
5657
self.schema = schema
5758
self.keep_alive = keep_alive
5859

59-
def get_graphql_params(self, connection_context, payload):
60-
return {
61-
'request_string': payload.get('query'),
62-
'variable_values': payload.get('variables'),
63-
'operation_name': payload.get('operationName'),
64-
'context_value': payload.get('context'),
65-
}
66-
67-
def build_message(self, id, op_type, payload):
68-
message = {}
69-
if id is not None:
70-
message['id'] = id
71-
if op_type is not None:
72-
message['type'] = op_type
73-
if payload is not None:
74-
message['payload'] = payload
75-
return message
76-
7760
def process_message(self, connection_context, parsed_message):
78-
op_id = parsed_message.get('id')
79-
op_type = parsed_message.get('type')
80-
payload = parsed_message.get('payload')
61+
op_id = parsed_message.get("id")
62+
op_type = parsed_message.get("type")
63+
payload = parsed_message.get("payload")
8164

8265
if op_type == GQL_CONNECTION_INIT:
8366
return self.on_connection_init(connection_context, op_id, payload)
@@ -92,22 +75,63 @@ def process_message(self, connection_context, parsed_message):
9275
if not isinstance(params, dict):
9376
error = Exception(
9477
"Invalid params returned from get_graphql_params!"
95-
" Return values must be a dict.")
78+
" Return values must be a dict."
79+
)
9680
return self.send_error(connection_context, op_id, error)
9781

9882
# If we already have a subscription with this id, unsubscribe from
9983
# it first
10084
if connection_context.has_operation(op_id):
10185
self.unsubscribe(connection_context, op_id)
10286

87+
params = self.get_graphql_params(connection_context, payload)
10388
return self.on_start(connection_context, op_id, params)
10489

10590
elif op_type == GQL_STOP:
10691
return self.on_stop(connection_context, op_id)
10792

10893
else:
109-
return self.send_error(connection_context, op_id, Exception(
110-
"Invalid message type: {}.".format(op_type)))
94+
return self.send_error(
95+
connection_context,
96+
op_id,
97+
Exception("Invalid message type: {}.".format(op_type)),
98+
)
99+
100+
def on_connection_init(self, connection_context, op_id, payload):
101+
raise NotImplementedError("on_connection_init method not implemented")
102+
103+
def on_connection_terminate(self, connection_context, op_id):
104+
return connection_context.close(1011)
105+
106+
def get_graphql_params(self, connection_context, payload):
107+
return {
108+
"request_string": payload.get("query"),
109+
"variable_values": payload.get("variables"),
110+
"operation_name": payload.get("operationName"),
111+
"context_value": payload.get("context"),
112+
"executor": self.graphql_executor(),
113+
}
114+
115+
def on_open(self, connection_context):
116+
raise NotImplementedError("on_open method not implemented")
117+
118+
def on_stop(self, connection_context, op_id):
119+
raise NotImplementedError("on_stop method not implemented")
120+
121+
def send_message(self, connection_context, op_id=None, op_type=None, payload=None):
122+
message = self.build_message(op_id, op_type, payload)
123+
assert message, "You need to send at least one thing"
124+
return connection_context.send(message)
125+
126+
def build_message(self, id, op_type, payload):
127+
message = {}
128+
if id is not None:
129+
message["id"] = id
130+
if op_type is not None:
131+
message["type"] = op_type
132+
if payload is not None:
133+
message["payload"] = payload
134+
return message
111135

112136
def send_execution_result(self, connection_context, op_id, execution_result):
113137
result = self.execution_result_to_dict(execution_result)
@@ -116,86 +140,34 @@ def send_execution_result(self, connection_context, op_id, execution_result):
116140
def execution_result_to_dict(self, execution_result):
117141
result = OrderedDict()
118142
if execution_result.data:
119-
result['data'] = execution_result.data
143+
result["data"] = execution_result.data
120144
if execution_result.errors:
121-
result['errors'] = [format_error(error)
122-
for error in execution_result.errors]
145+
result["errors"] = [
146+
format_error(error) for error in execution_result.errors
147+
]
123148
return result
124149

125-
def send_message(self, connection_context, op_id=None, op_type=None, payload=None):
126-
message = self.build_message(op_id, op_type, payload)
127-
assert message, "You need to send at least one thing"
128-
json_message = json.dumps(message)
129-
return connection_context.send(json_message)
130-
131150
def send_error(self, connection_context, op_id, error, error_type=None):
132151
if error_type is None:
133152
error_type = GQL_ERROR
134153

135154
assert error_type in [GQL_CONNECTION_ERROR, GQL_ERROR], (
136-
'error_type should be one of the allowed error messages'
137-
' GQL_CONNECTION_ERROR or GQL_ERROR'
138-
)
139-
140-
error_payload = {
141-
'message': str(error)
142-
}
143-
144-
return self.send_message(
145-
connection_context,
146-
op_id,
147-
error_type,
148-
error_payload
155+
"error_type should be one of the allowed error messages"
156+
" GQL_CONNECTION_ERROR or GQL_ERROR"
149157
)
150158

151-
def unsubscribe(self, connection_context, op_id):
152-
if connection_context.has_operation(op_id):
153-
# Close async iterator
154-
connection_context.get_operation(op_id).dispose()
155-
# Close operation
156-
connection_context.remove_operation(op_id)
157-
self.on_operation_complete(connection_context, op_id)
159+
error_payload = {"message": str(error)}
158160

159-
def on_operation_complete(self, connection_context, op_id):
160-
pass
161-
162-
def on_connection_terminate(self, connection_context, op_id):
163-
return connection_context.close(1011)
164-
165-
def execute(self, request_context, params):
166-
return graphql(
167-
self.schema, **dict(params, allow_subscriptions=True))
168-
169-
def handle(self, ws, request_context=None):
170-
raise NotImplementedError("handle method not implemented")
161+
return self.send_message(connection_context, op_id, error_type, error_payload)
171162

172163
def on_message(self, connection_context, message):
173164
try:
174165
if not isinstance(message, dict):
175166
parsed_message = json.loads(message)
176-
assert isinstance(
177-
parsed_message, dict), "Payload must be an object."
167+
assert isinstance(parsed_message, dict), "Payload must be an object."
178168
else:
179169
parsed_message = message
180170
except Exception as e:
181171
return self.send_error(connection_context, None, e)
182172

183173
return self.process_message(connection_context, parsed_message)
184-
185-
def on_open(self, connection_context):
186-
raise NotImplementedError("on_open method not implemented")
187-
188-
def on_connect(self, connection_context, payload):
189-
raise NotImplementedError("on_connect method not implemented")
190-
191-
def on_close(self, connection_context):
192-
raise NotImplementedError("on_close method not implemented")
193-
194-
def on_connection_init(self, connection_context, op_id, payload):
195-
raise NotImplementedError("on_connection_init method not implemented")
196-
197-
def on_stop(self, connection_context, op_id):
198-
raise NotImplementedError("on_stop method not implemented")
199-
200-
def on_start(self, connection_context, op_id, params):
201-
raise NotImplementedError("on_start method not implemented")

0 commit comments

Comments
 (0)