Skip to content

Commit de8ced3

Browse files
committed
Move unsubscribe logic to the connection context
1 parent 583f3f0 commit de8ced3

File tree

5 files changed

+47
-44
lines changed

5 files changed

+47
-44
lines changed

graphql_ws/base.py

Lines changed: 15 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,18 @@ def get_operation(self, op_id):
3535

3636
def remove_operation(self, op_id):
3737
try:
38-
del self.operations[op_id]
38+
return self.operations.pop(op_id)
3939
except KeyError:
40-
pass
40+
return
41+
42+
def unsubscribe(self, op_id):
43+
async_iterator = self.remove_operation(op_id)
44+
if hasattr(async_iterator, 'dispose'):
45+
async_iterator.dispose()
46+
47+
def unsubscribe_all(self):
48+
for op_id in list(self.operations):
49+
self.unsubscribe(op_id)
4150

4251
def receive(self):
4352
raise NotImplementedError("receive method not implemented")
@@ -76,12 +85,6 @@ def process_message(self, connection_context, parsed_message):
7685

7786
elif op_type == GQL_START:
7887
assert isinstance(payload, dict), "The payload must be a dict"
79-
80-
# If we already have a subscription with this id, unsubscribe from
81-
# it first
82-
if connection_context.has_operation(op_id):
83-
self.unsubscribe(connection_context, op_id)
84-
8588
params = self.get_graphql_params(connection_context, payload)
8689
return self.on_start(connection_context, op_id, params)
8790

@@ -116,7 +119,10 @@ def on_open(self, connection_context):
116119
raise NotImplementedError("on_open method not implemented")
117120

118121
def on_stop(self, connection_context, op_id):
119-
raise NotImplementedError("on_stop method not implemented")
122+
return connection_context.unsubscribe(op_id)
123+
124+
def on_close(self, connection_context):
125+
return connection_context.unsubscribe_all()
120126

121127
def send_message(self, connection_context, op_id=None, op_type=None, payload=None):
122128
message = self.build_message(op_id, op_type, payload)
@@ -171,11 +177,3 @@ def on_message(self, connection_context, message):
171177
return self.send_error(connection_context, None, e)
172178

173179
return self.process_message(connection_context, parsed_message)
174-
175-
def unsubscribe(self, connection_context, op_id):
176-
if connection_context.has_operation(op_id):
177-
# Close async iterator
178-
connection_context.get_operation(op_id).dispose()
179-
# Close operation
180-
connection_context.remove_operation(op_id)
181-
return self.on_operation_complete(connection_context, op_id)

graphql_ws/base_async.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,20 @@ def remember_task(self, task):
8888
task for task in self.pending_tasks if task.done()
8989
)
9090

91+
async def unsubscribe(self, op_id):
92+
super().unsubscribe(op_id)
93+
94+
async def unsubscribe_all(self):
95+
awaitables = [self.unsubscribe(op_id) for op_id in list(self.operations)]
96+
for task in self.pending_tasks:
97+
task.cancel()
98+
awaitables.append(task)
99+
if awaitables:
100+
try:
101+
await asyncio.gather(*awaitables)
102+
except asyncio.CancelledError:
103+
pass
104+
91105

92106
class BaseAsyncSubscriptionServer(base.BaseSubscriptionServer, ABC):
93107
graphql_executor = AsyncioExecutor
@@ -125,6 +139,10 @@ async def on_connection_init(self, connection_context, op_id, payload):
125139
await connection_context.close(1011)
126140

127141
async def on_start(self, connection_context, op_id, params):
142+
# Attempt to unsubscribe first in case we already have a subscription
143+
# with this id.
144+
await connection_context.unsubscribe(op_id)
145+
128146
execution_result = self.execute(params)
129147

130148
if is_awaitable(execution_result):
@@ -153,20 +171,6 @@ async def on_start(self, connection_context, op_id, params):
153171
await self.send_message(connection_context, op_id, GQL_COMPLETE)
154172
await self.on_operation_complete(connection_context, op_id)
155173

156-
async def on_close(self, connection_context):
157-
awaitables = tuple(
158-
self.unsubscribe(connection_context, op_id)
159-
for op_id in connection_context.operations
160-
) + tuple(task.cancel() for task in connection_context.pending_tasks)
161-
if awaitables:
162-
try:
163-
await asyncio.gather(*awaitables, loop=self.loop)
164-
except asyncio.CancelledError:
165-
pass
166-
167-
async def on_stop(self, connection_context, op_id):
168-
await self.unsubscribe(connection_context, op_id)
169-
170174
async def on_operation_complete(self, connection_context, op_id):
171175
pass
172176

graphql_ws/base_sync.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,6 @@ def on_open(self, connection_context):
2020
def on_connect(self, connection_context, payload):
2121
pass
2222

23-
def on_close(self, connection_context):
24-
remove_operations = list(connection_context.operations)
25-
for op_id in remove_operations:
26-
self.unsubscribe(connection_context, op_id)
27-
2823
def on_connection_init(self, connection_context, op_id, payload):
2924
try:
3025
self.on_connect(connection_context, payload)
@@ -34,10 +29,10 @@ def on_connection_init(self, connection_context, op_id, payload):
3429
self.send_error(connection_context, op_id, e, GQL_CONNECTION_ERROR)
3530
connection_context.close(1011)
3631

37-
def on_stop(self, connection_context, op_id):
38-
self.unsubscribe(connection_context, op_id)
39-
4032
def on_start(self, connection_context, op_id, params):
33+
# Attempt to unsubscribe first in case we already have a subscription
34+
# with this id.
35+
connection_context.unsubscribe(op_id)
4136
try:
4237
execution_result = self.execute(params)
4338
assert isinstance(

tests/test_base.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,13 @@ def test_not_implemented():
1616
server.on_connection_init(connection_context=None, op_id=1, payload={})
1717
with pytest.raises(NotImplementedError):
1818
server.on_open(connection_context=None)
19-
with pytest.raises(NotImplementedError):
20-
server.on_stop(connection_context=None, op_id=1)
19+
20+
21+
def test_on_stop():
22+
server = base.BaseSubscriptionServer(schema=None)
23+
context = mock.Mock()
24+
server.on_stop(connection_context=context, op_id=1)
25+
context.unsubscribe.assert_called_with(1)
2126

2227

2328
def test_terminate():

tests/test_graphql_ws.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -94,12 +94,12 @@ def test_start_existing_op(self, ss, cc):
9494
ss.get_graphql_params.return_value = {"params": True}
9595
cc.has_operation = mock.Mock()
9696
cc.has_operation.return_value = True
97-
ss.unsubscribe = mock.Mock()
97+
cc.unsubscribe = mock.Mock()
9898
ss.on_start = mock.Mock()
9999
ss.process_message(
100100
cc, {"id": "1", "type": constants.GQL_START, "payload": {"a": "b"}}
101101
)
102-
assert ss.unsubscribe.called
102+
assert cc.unsubscribe.called
103103
ss.on_start.assert_called_with(cc, "1", {"params": True})
104104

105105
def test_start_bad_graphql_params(self, ss, cc):
@@ -162,7 +162,8 @@ def test_build_message_partial(ss):
162162
assert ss.build_message(id=None, op_type=None, payload="PAYLOAD") == {
163163
"payload": "PAYLOAD"
164164
}
165-
assert ss.build_message(id=None, op_type=None, payload=None) == {}
165+
with pytest.raises(AssertionError):
166+
ss.build_message(id=None, op_type=None, payload=None)
166167

167168

168169
def test_send_execution_result(ss):

0 commit comments

Comments
 (0)