diff --git a/README.md b/README.md index b071e54..de1a74f 100644 --- a/README.md +++ b/README.md @@ -5,6 +5,7 @@ Websocket server for GraphQL subscriptions. Currently supports: * [aiohttp](https://github.com/graphql-python/graphql-ws#aiohttp) * [Gevent](https://github.com/graphql-python/graphql-ws#gevent) +* Sanic (uses [websockets](https://github.com/aaugustin/websockets/) library) # Installation instructions @@ -40,6 +41,29 @@ app.router.add_get('/subscriptions', subscriptions) web.run_app(app, port=8000) ``` +### Sanic + +Works with any framework that uses the websockets library for +it's websocket implementation. For this example, plug in +your Sanic server. + +```python +from graphql_ws.websockets_lib import WsLibSubscriptionServer + + +app = Sanic(__name__) + +subscription_server = WsLibSubscriptionServer(schema) + +@app.websocket('/subscriptions', subprotocols=['graphql-ws']) +async def subscriptions(request, ws): + await subscription_server.handle(ws) + return ws + + +app.run(host="0.0.0.0", port=8000) +``` + And then, plug into a subscribable schema: ```python diff --git a/examples/websockets_lib/__init__.py b/examples/websockets_lib/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/examples/websockets_lib/app.py b/examples/websockets_lib/app.py new file mode 100644 index 0000000..0de6988 --- /dev/null +++ b/examples/websockets_lib/app.py @@ -0,0 +1,31 @@ +from graphql_ws.websockets_lib import WsLibSubscriptionServer +from graphql.execution.executors.asyncio import AsyncioExecutor +from sanic import Sanic, response +from sanic_graphql import GraphQLView +from schema import schema +from template import render_graphiql + +app = Sanic(__name__) + + +@app.listener('before_server_start') +def init_graphql(app, loop): + app.add_route(GraphQLView.as_view(schema=schema, + executor=AsyncioExecutor(loop=loop)), + '/graphql') + + +@app.route('/graphiql') +async def graphiql_view(request): + return response.html(render_graphiql()) + +subscription_server = WsLibSubscriptionServer(schema) + + +@app.websocket('/subscriptions', subprotocols=['graphql-ws']) +async def subscriptions(request, ws): + await subscription_server.handle(ws) + return ws + + +app.run(host="0.0.0.0", port=8000) diff --git a/examples/websockets_lib/requirements.txt b/examples/websockets_lib/requirements.txt new file mode 100644 index 0000000..6d3e723 --- /dev/null +++ b/examples/websockets_lib/requirements.txt @@ -0,0 +1,4 @@ +graphql_ws +sanic>=0.7.0 +graphene>=2.0 +sanic-graphql>=1.1.0 diff --git a/examples/websockets_lib/schema.py b/examples/websockets_lib/schema.py new file mode 100644 index 0000000..3c23d00 --- /dev/null +++ b/examples/websockets_lib/schema.py @@ -0,0 +1,34 @@ +import random +import asyncio +import graphene + + +class Query(graphene.ObjectType): + base = graphene.String() + + +class RandomType(graphene.ObjectType): + seconds = graphene.Int() + random_int = graphene.Int() + + +class Subscription(graphene.ObjectType): + count_seconds = graphene.Float(up_to=graphene.Int()) + random_int = graphene.Field(RandomType) + + async def resolve_count_seconds(root, info, up_to=5): + for i in range(up_to): + print("YIELD SECOND", i) + yield i + await asyncio.sleep(1.) + yield up_to + + async def resolve_random_int(root, info): + i = 0 + while True: + yield RandomType(seconds=i, random_int=random.randint(0, 500)) + await asyncio.sleep(1.) + i += 1 + + +schema = graphene.Schema(query=Query, subscription=Subscription) diff --git a/examples/websockets_lib/template.py b/examples/websockets_lib/template.py new file mode 100644 index 0000000..03587bb --- /dev/null +++ b/examples/websockets_lib/template.py @@ -0,0 +1,124 @@ + +from string import Template + + +def render_graphiql(): + return Template(''' + + + + + GraphiQL + + + + + + + + + + + + + +''').substitute( + GRAPHIQL_VERSION='0.10.2', + SUBSCRIPTIONS_TRANSPORT_VERSION='0.7.0', + subscriptionsEndpoint='ws://localhost:8000/subscriptions', + endpointURL='/graphql', + ) diff --git a/graphql_ws/websockets_lib.py b/graphql_ws/websockets_lib.py new file mode 100644 index 0000000..f41a1bb --- /dev/null +++ b/graphql_ws/websockets_lib.py @@ -0,0 +1,99 @@ +from inspect import isawaitable, isasyncgen + +from asyncio import ensure_future +from websockets import ConnectionClosed +from graphql.execution.executors.asyncio import AsyncioExecutor + +from .base import ConnectionClosedException, BaseConnectionContext, BaseSubscriptionServer +from .observable_aiter import setup_observable_extension + +from .constants import ( + GQL_CONNECTION_ACK, + GQL_CONNECTION_ERROR, + GQL_COMPLETE +) + +setup_observable_extension() + + +class WsLibConnectionContext(BaseConnectionContext): + async def receive(self): + try: + msg = await self.ws.recv() + return msg + except ConnectionClosed: + raise ConnectionClosedException() + + async def send(self, data): + if self.closed: + return + await self.ws.send(data) + + @property + def closed(self): + return self.ws.open is False + + async def close(self, code): + await self.ws.close(code) + + +class WsLibSubscriptionServer(BaseSubscriptionServer): + + def get_graphql_params(self, *args, **kwargs): + params = super(WsLibSubscriptionServer, + self).get_graphql_params(*args, **kwargs) + return dict(params, return_promise=True, executor=AsyncioExecutor()) + + async def handle(self, ws, request_context=None): + connection_context = WsLibConnectionContext(ws, request_context) + await self.on_open(connection_context) + while True: + try: + if connection_context.closed: + raise ConnectionClosedException() + message = await connection_context.receive() + except ConnectionClosedException: + self.on_close(connection_context) + return + + ensure_future(self.on_message(connection_context, message)) + + async def on_open(self, connection_context): + pass + + def on_close(self, connection_context): + remove_operations = list(connection_context.operations.keys()) + for op_id in remove_operations: + self.unsubscribe(connection_context, op_id) + + async def on_connect(self, connection_context, payload): + pass + + async def on_connection_init(self, connection_context, op_id, payload): + try: + await self.on_connect(connection_context, payload) + await self.send_message(connection_context, op_type=GQL_CONNECTION_ACK) + except Exception as e: + await self.send_error(connection_context, op_id, e, GQL_CONNECTION_ERROR) + await connection_context.close(1011) + + async def on_start(self, connection_context, op_id, params): + execution_result = self.execute( + connection_context.request_context, params) + + if isawaitable(execution_result): + execution_result = await execution_result + + if not hasattr(execution_result, '__aiter__'): + await self.send_execution_result(connection_context, op_id, execution_result) + else: + iterator = await execution_result.__aiter__() + connection_context.register_operation(op_id, iterator) + async for single_result in iterator: + if not connection_context.has_operation(op_id): + break + await self.send_execution_result(connection_context, op_id, single_result) + await self.send_message(connection_context, op_id, GQL_COMPLETE) + + async def on_stop(self, connection_context, op_id): + self.unsubscribe(connection_context, op_id)