Skip to content

Commit 598acd8

Browse files
committed
Fix type_adapter for subscriptions and async transports
1 parent 5b19de5 commit 598acd8

File tree

5 files changed

+204
-36
lines changed

5 files changed

+204
-36
lines changed

gql/client.py

Lines changed: 28 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -67,16 +67,21 @@ def __init__(
6767
# Enforced timeout of the execute function
6868
self.execute_timeout = execute_timeout
6969

70+
# Fetch schema from transport directly if we are using a sync transport
71+
if isinstance(transport, Transport) and fetch_schema_from_transport:
72+
with self as session:
73+
session.fetch_schema()
74+
7075
# Dictionary where the name of the custom scalar type is the key and the
7176
# value is a class which has a `parse_value()` function
77+
self.custom_types = custom_types
78+
79+
# Create a type_adapter instance directly here if we received the schema
80+
# locally or from a sync transport
7281
self.type_adapter = (
7382
TypeAdapter(schema, custom_types) if custom_types and schema else None
7483
)
7584

76-
if isinstance(transport, Transport) and fetch_schema_from_transport:
77-
with self as session:
78-
session.fetch_schema()
79-
8085
def validate(self, document):
8186
assert (
8287
self.schema
@@ -211,9 +216,7 @@ def _execute(self, document: DocumentNode, *args, **kwargs) -> ExecutionResult:
211216

212217
return self.transport.execute(document, *args, **kwargs)
213218

214-
def execute(
215-
self, document: DocumentNode, *args, **kwargs
216-
) -> Optional[Dict[str, Any]]:
219+
def execute(self, document: DocumentNode, *args, **kwargs) -> Dict[str, Any]:
217220

218221
# Validate and execute on the transport
219222
result = self._execute(document, *args, **kwargs)
@@ -227,11 +230,9 @@ def execute(
227230
), "Transport returned an ExecutionResult without data or errors"
228231

229232
if self.client.type_adapter:
230-
result = result._replace(
231-
data=self.client.type_adapter.convert_scalars(result.data)
232-
)
233-
234-
return result.data
233+
return self.client.type_adapter.convert_scalars(result.data)
234+
else:
235+
return result.data
235236

236237
def fetch_schema(self) -> None:
237238
execution_result = self.transport.execute(parse(get_introspection_query()))
@@ -263,6 +264,13 @@ async def fetch_and_validate(self, document: DocumentNode):
263264
if self.client.fetch_schema_from_transport and not self.client.schema:
264265
await self.fetch_schema()
265266

267+
# Once we have received the schema from the async transport,
268+
# we can create a TypeAdapter instance if the user provided custom types
269+
if self.client.custom_types and self.client.schema:
270+
self.client.type_adapter = TypeAdapter(
271+
self.client.schema, self.client.custom_types
272+
)
273+
266274
# Validate document
267275
if self.client.schema:
268276
self.client.validate(document)
@@ -293,7 +301,7 @@ async def _subscribe(
293301

294302
async def subscribe(
295303
self, document: DocumentNode, *args, **kwargs
296-
) -> AsyncGenerator[Optional[Dict[str, Any]], None]:
304+
) -> AsyncGenerator[Dict[str, Any], None]:
297305

298306
# Validate and subscribe on the transport
299307
async for result in self._subscribe(document, *args, **kwargs):
@@ -304,10 +312,9 @@ async def subscribe(
304312

305313
elif result.data is not None:
306314
if self.client.type_adapter:
307-
result = result._replace(
308-
data=self.client.type_adapter.convert_scalars(result.data)
309-
)
310-
yield result.data
315+
yield self.client.type_adapter.convert_scalars(result.data)
316+
else:
317+
yield result.data
311318

312319
async def _execute(
313320
self, document: DocumentNode, *args, **kwargs
@@ -322,9 +329,7 @@ async def _execute(
322329
self.client.execute_timeout,
323330
)
324331

325-
async def execute(
326-
self, document: DocumentNode, *args, **kwargs
327-
) -> Optional[Dict[str, Any]]:
332+
async def execute(self, document: DocumentNode, *args, **kwargs) -> Dict[str, Any]:
328333

329334
# Validate and execute on the transport
330335
result = await self._execute(document, *args, **kwargs)
@@ -338,11 +343,9 @@ async def execute(
338343
), "Transport returned an ExecutionResult without data or errors"
339344

340345
if self.client.type_adapter:
341-
result = result._replace(
342-
data=self.client.type_adapter.convert_scalars(result.data)
343-
)
344-
345-
return result.data
346+
return self.client.type_adapter.convert_scalars(result.data)
347+
else:
348+
return result.data
346349

347350
async def fetch_schema(self) -> None:
348351
execution_result = await self.transport.execute(

gql/type_adapter.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,11 @@ def traverse_schema(
8080
schema_root = self.schema.query_type
8181
elif self.schema.mutation_type and keys[0] in self.schema.mutation_type.fields:
8282
schema_root = self.schema.mutation_type
83+
elif (
84+
self.schema.subscription_type
85+
and keys[0] in self.schema.subscription_type.fields
86+
):
87+
schema_root = self.schema.subscription_type
8388
else:
8489
return None
8590

@@ -95,6 +100,7 @@ def _get_decoded_scalar_type(self, keys: List[str], value: str) -> str:
95100
96101
If it is a custom scalar, return the deserialized value, as
97102
output by `<CustomScalarType>.parse_value()`"""
103+
98104
scalar_type = self._lookup_scalar_type(keys)
99105
if scalar_type and scalar_type in self.custom_types:
100106
return self.custom_types[scalar_type].parse_value(value)

tests/test_async_client_validation.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,3 +267,47 @@ async def test_async_client_validation_fetch_schema_from_server_with_client_argu
267267

268268
with pytest.raises(graphql.error.GraphQLError):
269269
await session.execute(query)
270+
271+
272+
class ToLowercase:
273+
@staticmethod
274+
def parse_value(value: str):
275+
return value.lower()
276+
277+
278+
@pytest.mark.asyncio
279+
@pytest.mark.parametrize("server", [hero_server_answers], indirect=True)
280+
async def test_async_client_validation_fetch_schema_from_server_with_custom_types(
281+
event_loop, server
282+
):
283+
284+
url = f"ws://{server.hostname}:{server.port}/graphql"
285+
286+
sample_transport = WebsocketsTransport(url=url)
287+
288+
custom_types = {"String": ToLowercase}
289+
290+
async with Client(
291+
transport=sample_transport,
292+
fetch_schema_from_transport=True,
293+
custom_types=custom_types,
294+
) as session:
295+
296+
query = gql(
297+
"""
298+
query HeroNameQuery {
299+
hero {
300+
name
301+
}
302+
}
303+
"""
304+
)
305+
306+
result = await session.execute(query)
307+
308+
print("Client received:", result)
309+
310+
# The expected hero name is now in lowercase
311+
expected = {"hero": {"name": "r2-d2"}}
312+
313+
assert result == expected

tests/test_requests.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
)
1414
from gql.transport.requests import RequestsHTTPTransport
1515

16+
from .test_type_adapter import Capitalize
17+
1618
query1_str = """
1719
query getContinents {
1820
continents {
@@ -201,3 +203,57 @@ def test_code():
201203
sample_transport.execute(query)
202204

203205
await run_sync_test(event_loop, server, test_code)
206+
207+
208+
partial_schema = """
209+
210+
type Continent {
211+
code: ID!
212+
name: String!
213+
}
214+
215+
type Query {
216+
continents: [Continent!]!
217+
}
218+
219+
"""
220+
221+
222+
@pytest.mark.asyncio
223+
async def test_requests_query_with_custom_types(event_loop, aiohttp_server):
224+
async def handler(request):
225+
return web.Response(text=query1_server_answer, content_type="application/json")
226+
227+
app = web.Application()
228+
app.router.add_route("POST", "/", handler)
229+
server = await aiohttp_server(app)
230+
231+
url = server.make_url("/")
232+
233+
def test_code():
234+
sample_transport = RequestsHTTPTransport(url=url)
235+
236+
custom_types = {"String": Capitalize}
237+
238+
# Instanciate a client which will capitalize all the String scalars
239+
with Client(
240+
transport=sample_transport,
241+
type_def=partial_schema,
242+
custom_types=custom_types,
243+
) as session:
244+
245+
query = gql(query1_str)
246+
247+
# Execute query synchronously
248+
result = session.execute(query)
249+
250+
continents = result["continents"]
251+
252+
africa = continents[0]
253+
254+
assert africa["code"] == "AF"
255+
256+
# Check that the string is capitalized
257+
assert africa["name"] == "AFRICA"
258+
259+
await run_sync_test(event_loop, server, test_code)

tests/test_websocket_subscription.py

Lines changed: 70 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,24 @@
1010

1111
from .conftest import MS, WebSocketServer
1212

13-
countdown_server_answer = (
14-
'{{"type":"data","id":"{query_id}","payload":{{"data":{{"number":{number}}}}}}}'
15-
)
13+
countdown_schema = """
14+
15+
type Number {
16+
number: Int!
17+
}
18+
19+
type Query {
20+
currentCount: Number!
21+
}
22+
23+
type Subscription {
24+
countdown(count: Int!): Number!
25+
}
26+
27+
"""
28+
29+
countdown_server_answer = '{{"type":"data","id":"{query_id}","payload":\
30+
{{"data":{{"countdown":{{"number":{number}}}}}}}}}'
1631

1732
WITH_KEEPALIVE = False
1833

@@ -111,7 +126,7 @@ async def test_websocket_subscription(event_loop, client_and_server, subscriptio
111126

112127
async for result in session.subscribe(subscription):
113128

114-
number = result["number"]
129+
number = result["countdown"]["number"]
115130
print(f"Number received: {number}")
116131

117132
assert number == count
@@ -134,7 +149,7 @@ async def test_websocket_subscription_break(
134149

135150
async for result in session.subscribe(subscription):
136151

137-
number = result["number"]
152+
number = result["countdown"]["number"]
138153
print(f"Number received: {number}")
139154

140155
assert number == count
@@ -165,7 +180,7 @@ async def task_coro():
165180
nonlocal count
166181
async for result in session.subscribe(subscription):
167182

168-
number = result["number"]
183+
number = result["countdown"]["number"]
169184
print(f"Number received: {number}")
170185

171186
assert number == count
@@ -204,7 +219,7 @@ async def task_coro():
204219
nonlocal count
205220
async for result in session.subscribe(subscription):
206221

207-
number = result["number"]
222+
number = result["countdown"]["number"]
208223
print(f"Number received: {number}")
209224

210225
assert number == count
@@ -269,7 +284,7 @@ async def test_websocket_subscription_server_connection_closed(
269284

270285
async for result in session.subscribe(subscription):
271286

272-
number = result["number"]
287+
number = result["countdown"]["number"]
273288
print(f"Number received: {number}")
274289

275290
assert number == count
@@ -292,7 +307,7 @@ async def test_websocket_subscription_slow_consumer(
292307
async for result in session.subscribe(subscription):
293308
await asyncio.sleep(10 * MS)
294309

295-
number = result["number"]
310+
number = result["countdown"]["number"]
296311
print(f"Number received: {number}")
297312

298313
assert number == count
@@ -319,7 +334,7 @@ async def test_websocket_subscription_with_keepalive(
319334

320335
async for result in session.subscribe(subscription):
321336

322-
number = result["number"]
337+
number = result["countdown"]["number"]
323338
print(f"Number received: {number}")
324339

325340
assert number == count
@@ -344,10 +359,54 @@ def test_websocket_subscription_sync(server, subscription_str):
344359

345360
for result in client.subscribe(subscription):
346361

347-
number = result["number"]
362+
number = result["countdown"]["number"]
348363
print(f"Number received: {number}")
349364

350365
assert number == count
351366
count -= 1
352367

353368
assert count == -1
369+
370+
371+
class NumberAddParser:
372+
""" Class with a parse_value method used to increment a number """
373+
374+
def __init__(self, increment: int):
375+
self.increment: int = increment
376+
377+
def parse_value(self, value: int) -> int:
378+
return value + self.increment
379+
380+
381+
@pytest.mark.asyncio
382+
@pytest.mark.parametrize("server", [server_countdown], indirect=True)
383+
@pytest.mark.parametrize("subscription_str", [countdown_subscription_str])
384+
async def test_websocket_subscription_with_custom_types(
385+
event_loop, server, subscription_str
386+
):
387+
388+
url = f"ws://{server.hostname}:{server.port}/graphql"
389+
390+
sample_transport = WebsocketsTransport(url=url)
391+
392+
count = 10
393+
subscription = gql(subscription_str.format(count=count))
394+
395+
add_10 = NumberAddParser(10)
396+
397+
custom_types = {"Int": add_10}
398+
399+
# Instanciate a client which will add 10 to all the scalars of type Int received
400+
async with Client(
401+
transport=sample_transport, custom_types=custom_types, type_def=countdown_schema
402+
) as session:
403+
async for result in session.subscribe(subscription):
404+
405+
number = result["countdown"]["number"]
406+
print(f"Number received: {number}")
407+
408+
# We check here that the Int scalar has been correctly incremented by 10
409+
assert number == count + 10
410+
count -= 1
411+
412+
assert count == -1

0 commit comments

Comments
 (0)