Skip to content

Commit f740aa6

Browse files
committed
Merge branch 'master' into tests_complete_coverage
2 parents fb45bcb + 5348b0a commit f740aa6

File tree

4 files changed

+132
-26
lines changed

4 files changed

+132
-26
lines changed

gql/client.py

Lines changed: 45 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import asyncio
2-
from inspect import isawaitable
3-
from typing import Any, AsyncGenerator, Dict, Generator, Optional, Union, cast
2+
from typing import Any, AsyncGenerator, Dict, Generator, Optional, Union
43

54
from graphql import (
65
DocumentNode,
@@ -196,19 +195,22 @@ class SyncClientSession:
196195
def __init__(self, client: Client):
197196
self.client = client
198197

199-
def execute(self, document: DocumentNode, *args, **kwargs) -> Dict:
198+
def _execute(self, document: DocumentNode, *args, **kwargs) -> ExecutionResult:
200199

201200
# Validate document
202201
if self.client.schema:
203202
self.client.validate(document)
204203

205-
result = self.transport.execute(document, *args, **kwargs)
204+
return self.transport.execute(document, *args, **kwargs)
205+
206+
def execute(self, document: DocumentNode, *args, **kwargs) -> Dict:
206207

207-
assert not isawaitable(result), "Transport returned an awaitable result."
208-
result = cast(ExecutionResult, result)
208+
# Validate and execute on the transport
209+
result = self._execute(document, *args, **kwargs)
209210

211+
# Raise an error if an error is returned in the ExecutionResult object
210212
if result.errors:
211-
raise TransportQueryError(str(result.errors[0]))
213+
raise TransportQueryError(str(result.errors[0]), errors=result.errors)
212214

213215
assert (
214216
result.data is not None
@@ -250,43 +252,69 @@ async def fetch_and_validate(self, document: DocumentNode):
250252
if self.client.schema:
251253
self.client.validate(document)
252254

253-
async def subscribe(
255+
async def _subscribe(
254256
self, document: DocumentNode, *args, **kwargs
255-
) -> AsyncGenerator[Dict, None]:
257+
) -> AsyncGenerator[ExecutionResult, None]:
256258

257259
# Fetch schema from transport if needed and validate document if possible
258260
await self.fetch_and_validate(document)
259261

260-
# Subscribe to the transport and yield data or raise error
261-
self._generator: AsyncGenerator[
262+
# Subscribe to the transport
263+
inner_generator: AsyncGenerator[
262264
ExecutionResult, None
263265
] = self.transport.subscribe(document, *args, **kwargs)
264266

265-
async for result in self._generator:
267+
# Keep a reference to the inner generator to allow the user to call aclose()
268+
# before a break if python version is too old (pypy3 py 3.6.1)
269+
self._generator = inner_generator
270+
271+
async for result in inner_generator:
266272
if result.errors:
267273
# Note: we need to run generator.aclose() here or the finally block in
268274
# transport.subscribe will not be reached in pypy3 (py 3.6.1)
269-
await self._generator.aclose()
275+
await inner_generator.aclose()
276+
277+
yield result
278+
279+
async def subscribe(
280+
self, document: DocumentNode, *args, **kwargs
281+
) -> AsyncGenerator[Dict, None]:
270282

271-
raise TransportQueryError(str(result.errors[0]))
283+
# Validate and subscribe on the transport
284+
async for result in self._subscribe(document, *args, **kwargs):
285+
286+
# Raise an error if an error is returned in the ExecutionResult object
287+
if result.errors:
288+
raise TransportQueryError(str(result.errors[0]), errors=result.errors)
272289

273290
elif result.data is not None:
274291
yield result.data
275292

276-
async def execute(self, document: DocumentNode, *args, **kwargs) -> Dict:
293+
async def _execute(
294+
self, document: DocumentNode, *args, **kwargs
295+
) -> ExecutionResult:
277296

278297
# Fetch schema from transport if needed and validate document if possible
279298
await self.fetch_and_validate(document)
280299

281300
# Execute the query with the transport with a timeout
282-
result = await asyncio.wait_for(
301+
return await asyncio.wait_for(
283302
self.transport.execute(document, *args, **kwargs),
284303
self.client.execute_timeout,
285304
)
286305

306+
async def execute(self, document: DocumentNode, *args, **kwargs) -> Dict:
307+
308+
# Validate and execute on the transport
309+
result = await self._execute(document, *args, **kwargs)
310+
287311
# Raise an error if an error is returned in the ExecutionResult object
288312
if result.errors:
289-
raise TransportQueryError(str(result.errors[0]))
313+
raise TransportQueryError(str(result.errors[0]), errors=result.errors)
314+
315+
assert (
316+
result.data is not None
317+
), "Transport returned an ExecutionResult without data or errors"
290318

291319
return result.data
292320

gql/transport/exceptions.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
from typing import Any, List, Optional
2+
3+
14
class TransportError(Exception):
25
pass
36

@@ -22,9 +25,15 @@ class TransportQueryError(Exception):
2225
This exception should not close the transport connection.
2326
"""
2427

25-
def __init__(self, msg, query_id=None):
28+
def __init__(
29+
self,
30+
msg: str,
31+
query_id: Optional[int] = None,
32+
errors: Optional[List[Any]] = None,
33+
):
2634
super().__init__(msg)
2735
self.query_id = query_id
36+
self.errors = errors
2837

2938

3039
class TransportClosed(TransportError):

gql/transport/websockets.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,8 @@ def __init__(
134134
self._no_more_listeners: asyncio.Event = asyncio.Event()
135135
self._no_more_listeners.set()
136136

137+
self._connecting: bool = False
138+
137139
self.close_exception: Optional[Exception] = None
138140

139141
async def _send(self, message: str) -> None:
@@ -291,7 +293,9 @@ def _parse_answer(
291293

292294
elif answer_type == "error":
293295

294-
raise TransportQueryError(str(payload), query_id=answer_id)
296+
raise TransportQueryError(
297+
str(payload), query_id=answer_id, errors=[payload]
298+
)
295299

296300
elif answer_type == "ka":
297301
# KeepAlive message
@@ -333,6 +337,9 @@ async def _receive_data_loop(self) -> None:
333337
# ==> Add an exception to this query queue
334338
# The exception is raised for this specific query,
335339
# but the transport is not closed.
340+
assert isinstance(
341+
e.query_id, int
342+
), "TransportQueryError should have a query_id defined here"
336343
try:
337344
await self.listeners[e.query_id].set_exception(e)
338345
except KeyError:
@@ -467,7 +474,11 @@ async def connect(self) -> None:
467474

468475
GRAPHQLWS_SUBPROTOCOL: Subprotocol = cast(Subprotocol, "graphql-ws")
469476

470-
if self.websocket is None:
477+
if self.websocket is None and not self._connecting:
478+
479+
# Set connecting to True to avoid a race condition if user is trying
480+
# to connect twice using the same client at the same time
481+
self._connecting = True
471482

472483
# If the ssl parameter is not provided,
473484
# generate the ssl value depending on the url
@@ -489,9 +500,13 @@ async def connect(self) -> None:
489500

490501
# Connection to the specified url
491502
# Generate a TimeoutError if taking more than connect_timeout seconds
492-
self.websocket = await asyncio.wait_for(
493-
websockets.connect(self.url, **connect_args,), self.connect_timeout,
494-
)
503+
# Set the _connecting flag to False after in all cases
504+
try:
505+
self.websocket = await asyncio.wait_for(
506+
websockets.connect(self.url, **connect_args,), self.connect_timeout,
507+
)
508+
finally:
509+
self._connecting = False
495510

496511
self.next_query_id = 1
497512
self.close_exception = None

tests/test_websocket_exceptions.py

Lines changed: 57 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
import asyncio
22
import json
33
import types
4+
from typing import List
45

56
import pytest
67
import websockets
78

89
from gql import Client, gql
910
from gql.transport.exceptions import (
11+
TransportAlreadyConnected,
1012
TransportClosed,
1113
TransportProtocolError,
1214
TransportQueryError,
@@ -44,9 +46,17 @@ async def test_websocket_invalid_query(event_loop, client_and_server, query_str)
4446

4547
query = gql(query_str)
4648

47-
with pytest.raises(TransportQueryError):
49+
with pytest.raises(TransportQueryError) as exc_info:
4850
await session.execute(query)
4951

52+
exception = exc_info.value
53+
54+
assert isinstance(exception.errors, List)
55+
56+
error = exception.errors[0]
57+
58+
assert error["extensions"]["code"] == "INTERNAL_SERVER_ERROR"
59+
5060

5161
invalid_subscription_str = """
5262
subscription getContinents {
@@ -75,10 +85,18 @@ async def test_websocket_invalid_subscription(event_loop, client_and_server, que
7585

7686
query = gql(query_str)
7787

78-
with pytest.raises(TransportQueryError):
88+
with pytest.raises(TransportQueryError) as exc_info:
7989
async for result in session.subscribe(query):
8090
pass
8191

92+
exception = exc_info.value
93+
94+
assert isinstance(exception.errors, List)
95+
96+
error = exception.errors[0]
97+
98+
assert error["extensions"]["code"] == "INTERNAL_SERVER_ERROR"
99+
82100

83101
connection_error_server_answer = (
84102
'{"type":"connection_error","id":null,'
@@ -170,9 +188,17 @@ async def monkey_patch_send_query(
170188

171189
query = gql(query_str)
172190

173-
with pytest.raises(TransportQueryError):
191+
with pytest.raises(TransportQueryError) as exc_info:
174192
await session.execute(query)
175193

194+
exception = exc_info.value
195+
196+
assert isinstance(exception.errors, List)
197+
198+
error = exception.errors[0]
199+
200+
assert error["message"] == "Must provide document"
201+
176202

177203
not_json_answer = ["BLAHBLAH"]
178204
missing_type_answer = ["{}"]
@@ -294,3 +320,31 @@ async def test_websocket_server_sending_invalid_query_errors(event_loop, server)
294320
# Invalid server message is ignored
295321
async with Client(transport=sample_transport):
296322
await asyncio.sleep(2 * MS)
323+
324+
325+
@pytest.mark.asyncio
326+
@pytest.mark.parametrize("server", [server_sending_invalid_query_errors], indirect=True)
327+
async def test_websocket_non_regression_bug_105(event_loop, server):
328+
329+
# This test will check a fix to a race condition which happens if the user is trying
330+
# to connect using the same client twice at the same time
331+
# See bug #105
332+
333+
url = f"ws://{server.hostname}:{server.port}/graphql"
334+
print(f"url = {url}")
335+
336+
sample_transport = WebsocketsTransport(url=url)
337+
338+
client = Client(transport=sample_transport)
339+
340+
# Create a coroutine which start the connection with the transport but does nothing
341+
async def client_connect(client):
342+
async with client:
343+
await asyncio.sleep(2 * MS)
344+
345+
# Create two tasks which will try to connect using the same client (not allowed)
346+
connect_task1 = asyncio.ensure_future(client_connect(client))
347+
connect_task2 = asyncio.ensure_future(client_connect(client))
348+
349+
with pytest.raises(TransportAlreadyConnected):
350+
await asyncio.gather(connect_task1, connect_task2)

0 commit comments

Comments
 (0)