Skip to content

Add websockets library context and server classes in order to use with Sanic #6

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jan 20, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ Websocket server for GraphQL subscriptions.
Currently supports:
* [aiohttp](https://github.com/graphql-python/graphql-ws#aiohttp)
* [Gevent](https://github.com/graphql-python/graphql-ws#gevent)
* Sanic (uses [websockets](https://github.com/aaugustin/websockets/) library)

# Installation instructions

Expand Down Expand Up @@ -40,6 +41,29 @@ app.router.add_get('/subscriptions', subscriptions)
web.run_app(app, port=8000)
```

### Sanic

Works with any framework that uses the websockets library for
it's websocket implementation. For this example, plug in
your Sanic server.

```python
from graphql_ws.websockets_lib import WsLibSubscriptionServer


app = Sanic(__name__)

subscription_server = WsLibSubscriptionServer(schema)

@app.websocket('/subscriptions', subprotocols=['graphql-ws'])
async def subscriptions(request, ws):
await subscription_server.handle(ws)
return ws


app.run(host="0.0.0.0", port=8000)
```

And then, plug into a subscribable schema:

```python
Expand Down
Empty file.
31 changes: 31 additions & 0 deletions examples/websockets_lib/app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from graphql_ws.websockets_lib import WsLibSubscriptionServer
from graphql.execution.executors.asyncio import AsyncioExecutor
from sanic import Sanic, response
from sanic_graphql import GraphQLView
from schema import schema
from template import render_graphiql

app = Sanic(__name__)


@app.listener('before_server_start')
def init_graphql(app, loop):
app.add_route(GraphQLView.as_view(schema=schema,
executor=AsyncioExecutor(loop=loop)),
'/graphql')


@app.route('/graphiql')
async def graphiql_view(request):
return response.html(render_graphiql())

subscription_server = WsLibSubscriptionServer(schema)


@app.websocket('/subscriptions', subprotocols=['graphql-ws'])
async def subscriptions(request, ws):
await subscription_server.handle(ws)
return ws


app.run(host="0.0.0.0", port=8000)
4 changes: 4 additions & 0 deletions examples/websockets_lib/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
graphql_ws
sanic>=0.7.0
graphene>=2.0
sanic-graphql>=1.1.0
34 changes: 34 additions & 0 deletions examples/websockets_lib/schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import random
import asyncio
import graphene


class Query(graphene.ObjectType):
base = graphene.String()


class RandomType(graphene.ObjectType):
seconds = graphene.Int()
random_int = graphene.Int()


class Subscription(graphene.ObjectType):
count_seconds = graphene.Float(up_to=graphene.Int())
random_int = graphene.Field(RandomType)

async def resolve_count_seconds(root, info, up_to=5):
for i in range(up_to):
print("YIELD SECOND", i)
yield i
await asyncio.sleep(1.)
yield up_to

async def resolve_random_int(root, info):
i = 0
while True:
yield RandomType(seconds=i, random_int=random.randint(0, 500))
await asyncio.sleep(1.)
i += 1


schema = graphene.Schema(query=Query, subscription=Subscription)
124 changes: 124 additions & 0 deletions examples/websockets_lib/template.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@

from string import Template


def render_graphiql():
return Template('''
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8" />
<title>GraphiQL</title>
<meta name="robots" content="noindex" />
<style>
html, body {
height: 100%;
margin: 0;
overflow: hidden;
width: 100%;
}
</style>
<link href="//cdn.jsdelivr.net/graphiql/${GRAPHIQL_VERSION}/graphiql.css" rel="stylesheet" />
<script src="//cdn.jsdelivr.net/fetch/0.9.0/fetch.min.js"></script>
<script src="//cdn.jsdelivr.net/react/15.0.0/react.min.js"></script>
<script src="//cdn.jsdelivr.net/react/15.0.0/react-dom.min.js"></script>
<script src="//cdn.jsdelivr.net/graphiql/${GRAPHIQL_VERSION}/graphiql.min.js"></script>
<script src="//unpkg.com/subscriptions-transport-ws@${SUBSCRIPTIONS_TRANSPORT_VERSION}/browser/client.js"></script>
<script src="//unpkg.com/graphiql-subscriptions-fetcher@0.0.2/browser/client.js"></script>
</head>
<body>
<script>
// Collect the URL parameters
var parameters = {};
window.location.search.substr(1).split('&').forEach(function (entry) {
var eq = entry.indexOf('=');
if (eq >= 0) {
parameters[decodeURIComponent(entry.slice(0, eq))] =
decodeURIComponent(entry.slice(eq + 1));
}
});
// Produce a Location query string from a parameter object.
function locationQuery(params, location) {
return (location ? location: '') + '?' + Object.keys(params).map(function (key) {
return encodeURIComponent(key) + '=' +
encodeURIComponent(params[key]);
}).join('&');
}
// Derive a fetch URL from the current URL, sans the GraphQL parameters.
var graphqlParamNames = {
query: true,
variables: true,
operationName: true
};
var otherParams = {};
for (var k in parameters) {
if (parameters.hasOwnProperty(k) && graphqlParamNames[k] !== true) {
otherParams[k] = parameters[k];
}
}
var fetcher;
if (true) {
var subscriptionsClient = new window.SubscriptionsTransportWs.SubscriptionClient('${subscriptionsEndpoint}', {
reconnect: true
});
fetcher = window.GraphiQLSubscriptionsFetcher.graphQLFetcher(subscriptionsClient, graphQLFetcher);
} else {
fetcher = graphQLFetcher;
}
// We don't use safe-serialize for location, because it's not client input.
var fetchURL = locationQuery(otherParams, '${endpointURL}');
// Defines a GraphQL fetcher using the fetch API.
function graphQLFetcher(graphQLParams) {
return fetch(fetchURL, {
method: 'post',
headers: {
'Accept': 'application/json',
'Content-Type': 'application/json',
},
body: JSON.stringify(graphQLParams),
credentials: 'include',
}).then(function (response) {
return response.text();
}).then(function (responseBody) {
try {
return JSON.parse(responseBody);
} catch (error) {
return responseBody;
}
});
}
// When the query and variables string is edited, update the URL bar so
// that it can be easily shared.
function onEditQuery(newQuery) {
parameters.query = newQuery;
updateURL();
}
function onEditVariables(newVariables) {
parameters.variables = newVariables;
updateURL();
}
function onEditOperationName(newOperationName) {
parameters.operationName = newOperationName;
updateURL();
}
function updateURL() {
history.replaceState(null, null, locationQuery(parameters) + window.location.hash);
}
// Render <GraphiQL /> into the body.
ReactDOM.render(
React.createElement(GraphiQL, {
fetcher: fetcher,
onEditQuery: onEditQuery,
onEditVariables: onEditVariables,
onEditOperationName: onEditOperationName,
}),
document.body
);
</script>
</body>
</html>''').substitute(
GRAPHIQL_VERSION='0.10.2',
SUBSCRIPTIONS_TRANSPORT_VERSION='0.7.0',
subscriptionsEndpoint='ws://localhost:8000/subscriptions',
endpointURL='/graphql',
)
99 changes: 99 additions & 0 deletions graphql_ws/websockets_lib.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
from inspect import isawaitable, isasyncgen

from asyncio import ensure_future
from websockets import ConnectionClosed
from graphql.execution.executors.asyncio import AsyncioExecutor

from .base import ConnectionClosedException, BaseConnectionContext, BaseSubscriptionServer
from .observable_aiter import setup_observable_extension

from .constants import (
GQL_CONNECTION_ACK,
GQL_CONNECTION_ERROR,
GQL_COMPLETE
)

setup_observable_extension()


class WsLibConnectionContext(BaseConnectionContext):
async def receive(self):
try:
msg = await self.ws.recv()
return msg
except ConnectionClosed:
raise ConnectionClosedException()

async def send(self, data):
if self.closed:
return
await self.ws.send(data)

@property
def closed(self):
return self.ws.open is False

async def close(self, code):
await self.ws.close(code)


class WsLibSubscriptionServer(BaseSubscriptionServer):

def get_graphql_params(self, *args, **kwargs):
params = super(WsLibSubscriptionServer,
self).get_graphql_params(*args, **kwargs)
return dict(params, return_promise=True, executor=AsyncioExecutor())

async def handle(self, ws, request_context=None):
connection_context = WsLibConnectionContext(ws, request_context)
await self.on_open(connection_context)
while True:
try:
if connection_context.closed:
raise ConnectionClosedException()
message = await connection_context.receive()
except ConnectionClosedException:
self.on_close(connection_context)
return

ensure_future(self.on_message(connection_context, message))

async def on_open(self, connection_context):
pass

def on_close(self, connection_context):
remove_operations = list(connection_context.operations.keys())
for op_id in remove_operations:
self.unsubscribe(connection_context, op_id)

async def on_connect(self, connection_context, payload):
pass

async def on_connection_init(self, connection_context, op_id, payload):
try:
await self.on_connect(connection_context, payload)
await self.send_message(connection_context, op_type=GQL_CONNECTION_ACK)
except Exception as e:
await self.send_error(connection_context, op_id, e, GQL_CONNECTION_ERROR)
await connection_context.close(1011)

async def on_start(self, connection_context, op_id, params):
execution_result = self.execute(
connection_context.request_context, params)

if isawaitable(execution_result):
execution_result = await execution_result

if not hasattr(execution_result, '__aiter__'):
await self.send_execution_result(connection_context, op_id, execution_result)
else:
iterator = await execution_result.__aiter__()
connection_context.register_operation(op_id, iterator)
async for single_result in iterator:
if not connection_context.has_operation(op_id):
break
await self.send_execution_result(connection_context, op_id, single_result)
await self.send_message(connection_context, op_id, GQL_COMPLETE)

async def on_stop(self, connection_context, op_id):
self.unsubscribe(connection_context, op_id)