diff --git a/gql/client.py b/gql/client.py index c4dca409..7b042a68 100644 --- a/gql/client.py +++ b/gql/client.py @@ -17,6 +17,7 @@ from .transport.exceptions import TransportQueryError from .transport.local_schema import LocalSchemaTransport from .transport.transport import Transport +from .type_adapter import TypeAdapter class Client: @@ -28,6 +29,7 @@ def __init__( transport: Optional[Union[Transport, AsyncTransport]] = None, fetch_schema_from_transport: bool = False, execute_timeout: Optional[int] = 10, + custom_types: Optional[Dict[str, Any]] = None, ): assert not ( type_def and introspection @@ -77,10 +79,21 @@ def __init__( # Enforced timeout of the execute function self.execute_timeout = execute_timeout + # Fetch schema from transport directly if we are using a sync transport if isinstance(transport, Transport) and fetch_schema_from_transport: with self as session: session.fetch_schema() + # Dictionary where the name of the custom scalar type is the key and the + # value is a class which has a `parse_value()` function + self.custom_types = custom_types + + # Create a type_adapter instance directly here if we received the schema + # locally or from a sync transport + self.type_adapter = ( + TypeAdapter(schema, custom_types) if custom_types and schema else None + ) + def validate(self, document): assert ( self.schema @@ -233,7 +246,7 @@ def _execute(self, document: DocumentNode, *args, **kwargs) -> ExecutionResult: return self.transport.execute(document, *args, **kwargs) - def execute(self, document: DocumentNode, *args, **kwargs) -> Dict: + def execute(self, document: DocumentNode, *args, **kwargs) -> Dict[str, Any]: # Validate and execute on the transport result = self._execute(document, *args, **kwargs) @@ -248,7 +261,10 @@ def execute(self, document: DocumentNode, *args, **kwargs) -> Dict: result.data is not None ), "Transport returned an ExecutionResult without data or errors" - return result.data + if self.client.type_adapter: + return 
self.client.type_adapter.convert_scalars(result.data) + else: + return result.data def fetch_schema(self) -> None: execution_result = self.transport.execute(parse(get_introspection_query())) @@ -280,6 +296,13 @@ async def fetch_and_validate(self, document: DocumentNode): if self.client.fetch_schema_from_transport and not self.client.schema: await self.fetch_schema() + # Once we have received the schema from the async transport, + # we can create a TypeAdapter instance if the user provided custom types + if self.client.custom_types and self.client.schema: + self.client.type_adapter = TypeAdapter( + self.client.schema, self.client.custom_types + ) + # Validate document if self.client.schema: self.client.validate(document) @@ -310,7 +333,7 @@ async def _subscribe( async def subscribe( self, document: DocumentNode, *args, **kwargs - ) -> AsyncGenerator[Dict, None]: + ) -> AsyncGenerator[Dict[str, Any], None]: # Validate and subscribe on the transport async for result in self._subscribe(document, *args, **kwargs): @@ -322,7 +345,10 @@ async def subscribe( ) elif result.data is not None: - yield result.data + if self.client.type_adapter: + yield self.client.type_adapter.convert_scalars(result.data) + else: + yield result.data async def _execute( self, document: DocumentNode, *args, **kwargs @@ -337,7 +363,7 @@ async def _execute( self.client.execute_timeout, ) - async def execute(self, document: DocumentNode, *args, **kwargs) -> Dict: + async def execute(self, document: DocumentNode, *args, **kwargs) -> Dict[str, Any]: # Validate and execute on the transport result = await self._execute(document, *args, **kwargs) @@ -352,7 +378,10 @@ async def execute(self, document: DocumentNode, *args, **kwargs) -> Dict: result.data is not None ), "Transport returned an ExecutionResult without data or errors" - return result.data + if self.client.type_adapter: + return self.client.type_adapter.convert_scalars(result.data) + else: + return result.data async def fetch_schema(self) -> 
None: execution_result = await self.transport.execute( diff --git a/gql/type_adapter.py b/gql/type_adapter.py new file mode 100644 index 00000000..6a987f27 --- /dev/null +++ b/gql/type_adapter.py @@ -0,0 +1,137 @@ +from typing import Any, Dict, List, Optional, Union, cast + +from graphql import GraphQLSchema +from graphql.type.definition import GraphQLField, GraphQLObjectType, GraphQLScalarType + + +class TypeAdapter: + """Substitute custom scalars in a GQL response with their decoded counterparts. + + GQL custom scalar types are defined on the GQL schema and are used to represent + fields which have special behaviour. To define a custom scalar type, you need + the type name, and a class which has a class method called `parse_value()` - + this is the function which will be used to deserialize the custom scalar field. + + We first need to iterate over all the fields in the response (which is done in + the `convert_scalars()` function). + + Each time we find a field which is a custom scalar (its type name appears + as a key in self.custom_types), we replace the value of that field with the + decoded value. All of this logic happens in `_get_decoded_scalar_type()`. + + Public Interface: + convert_scalars(): pass in a GQL response to replace all instances of custom + scalar strings with their deserialized representation.""" + + def __init__( + self, schema: GraphQLSchema, custom_types: Optional[Dict[str, Any]] = None + ): + self.schema = schema + self.custom_types = custom_types or {} + + @staticmethod + def _follow_type_chain(node): + """Get the type of the schema node in question. + + In the GraphQL schema, GraphQLFields have a "type" property. However, often + that dict has an "of_type" property itself. 
In order to get to the actual + type, we need to indefinitely follow the chain of "of_type" fields to get + to the last one, which is the one we care about.""" + if isinstance(node, GraphQLObjectType): + return node + + field_type = node.type + while hasattr(field_type, "of_type"): + field_type = field_type.of_type + + return field_type + + def _get_scalar_type_name(self, field): + """Returns the name of the type if the type is a scalar type. + Returns `None` otherwise""" + node = self._follow_type_chain(field) + if isinstance(node, GraphQLScalarType): + return node.name + return None + + def _lookup_scalar_type(self, keys: List[str]): + """Search through the GQL schema and return the type identified by 'keys'. + + If keys (e.g. ['film', 'release_date']) points to a scalar type, then + this function returns the name of that type. (e.g. 'DateTime') + + If it is not a scalar type (e.g. a GraphQLObject), then this + function returns `None`. + + `keys` is a breadcrumb trail telling us where to look in the GraphQL schema. 
+ By default the root level is `schema.query`; if that fails, then we check + `schema.mutation`, then `schema.subscription`.""" + + def traverse_schema( + node: Optional[Union[GraphQLObjectType, GraphQLField]], lookup + ): + if not lookup: + return self._get_scalar_type_name(node) + + final_node = self._follow_type_chain(node) + return traverse_schema(final_node.fields[lookup[0]], lookup[1:]) + + if self.schema.query_type and keys[0] in self.schema.query_type.fields: + schema_root = self.schema.query_type + elif self.schema.mutation_type and keys[0] in self.schema.mutation_type.fields: + schema_root = self.schema.mutation_type + elif ( + self.schema.subscription_type + and keys[0] in self.schema.subscription_type.fields + ): + schema_root = self.schema.subscription_type + else: + return None + + try: + return traverse_schema(schema_root, keys) + except (KeyError, AttributeError): + return None + + def _get_decoded_scalar_type(self, keys: List[str], value: Any) -> Any: + """Get the decoded value of the type identified by `keys`. + + If the type is not a custom scalar, then return the original value. + + If it is a custom scalar, return the deserialized value, as + output by `.parse_value()`""" + + scalar_type = self._lookup_scalar_type(keys) + if scalar_type and scalar_type in self.custom_types: + return self.custom_types[scalar_type].parse_value(value) + return value + + def convert_scalars(self, response: Dict[str, Any]) -> Dict[str, Any]: + """Recursively traverse the GQL response + + Recursively traverses the GQL response and calls _get_decoded_scalar_type() + for all leaf nodes. + + The function is called with 2 arguments: + keys: is a breadcrumb trail telling us where we are in the + response, and therefore, where to look in the GQL Schema. 
+ value: is the value at that node in the response + + Builds a new tree with the substituted values so old `response` is not + modified.""" + + def iterate( + node: Union[List, Dict, str], keys: List[str] = None + ) -> Union[Dict[str, Any], List, Any]: + if keys is None: + keys = [] + if isinstance(node, dict): + return { + _key: iterate(value, keys + [_key]) for _key, value in node.items() + } + elif isinstance(node, list): + return [(iterate(item, keys)) for item in node] + else: + return self._get_decoded_scalar_type(keys, node) + + return cast(Dict, iterate(response)) diff --git a/tests/test_async_client_validation.py b/tests/test_async_client_validation.py index 55239f9e..d1e9a450 100644 --- a/tests/test_async_client_validation.py +++ b/tests/test_async_client_validation.py @@ -269,3 +269,47 @@ async def test_async_client_validation_fetch_schema_from_server_with_client_argu with pytest.raises(graphql.error.GraphQLError): await session.execute(query) + + +class ToLowercase: + @staticmethod + def parse_value(value: str): + return value.lower() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [hero_server_answers], indirect=True) +async def test_async_client_validation_fetch_schema_from_server_with_custom_types( + event_loop, server +): + + url = f"ws://{server.hostname}:{server.port}/graphql" + + sample_transport = WebsocketsTransport(url=url) + + custom_types = {"String": ToLowercase} + + async with Client( + transport=sample_transport, + fetch_schema_from_transport=True, + custom_types=custom_types, + ) as session: + + query = gql( + """ + query HeroNameQuery { + hero { + name + } + } + """ + ) + + result = await session.execute(query) + + print("Client received:", result) + + # The expected hero name is now in lowercase + expected = {"hero": {"name": "r2-d2"}} + + assert result == expected diff --git a/tests/test_requests.py b/tests/test_requests.py index 24fab2d2..b31f0834 100644 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ 
-13,6 +13,8 @@ ) from gql.transport.requests import RequestsHTTPTransport +from .test_type_adapter import Capitalize + query1_str = """ query getContinents { continents { @@ -201,3 +203,57 @@ def test_code(): sample_transport.execute(query) await run_sync_test(event_loop, server, test_code) + + +partial_schema = """ + +type Continent { + code: ID! + name: String! +} + +type Query { + continents: [Continent!]! +} + +""" + + +@pytest.mark.asyncio +async def test_requests_query_with_custom_types(event_loop, aiohttp_server): + async def handler(request): + return web.Response(text=query1_server_answer, content_type="application/json") + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + def test_code(): + sample_transport = RequestsHTTPTransport(url=url) + + custom_types = {"String": Capitalize} + + # Instanciate a client which will capitalize all the String scalars + with Client( + transport=sample_transport, + type_def=partial_schema, + custom_types=custom_types, + ) as session: + + query = gql(query1_str) + + # Execute query synchronously + result = session.execute(query) + + continents = result["continents"] + + africa = continents[0] + + assert africa["code"] == "AF" + + # Check that the string is capitalized + assert africa["name"] == "AFRICA" + + await run_sync_test(event_loop, server, test_code) diff --git a/tests/test_type_adapter.py b/tests/test_type_adapter.py new file mode 100644 index 00000000..fd8dc2e9 --- /dev/null +++ b/tests/test_type_adapter.py @@ -0,0 +1,135 @@ +"""Tests for the GraphQL Response Parser. + +At the moment we use the Star Wars schema which is fetched each time from the +server endpoint. In future it would be better to store this schema in a file +locally. 
+""" +import copy +import os + +import pytest +import requests +import vcr +from graphql import GraphQLSchema + +from gql import Client +from gql.transport.requests import RequestsHTTPTransport +from gql.type_adapter import TypeAdapter + +# We serve https://github.com/graphql-python/swapi-graphene locally: +URL = "http://127.0.0.1:8000/graphql" + + +query_vcr = vcr.VCR( + cassette_library_dir=os.path.join( + os.path.dirname(__file__), "fixtures", "vcr_cassettes" + ), + record_mode="new_episodes", + match_on=["uri", "method", "body"], +) + + +def use_cassette(name): + return query_vcr.use_cassette(name + ".yaml") + + +@pytest.fixture +def client(): + with use_cassette("client"): + response = requests.get( + URL, headers={"Host": "swapi.graphene-python.org", "Accept": "text/html"} + ) + response.raise_for_status() + csrf = response.cookies["csrftoken"] + + return Client( + transport=RequestsHTTPTransport( + url=URL, cookies={"csrftoken": csrf}, headers={"x-csrftoken": csrf} + ), + fetch_schema_from_transport=True, + ) + + +class Capitalize: + @classmethod + def parse_value(cls, value: str): + return value.upper() + + +@pytest.fixture() +def schema(client): + return client.schema + + +def test_scalar_type_name_for_scalar_field_returns_name(schema: GraphQLSchema): + type_adapter = TypeAdapter(schema) + schema_obj = schema.query_type.fields["film"] if schema.query_type else None + + assert ( + type_adapter._get_scalar_type_name(schema_obj.type.fields["releaseDate"]) + == "Date" + ) + + +def test_scalar_type_name_for_non_scalar_field_returns_none(schema: GraphQLSchema): + type_adapter = TypeAdapter(schema) + schema_obj = schema.query_type.fields["film"] if schema.query_type else None + + assert type_adapter._get_scalar_type_name(schema_obj.type.fields["species"]) is None + + +def test_lookup_scalar_type(schema): + type_adapter = TypeAdapter(schema) + + assert type_adapter._lookup_scalar_type(["film"]) is None + assert type_adapter._lookup_scalar_type(["film", 
"releaseDate"]) == "Date" + assert type_adapter._lookup_scalar_type(["film", "species"]) is None + + +def test_lookup_scalar_type_in_mutation(schema: GraphQLSchema): + type_adapter = TypeAdapter(schema) + + assert type_adapter._lookup_scalar_type(["createHero"]) is None + assert type_adapter._lookup_scalar_type(["createHero", "hero"]) is None + assert type_adapter._lookup_scalar_type(["createHero", "ok"]) == "Boolean" + + +def test_parse_response(schema: GraphQLSchema): + custom_types = {"Date": Capitalize} + type_adapter = TypeAdapter(schema, custom_types) + + response = {"film": {"id": "some_id", "releaseDate": "some_datetime"}} + + expected = {"film": {"id": "some_id", "releaseDate": "SOME_DATETIME"}} + + assert type_adapter.convert_scalars(response) == expected + # ensure original response is not changed + assert response["film"]["releaseDate"] == "some_datetime" + + +def test_parse_response_containing_list(schema: GraphQLSchema): + custom_types = {"Date": Capitalize} + type_adapter = TypeAdapter(schema, custom_types) + + response = { + "allFilms": { + "edges": [ + {"node": {"id": "some_id", "releaseDate": "some_datetime"}}, + {"node": {"id": "some_id", "releaseDate": "some_other_datetime"}}, + ] + } + } + + expected = copy.deepcopy(response) + expected["allFilms"]["edges"][0]["node"]["releaseDate"] = "SOME_DATETIME" + expected["allFilms"]["edges"][1]["node"]["releaseDate"] = "SOME_OTHER_DATETIME" + + result = type_adapter.convert_scalars(response) + assert result == expected + + # ensure original response is not changed + assert response["allFilms"]["edges"][0]["node"]["releaseDate"] == "some_datetime" + # ensure original response is not changed + assert ( + response["allFilms"]["edges"][1]["node"]["releaseDate"] == "some_other_datetime" + ) diff --git a/tests/test_websocket_subscription.py b/tests/test_websocket_subscription.py index 2e42cb1c..5634cb1b 100644 --- a/tests/test_websocket_subscription.py +++ b/tests/test_websocket_subscription.py @@ -12,9 +12,24 
@@ from .conftest import MS, WebSocketServer -countdown_server_answer = ( - '{{"type":"data","id":"{query_id}","payload":{{"data":{{"number":{number}}}}}}}' -) +countdown_schema = """ + +type Number { + number: Int! +} + +type Query { + currentCount: Number! +} + +type Subscription { + countdown(count: Int!): Number! +} + +""" + +countdown_server_answer = '{{"type":"data","id":"{query_id}","payload":\ +{{"data":{{"countdown":{{"number":{number}}}}}}}}}' WITH_KEEPALIVE = False @@ -129,7 +144,7 @@ async def test_websocket_subscription(event_loop, client_and_server, subscriptio async for result in session.subscribe(subscription): - number = result["number"] + number = result["countdown"]["number"] print(f"Number received: {number}") assert number == count @@ -152,7 +167,7 @@ async def test_websocket_subscription_break( async for result in session.subscribe(subscription): - number = result["number"] + number = result["countdown"]["number"] print(f"Number received: {number}") assert number == count @@ -183,7 +198,7 @@ async def task_coro(): nonlocal count async for result in session.subscribe(subscription): - number = result["number"] + number = result["countdown"]["number"] print(f"Number received: {number}") assert number == count @@ -222,7 +237,7 @@ async def task_coro(): nonlocal count async for result in session.subscribe(subscription): - number = result["number"] + number = result["countdown"]["number"] print(f"Number received: {number}") assert number == count @@ -287,7 +302,7 @@ async def test_websocket_subscription_server_connection_closed( async for result in session.subscribe(subscription): - number = result["number"] + number = result["countdown"]["number"] print(f"Number received: {number}") assert number == count @@ -310,7 +325,7 @@ async def test_websocket_subscription_slow_consumer( async for result in session.subscribe(subscription): await asyncio.sleep(10 * MS) - number = result["number"] + number = result["countdown"]["number"] print(f"Number 
received: {number}") assert number == count @@ -365,7 +380,7 @@ async def test_websocket_subscription_with_keepalive( async for result in session.subscribe(subscription): - number = result["number"] + number = result["countdown"]["number"] print(f"Number received: {number}") assert number == count @@ -390,7 +405,7 @@ def test_websocket_subscription_sync(server, subscription_str): for result in client.subscribe(subscription): - number = result["number"] + number = result["countdown"]["number"] print(f"Number received: {number}") assert number == count @@ -446,3 +461,47 @@ def test_websocket_subscription_sync_graceful_shutdown(server, subscription_str) # Check that the server received a connection_terminate message last assert logged_messages.pop() == '{"type": "connection_terminate"}' + + +class NumberAddParser: + """ Class with a parse_value method used to increment a number """ + + def __init__(self, increment: int): + self.increment: int = increment + + def parse_value(self, value: int) -> int: + return value + self.increment + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_websocket_subscription_with_custom_types( + event_loop, server, subscription_str +): + + url = f"ws://{server.hostname}:{server.port}/graphql" + + sample_transport = WebsocketsTransport(url=url) + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + add_10 = NumberAddParser(10) + + custom_types = {"Int": add_10} + + # Instanciate a client which will add 10 to all the scalars of type Int received + async with Client( + transport=sample_transport, custom_types=custom_types, type_def=countdown_schema + ) as session: + async for result in session.subscribe(subscription): + + number = result["countdown"]["number"] + print(f"Number received: {number}") + + # We check here that the Int scalar has been correctly incremented by 10 + assert 
number == count + 10 + count -= 1 + + assert count == -1