Skip to content

Commit f6ffd5b

Browse files
authored
Merge pull request #6 from hballard/websocketslib
Add websockets library context and server classes in order to use with Sanic
2 parents 5f5e85e + e262ee6 commit f6ffd5b

File tree

7 files changed

+316
-0
lines changed

7 files changed

+316
-0
lines changed

README.md

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ Websocket server for GraphQL subscriptions.
55
Currently supports:
66
* [aiohttp](https://github.com/graphql-python/graphql-ws#aiohttp)
77
* [Gevent](https://github.com/graphql-python/graphql-ws#gevent)
8+
* Sanic (uses [websockets](https://github.com/aaugustin/websockets/) library)
89

910
# Installation instructions
1011

@@ -40,6 +41,29 @@ app.router.add_get('/subscriptions', subscriptions)
4041
web.run_app(app, port=8000)
4142
```
4243

44+
### Sanic
45+
46+
Works with any framework that uses the websockets library for
47+
it's websocket implementation. For this example, plug in
48+
your Sanic server.
49+
50+
```python
51+
from graphql_ws.websockets_lib import WsLibSubscriptionServer
52+
53+
54+
app = Sanic(__name__)
55+
56+
subscription_server = WsLibSubscriptionServer(schema)
57+
58+
@app.websocket('/subscriptions', subprotocols=['graphql-ws'])
59+
async def subscriptions(request, ws):
60+
await subscription_server.handle(ws)
61+
return ws
62+
63+
64+
app.run(host="0.0.0.0", port=8000)
65+
```
66+
4367
And then, plug into a subscribable schema:
4468

4569
```python

examples/websockets_lib/__init__.py

Whitespace-only changes.

examples/websockets_lib/app.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
from graphql_ws.websockets_lib import WsLibSubscriptionServer
2+
from graphql.execution.executors.asyncio import AsyncioExecutor
3+
from sanic import Sanic, response
4+
from sanic_graphql import GraphQLView
5+
from schema import schema
6+
from template import render_graphiql
7+
8+
app = Sanic(__name__)
9+
10+
11+
@app.listener('before_server_start')
12+
def init_graphql(app, loop):
13+
app.add_route(GraphQLView.as_view(schema=schema,
14+
executor=AsyncioExecutor(loop=loop)),
15+
'/graphql')
16+
17+
18+
@app.route('/graphiql')
19+
async def graphiql_view(request):
20+
return response.html(render_graphiql())
21+
22+
subscription_server = WsLibSubscriptionServer(schema)
23+
24+
25+
@app.websocket('/subscriptions', subprotocols=['graphql-ws'])
26+
async def subscriptions(request, ws):
27+
await subscription_server.handle(ws)
28+
return ws
29+
30+
31+
app.run(host="0.0.0.0", port=8000)
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
graphql_ws
2+
sanic>=0.7.0
3+
graphene>=2.0
4+
sanic-graphql>=1.1.0

examples/websockets_lib/schema.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import random
2+
import asyncio
3+
import graphene
4+
5+
6+
class Query(graphene.ObjectType):
7+
base = graphene.String()
8+
9+
10+
class RandomType(graphene.ObjectType):
11+
seconds = graphene.Int()
12+
random_int = graphene.Int()
13+
14+
15+
class Subscription(graphene.ObjectType):
16+
count_seconds = graphene.Float(up_to=graphene.Int())
17+
random_int = graphene.Field(RandomType)
18+
19+
async def resolve_count_seconds(root, info, up_to=5):
20+
for i in range(up_to):
21+
print("YIELD SECOND", i)
22+
yield i
23+
await asyncio.sleep(1.)
24+
yield up_to
25+
26+
async def resolve_random_int(root, info):
27+
i = 0
28+
while True:
29+
yield RandomType(seconds=i, random_int=random.randint(0, 500))
30+
await asyncio.sleep(1.)
31+
i += 1
32+
33+
34+
schema = graphene.Schema(query=Query, subscription=Subscription)

examples/websockets_lib/template.py

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
2+
from string import Template
3+
4+
5+
def render_graphiql():
6+
return Template('''
7+
<!DOCTYPE html>
8+
<html>
9+
<head>
10+
<meta charset="utf-8" />
11+
<title>GraphiQL</title>
12+
<meta name="robots" content="noindex" />
13+
<style>
14+
html, body {
15+
height: 100%;
16+
margin: 0;
17+
overflow: hidden;
18+
width: 100%;
19+
}
20+
</style>
21+
<link href="//cdn.jsdelivr.net/graphiql/${GRAPHIQL_VERSION}/graphiql.css" rel="stylesheet" />
22+
<script src="//cdn.jsdelivr.net/fetch/0.9.0/fetch.min.js"></script>
23+
<script src="//cdn.jsdelivr.net/react/15.0.0/react.min.js"></script>
24+
<script src="//cdn.jsdelivr.net/react/15.0.0/react-dom.min.js"></script>
25+
<script src="//cdn.jsdelivr.net/graphiql/${GRAPHIQL_VERSION}/graphiql.min.js"></script>
26+
<script src="//unpkg.com/subscriptions-transport-ws@${SUBSCRIPTIONS_TRANSPORT_VERSION}/browser/client.js"></script>
27+
<script src="//unpkg.com/graphiql-subscriptions-fetcher@0.0.2/browser/client.js"></script>
28+
</head>
29+
<body>
30+
<script>
31+
// Collect the URL parameters
32+
var parameters = {};
33+
window.location.search.substr(1).split('&').forEach(function (entry) {
34+
var eq = entry.indexOf('=');
35+
if (eq >= 0) {
36+
parameters[decodeURIComponent(entry.slice(0, eq))] =
37+
decodeURIComponent(entry.slice(eq + 1));
38+
}
39+
});
40+
// Produce a Location query string from a parameter object.
41+
function locationQuery(params, location) {
42+
return (location ? location: '') + '?' + Object.keys(params).map(function (key) {
43+
return encodeURIComponent(key) + '=' +
44+
encodeURIComponent(params[key]);
45+
}).join('&');
46+
}
47+
// Derive a fetch URL from the current URL, sans the GraphQL parameters.
48+
var graphqlParamNames = {
49+
query: true,
50+
variables: true,
51+
operationName: true
52+
};
53+
var otherParams = {};
54+
for (var k in parameters) {
55+
if (parameters.hasOwnProperty(k) && graphqlParamNames[k] !== true) {
56+
otherParams[k] = parameters[k];
57+
}
58+
}
59+
var fetcher;
60+
if (true) {
61+
var subscriptionsClient = new window.SubscriptionsTransportWs.SubscriptionClient('${subscriptionsEndpoint}', {
62+
reconnect: true
63+
});
64+
fetcher = window.GraphiQLSubscriptionsFetcher.graphQLFetcher(subscriptionsClient, graphQLFetcher);
65+
} else {
66+
fetcher = graphQLFetcher;
67+
}
68+
// We don't use safe-serialize for location, because it's not client input.
69+
var fetchURL = locationQuery(otherParams, '${endpointURL}');
70+
// Defines a GraphQL fetcher using the fetch API.
71+
function graphQLFetcher(graphQLParams) {
72+
return fetch(fetchURL, {
73+
method: 'post',
74+
headers: {
75+
'Accept': 'application/json',
76+
'Content-Type': 'application/json',
77+
},
78+
body: JSON.stringify(graphQLParams),
79+
credentials: 'include',
80+
}).then(function (response) {
81+
return response.text();
82+
}).then(function (responseBody) {
83+
try {
84+
return JSON.parse(responseBody);
85+
} catch (error) {
86+
return responseBody;
87+
}
88+
});
89+
}
90+
// When the query and variables string is edited, update the URL bar so
91+
// that it can be easily shared.
92+
function onEditQuery(newQuery) {
93+
parameters.query = newQuery;
94+
updateURL();
95+
}
96+
function onEditVariables(newVariables) {
97+
parameters.variables = newVariables;
98+
updateURL();
99+
}
100+
function onEditOperationName(newOperationName) {
101+
parameters.operationName = newOperationName;
102+
updateURL();
103+
}
104+
function updateURL() {
105+
history.replaceState(null, null, locationQuery(parameters) + window.location.hash);
106+
}
107+
// Render <GraphiQL /> into the body.
108+
ReactDOM.render(
109+
React.createElement(GraphiQL, {
110+
fetcher: fetcher,
111+
onEditQuery: onEditQuery,
112+
onEditVariables: onEditVariables,
113+
onEditOperationName: onEditOperationName,
114+
}),
115+
document.body
116+
);
117+
</script>
118+
</body>
119+
</html>''').substitute(
120+
GRAPHIQL_VERSION='0.10.2',
121+
SUBSCRIPTIONS_TRANSPORT_VERSION='0.7.0',
122+
subscriptionsEndpoint='ws://localhost:8000/subscriptions',
123+
endpointURL='/graphql',
124+
)

graphql_ws/websockets_lib.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
from inspect import isawaitable, isasyncgen
2+
3+
from asyncio import ensure_future
4+
from websockets import ConnectionClosed
5+
from graphql.execution.executors.asyncio import AsyncioExecutor
6+
7+
from .base import ConnectionClosedException, BaseConnectionContext, BaseSubscriptionServer
8+
from .observable_aiter import setup_observable_extension
9+
10+
from .constants import (
11+
GQL_CONNECTION_ACK,
12+
GQL_CONNECTION_ERROR,
13+
GQL_COMPLETE
14+
)
15+
16+
setup_observable_extension()
17+
18+
19+
class WsLibConnectionContext(BaseConnectionContext):
20+
async def receive(self):
21+
try:
22+
msg = await self.ws.recv()
23+
return msg
24+
except ConnectionClosed:
25+
raise ConnectionClosedException()
26+
27+
async def send(self, data):
28+
if self.closed:
29+
return
30+
await self.ws.send(data)
31+
32+
@property
33+
def closed(self):
34+
return self.ws.open is False
35+
36+
async def close(self, code):
37+
await self.ws.close(code)
38+
39+
40+
class WsLibSubscriptionServer(BaseSubscriptionServer):
41+
42+
def get_graphql_params(self, *args, **kwargs):
43+
params = super(WsLibSubscriptionServer,
44+
self).get_graphql_params(*args, **kwargs)
45+
return dict(params, return_promise=True, executor=AsyncioExecutor())
46+
47+
async def handle(self, ws, request_context=None):
48+
connection_context = WsLibConnectionContext(ws, request_context)
49+
await self.on_open(connection_context)
50+
while True:
51+
try:
52+
if connection_context.closed:
53+
raise ConnectionClosedException()
54+
message = await connection_context.receive()
55+
except ConnectionClosedException:
56+
self.on_close(connection_context)
57+
return
58+
59+
ensure_future(self.on_message(connection_context, message))
60+
61+
async def on_open(self, connection_context):
62+
pass
63+
64+
def on_close(self, connection_context):
65+
remove_operations = list(connection_context.operations.keys())
66+
for op_id in remove_operations:
67+
self.unsubscribe(connection_context, op_id)
68+
69+
async def on_connect(self, connection_context, payload):
70+
pass
71+
72+
async def on_connection_init(self, connection_context, op_id, payload):
73+
try:
74+
await self.on_connect(connection_context, payload)
75+
await self.send_message(connection_context, op_type=GQL_CONNECTION_ACK)
76+
except Exception as e:
77+
await self.send_error(connection_context, op_id, e, GQL_CONNECTION_ERROR)
78+
await connection_context.close(1011)
79+
80+
async def on_start(self, connection_context, op_id, params):
81+
execution_result = self.execute(
82+
connection_context.request_context, params)
83+
84+
if isawaitable(execution_result):
85+
execution_result = await execution_result
86+
87+
if not hasattr(execution_result, '__aiter__'):
88+
await self.send_execution_result(connection_context, op_id, execution_result)
89+
else:
90+
iterator = await execution_result.__aiter__()
91+
connection_context.register_operation(op_id, iterator)
92+
async for single_result in iterator:
93+
if not connection_context.has_operation(op_id):
94+
break
95+
await self.send_execution_result(connection_context, op_id, single_result)
96+
await self.send_message(connection_context, op_id, GQL_COMPLETE)
97+
98+
async def on_stop(self, connection_context, op_id):
99+
self.unsubscribe(connection_context, op_id)

0 commit comments

Comments
 (0)