From 1bbfbd78cf566d4bb7f607665cfd8a32ef1d2601 Mon Sep 17 00:00:00 2001 From: Awais Hussain Date: Fri, 20 Apr 2018 17:47:42 +0000 Subject: [PATCH 01/15] Include support for GQL custom scalar types Move the GQLResponseParser into the gql library. Each time a GQL query is made by the client, the response is passed through ResponseParser.parse() where it is processed. Configure your own custom scalars as follows: ``` custom_scalars = { 'SomeCustomScalar': } gql_client = GQLClient(transport=gql_transport, schema=schema, custom_scalars=custom_scalars) gql_client.execute(...) ``` should have a .parse_value(value) function There are a few anti-patterns I have had to include in order to support some new functionality that we require: - client.execute now accepts `url` and `headers` as arguments, since we need to be able to set these on a per request basis. Previously they were supplied by the transport (which gets set only once at initialization time). - As a consequence, the url supplied in to the transport goes unused if a url is passed in to `execute()`. It is a required field so I have to pass in a string, but it's not the best. --- gql/client.py | 7 ++- gql/response_parser.py | 114 ++++++++++++++++++++++++++++++++++ gql/transport/requests.py | 2 +- tests/test_response_parser.py | 94 ++++++++++++++++++++++++++++ 4 files changed, 215 insertions(+), 2 deletions(-) create mode 100644 gql/response_parser.py create mode 100644 tests/test_response_parser.py diff --git a/gql/client.py b/gql/client.py index dcfeb7af..2ca3d0e5 100644 --- a/gql/client.py +++ b/gql/client.py @@ -4,6 +4,7 @@ from graphql.validation import validate from .transport.local_schema import LocalSchemaTransport +from .response_parser import ResponseParser log = logging.getLogger(__name__) @@ -17,7 +18,7 @@ def __init__(self, retries_count, last_exception): class Client(object): def __init__(self, schema=None, introspection=None, type_def=None, transport=None, - fetch_schema_from_transport=False, retries=0): + fetch_schema_from_transport=False, custom_scalars={}, retries=0): assert not(type_def and introspection), 'Cant provide introspection type definition at the same time' if transport and fetch_schema_from_transport: assert not schema, 'Cant fetch the schema from transport if is already provided' @@ -36,6 +37,7 @@ def __init__(self, schema=None, introspection=None, type_def=None, transport=Non self.introspection = introspection self.transport = transport self.retries = retries + self.response_parser = ResponseParser(schema, custom_scalars) if custom_scalars else None def validate(self, document): if not self.schema: @@ -52,6 +54,9 @@ def execute(self, document, *args, **kwargs): if result.errors: raise Exception(str(result.errors[0])) + if self.response_parser: + result.data = self.response_parser.parse(result.data) + return result.data def _get_result(self, document, *args, **kwargs): diff --git a/gql/response_parser.py b/gql/response_parser.py new file mode 100644 index 00000000..088bd58a --- /dev/null +++ b/gql/response_parser.py @@ -0,0 +1,114 @@ +from typing import Any, Dict, Callable, Optional, List + +from graphql.type.schema import GraphQLSchema +from graphql.type.definition import GraphQLObjectType, GraphQLField, GraphQLScalarType + + +class ResponseParser(object): + """The challenge is to substitute custom scalars in a GQL response with their + decoded counterparts. + + To solve this problem, we first need to iterate over all the fields in the + response (which is done in the `_traverse()` function). + + Each time we find a field which has type scalar and is a custom scalar, we + need to replace the value of that field with the decoded value. All of this + logic happens in `_substitute()`. + + Public Interface: + parse(): call parse with a GQL response to replace all instances of custom + scalar strings with their deserialized representation.""" + + def __init__(self, schema: GraphQLSchema, custom_scalars: Dict[str, Any] = {}) -> None: + """ schema: a graphQL schema in the GraphQLSchema format + custom_scalars: a Dict[str, Any], + where str is the name of the custom scalar type, and + Any is a class which has a `parse_value()` function""" + self.schema = schema + self.custom_scalars = custom_scalars + + def _follow_type_chain(self, node: Any) -> Any: + """In the schema GraphQL types are often listed with the format + `obj.type.of_type...` where there are 0 or more 'of_type' fields before + you get to the type you are interested in. + + This is a convenience method to help us get to these nested types.""" + if isinstance(node, GraphQLObjectType): + return node + + field_type = node.type + while hasattr(field_type, 'of_type'): + field_type = field_type.of_type + + return field_type + + def _get_scalar_type_name(self, field: GraphQLField) -> Optional[str]: + """Returns the name of the type if the type is a scalar type. + Returns None otherwise""" + node = self._follow_type_chain(field) + if isinstance(node, GraphQLScalarType): + return node.name + return None + + def _lookup_scalar_type(self, keys: List[str]) -> Optional[str]: + """ + `keys` is a breadcrumb trail telling us where to look in the GraphQL schema. + By default the root level is `schema.query`, if that fails, then we check + `schema.mutation`. + + If keys (e.g. ['wallet', 'balance']) points to a scalar type, then + this function returns the name of that type. (e.g. 'Money') + + If it is not a scalar type (e..g a GraphQLObject or list), then this + function returns None""" + + def iterate(node: Any, lookup: List[str]): + lookup = lookup.copy() + if not lookup: + return self._get_scalar_type_name(node) + + final_node = self._follow_type_chain(node) + return iterate(final_node.fields[lookup.pop(0)], lookup) + + try: + return iterate(self.schema.get_query_type(), keys) + except (KeyError, AttributeError): + try: + return iterate(self.schema.get_mutation_type(), keys) + except (KeyError, AttributeError): + return None + + def _substitute(self, keys: List[str], value: Any) -> Any: + """Looks in the GraphQL schema to find the type identified by 'keys' + + If that type is not a custom scalar, we return the original value. + If it is a custom scalar, we return the deserialized value, as + processed by `.parse_value()`""" + scalar_type = self._lookup_scalar_type(keys) + if scalar_type and scalar_type in self.custom_scalars: + return self.custom_scalars[scalar_type].parse_value(value) + return value + + def _traverse(self, response: Dict[str, Any], substitute: Callable) -> Dict[str, Any]: + """Recursively traverses the GQL response and calls the `substitute` + function on all leaf nodes. The function is called with 2 arguments: + keys: List[str] is a breadcrumb trail telling us where we are in the + response, and therefore, where to look in the GQL Schema. + value: Any is the value at that node in the tree + + Builds a new tree with the substituted values so `response` is not + modified.""" + def iterate(node: Any, keys: List[str] = []): + if isinstance(node, dict): + result = {} + for _key, value in node.items(): + result[_key] = iterate(value, keys + [_key]) + return result + elif isinstance(node, list): + return [(iterate(item, keys)) for item in node] + else: + return substitute(keys, node) + return iterate(response) + + def parse(self, response: Dict[str, Any]) -> Dict[str, Any]: + return self._traverse(response, self._substitute) diff --git a/gql/transport/requests.py b/gql/transport/requests.py index 71399a55..b748b322 100644 --- a/gql/transport/requests.py +++ b/gql/transport/requests.py @@ -35,7 +35,7 @@ def execute(self, document, variable_values=None, timeout=None): 'timeout': timeout or self.default_timeout, data_key: payload } - request = requests.post(self.url, **post_args) + request = requests.post(url or self.url, **post_args) request.raise_for_status() result = request.json() diff --git a/tests/test_response_parser.py b/tests/test_response_parser.py new file mode 100644 index 00000000..675ba9c2 --- /dev/null +++ b/tests/test_response_parser.py @@ -0,0 +1,94 @@ +"""Tests for the GraphQL Response Parser. + +These tests are worthless until I have a schema I can work with. +""" +import copy +from gql.response_parser import ResponseParser + + +class Capitalize(): + def parse_value(self, value: str): + return value.upper(); + +def test_scalar_type_name_for_scalar_field_returns_name(gql_schema): + parser = ResponseParser(gql_schema) + schema_obj = gql_schema.get_type_map().get('Wallet') + + assert parser._get_scalar_type_name(schema_obj.fields['balance']) == 'Money' + + +def test_scalar_type_name_for_non_scalar_field_returns_none(gql_schema): + parser = ResponseParser(gql_schema) + schema_obj = gql_schema.get_type_map().get('Wallet') + + assert parser._get_scalar_type_name(schema_obj.fields['user']) is None + +def test_lookup_scalar_type(gql_schema): + parser = ResponseParser(gql_schema) + + assert parser._lookup_scalar_type(["wallet"]) is None + assert parser._lookup_scalar_type(["searchWallets"]) is None + assert parser._lookup_scalar_type(["wallet", "balance"]) == 'Money' + assert parser._lookup_scalar_type(["searchWallets", "balance"]) == 'Money' + assert parser._lookup_scalar_type(["wallet", "name"]) == 'String' + assert parser._lookup_scalar_type(["wallet", "invalid"]) is None + +def test_lookup_scalar_type_in_mutation(gql_schema): + parser = ResponseParser(gql_schema) + + assert parser._lookup_scalar_type(["manualWithdraw", "agentTransaction"]) is None + assert parser._lookup_scalar_type(["manualWithdraw", "agentTransaction", "amount"]) == 'Money' + +def test_parse_response(gql_schema): + custom_scalars = { + 'Money': Capitalize + } + parser = ResponseParser(gql_schema, custom_scalars) + + response = { + 'wallet': { + 'id': 'some_id', + 'name': 'U1_test', + } + } + + expected = { + 'wallet': { + 'id': 'some_id', + 'name': 'U1_test', + } + } + + assert parser.parse(response) == expected + assert response['wallet']['balance'] == 'CFA 3850' + +def test_parse_response_containing_list(gql_schema): + custom_scalars = { + 'Money': M + } + parser = ResponseParser(gql_schema, custom_scalars) + + response = { + "searchWallets": [ + { + "id": "W_wz518BXTDJuQ", + "name": "U2_test", + "balance": "CFA 4148" + }, + { + "id": "W_uOe9fHPoKO21", + "name": "Agent_test", + "balance": "CFA 2641" + } + ] + } + + expected = copy.deepcopy(response) + expected['searchWallets'][0]['balance'] = M("CFA", "4148") + expected['searchWallets'][1]['balance'] = M("CFA", "2641") + + result = parser.parse(response) + + assert result == expected + assert response['searchWallets'][0]['balance'] == "CFA 4148" + assert response['searchWallets'][1]['balance'] == "CFA 2641" \ No newline at end of file From 9d80656b0444fd6f13dbc7cb8ddbc04caf2fdb1e Mon Sep 17 00:00:00 2001 From: Awais Hussain Date: Sat, 21 Apr 2018 12:42:28 +0000 Subject: [PATCH 02/15] Fix failing tests The tests were copied over from a different repo and so were inconsistent with the Star Wars schema we are using here. This PR should fix the tests (although we are still using a remote schema rather than a local one, so the tests run slowly becuase they need to make a http request each time). --- gql/transport/requests.py | 2 +- tests/test_response_parser.py | 118 ++++++++++++++++++++-------------- 2 files changed, 72 insertions(+), 48 deletions(-) diff --git a/gql/transport/requests.py b/gql/transport/requests.py index b748b322..71399a55 100644 --- a/gql/transport/requests.py +++ b/gql/transport/requests.py @@ -35,7 +35,7 @@ def execute(self, document, variable_values=None, timeout=None): 'timeout': timeout or self.default_timeout, data_key: payload } - request = requests.post(url or self.url, **post_args) + request = requests.post(self.url, **post_args) request.raise_for_status() result = request.json() diff --git a/tests/test_response_parser.py b/tests/test_response_parser.py index 675ba9c2..531ad9a9 100644 --- a/tests/test_response_parser.py +++ b/tests/test_response_parser.py @@ -1,94 +1,118 @@ """Tests for the GraphQL Response Parser. -These tests are worthless until I have a schema I can work with. +At the moment we use the Star Wars schema which is fetched each time from the +server endpoint. In future it would be better to store this schema in a file +locally. """ import copy from gql.response_parser import ResponseParser - +import pytest +import requests +from gql import Client +from gql.transport.requests import RequestsHTTPTransport class Capitalize(): + @classmethod def parse_value(self, value: str): return value.upper(); -def test_scalar_type_name_for_scalar_field_returns_name(gql_schema): - parser = ResponseParser(gql_schema) - schema_obj = gql_schema.get_type_map().get('Wallet') +@pytest.fixture +def schema(): + request = requests.get('http://swapi.graphene-python.org/graphql', + headers={ + 'Host': 'swapi.graphene-python.org', + 'Accept': 'text/html', + }) + request.raise_for_status() + csrf = request.cookies['csrftoken'] - assert parser._get_scalar_type_name(schema_obj.fields['balance']) == 'Money' + client = Client( + transport=RequestsHTTPTransport(url='http://swapi.graphene-python.org/graphql', + cookies={"csrftoken": csrf}, + headers={'x-csrftoken': csrf}), + fetch_schema_from_transport=True + ) + return client.schema -def test_scalar_type_name_for_non_scalar_field_returns_none(gql_schema): - parser = ResponseParser(gql_schema) - schema_obj = gql_schema.get_type_map().get('Wallet') +def test_scalar_type_name_for_scalar_field_returns_name(schema): + parser = ResponseParser(schema) + schema_obj = schema.get_query_type().fields['film'] + + assert parser._get_scalar_type_name(schema_obj.type.fields['releaseDate']) == 'DateTime' - assert parser._get_scalar_type_name(schema_obj.fields['user']) is None + +def test_scalar_type_name_for_non_scalar_field_returns_none(schema): + parser = ResponseParser(schema) + schema_obj = schema.get_query_type().fields['film'] + + assert parser._get_scalar_type_name(schema_obj.type.fields['species']) is None def test_lookup_scalar_type(gql_schema): parser = ResponseParser(gql_schema) - assert parser._lookup_scalar_type(["wallet"]) is None - assert parser._lookup_scalar_type(["searchWallets"]) is None - assert parser._lookup_scalar_type(["wallet", "balance"]) == 'Money' - assert parser._lookup_scalar_type(["searchWallets", "balance"]) == 'Money' - assert parser._lookup_scalar_type(["wallet", "name"]) == 'String' - assert parser._lookup_scalar_type(["wallet", "invalid"]) is None + assert parser._lookup_scalar_type(["film"]) is None + assert parser._lookup_scalar_type(["film", "releaseDate"]) == 'DateTime' + assert parser._lookup_scalar_type(["film", "species"]) is None -def test_lookup_scalar_type_in_mutation(gql_schema): - parser = ResponseParser(gql_schema) +def test_lookup_scalar_type_in_mutation(schema): + parser = ResponseParser(schema) - assert parser._lookup_scalar_type(["manualWithdraw", "agentTransaction"]) is None - assert parser._lookup_scalar_type(["manualWithdraw", "agentTransaction", "amount"]) == 'Money' + assert parser._lookup_scalar_type(["createHero"]) is None + assert parser._lookup_scalar_type(["createHero", "hero"]) is None + assert parser._lookup_scalar_type(["createHero", "ok"]) == 'Boolean' -def test_parse_response(gql_schema): +def test_parse_response(schema): custom_scalars = { - 'Money': Capitalize + 'DateTime': Capitalize } - parser = ResponseParser(gql_schema, custom_scalars) + parser = ResponseParser(schema, custom_scalars) response = { - 'wallet': { + 'film': { 'id': 'some_id', - 'name': 'U1_test', + 'releaseDate': 'some_datetime', } } expected = { - 'wallet': { + 'film': { 'id': 'some_id', - 'name': 'U1_test', + 'releaseDate': 'SOME_DATETIME', } } assert parser.parse(response) == expected - assert response['wallet']['balance'] == 'CFA 3850' + assert response['film']['releaseDate'] == 'some_datetime' # ensure original response is not changed -def test_parse_response_containing_list(gql_schema): +def test_parse_response_containing_list(schema): custom_scalars = { - 'Money': M + 'DateTime': Capitalize } - parser = ResponseParser(gql_schema, custom_scalars) + parser = ResponseParser(schema, custom_scalars) response = { - "searchWallets": [ - { - "id": "W_wz518BXTDJuQ", - "name": "U2_test", - "balance": "CFA 4148" - }, - { - "id": "W_uOe9fHPoKO21", - "name": "Agent_test", - "balance": "CFA 2641" - } - ] + "allFilms": { + "edges": [{ + "node": { + 'id': 'some_id', + 'releaseDate': 'some_datetime', + } + },{ + "node": { + 'id': 'some_id', + 'releaseDate': 'some_other_datetime', + } + }] + } } expected = copy.deepcopy(response) - expected['searchWallets'][0]['balance'] = M("CFA", "4148") - expected['searchWallets'][1]['balance'] = M("CFA", "2641") + expected['allFilms']['edges'][0]['node']['releaseDate'] = "SOME_DATETIME" + expected['allFilms']['edges'][1]['node']['releaseDate'] = "SOME_OTHER_DATETIME" result = parser.parse(response) assert result == expected - assert response['searchWallets'][0]['balance'] == "CFA 4148" - assert response['searchWallets'][1]['balance'] == "CFA 2641" \ No newline at end of file + expected['allFilms']['edges'][0]['node']['releaseDate'] = "some_datetime" + expected['allFilms']['edges'][1]['node']['releaseDate'] = "some_other_datetime" From a13d1667d6847182d887c2a77a198f3a95a10bd9 Mon Sep 17 00:00:00 2001 From: Awais Hussain Date: Sat, 21 Apr 2018 13:34:17 +0000 Subject: [PATCH 03/15] Rename ResponseParser to TypeAdaptor The 'ResponseParser' name was confusing because it was too broad, and we weren't really actually parsing the response. Hopefully this makes it clearer that what we are actually doing it just adding support for custom types. --- gql/client.py | 8 ++--- gql/{response_parser.py => type_adaptor.py} | 24 +++++++------ ...esponse_parser.py => test_type_adaptor.py} | 36 +++++++++---------- 3 files changed, 36 insertions(+), 32 deletions(-) rename gql/{response_parser.py => type_adaptor.py} (82%) rename tests/{test_response_parser.py => test_type_adaptor.py} (72%) diff --git a/gql/client.py b/gql/client.py index 2ca3d0e5..c35c1cc5 100644 --- a/gql/client.py +++ b/gql/client.py @@ -4,7 +4,7 @@ from graphql.validation import validate from .transport.local_schema import LocalSchemaTransport -from .response_parser import ResponseParser +from .type_adaptor import TypeAdaptor log = logging.getLogger(__name__) @@ -37,7 +37,7 @@ def __init__(self, schema=None, introspection=None, type_def=None, transport=Non self.introspection = introspection self.transport = transport self.retries = retries - self.response_parser = ResponseParser(schema, custom_scalars) if custom_scalars else None + self.type_adaptor = TypeAdaptor(schema, custom_scalars) if custom_scalars else None def validate(self, document): if not self.schema: @@ -54,8 +54,8 @@ def execute(self, document, *args, **kwargs): if result.errors: raise Exception(str(result.errors[0])) - if self.response_parser: - result.data = self.response_parser.parse(result.data) + if self.type_adaptor: + result.data = self.type_adaptor.apply(result.data) return result.data diff --git a/gql/response_parser.py b/gql/type_adaptor.py similarity index 82% rename from gql/response_parser.py rename to gql/type_adaptor.py index 088bd58a..ea4c15e5 100644 --- a/gql/response_parser.py +++ b/gql/type_adaptor.py @@ -4,19 +4,23 @@ from graphql.type.definition import GraphQLObjectType, GraphQLField, GraphQLScalarType -class ResponseParser(object): - """The challenge is to substitute custom scalars in a GQL response with their - decoded counterparts. +class TypeAdaptor(object): + """Substitute custom scalars in a GQL response with their decoded counterparts. - To solve this problem, we first need to iterate over all the fields in the - response (which is done in the `_traverse()` function). + GQL custom scalar types are defined on the GQL schema and are used to represent + fields which have special behaviour. To define custom scalar type, you need + the type name, and a class which has a class method called `parse_value()` - + this is the function which will be used to deserialize the custom scalar field. - Each time we find a field which has type scalar and is a custom scalar, we - need to replace the value of that field with the decoded value. All of this - logic happens in `_substitute()`. + We first need iterate over all the fields in the response (which is done in + the `_traverse()` function). + + Each time we find a field which is a custom scalar (it's type name appears + as a key in self.custom_scalars), we replace the value of that field with the + decoded value. All of this logic happens in `_substitute()`. Public Interface: - parse(): call parse with a GQL response to replace all instances of custom + parse(): pass in a GQL response to replace all instances of custom scalar strings with their deserialized representation.""" def __init__(self, schema: GraphQLSchema, custom_scalars: Dict[str, Any] = {}) -> None: @@ -110,5 +114,5 @@ def iterate(node: Any, keys: List[str] = []): return substitute(keys, node) return iterate(response) - def parse(self, response: Dict[str, Any]) -> Dict[str, Any]: + def apply(self, response: Dict[str, Any]) -> Dict[str, Any]: return self._traverse(response, self._substitute) diff --git a/tests/test_response_parser.py b/tests/test_type_adaptor.py similarity index 72% rename from tests/test_response_parser.py rename to tests/test_type_adaptor.py index 531ad9a9..1b344374 100644 --- a/tests/test_response_parser.py +++ b/tests/test_type_adaptor.py @@ -5,7 +5,7 @@ locally. """ import copy -from gql.response_parser import ResponseParser +from gql.type_adaptor import TypeAdaptor import pytest import requests from gql import Client @@ -36,37 +36,37 @@ def schema(): return client.schema def test_scalar_type_name_for_scalar_field_returns_name(schema): - parser = ResponseParser(schema) + type_adaptor = TypeAdaptor(schema) schema_obj = schema.get_query_type().fields['film'] - assert parser._get_scalar_type_name(schema_obj.type.fields['releaseDate']) == 'DateTime' + assert type_adaptor ._get_scalar_type_name(schema_obj.type.fields['releaseDate']) == 'DateTime' def test_scalar_type_name_for_non_scalar_field_returns_none(schema): - parser = ResponseParser(schema) + type_adaptor = TypeAdaptor(schema) schema_obj = schema.get_query_type().fields['film'] - assert parser._get_scalar_type_name(schema_obj.type.fields['species']) is None + assert type_adaptor._get_scalar_type_name(schema_obj.type.fields['species']) is None -def test_lookup_scalar_type(gql_schema): - parser = ResponseParser(gql_schema) +def test_lookup_scalar_type(schema): + type_adaptor = TypeAdaptor(schema) - assert parser._lookup_scalar_type(["film"]) is None - assert parser._lookup_scalar_type(["film", "releaseDate"]) == 'DateTime' - assert parser._lookup_scalar_type(["film", "species"]) is None + assert type_adaptor._lookup_scalar_type(["film"]) is None + assert type_adaptor._lookup_scalar_type(["film", "releaseDate"]) == 'DateTime' + assert type_adaptor._lookup_scalar_type(["film", "species"]) is None def test_lookup_scalar_type_in_mutation(schema): - parser = ResponseParser(schema) + type_adaptor = TypeAdaptor(schema) - assert parser._lookup_scalar_type(["createHero"]) is None - assert parser._lookup_scalar_type(["createHero", "hero"]) is None - assert parser._lookup_scalar_type(["createHero", "ok"]) == 'Boolean' + assert type_adaptor._lookup_scalar_type(["createHero"]) is None + assert type_adaptor._lookup_scalar_type(["createHero", "hero"]) is None + assert type_adaptor._lookup_scalar_type(["createHero", "ok"]) == 'Boolean' def test_parse_response(schema): custom_scalars = { 'DateTime': Capitalize } - parser = ResponseParser(schema, custom_scalars) + type_adaptor = TypeAdaptor(schema, custom_scalars) response = { 'film': { @@ -82,14 +82,14 @@ def test_parse_response(schema): } } - assert parser.parse(response) == expected + assert type_adaptor.apply(response) == expected assert response['film']['releaseDate'] == 'some_datetime' # ensure original response is not changed def test_parse_response_containing_list(schema): custom_scalars = { 'DateTime': Capitalize } - parser = ResponseParser(schema, custom_scalars) + type_adaptor = TypeAdaptor(schema, custom_scalars) response = { "allFilms": { @@ -111,7 +111,7 @@ def test_parse_response_containing_list(schema): expected['allFilms']['edges'][0]['node']['releaseDate'] = "SOME_DATETIME" expected['allFilms']['edges'][1]['node']['releaseDate'] = "SOME_OTHER_DATETIME" - result = parser.parse(response) + result = type_adaptor.apply(response) assert result == expected expected['allFilms']['edges'][0]['node']['releaseDate'] = "some_datetime" From 44becd00a3329e070b2450fa0a5e1816e9b187c3 Mon Sep 17 00:00:00 2001 From: Awais Hussain Date: Mon, 30 Apr 2018 08:40:37 +0000 Subject: [PATCH 04/15] Clean up docstrings in `type_adaptor` file Stick to the format of having a one line summary followed by more detailed information. --- gql/type_adaptor.py | 43 ++++++++++++++++++++++++------------------- 1 file changed, 24 insertions(+), 19 deletions(-) diff --git a/gql/type_adaptor.py b/gql/type_adaptor.py index ea4c15e5..2e14b11d 100644 --- a/gql/type_adaptor.py +++ b/gql/type_adaptor.py @@ -20,7 +20,7 @@ class TypeAdaptor(object): decoded value. All of this logic happens in `_substitute()`. Public Interface: - parse(): pass in a GQL response to replace all instances of custom + apply(): pass in a GQL response to replace all instances of custom scalar strings with their deserialized representation.""" def __init__(self, schema: GraphQLSchema, custom_scalars: Dict[str, Any] = {}) -> None: @@ -32,11 +32,12 @@ def __init__(self, schema: GraphQLSchema, custom_scalars: Dict[str, Any] = {}) - self.custom_scalars = custom_scalars def _follow_type_chain(self, node: Any) -> Any: - """In the schema GraphQL types are often listed with the format - `obj.type.of_type...` where there are 0 or more 'of_type' fields before - you get to the type you are interested in. + """ Get the type of the schema node in question. - This is a convenience method to help us get to these nested types.""" + In the GraphQL schema, GraphQLFields have a "type" property. However, often + that dict has an "of_type" property itself. In order to get to the actual + type, we need to indefinitely follow the chain of "of_type" fields to get + to the last one, which is the one we care about.""" if isinstance(node, GraphQLObjectType): return node @@ -55,16 +56,17 @@ def _get_scalar_type_name(self, field: GraphQLField) -> Optional[str]: return None def _lookup_scalar_type(self, keys: List[str]) -> Optional[str]: - """ - `keys` is a breadcrumb trail telling us where to look in the GraphQL schema. - By default the root level is `schema.query`, if that fails, then we check - `schema.mutation`. + """Search through the GQL schema and return the type identified by 'keys'. - If keys (e.g. ['wallet', 'balance']) points to a scalar type, then - this function returns the name of that type. (e.g. 'Money') + If keys (e.g. ['film', 'release_date']) points to a scalar type, then + this function returns the name of that type. (e.g. 'DateTime') If it is not a scalar type (e..g a GraphQLObject or list), then this - function returns None""" + function returns None. + + `keys` is a breadcrumb trail telling us where to look in the GraphQL schema. + By default the root level is `schema.query`, if that fails, then we check + `schema.mutation`.""" def iterate(node: Any, lookup: List[str]): lookup = lookup.copy() @@ -83,24 +85,27 @@ def iterate(node: Any, lookup: List[str]): return None def _substitute(self, keys: List[str], value: Any) -> Any: - """Looks in the GraphQL schema to find the type identified by 'keys' + """Get the decoded value of the type identified by `keys`. + + If the type is not a custom scalar, then return the original value. - If that type is not a custom scalar, we return the original value. - If it is a custom scalar, we return the deserialized value, as - processed by `.parse_value()`""" + If it is a custom scalar, return the deserialized value, as + output by `.parse_value()`""" scalar_type = self._lookup_scalar_type(keys) if scalar_type and scalar_type in self.custom_scalars: return self.custom_scalars[scalar_type].parse_value(value) return value def _traverse(self, response: Dict[str, Any], substitute: Callable) -> Dict[str, Any]: - """Recursively traverses the GQL response and calls the `substitute` + """Recursively traverse the GQL response + + Recursively traverses the GQL response and calls the `substitute` function on all leaf nodes. The function is called with 2 arguments: keys: List[str] is a breadcrumb trail telling us where we are in the response, and therefore, where to look in the GQL Schema. - value: Any is the value at that node in the tree + value: Any is the value at that node in the response - Builds a new tree with the substituted values so `response` is not + Builds a new tree with the substituted values so old `response` is not modified.""" def iterate(node: Any, keys: List[str] = []): if isinstance(node, dict): From 0457275cc13d4ebb8baac81e6996f6ccd598b52e Mon Sep 17 00:00:00 2001 From: Awais Hussain Date: Mon, 30 Apr 2018 19:20:37 +0000 Subject: [PATCH 05/15] Bugfix: Empty headers should default to empty dict Since we are now merging the headers specified in the transport with the headers specified on a per request basis, we need to make sure the old headers are of type `dict`, a `None` is not sufficient. --- gql/exceptions.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 gql/exceptions.py diff --git a/gql/exceptions.py b/gql/exceptions.py new file mode 100644 index 00000000..e69de29b From 5908011dc8478738f6bcd915ffe04fa52232b8d1 Mon Sep 17 00:00:00 2001 From: Awais Hussain Date: Mon, 30 Apr 2018 19:23:25 +0000 Subject: [PATCH 06/15] Add new exceptions (GQLSyntaxError and GQLServerError) GQLServerError maps to errors which are generated on the backend server. I.e. the server responds with a <200 OK> but an "errors" field is included within the response. These are often user facing errors and should be handled as such. GQLSyntaxError maps to cases where either the query or the schema is improperly formatted. This usually indicates a mistake in the code. --- gql/client.py | 5 +++-- gql/exceptions.py | 5 +++++ 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/gql/client.py b/gql/client.py index c35c1cc5..0dcb4148 100644 --- a/gql/client.py +++ b/gql/client.py @@ -5,6 +5,7 @@ from .transport.local_schema import LocalSchemaTransport from .type_adaptor import TypeAdaptor +from .exceptions import GQLServerError, GQLSyntaxError log = logging.getLogger(__name__) @@ -41,7 +42,7 @@ def __init__(self, schema=None, introspection=None, type_def=None, transport=Non def validate(self, document): if not self.schema: - raise Exception("Cannot validate locally the document, you need to pass a schema.") + raise GQLSyntaxError("Cannot validate locally the document, you need to pass a schema.") validation_errors = validate(self.schema, document) if validation_errors: raise validation_errors[0] @@ -52,7 +53,7 @@ def execute(self, document, *args, **kwargs): result = self._get_result(document, *args, **kwargs) if result.errors: - raise Exception(str(result.errors[0])) + raise GQLServerError(result.errors[0]) if self.type_adaptor: result.data = self.type_adaptor.apply(result.data) diff --git a/gql/exceptions.py b/gql/exceptions.py index e69de29b..5a45fb63 100644 --- a/gql/exceptions.py +++ b/gql/exceptions.py @@ -0,0 +1,5 @@ +class GQLSyntaxError(Exception): + """A problem with the GQL query or schema syntax""" + +class GQLServerError(Exception): + """Errors which should be explicitly handled by the calling code""" From 0fd8937a40501f6b96d0c57bb472cae625d12dac Mon Sep 17 00:00:00 2001 From: Awais Hussain Date: Tue, 1 May 2018 11:40:03 +0000 Subject: [PATCH 07/15] Address CR changes --- gql/client.py | 13 +++-- gql/{type_adaptor.py => type_adapter.py} | 55 +++++++++---------- ...t_type_adaptor.py => test_type_adapter.py} | 50 ++++++++--------- 3 files changed, 59 insertions(+), 59 deletions(-) rename gql/{type_adaptor.py => type_adapter.py} (70%) rename tests/{test_type_adaptor.py => test_type_adapter.py} (64%) diff --git a/gql/client.py b/gql/client.py index 0dcb4148..8ca630f3 100644 --- a/gql/client.py +++ b/gql/client.py @@ -4,7 +4,7 @@ from graphql.validation import validate from .transport.local_schema import LocalSchemaTransport -from .type_adaptor import TypeAdaptor +from .type_adapter import TypeAdapter from .exceptions import GQLServerError, GQLSyntaxError log = logging.getLogger(__name__) @@ -19,7 +19,10 @@ def __init__(self, retries_count, last_exception): class Client(object): def __init__(self, schema=None, introspection=None, type_def=None, transport=None, - fetch_schema_from_transport=False, custom_scalars={}, retries=0): + fetch_schema_from_transport=False, retries=0, custom_types={}): + """custom_types should be of type Dict[str, Any] + where str is the name of the custom scalar type, and + Any is a class which has a `parse_value()` function""" assert not(type_def and introspection), 'Cant provide introspection type definition at the same time' if transport and fetch_schema_from_transport: assert not schema, 'Cant fetch the schema from transport if is already provided' @@ -38,7 +41,7 @@ def __init__(self, schema=None, introspection=None, type_def=None, transport=Non self.introspection = introspection self.transport = transport self.retries = retries - self.type_adaptor = TypeAdaptor(schema, custom_scalars) if custom_scalars else None + self.type_adapter = TypeAdapter(schema, custom_types) if custom_types else None def validate(self, document): if not self.schema: @@ -55,8 +58,8 @@ def execute(self, document, *args, **kwargs): if result.errors: raise GQLServerError(result.errors[0]) - if self.type_adaptor: - result.data = self.type_adaptor.apply(result.data) + if self.type_adapter: + result.data = self.type_adapter.convert_scalars(result.data) return result.data diff --git a/gql/type_adaptor.py b/gql/type_adapter.py similarity index 70% rename from gql/type_adaptor.py rename to gql/type_adapter.py index 2e14b11d..8014ae0a 100644 --- a/gql/type_adaptor.py +++ b/gql/type_adapter.py @@ -4,7 +4,7 @@ from graphql.type.definition import GraphQLObjectType, GraphQLField, GraphQLScalarType -class TypeAdaptor(object): +class TypeAdapter(object): """Substitute custom scalars in a GQL response with their decoded counterparts. GQL custom scalar types are defined on the GQL schema and are used to represent @@ -16,20 +16,20 @@ class TypeAdaptor(object): the `_traverse()` function). Each time we find a field which is a custom scalar (it's type name appears - as a key in self.custom_scalars), we replace the value of that field with the + as a key in self.custom_types), we replace the value of that field with the decoded value. All of this logic happens in `_substitute()`. Public Interface: apply(): pass in a GQL response to replace all instances of custom scalar strings with their deserialized representation.""" - def __init__(self, schema: GraphQLSchema, custom_scalars: Dict[str, Any] = {}) -> None: + def __init__(self, schema: GraphQLSchema, custom_types: Dict[str, Any] = {}) -> None: """ schema: a graphQL schema in the GraphQLSchema format - custom_scalars: a Dict[str, Any], + custom_types: a Dict[str, Any], where str is the name of the custom scalar type, and - Any is a class which has a `parse_value()` function""" + Any is a class which has a `parse_value(str)` function""" self.schema = schema - self.custom_scalars = custom_scalars + self.custom_types = custom_types def _follow_type_chain(self, node: Any) -> Any: """ Get the type of the schema node in question. @@ -61,30 +61,33 @@ def _lookup_scalar_type(self, keys: List[str]) -> Optional[str]: If keys (e.g. ['film', 'release_date']) points to a scalar type, then this function returns the name of that type. (e.g. 'DateTime') - If it is not a scalar type (e..g a GraphQLObject or list), then this + If it is not a scalar type (e..g a GraphQLObject), then this function returns None. `keys` is a breadcrumb trail telling us where to look in the GraphQL schema. By default the root level is `schema.query`, if that fails, then we check `schema.mutation`.""" - def iterate(node: Any, lookup: List[str]): - lookup = lookup.copy() + def traverse_schema(node: Any, lookup: List[str]): if not lookup: return self._get_scalar_type_name(node) final_node = self._follow_type_chain(node) - return iterate(final_node.fields[lookup.pop(0)], lookup) + return traverse_schema(final_node.fields[lookup[0]], lookup[1:]) + + if keys[0] in self.schema.get_query_type().fields: + schema_root = self.schema.get_query_type() + elif keys[0] in self.schema.get_mutation_type().fields: + schema_root = self.schema.get_mutation_type() + else: + return None try: - return iterate(self.schema.get_query_type(), keys) + return traverse_schema(schema_root, keys) except (KeyError, AttributeError): - try: - return iterate(self.schema.get_mutation_type(), keys) - except (KeyError, AttributeError): - return None + return None - def _substitute(self, keys: List[str], value: Any) -> Any: + def _get_decoded_scalar_type(self, keys: List[str], value: Any) -> Any: """Get the decoded value of the type identified by `keys`. If the type is not a custom scalar, then return the original value. @@ -92,15 +95,15 @@ def _substitute(self, keys: List[str], value: Any) -> Any: If it is a custom scalar, return the deserialized value, as output by `.parse_value()`""" scalar_type = self._lookup_scalar_type(keys) - if scalar_type and scalar_type in self.custom_scalars: - return self.custom_scalars[scalar_type].parse_value(value) + if scalar_type and scalar_type in self.custom_types: + return self.custom_types[scalar_type].parse_value(value) return value - def _traverse(self, response: Dict[str, Any], substitute: Callable) -> Dict[str, Any]: + def convert_scalars(self, response: Dict[str, Any]) -> Dict[str, Any]: """Recursively traverse the GQL response - Recursively traverses the GQL response and calls the `substitute` - function on all leaf nodes. The function is called with 2 arguments: + Recursively traverses the GQL response and calls _get_decoded_scalar_type() + for all leaf nodes. The function is called with 2 arguments: keys: List[str] is a breadcrumb trail telling us where we are in the response, and therefore, where to look in the GQL Schema. value: Any is the value at that node in the response @@ -109,15 +112,9 @@ def _traverse(self, response: Dict[str, Any], substitute: Callable) -> Dict[str, modified.""" def iterate(node: Any, keys: List[str] = []): if isinstance(node, dict): - result = {} - for _key, value in node.items(): - result[_key] = iterate(value, keys + [_key]) - return result + return {_key: iterate(value, keys + [_key]) for _key, value in node.items()} elif isinstance(node, list): return [(iterate(item, keys)) for item in node] else: - return substitute(keys, node) + return self._get_decoded_scalar_type(keys, node) return iterate(response) - - def apply(self, response: Dict[str, Any]) -> Dict[str, Any]: - return self._traverse(response, self._substitute) diff --git a/tests/test_type_adaptor.py b/tests/test_type_adapter.py similarity index 64% rename from tests/test_type_adaptor.py rename to tests/test_type_adapter.py index 1b344374..49ce8b2e 100644 --- a/tests/test_type_adaptor.py +++ b/tests/test_type_adapter.py @@ -5,7 +5,7 @@ locally. """ import copy -from gql.type_adaptor import TypeAdaptor +from gql.type_adapter import TypeAdapter import pytest import requests from gql import Client @@ -16,7 +16,7 @@ class Capitalize(): def parse_value(self, value: str): return value.upper(); -@pytest.fixture +@pytest.fixture(scope='session') def schema(): request = requests.get('http://swapi.graphene-python.org/graphql', headers={ @@ -36,37 +36,37 @@ def schema(): return client.schema def test_scalar_type_name_for_scalar_field_returns_name(schema): - type_adaptor = TypeAdaptor(schema) + type_adapter = TypeAdapter(schema) schema_obj = schema.get_query_type().fields['film'] - assert type_adaptor ._get_scalar_type_name(schema_obj.type.fields['releaseDate']) == 'DateTime' + assert type_adapter ._get_scalar_type_name(schema_obj.type.fields['releaseDate']) == 'DateTime' def test_scalar_type_name_for_non_scalar_field_returns_none(schema): - type_adaptor = TypeAdaptor(schema) + type_adapter = TypeAdapter(schema) schema_obj = schema.get_query_type().fields['film'] - assert type_adaptor._get_scalar_type_name(schema_obj.type.fields['species']) is None + assert type_adapter._get_scalar_type_name(schema_obj.type.fields['species']) is None def test_lookup_scalar_type(schema): - type_adaptor = TypeAdaptor(schema) + type_adapter = TypeAdapter(schema) - assert type_adaptor._lookup_scalar_type(["film"]) is None - assert type_adaptor._lookup_scalar_type(["film", "releaseDate"]) == 'DateTime' - assert type_adaptor._lookup_scalar_type(["film", "species"]) is None + assert type_adapter._lookup_scalar_type(["film"]) is None + assert type_adapter._lookup_scalar_type(["film", "releaseDate"]) == 'DateTime' + assert type_adapter._lookup_scalar_type(["film", "species"]) is None def test_lookup_scalar_type_in_mutation(schema): - type_adaptor = TypeAdaptor(schema) + type_adapter = TypeAdapter(schema) - assert type_adaptor._lookup_scalar_type(["createHero"]) is None - assert type_adaptor._lookup_scalar_type(["createHero", "hero"]) is None - assert type_adaptor._lookup_scalar_type(["createHero", "ok"]) == 'Boolean' + assert type_adapter._lookup_scalar_type(["createHero"]) is None + assert type_adapter._lookup_scalar_type(["createHero", "hero"]) is None + assert type_adapter._lookup_scalar_type(["createHero", "ok"]) == 'Boolean' def test_parse_response(schema): - custom_scalars = { + custom_types = { 'DateTime': Capitalize } - type_adaptor = TypeAdaptor(schema, custom_scalars) + type_adapter = TypeAdapter(schema, custom_types) response = { 'film': { @@ -82,14 +82,14 @@ def test_parse_response(schema): } } - assert type_adaptor.apply(response) == expected + assert type_adapter.convert_scalars(response) == expected assert response['film']['releaseDate'] == 'some_datetime' # ensure original response is not changed def test_parse_response_containing_list(schema): - custom_scalars = { + custom_types = { 'DateTime': Capitalize } - type_adaptor = TypeAdaptor(schema, custom_scalars) + type_adapter = TypeAdapter(schema, custom_types) response = { "allFilms": { @@ -108,11 +108,11 @@ def test_parse_response_containing_list(schema): } expected = copy.deepcopy(response) - expected['allFilms']['edges'][0]['node']['releaseDate'] = "SOME_DATETIME" - expected['allFilms']['edges'][1]['node']['releaseDate'] = "SOME_OTHER_DATETIME" - - result = type_adaptor.apply(response) + expected['allFilms']['edges'][0]['node']['releaseDate'] = 'SOME_DATETIME' + expected['allFilms']['edges'][1]['node']['releaseDate'] = 'SOME_OTHER_DATETIME' + result = type_adapter.convert_scalars(response) assert result == expected - expected['allFilms']['edges'][0]['node']['releaseDate'] = "some_datetime" - expected['allFilms']['edges'][1]['node']['releaseDate'] = "some_other_datetime" + + assert response['allFilms']['edges'][0]['node']['releaseDate'] == 'some_datetime' # ensure original response is not changed + assert response['allFilms']['edges'][1]['node']['releaseDate'] == 'some_other_datetime' # ensure original response is not changed From 2988ed5f2c0a50267398baa3786fdb3a0e91503a Mon Sep 17 00:00:00 2001 From: Awais Hussain Date: Tue, 1 May 2018 11:45:26 +0000 Subject: [PATCH 08/15] Remove type annotations from TypeAdapter for py2.7 compatibility Previously we were using Python3.5+ type annotation syntax but this is not backwards compatible so we are removing it. --- gql/type_adapter.py | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/gql/type_adapter.py b/gql/type_adapter.py index 8014ae0a..37566331 100644 --- a/gql/type_adapter.py +++ b/gql/type_adapter.py @@ -1,7 +1,4 @@ -from typing import Any, Dict, Callable, Optional, List - -from graphql.type.schema import GraphQLSchema -from graphql.type.definition import GraphQLObjectType, GraphQLField, GraphQLScalarType +from graphql.type.definition import GraphQLObjectType, GraphQLScalarType class TypeAdapter(object): @@ -23,7 +20,7 @@ class TypeAdapter(object): apply(): pass in a GQL response to replace all instances of custom scalar strings with their deserialized representation.""" - def __init__(self, schema: GraphQLSchema, custom_types: Dict[str, Any] = {}) -> None: + def __init__(self, schema, custom_types = {}): """ schema: a graphQL schema in the GraphQLSchema format custom_types: a Dict[str, Any], where str is the name of the custom scalar type, and @@ -31,7 +28,7 @@ def __init__(self, schema: GraphQLSchema, custom_types: Dict[str, Any] = {}) -> self.schema = schema self.custom_types = custom_types - def _follow_type_chain(self, node: Any) -> Any: + def _follow_type_chain(self, node): """ Get the type of the schema node in question. In the GraphQL schema, GraphQLFields have a "type" property. However, often @@ -47,7 +44,7 @@ def _follow_type_chain(self, node: Any) -> Any: return field_type - def _get_scalar_type_name(self, field: GraphQLField) -> Optional[str]: + def _get_scalar_type_name(self, field): """Returns the name of the type if the type is a scalar type. Returns None otherwise""" node = self._follow_type_chain(field) @@ -55,7 +52,7 @@ def _get_scalar_type_name(self, field: GraphQLField) -> Optional[str]: return node.name return None - def _lookup_scalar_type(self, keys: List[str]) -> Optional[str]: + def _lookup_scalar_type(self, keys): """Search through the GQL schema and return the type identified by 'keys'. If keys (e.g. ['film', 'release_date']) points to a scalar type, then @@ -68,7 +65,7 @@ def _lookup_scalar_type(self, keys: List[str]) -> Optional[str]: By default the root level is `schema.query`, if that fails, then we check `schema.mutation`.""" - def traverse_schema(node: Any, lookup: List[str]): + def traverse_schema(node, lookup): if not lookup: return self._get_scalar_type_name(node) @@ -87,7 +84,7 @@ def traverse_schema(node: Any, lookup: List[str]): except (KeyError, AttributeError): return None - def _get_decoded_scalar_type(self, keys: List[str], value: Any) -> Any: + def _get_decoded_scalar_type(self, keys, value): """Get the decoded value of the type identified by `keys`. If the type is not a custom scalar, then return the original value. @@ -99,7 +96,7 @@ def _get_decoded_scalar_type(self, keys: List[str], value: Any) -> Any: return self.custom_types[scalar_type].parse_value(value) return value - def convert_scalars(self, response: Dict[str, Any]) -> Dict[str, Any]: + def convert_scalars(self, response): """Recursively traverse the GQL response Recursively traverses the GQL response and calls _get_decoded_scalar_type() @@ -110,7 +107,7 @@ def convert_scalars(self, response: Dict[str, Any]) -> Dict[str, Any]: Builds a new tree with the substituted values so old `response` is not modified.""" - def iterate(node: Any, keys: List[str] = []): + def iterate(node, keys = []): if isinstance(node, dict): return {_key: iterate(value, keys + [_key]) for _key, value in node.items()} elif isinstance(node, list): From 5839be8cb8431cdd371e1c9e7e227b28e3041058 Mon Sep 17 00:00:00 2001 From: KingDarBoja Date: Sun, 14 Jun 2020 16:52:19 -0500 Subject: [PATCH 09/15] styles: apply black, flake8 and isort formatting --- gql/client.py | 10 ++-- gql/type_adapter.py | 40 ++++++++------ tests/test_type_adapter.py | 103 ++++++++++++++++++------------------- 3 files changed, 77 insertions(+), 76 deletions(-) diff --git a/gql/client.py b/gql/client.py index 5713a994..d84deecf 100644 --- a/gql/client.py +++ b/gql/client.py @@ -29,7 +29,7 @@ def __init__( transport: Optional[Union[Transport, AsyncTransport]] = None, fetch_schema_from_transport: bool = False, execute_timeout: Optional[int] = 10, - custom_types: Dict[str, Any] = {} + custom_types: Optional[Dict[str, Any]] = None, ): assert not ( type_def and introspection @@ -220,8 +220,8 @@ def execute(self, document: DocumentNode, *args, **kwargs) -> Dict: result.data is not None ), "Transport returned an ExecutionResult without data or errors" - if self.type_adapter: - result.data = self.type_adapter.convert_scalars(result.data) + if self.client.type_adapter: + result.data = self.client.type_adapter.convert_scalars(result.data) return result.data @@ -297,8 +297,8 @@ async def execute(self, document: DocumentNode, *args, **kwargs) -> Dict: if result.errors: raise TransportQueryError(str(result.errors[0])) - if self.type_adapter: - result.data = self.type_adapter.convert_scalars(result.data) + if self.client.type_adapter: + result.data = self.client.type_adapter.convert_scalars(result.data) return result.data diff --git a/gql/type_adapter.py b/gql/type_adapter.py index 37566331..93e3af2e 100644 --- a/gql/type_adapter.py +++ b/gql/type_adapter.py @@ -1,4 +1,7 @@ -from graphql.type.definition import GraphQLObjectType, GraphQLScalarType +from typing import Optional, Dict, Any, List + +from graphql import GraphQLSchema +from graphql.type.definition import GraphQLObjectType, GraphQLScalarType, GraphQLField class TypeAdapter(object): @@ -20,15 +23,14 @@ class TypeAdapter(object): apply(): pass in a GQL response to replace all instances of custom scalar strings with their deserialized representation.""" - def __init__(self, schema, custom_types = {}): - """ schema: a graphQL schema in the GraphQLSchema format - custom_types: a Dict[str, Any], - where str is the name of the custom scalar type, and - Any is a class which has a `parse_value(str)` function""" + def __init__( + self, schema: GraphQLSchema, custom_types: Optional[Dict[str, Any]] = None + ): self.schema = schema self.custom_types = custom_types - def _follow_type_chain(self, node): + @staticmethod + def _follow_type_chain(node): """ Get the type of the schema node in question. In the GraphQL schema, GraphQLFields have a "type" property. However, often @@ -39,7 +41,7 @@ def _follow_type_chain(self, node): return node field_type = node.type - while hasattr(field_type, 'of_type'): + while hasattr(field_type, "of_type"): field_type = field_type.of_type return field_type @@ -52,7 +54,7 @@ def _get_scalar_type_name(self, field): return node.name return None - def _lookup_scalar_type(self, keys): + def _lookup_scalar_type(self, keys: List[str]): """Search through the GQL schema and return the type identified by 'keys'. If keys (e.g. ['film', 'release_date']) points to a scalar type, then @@ -65,17 +67,17 @@ def _lookup_scalar_type(self, keys): By default the root level is `schema.query`, if that fails, then we check `schema.mutation`.""" - def traverse_schema(node, lookup): + def traverse_schema(node: Optional[GraphQLField], lookup): if not lookup: return self._get_scalar_type_name(node) final_node = self._follow_type_chain(node) return traverse_schema(final_node.fields[lookup[0]], lookup[1:]) - if keys[0] in self.schema.get_query_type().fields: - schema_root = self.schema.get_query_type() - elif keys[0] in self.schema.get_mutation_type().fields: - schema_root = self.schema.get_mutation_type() + if keys[0] in self.schema.query_type.fields: + schema_root = self.schema.query_type + elif keys[0] in self.schema.mutation_type.fields: + schema_root = self.schema.mutation_type else: return None @@ -84,7 +86,7 @@ def traverse_schema(node, lookup): except (KeyError, AttributeError): return None - def _get_decoded_scalar_type(self, keys, value): + def _get_decoded_scalar_type(self, keys: List[str], value): """Get the decoded value of the type identified by `keys`. If the type is not a custom scalar, then return the original value. @@ -107,11 +109,15 @@ def convert_scalars(self, response): Builds a new tree with the substituted values so old `response` is not modified.""" - def iterate(node, keys = []): + + def iterate(node, keys: List[str] = None): if isinstance(node, dict): - return {_key: iterate(value, keys + [_key]) for _key, value in node.items()} + return { + _key: iterate(value, keys + [_key]) for _key, value in node.items() + } elif isinstance(node, list): return [(iterate(item, keys)) for item in node] else: return self._get_decoded_scalar_type(keys, node) + return iterate(response) diff --git a/tests/test_type_adapter.py b/tests/test_type_adapter.py index 49ce8b2e..c39e7346 100644 --- a/tests/test_type_adapter.py +++ b/tests/test_type_adapter.py @@ -11,108 +11,103 @@ from gql import Client from gql.transport.requests import RequestsHTTPTransport -class Capitalize(): + +class Capitalize: @classmethod - def parse_value(self, value: str): - return value.upper(); + def parse_value(cls, value: str): + return value.upper() + -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def schema(): - request = requests.get('http://swapi.graphene-python.org/graphql', - headers={ - 'Host': 'swapi.graphene-python.org', - 'Accept': 'text/html', - }) + request = requests.get( + "http://swapi.graphene-python.org/graphql", + headers={"Host": "swapi.graphene-python.org", "Accept": "text/html"}, + ) request.raise_for_status() - csrf = request.cookies['csrftoken'] + csrf = request.cookies["csrftoken"] client = Client( - transport=RequestsHTTPTransport(url='http://swapi.graphene-python.org/graphql', - cookies={"csrftoken": csrf}, - headers={'x-csrftoken': csrf}), - fetch_schema_from_transport=True + transport=RequestsHTTPTransport( + url="http://swapi.graphene-python.org/graphql", + cookies={"csrftoken": csrf}, + headers={"x-csrftoken": csrf}, + ), + fetch_schema_from_transport=True, ) return client.schema + def test_scalar_type_name_for_scalar_field_returns_name(schema): type_adapter = TypeAdapter(schema) - schema_obj = schema.get_query_type().fields['film'] + schema_obj = schema.get_query_type().fields["film"] - assert type_adapter ._get_scalar_type_name(schema_obj.type.fields['releaseDate']) == 'DateTime' + assert ( + type_adapter._get_scalar_type_name(schema_obj.type.fields["releaseDate"]) + == "DateTime" + ) def test_scalar_type_name_for_non_scalar_field_returns_none(schema): type_adapter = TypeAdapter(schema) - schema_obj = schema.get_query_type().fields['film'] + schema_obj = schema.get_query_type().fields["film"] + + assert type_adapter._get_scalar_type_name(schema_obj.type.fields["species"]) is None - assert type_adapter._get_scalar_type_name(schema_obj.type.fields['species']) is None def test_lookup_scalar_type(schema): type_adapter = TypeAdapter(schema) assert type_adapter._lookup_scalar_type(["film"]) is None - assert type_adapter._lookup_scalar_type(["film", "releaseDate"]) == 'DateTime' + assert type_adapter._lookup_scalar_type(["film", "releaseDate"]) == "DateTime" assert type_adapter._lookup_scalar_type(["film", "species"]) is None + def test_lookup_scalar_type_in_mutation(schema): type_adapter = TypeAdapter(schema) assert type_adapter._lookup_scalar_type(["createHero"]) is None assert type_adapter._lookup_scalar_type(["createHero", "hero"]) is None - assert type_adapter._lookup_scalar_type(["createHero", "ok"]) == 'Boolean' + assert type_adapter._lookup_scalar_type(["createHero", "ok"]) == "Boolean" + def test_parse_response(schema): - custom_types = { - 'DateTime': Capitalize - } + custom_types = {"DateTime": Capitalize} type_adapter = TypeAdapter(schema, custom_types) - response = { - 'film': { - 'id': 'some_id', - 'releaseDate': 'some_datetime', - } - } + response = {"film": {"id": "some_id", "releaseDate": "some_datetime"}} - expected = { - 'film': { - 'id': 'some_id', - 'releaseDate': 'SOME_DATETIME', - } - } + expected = {"film": {"id": "some_id", "releaseDate": "SOME_DATETIME"}} assert type_adapter.convert_scalars(response) == expected - assert response['film']['releaseDate'] == 'some_datetime' # ensure original response is not changed + # ensure original response is not changed + assert response["film"]["releaseDate"] == "some_datetime" + def test_parse_response_containing_list(schema): - custom_types = { - 'DateTime': Capitalize - } + custom_types = {"DateTime": Capitalize} type_adapter = TypeAdapter(schema, custom_types) response = { "allFilms": { - "edges": [{ - "node": { - 'id': 'some_id', - 'releaseDate': 'some_datetime', - } - },{ - "node": { - 'id': 'some_id', - 'releaseDate': 'some_other_datetime', - } - }] + "edges": [ + {"node": {"id": "some_id", "releaseDate": "some_datetime"}}, + {"node": {"id": "some_id", "releaseDate": "some_other_datetime"}}, + ] } } expected = copy.deepcopy(response) - expected['allFilms']['edges'][0]['node']['releaseDate'] = 'SOME_DATETIME' - expected['allFilms']['edges'][1]['node']['releaseDate'] = 'SOME_OTHER_DATETIME' + expected["allFilms"]["edges"][0]["node"]["releaseDate"] = "SOME_DATETIME" + expected["allFilms"]["edges"][1]["node"]["releaseDate"] = "SOME_OTHER_DATETIME" result = type_adapter.convert_scalars(response) assert result == expected - assert response['allFilms']['edges'][0]['node']['releaseDate'] == 'some_datetime' # ensure original response is not changed - assert response['allFilms']['edges'][1]['node']['releaseDate'] == 'some_other_datetime' # ensure original response is not changed + # ensure original response is not changed + assert response["allFilms"]["edges"][0]["node"]["releaseDate"] == "some_datetime" + # ensure original response is not changed + assert ( + response["allFilms"]["edges"][1]["node"]["releaseDate"] == "some_other_datetime" + ) From 7e432559d6b581001386827f6a4300c2f120cb90 Mon Sep 17 00:00:00 2001 From: KingDarBoja Date: Sun, 21 Jun 2020 20:43:52 -0500 Subject: [PATCH 10/15] tests: use vcr to reproduce original schema --- gql/type_adapter.py | 2 + tests/test_type_adapter.py | 80 ++++++++++++++++++++++++-------------- 2 files changed, 53 insertions(+), 29 deletions(-) diff --git a/gql/type_adapter.py b/gql/type_adapter.py index 93e3af2e..01f8b2f8 100644 --- a/gql/type_adapter.py +++ b/gql/type_adapter.py @@ -111,6 +111,8 @@ def convert_scalars(self, response): modified.""" def iterate(node, keys: List[str] = None): + if keys is None: + keys = [] if isinstance(node, dict): return { _key: iterate(value, keys + [_key]) for _key, value in node.items() diff --git a/tests/test_type_adapter.py b/tests/test_type_adapter.py index c39e7346..bc75b3eb 100644 --- a/tests/test_type_adapter.py +++ b/tests/test_type_adapter.py @@ -5,12 +5,50 @@ locally. """ import copy -from gql.type_adapter import TypeAdapter +import os + +import vcr import pytest import requests +from graphql import GraphQLSchema + from gql import Client +from gql.type_adapter import TypeAdapter from gql.transport.requests import RequestsHTTPTransport +# We serve https://github.com/graphql-python/swapi-graphene locally: +URL = "http://127.0.0.1:8000/graphql" + + +query_vcr = vcr.VCR( + cassette_library_dir=os.path.join( + os.path.dirname(__file__), "fixtures", "vcr_cassettes" + ), + record_mode="new_episodes", + match_on=["uri", "method", "body"], +) + + +def use_cassette(name): + return query_vcr.use_cassette(name + ".yaml") + + +@pytest.fixture +def client(): + with use_cassette("client"): + response = requests.get( + URL, headers={"Host": "swapi.graphene-python.org", "Accept": "text/html"} + ) + response.raise_for_status() + csrf = response.cookies["csrftoken"] + + return Client( + transport=RequestsHTTPTransport( + url=URL, cookies={"csrftoken": csrf}, headers={"x-csrftoken": csrf} + ), + fetch_schema_from_transport=True, + ) + class Capitalize: @classmethod @@ -18,40 +56,24 @@ def parse_value(cls, value: str): return value.upper() -@pytest.fixture(scope="session") -def schema(): - request = requests.get( - "http://swapi.graphene-python.org/graphql", - headers={"Host": "swapi.graphene-python.org", "Accept": "text/html"}, - ) - request.raise_for_status() - csrf = request.cookies["csrftoken"] - - client = Client( - transport=RequestsHTTPTransport( - url="http://swapi.graphene-python.org/graphql", - cookies={"csrftoken": csrf}, - headers={"x-csrftoken": csrf}, - ), - fetch_schema_from_transport=True, - ) - +@pytest.fixture() +def schema(client): return client.schema -def test_scalar_type_name_for_scalar_field_returns_name(schema): +def test_scalar_type_name_for_scalar_field_returns_name(schema: GraphQLSchema): type_adapter = TypeAdapter(schema) - schema_obj = schema.get_query_type().fields["film"] + schema_obj = schema.query_type.fields["film"] assert ( type_adapter._get_scalar_type_name(schema_obj.type.fields["releaseDate"]) - == "DateTime" + == "Date" ) -def test_scalar_type_name_for_non_scalar_field_returns_none(schema): +def test_scalar_type_name_for_non_scalar_field_returns_none(schema: GraphQLSchema): type_adapter = TypeAdapter(schema) - schema_obj = schema.get_query_type().fields["film"] + schema_obj = schema.query_type.fields["film"] assert type_adapter._get_scalar_type_name(schema_obj.type.fields["species"]) is None @@ -60,11 +82,11 @@ def test_lookup_scalar_type(schema): type_adapter = TypeAdapter(schema) assert type_adapter._lookup_scalar_type(["film"]) is None - assert type_adapter._lookup_scalar_type(["film", "releaseDate"]) == "DateTime" + assert type_adapter._lookup_scalar_type(["film", "releaseDate"]) == "Date" assert type_adapter._lookup_scalar_type(["film", "species"]) is None -def test_lookup_scalar_type_in_mutation(schema): +def test_lookup_scalar_type_in_mutation(schema: GraphQLSchema): type_adapter = TypeAdapter(schema) assert type_adapter._lookup_scalar_type(["createHero"]) is None @@ -72,8 +94,8 @@ def test_lookup_scalar_type_in_mutation(schema): assert type_adapter._lookup_scalar_type(["createHero", "ok"]) == "Boolean" -def test_parse_response(schema): - custom_types = {"DateTime": Capitalize} +def test_parse_response(schema: GraphQLSchema): + custom_types = {"Date": Capitalize} type_adapter = TypeAdapter(schema, custom_types) response = {"film": {"id": "some_id", "releaseDate": "some_datetime"}} @@ -85,7 +107,7 @@ def test_parse_response(schema): assert response["film"]["releaseDate"] == "some_datetime" -def test_parse_response_containing_list(schema): +def test_parse_response_containing_list(schema: GraphQLSchema): custom_types = {"DateTime": Capitalize} type_adapter = TypeAdapter(schema, custom_types) From 53fd1cbc8288f57c6baaa96a0533606dc3ad41e8 Mon Sep 17 00:00:00 2001 From: KingDarBoja Date: Mon, 22 Jun 2020 11:42:40 -0500 Subject: [PATCH 11/15] refactor: add better typings to type adapter --- gql/client.py | 5 ++++- gql/type_adapter.py | 34 +++++++++++++++++++--------------- tests/test_type_adapter.py | 10 +++++----- 3 files changed, 28 insertions(+), 21 deletions(-) diff --git a/gql/client.py b/gql/client.py index d84deecf..f8c4436e 100644 --- a/gql/client.py +++ b/gql/client.py @@ -70,7 +70,10 @@ def __init__( # Dictionary where the name of the custom scalar type is the key and the # value is a class which has a `parse_value()` function - self.type_adapter = TypeAdapter(schema, custom_types) if custom_types else None + if schema: + self.type_adapter = ( + TypeAdapter(schema, custom_types) if custom_types else None + ) if isinstance(transport, Transport) and fetch_schema_from_transport: with self as session: diff --git a/gql/type_adapter.py b/gql/type_adapter.py index 01f8b2f8..fa50722c 100644 --- a/gql/type_adapter.py +++ b/gql/type_adapter.py @@ -1,10 +1,10 @@ -from typing import Optional, Dict, Any, List +from typing import Any, Dict, List, Optional, Union from graphql import GraphQLSchema -from graphql.type.definition import GraphQLObjectType, GraphQLScalarType, GraphQLField +from graphql.type.definition import GraphQLField, GraphQLObjectType, GraphQLScalarType -class TypeAdapter(object): +class TypeAdapter: """Substitute custom scalars in a GQL response with their decoded counterparts. GQL custom scalar types are defined on the GQL schema and are used to represent @@ -27,7 +27,7 @@ def __init__( self, schema: GraphQLSchema, custom_types: Optional[Dict[str, Any]] = None ): self.schema = schema - self.custom_types = custom_types + self.custom_types = custom_types or {} @staticmethod def _follow_type_chain(node): @@ -48,7 +48,7 @@ def _follow_type_chain(node): def _get_scalar_type_name(self, field): """Returns the name of the type if the type is a scalar type. - Returns None otherwise""" + Returns `None` otherwise""" node = self._follow_type_chain(field) if isinstance(node, GraphQLScalarType): return node.name @@ -61,22 +61,24 @@ def _lookup_scalar_type(self, keys: List[str]): this function returns the name of that type. (e.g. 'DateTime') If it is not a scalar type (e..g a GraphQLObject), then this - function returns None. + function returns `None`. `keys` is a breadcrumb trail telling us where to look in the GraphQL schema. By default the root level is `schema.query`, if that fails, then we check `schema.mutation`.""" - def traverse_schema(node: Optional[GraphQLField], lookup): + def traverse_schema( + node: Optional[Union[GraphQLObjectType, GraphQLField]], lookup + ): if not lookup: return self._get_scalar_type_name(node) final_node = self._follow_type_chain(node) return traverse_schema(final_node.fields[lookup[0]], lookup[1:]) - if keys[0] in self.schema.query_type.fields: + if self.schema.query_type and keys[0] in self.schema.query_type.fields: schema_root = self.schema.query_type - elif keys[0] in self.schema.mutation_type.fields: + elif self.schema.mutation_type and keys[0] in self.schema.mutation_type.fields: schema_root = self.schema.mutation_type else: return None @@ -86,7 +88,7 @@ def traverse_schema(node: Optional[GraphQLField], lookup): except (KeyError, AttributeError): return None - def _get_decoded_scalar_type(self, keys: List[str], value): + def _get_decoded_scalar_type(self, keys: List[str], value: Any): """Get the decoded value of the type identified by `keys`. If the type is not a custom scalar, then return the original value. @@ -98,19 +100,21 @@ def _get_decoded_scalar_type(self, keys: List[str], value): return self.custom_types[scalar_type].parse_value(value) return value - def convert_scalars(self, response): + def convert_scalars(self, response: Dict[str, Any]): """Recursively traverse the GQL response Recursively traverses the GQL response and calls _get_decoded_scalar_type() - for all leaf nodes. The function is called with 2 arguments: - keys: List[str] is a breadcrumb trail telling us where we are in the + for all leaf nodes. + + The function is called with 2 arguments: + keys: is a breadcrumb trail telling us where we are in the response, and therefore, where to look in the GQL Schema. - value: Any is the value at that node in the response + value: is the value at that node in the response Builds a new tree with the substituted values so old `response` is not modified.""" - def iterate(node, keys: List[str] = None): + def iterate(node: Union[List, Dict, str], keys: List[str] = None): if keys is None: keys = [] if isinstance(node, dict): diff --git a/tests/test_type_adapter.py b/tests/test_type_adapter.py index bc75b3eb..fd8dc2e9 100644 --- a/tests/test_type_adapter.py +++ b/tests/test_type_adapter.py @@ -7,14 +7,14 @@ import copy import os -import vcr import pytest import requests +import vcr from graphql import GraphQLSchema from gql import Client -from gql.type_adapter import TypeAdapter from gql.transport.requests import RequestsHTTPTransport +from gql.type_adapter import TypeAdapter # We serve https://github.com/graphql-python/swapi-graphene locally: URL = "http://127.0.0.1:8000/graphql" @@ -63,7 +63,7 @@ def schema(client): def test_scalar_type_name_for_scalar_field_returns_name(schema: GraphQLSchema): type_adapter = TypeAdapter(schema) - schema_obj = schema.query_type.fields["film"] + schema_obj = schema.query_type.fields["film"] if schema.query_type else None assert ( type_adapter._get_scalar_type_name(schema_obj.type.fields["releaseDate"]) @@ -73,7 +73,7 @@ def test_scalar_type_name_for_scalar_field_returns_name(schema: GraphQLSchema): def test_scalar_type_name_for_non_scalar_field_returns_none(schema: GraphQLSchema): type_adapter = TypeAdapter(schema) - schema_obj = schema.query_type.fields["film"] + schema_obj = schema.query_type.fields["film"] if schema.query_type else None assert type_adapter._get_scalar_type_name(schema_obj.type.fields["species"]) is None @@ -108,7 +108,7 @@ def test_parse_response(schema: GraphQLSchema): def test_parse_response_containing_list(schema: GraphQLSchema): - custom_types = {"DateTime": Capitalize} + custom_types = {"Date": Capitalize} type_adapter = TypeAdapter(schema, custom_types) response = { From 6d83048966ad00bef832bec86beebb9db297f952 Mon Sep 17 00:00:00 2001 From: KingDarBoja Date: Mon, 22 Jun 2020 12:05:51 -0500 Subject: [PATCH 12/15] fix: correct check of type_adapter on client --- gql/client.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/gql/client.py b/gql/client.py index f8c4436e..64b32c53 100644 --- a/gql/client.py +++ b/gql/client.py @@ -70,10 +70,9 @@ def __init__( # Dictionary where the name of the custom scalar type is the key and the # value is a class which has a `parse_value()` function - if schema: - self.type_adapter = ( - TypeAdapter(schema, custom_types) if custom_types else None - ) + self.type_adapter = ( + TypeAdapter(schema, custom_types) if custom_types and schema else None + ) if isinstance(transport, Transport) and fetch_schema_from_transport: with self as session: @@ -283,6 +282,8 @@ async def subscribe( raise TransportQueryError(str(result.errors[0])) elif result.data is not None: + if self.client.type_adapter: + result.data = self.client.type_adapter.convert_scalars(result.data) yield result.data async def execute(self, document: DocumentNode, *args, **kwargs) -> Dict: From 5b19de5eb7027d5783b7b01598f5abe08936231f Mon Sep 17 00:00:00 2001 From: KingDarBoja Date: Mon, 29 Jun 2020 12:40:24 -0500 Subject: [PATCH 13/15] Fix mypy issues and bring extra type hints --- gql/client.py | 22 ++++++++++++++++------ gql/type_adapter.py | 12 +++++++----- 2 files changed, 23 insertions(+), 11 deletions(-) diff --git a/gql/client.py b/gql/client.py index 6a7eb562..9ef6240a 100644 --- a/gql/client.py +++ b/gql/client.py @@ -211,7 +211,9 @@ def _execute(self, document: DocumentNode, *args, **kwargs) -> ExecutionResult: return self.transport.execute(document, *args, **kwargs) - def execute(self, document: DocumentNode, *args, **kwargs) -> Dict: + def execute( + self, document: DocumentNode, *args, **kwargs + ) -> Optional[Dict[str, Any]]: # Validate and execute on the transport result = self._execute(document, *args, **kwargs) @@ -225,7 +227,9 @@ def execute(self, document: DocumentNode, *args, **kwargs) -> Dict: ), "Transport returned an ExecutionResult without data or errors" if self.client.type_adapter: - result.data = self.client.type_adapter.convert_scalars(result.data) + result = result._replace( + data=self.client.type_adapter.convert_scalars(result.data) + ) return result.data @@ -289,7 +293,7 @@ async def _subscribe( async def subscribe( self, document: DocumentNode, *args, **kwargs - ) -> AsyncGenerator[Dict, None]: + ) -> AsyncGenerator[Optional[Dict[str, Any]], None]: # Validate and subscribe on the transport async for result in self._subscribe(document, *args, **kwargs): @@ -300,7 +304,9 @@ async def subscribe( elif result.data is not None: if self.client.type_adapter: - result.data = self.client.type_adapter.convert_scalars(result.data) + result = result._replace( + data=self.client.type_adapter.convert_scalars(result.data) + ) yield result.data async def _execute( @@ -316,7 +322,9 @@ async def _execute( self.client.execute_timeout, ) - async def execute(self, document: DocumentNode, *args, **kwargs) -> Dict: + async def execute( + self, document: DocumentNode, *args, **kwargs + ) -> Optional[Dict[str, Any]]: # Validate and execute on the transport result = await self._execute(document, *args, **kwargs) @@ -330,7 +338,9 @@ async def execute(self, document: DocumentNode, *args, **kwargs) -> Dict: ), "Transport returned an ExecutionResult without data or errors" if self.client.type_adapter: - result.data = self.client.type_adapter.convert_scalars(result.data) + result = result._replace( + data=self.client.type_adapter.convert_scalars(result.data) + ) return result.data diff --git a/gql/type_adapter.py b/gql/type_adapter.py index fa50722c..cf7829c0 100644 --- a/gql/type_adapter.py +++ b/gql/type_adapter.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Union, cast from graphql import GraphQLSchema from graphql.type.definition import GraphQLField, GraphQLObjectType, GraphQLScalarType @@ -88,7 +88,7 @@ def traverse_schema( except (KeyError, AttributeError): return None - def _get_decoded_scalar_type(self, keys: List[str], value: Any): + def _get_decoded_scalar_type(self, keys: List[str], value: str) -> str: """Get the decoded value of the type identified by `keys`. If the type is not a custom scalar, then return the original value. @@ -100,7 +100,7 @@ def _get_decoded_scalar_type(self, keys: List[str], value: Any): return self.custom_types[scalar_type].parse_value(value) return value - def convert_scalars(self, response: Dict[str, Any]): + def convert_scalars(self, response: Dict[str, Any]) -> Dict[str, Any]: """Recursively traverse the GQL response Recursively traverses the GQL response and calls _get_decoded_scalar_type() @@ -114,7 +114,9 @@ def convert_scalars(self, response: Dict[str, Any]): Builds a new tree with the substituted values so old `response` is not modified.""" - def iterate(node: Union[List, Dict, str], keys: List[str] = None): + def iterate( + node: Union[List, Dict, str], keys: List[str] = None + ) -> Union[Dict[str, Any], List, str]: if keys is None: keys = [] if isinstance(node, dict): @@ -126,4 +128,4 @@ def iterate(node: Union[List, Dict, str], keys: List[str] = None): else: return self._get_decoded_scalar_type(keys, node) - return iterate(response) + return cast(Dict, iterate(response)) From 598acd844d1121ced55a7ce91d07884b07591ae1 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Sun, 28 Jun 2020 21:14:42 +0200 Subject: [PATCH 14/15] Fix type_adapter for subscriptions and async transports --- gql/client.py | 53 +++++++++--------- gql/type_adapter.py | 6 ++ tests/test_async_client_validation.py | 44 +++++++++++++++ tests/test_requests.py | 56 ++++++++++++++++++ tests/test_websocket_subscription.py | 81 +++++++++++++++++++++++---- 5 files changed, 204 insertions(+), 36 deletions(-) diff --git a/gql/client.py b/gql/client.py index 9ef6240a..96110f54 100644 --- a/gql/client.py +++ b/gql/client.py @@ -67,16 +67,21 @@ def __init__( # Enforced timeout of the execute function self.execute_timeout = execute_timeout + # Fetch schema from transport directly if we are using a sync transport + if isinstance(transport, Transport) and fetch_schema_from_transport: + with self as session: + session.fetch_schema() + # Dictionary where the name of the custom scalar type is the key and the # value is a class which has a `parse_value()` function + self.custom_types = custom_types + + # Create a type_adapter instance directly here if we received the schema + # locally or from a sync transport self.type_adapter = ( TypeAdapter(schema, custom_types) if custom_types and schema else None ) - if isinstance(transport, Transport) and fetch_schema_from_transport: - with self as session: - session.fetch_schema() - def validate(self, document): assert ( self.schema @@ -211,9 +216,7 @@ def _execute(self, document: DocumentNode, *args, **kwargs) -> ExecutionResult: return self.transport.execute(document, *args, **kwargs) - def execute( - self, document: DocumentNode, *args, **kwargs - ) -> Optional[Dict[str, Any]]: + def execute(self, document: DocumentNode, *args, **kwargs) -> Dict[str, Any]: # Validate and execute on the transport result = self._execute(document, *args, **kwargs) @@ -227,11 +230,9 @@ def execute( ), "Transport returned an ExecutionResult without data or errors" if self.client.type_adapter: - result = result._replace( - data=self.client.type_adapter.convert_scalars(result.data) - ) - - return result.data + return self.client.type_adapter.convert_scalars(result.data) + else: + return result.data def fetch_schema(self) -> None: execution_result = self.transport.execute(parse(get_introspection_query())) @@ -263,6 +264,13 @@ async def fetch_and_validate(self, document: DocumentNode): if self.client.fetch_schema_from_transport and not self.client.schema: await self.fetch_schema() + # Once we have received the schema from the async transport, + # we can create a TypeAdapter instance if the user provided custom types + if self.client.custom_types and self.client.schema: + self.client.type_adapter = TypeAdapter( + self.client.schema, self.client.custom_types + ) + # Validate document if self.client.schema: self.client.validate(document) @@ -293,7 +301,7 @@ async def _subscribe( async def subscribe( self, document: DocumentNode, *args, **kwargs - ) -> AsyncGenerator[Optional[Dict[str, Any]], None]: + ) -> AsyncGenerator[Dict[str, Any], None]: # Validate and subscribe on the transport async for result in self._subscribe(document, *args, **kwargs): @@ -304,10 +312,9 @@ async def subscribe( elif result.data is not None: if self.client.type_adapter: - result = result._replace( - data=self.client.type_adapter.convert_scalars(result.data) - ) - yield result.data + yield self.client.type_adapter.convert_scalars(result.data) + else: + yield result.data async def _execute( self, document: DocumentNode, *args, **kwargs @@ -322,9 +329,7 @@ async def _execute( self.client.execute_timeout, ) - async def execute( - self, document: DocumentNode, *args, **kwargs - ) -> Optional[Dict[str, Any]]: + async def execute(self, document: DocumentNode, *args, **kwargs) -> Dict[str, Any]: # Validate and execute on the transport result = await self._execute(document, *args, **kwargs) @@ -338,11 +343,9 @@ async def execute( ), "Transport returned an ExecutionResult without data or errors" if self.client.type_adapter: - result = result._replace( - data=self.client.type_adapter.convert_scalars(result.data) - ) - - return result.data + return self.client.type_adapter.convert_scalars(result.data) + else: + return result.data async def fetch_schema(self) -> None: execution_result = await self.transport.execute( diff --git a/gql/type_adapter.py b/gql/type_adapter.py index cf7829c0..817b33ee 100644 --- a/gql/type_adapter.py +++ b/gql/type_adapter.py @@ -80,6 +80,11 @@ def traverse_schema( schema_root = self.schema.query_type elif self.schema.mutation_type and keys[0] in self.schema.mutation_type.fields: schema_root = self.schema.mutation_type + elif ( + self.schema.subscription_type + and keys[0] in self.schema.subscription_type.fields + ): + schema_root = self.schema.subscription_type else: return None @@ -95,6 +100,7 @@ def _get_decoded_scalar_type(self, keys: List[str], value: str) -> str: If it is a custom scalar, return the deserialized value, as output by `.parse_value()`""" + scalar_type = self._lookup_scalar_type(keys) if scalar_type and scalar_type in self.custom_types: return self.custom_types[scalar_type].parse_value(value) diff --git a/tests/test_async_client_validation.py b/tests/test_async_client_validation.py index ec651866..cd6197d0 100644 --- a/tests/test_async_client_validation.py +++ b/tests/test_async_client_validation.py @@ -267,3 +267,47 @@ async def test_async_client_validation_fetch_schema_from_server_with_client_argu with pytest.raises(graphql.error.GraphQLError): await session.execute(query) + + +class ToLowercase: + @staticmethod + def parse_value(value: str): + return value.lower() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [hero_server_answers], indirect=True) +async def test_async_client_validation_fetch_schema_from_server_with_custom_types( + event_loop, server +): + + url = f"ws://{server.hostname}:{server.port}/graphql" + + sample_transport = WebsocketsTransport(url=url) + + custom_types = {"String": ToLowercase} + + async with Client( + transport=sample_transport, + fetch_schema_from_transport=True, + custom_types=custom_types, + ) as session: + + query = gql( + """ + query HeroNameQuery { + hero { + name + } + } + """ + ) + + result = await session.execute(query) + + print("Client received:", result) + + # The expected hero name is now in lowercase + expected = {"hero": {"name": "r2-d2"}} + + assert result == expected diff --git a/tests/test_requests.py b/tests/test_requests.py index 24fab2d2..b31f0834 100644 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -13,6 +13,8 @@ ) from gql.transport.requests import RequestsHTTPTransport +from .test_type_adapter import Capitalize + query1_str = """ query getContinents { continents { @@ -201,3 +203,57 @@ def test_code(): sample_transport.execute(query) await run_sync_test(event_loop, server, test_code) + + +partial_schema = """ + +type Continent { + code: ID! + name: String! +} + +type Query { + continents: [Continent!]! +} + +""" + + +@pytest.mark.asyncio +async def test_requests_query_with_custom_types(event_loop, aiohttp_server): + async def handler(request): + return web.Response(text=query1_server_answer, content_type="application/json") + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + def test_code(): + sample_transport = RequestsHTTPTransport(url=url) + + custom_types = {"String": Capitalize} + + # Instanciate a client which will capitalize all the String scalars + with Client( + transport=sample_transport, + type_def=partial_schema, + custom_types=custom_types, + ) as session: + + query = gql(query1_str) + + # Execute query synchronously + result = session.execute(query) + + continents = result["continents"] + + africa = continents[0] + + assert africa["code"] == "AF" + + # Check that the string is capitalized + assert africa["name"] == "AFRICA" + + await run_sync_test(event_loop, server, test_code) diff --git a/tests/test_websocket_subscription.py b/tests/test_websocket_subscription.py index 2a9942ff..ec135a34 100644 --- a/tests/test_websocket_subscription.py +++ b/tests/test_websocket_subscription.py @@ -10,9 +10,24 @@ from .conftest import MS, WebSocketServer -countdown_server_answer = ( - '{{"type":"data","id":"{query_id}","payload":{{"data":{{"number":{number}}}}}}}' -) +countdown_schema = """ + +type Number { + number: Int! +} + +type Query { + currentCount: Number! +} + +type Subscription { + countdown(count: Int!): Number! +} + +""" + +countdown_server_answer = '{{"type":"data","id":"{query_id}","payload":\ +{{"data":{{"countdown":{{"number":{number}}}}}}}}}' WITH_KEEPALIVE = False @@ -111,7 +126,7 @@ async def test_websocket_subscription(event_loop, client_and_server, subscriptio async for result in session.subscribe(subscription): - number = result["number"] + number = result["countdown"]["number"] print(f"Number received: {number}") assert number == count @@ -134,7 +149,7 @@ async def test_websocket_subscription_break( async for result in session.subscribe(subscription): - number = result["number"] + number = result["countdown"]["number"] print(f"Number received: {number}") assert number == count @@ -165,7 +180,7 @@ async def task_coro(): nonlocal count async for result in session.subscribe(subscription): - number = result["number"] + number = result["countdown"]["number"] print(f"Number received: {number}") assert number == count @@ -204,7 +219,7 @@ async def task_coro(): nonlocal count async for result in session.subscribe(subscription): - number = result["number"] + number = result["countdown"]["number"] print(f"Number received: {number}") assert number == count @@ -269,7 +284,7 @@ async def test_websocket_subscription_server_connection_closed( async for result in session.subscribe(subscription): - number = result["number"] + number = result["countdown"]["number"] print(f"Number received: {number}") assert number == count @@ -292,7 +307,7 @@ async def test_websocket_subscription_slow_consumer( async for result in session.subscribe(subscription): await asyncio.sleep(10 * MS) - number = result["number"] + number = result["countdown"]["number"] print(f"Number received: {number}") assert number == count @@ -319,7 +334,7 @@ async def test_websocket_subscription_with_keepalive( async for result in session.subscribe(subscription): - number = result["number"] + number = result["countdown"]["number"] print(f"Number received: {number}") assert number == count @@ -344,10 +359,54 @@ def test_websocket_subscription_sync(server, subscription_str): for result in client.subscribe(subscription): - number = result["number"] + number = result["countdown"]["number"] print(f"Number received: {number}") assert number == count count -= 1 assert count == -1 + + +class NumberAddParser: + """ Class with a parse_value method used to increment a number """ + + def __init__(self, increment: int): + self.increment: int = increment + + def parse_value(self, value: int) -> int: + return value + self.increment + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_websocket_subscription_with_custom_types( + event_loop, server, subscription_str +): + + url = f"ws://{server.hostname}:{server.port}/graphql" + + sample_transport = WebsocketsTransport(url=url) + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + add_10 = NumberAddParser(10) + + custom_types = {"Int": add_10} + + # Instanciate a client which will add 10 to all the scalars of type Int received + async with Client( + transport=sample_transport, custom_types=custom_types, type_def=countdown_schema + ) as session: + async for result in session.subscribe(subscription): + + number = result["countdown"]["number"] + print(f"Number received: {number}") + + # We check here that the Int scalar has been correctly incremented by 10 + assert number == count + 10 + count -= 1 + + assert count == -1 From a98b272c0015b68c4cf74ab6cb18f103551f4e9c Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Mon, 29 Jun 2020 23:56:36 +0200 Subject: [PATCH 15/15] Fix type hints - scalar values can be str or int or something else --- gql/type_adapter.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gql/type_adapter.py b/gql/type_adapter.py index 817b33ee..6a987f27 100644 --- a/gql/type_adapter.py +++ b/gql/type_adapter.py @@ -93,7 +93,7 @@ def traverse_schema( except (KeyError, AttributeError): return None - def _get_decoded_scalar_type(self, keys: List[str], value: str) -> str: + def _get_decoded_scalar_type(self, keys: List[str], value: Any) -> Any: """Get the decoded value of the type identified by `keys`. If the type is not a custom scalar, then return the original value. @@ -122,7 +122,7 @@ def convert_scalars(self, response: Dict[str, Any]) -> Dict[str, Any]: def iterate( node: Union[List, Dict, str], keys: List[str] = None - ) -> Union[Dict[str, Any], List, str]: + ) -> Union[Dict[str, Any], List, Any]: if keys is None: keys = [] if isinstance(node, dict):