From 246df9698253ba1b1fc7def858ba2233e46964ac Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Fri, 6 Nov 2020 22:53:29 +0100 Subject: [PATCH 01/21] DSL Refactor 1 Adding type annotations Replace selections generator by get_ast_fields function Remove ast property Replace serialize_list method by a lambda Remove obsolete test_nested_input_with_old_get_arg_serializer test --- gql/dsl.py | 142 ++++++++++++++---------- tests/nested_input/test_nested_input.py | 49 +------- 2 files changed, 83 insertions(+), 108 deletions(-) diff --git a/gql/dsl.py b/gql/dsl.py index 27894fcf..6d71e444 100644 --- a/gql/dsl.py +++ b/gql/dsl.py @@ -1,5 +1,7 @@ +from __future__ import annotations + from collections.abc import Iterable -from functools import partial +from typing import Any, Callable, Dict, List, Optional, Tuple, cast from graphql import ( ArgumentNode, @@ -7,10 +9,14 @@ EnumValueNode, FieldNode, GraphQLEnumType, + GraphQLField, GraphQLInputField, GraphQLInputObjectType, + GraphQLInputType, GraphQLList, GraphQLNonNull, + GraphQLObjectType, + GraphQLScalarType, ListValueNode, NameNode, ObjectFieldNode, @@ -18,6 +24,7 @@ OperationDefinitionNode, OperationType, SelectionSetNode, + ValueNode, ast_from_value, print_ast, ) @@ -26,7 +33,7 @@ from .utils import to_camel_case -class DSLSchema(object): +class DSLSchema: def __init__(self, client): self.client = client @@ -34,8 +41,8 @@ def __init__(self, client): def schema(self): return self.client.schema - def __getattr__(self, name): - type_def = self.schema.get_type(name) + def __getattr__(self, name: str) -> DSLType: + type_def: GraphQLObjectType = self.schema.get_type(name) return DSLType(type_def) def query(self, *args, **kwargs): @@ -48,15 +55,15 @@ def execute(self, document): return self.client.execute(document) -class DSLType(object): - def __init__(self, type_): - self._type = type_ +class DSLType: + def __init__(self, type_: GraphQLObjectType): + self._type: GraphQLObjectType = type_ - def __getattr__(self, name): + def __getattr__(self, name: str) -> DSLField: formatted_name, field_def = self.get_field(name) return DSLField(formatted_name, field_def) - def get_field(self, name): + def get_field(self, name: str) -> Tuple[str, GraphQLField]: camel_cased_name = to_camel_case(name) if name in self._type.fields: @@ -68,38 +75,60 @@ def get_field(self, name): raise KeyError(f"Field {name} does not exist in type {self._type.name}.") -def selections(*fields): - for _field in fields: - yield selection_field(_field).ast +def get_ast_fields(fields: Iterable[DSLField]) -> List[FieldNode]: + """ + Roughtly equivalent to: [field.ast_field for field in fields] + But with a type check for each field in the list + + Raises a TypeError if any of the provided fields are not of the DSLField type + """ + ast_fields = [] + for field in fields: + if isinstance(field, DSLField): + ast_fields.append(field.ast_field) + else: + raise TypeError(f'Received incompatible query field: "{field}".') + + return ast_fields + + +class DSLField: + # Definition of field from the schema + field: GraphQLField -class DSLField(object): - def __init__(self, name, field): + # Current selection in the query + ast_field: FieldNode + + def __init__(self, name: str, field: GraphQLField): self.field = field self.ast_field = FieldNode(name=NameNode(value=name), arguments=FrozenList()) - self.selection_set = None - - def select(self, *fields): - selection_set = self.ast_field.selection_set - added_selections = selections(*fields) - if selection_set: - selection_set.selections = FrozenList( - selection_set.selections + list(added_selections) - ) - else: + + def select(self, *fields: DSLField) -> DSLField: + + added_selections: List[FieldNode] = get_ast_fields(fields) + + current_selection_set: Optional[SelectionSetNode] = self.ast_field.selection_set + + if current_selection_set is None: self.ast_field.selection_set = SelectionSetNode( selections=FrozenList(added_selections) ) + else: + current_selection_set.selections = FrozenList( + current_selection_set.selections + added_selections + ) + return self - def __call__(self, **kwargs): + def __call__(self, **kwargs) -> DSLField: return self.args(**kwargs) - def alias(self, alias): + def alias(self, alias: str) -> DSLField: self.ast_field.alias = NameNode(value=alias) return self - def args(self, **kwargs): + def args(self, **kwargs) -> DSLField: added_args = [] for name, value in kwargs.items(): arg = self.field.args.get(name) @@ -110,55 +139,41 @@ def args(self, **kwargs): added_args.append( ArgumentNode(name=NameNode(value=name), value=serialized_value) ) - ast_field = self.ast_field - ast_field.arguments = FrozenList(ast_field.arguments + added_args) + self.ast_field.arguments = FrozenList(self.ast_field.arguments + added_args) return self - @property - def ast(self): - return self.ast_field - - def __str__(self): + def __str__(self) -> str: return print_ast(self.ast_field) -def selection_field(field): - if isinstance(field, DSLField): - return field +def query(*fields: DSLField, operation: str = "query") -> DocumentNode: - raise TypeError(f'Received incompatible query field: "{field}".') - - -def query(*fields, **kwargs): - if "operation" not in kwargs: - kwargs["operation"] = "query" return DocumentNode( definitions=[ OperationDefinitionNode( - operation=OperationType(kwargs["operation"]), + operation=OperationType(operation), selection_set=SelectionSetNode( - selections=FrozenList(selections(*fields)) + selections=FrozenList(get_ast_fields(fields)) ), ) ] ) -def serialize_list(serializer, list_values): - assert isinstance( - list_values, Iterable - ), f'Expected iterable, received "{list_values!r}".' - return ListValueNode(values=FrozenList(serializer(v) for v in list_values)) +Serializer = Callable[[Any], Optional[ValueNode]] -def get_arg_serializer(arg_type, known_serializers): +def get_arg_serializer( + arg_type: GraphQLInputType, + known_serializers: Dict[GraphQLInputType, Optional[Serializer]], +) -> Serializer: if isinstance(arg_type, GraphQLNonNull): return get_arg_serializer(arg_type.of_type, known_serializers) - if isinstance(arg_type, GraphQLInputField): + elif isinstance(arg_type, GraphQLInputField): return get_arg_serializer(arg_type.type, known_serializers) - if isinstance(arg_type, GraphQLInputObjectType): + elif isinstance(arg_type, GraphQLInputObjectType): if arg_type in known_serializers: - return known_serializers[arg_type] + return cast(Serializer, known_serializers[arg_type]) known_serializers[arg_type] = None serializers = { k: get_arg_serializer(v, known_serializers) @@ -170,10 +185,17 @@ def get_arg_serializer(arg_type, known_serializers): for k, v in value.items() ) ) - return known_serializers[arg_type] - if isinstance(arg_type, GraphQLList): + return cast(Serializer, known_serializers[arg_type]) + elif isinstance(arg_type, GraphQLList): inner_serializer = get_arg_serializer(arg_type.of_type, known_serializers) - return partial(serialize_list, inner_serializer) - if isinstance(arg_type, GraphQLEnumType): - return lambda value: EnumValueNode(value=arg_type.serialize(value)) - return lambda value: ast_from_value(arg_type.serialize(value), arg_type) + return lambda list_values: ListValueNode( + values=FrozenList(inner_serializer(v) for v in list_values) + ) + elif isinstance(arg_type, GraphQLEnumType): + return lambda value: EnumValueNode( + value=cast(GraphQLEnumType, arg_type).serialize(value) + ) + + return lambda value: ast_from_value( + cast(GraphQLScalarType, arg_type).serialize(value), arg_type + ) diff --git a/tests/nested_input/test_nested_input.py b/tests/nested_input/test_nested_input.py index 037d1518..e8a0ca17 100644 --- a/tests/nested_input/test_nested_input.py +++ b/tests/nested_input/test_nested_input.py @@ -1,49 +1,9 @@ -from functools import partial - import pytest -from graphql import ( - EnumValueNode, - GraphQLEnumType, - GraphQLInputField, - GraphQLInputObjectType, - GraphQLList, - GraphQLNonNull, - NameNode, - ObjectFieldNode, - ObjectValueNode, - ast_from_value, -) -from graphql.pyutils import FrozenList -import gql.dsl as dsl from gql import Client -from gql.dsl import DSLSchema, serialize_list +from gql.dsl import DSLSchema from tests.nested_input.schema import NestedInputSchema -# back up the new func -new_get_arg_serializer = dsl.get_arg_serializer - - -def old_get_arg_serializer(arg_type, known_serializers=None): - if isinstance(arg_type, GraphQLNonNull): - return old_get_arg_serializer(arg_type.of_type) - if isinstance(arg_type, GraphQLInputField): - return old_get_arg_serializer(arg_type.type) - if isinstance(arg_type, GraphQLInputObjectType): - serializers = {k: old_get_arg_serializer(v) for k, v in arg_type.fields.items()} - return lambda value: ObjectValueNode( - fields=FrozenList( - ObjectFieldNode(name=NameNode(value=k), value=serializers[k](v)) - for k, v in value.items() - ) - ) - if isinstance(arg_type, GraphQLList): - inner_serializer = old_get_arg_serializer(arg_type.of_type) - return partial(serialize_list, inner_serializer) - if isinstance(arg_type, GraphQLEnumType): - return lambda value: EnumValueNode(value=arg_type.serialize(value)) - return lambda value: ast_from_value(arg_type.serialize(value), arg_type) - @pytest.fixture def ds(): @@ -52,12 +12,5 @@ def ds(): return ds -def test_nested_input_with_old_get_arg_serializer(ds): - dsl.get_arg_serializer = old_get_arg_serializer - with pytest.raises(RecursionError, match="maximum recursion depth exceeded"): - ds.query(ds.Query.foo.args(nested={"foo": 1})) - - def test_nested_input_with_new_get_arg_serializer(ds): - dsl.get_arg_serializer = new_get_arg_serializer assert ds.query(ds.Query.foo.args(nested={"foo": 1})) == {"foo": 1} From 9d501a3d1be619b6f542aa8972a46137ac855e4d Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Fri, 6 Nov 2020 23:08:02 +0100 Subject: [PATCH 02/21] DSL Refactor 2 Move get_arg_serializer function inside the DSLField class Rename get_field method of DSLType to _get_field --- gql/dsl.py | 87 +++++++++++++++++++++++++++--------------------------- 1 file changed, 43 insertions(+), 44 deletions(-) diff --git a/gql/dsl.py b/gql/dsl.py index 6d71e444..6d8106f1 100644 --- a/gql/dsl.py +++ b/gql/dsl.py @@ -60,10 +60,10 @@ def __init__(self, type_: GraphQLObjectType): self._type: GraphQLObjectType = type_ def __getattr__(self, name: str) -> DSLField: - formatted_name, field_def = self.get_field(name) + formatted_name, field_def = self._get_field(name) return DSLField(formatted_name, field_def) - def get_field(self, name: str) -> Tuple[str, GraphQLField]: + def _get_field(self, name: str) -> Tuple[str, GraphQLField]: camel_cased_name = to_camel_case(name) if name in self._type.fields: @@ -92,6 +92,9 @@ def get_ast_fields(fields: Iterable[DSLField]) -> List[FieldNode]: return ast_fields +Serializer = Callable[[Any], Optional[ValueNode]] + + class DSLField: # Definition of field from the schema @@ -100,9 +103,13 @@ class DSLField: # Current selection in the query ast_field: FieldNode + # Known serializers + known_serializers: Dict[GraphQLInputType, Optional[Serializer]] + def __init__(self, name: str, field: GraphQLField): self.field = field self.ast_field = FieldNode(name=NameNode(value=name), arguments=FrozenList()) + self.known_serializers = dict() def select(self, *fields: DSLField) -> DSLField: @@ -134,7 +141,7 @@ def args(self, **kwargs) -> DSLField: arg = self.field.args.get(name) if not arg: raise KeyError(f"Argument {name} does not exist in {self.field}.") - arg_type_serializer = get_arg_serializer(arg.type, known_serializers=dict()) + arg_type_serializer = self._get_arg_serializer(arg.type) serialized_value = arg_type_serializer(value) added_args.append( ArgumentNode(name=NameNode(value=name), value=serialized_value) @@ -142,6 +149,39 @@ def args(self, **kwargs) -> DSLField: self.ast_field.arguments = FrozenList(self.ast_field.arguments + added_args) return self + def _get_arg_serializer(self, arg_type: GraphQLInputType,) -> Serializer: + if isinstance(arg_type, GraphQLNonNull): + return self._get_arg_serializer(arg_type.of_type) + elif isinstance(arg_type, GraphQLInputField): + return self._get_arg_serializer(arg_type.type) + elif isinstance(arg_type, GraphQLInputObjectType): + if arg_type in self.known_serializers: + return cast(Serializer, self.known_serializers[arg_type]) + self.known_serializers[arg_type] = None + serializers = { + k: self._get_arg_serializer(v) for k, v in arg_type.fields.items() + } + self.known_serializers[arg_type] = lambda value: ObjectValueNode( + fields=FrozenList( + ObjectFieldNode(name=NameNode(value=k), value=serializers[k](v)) + for k, v in value.items() + ) + ) + return cast(Serializer, self.known_serializers[arg_type]) + elif isinstance(arg_type, GraphQLList): + inner_serializer = self._get_arg_serializer(arg_type.of_type) + return lambda list_values: ListValueNode( + values=FrozenList(inner_serializer(v) for v in list_values) + ) + elif isinstance(arg_type, GraphQLEnumType): + return lambda value: EnumValueNode( + value=cast(GraphQLEnumType, arg_type).serialize(value) + ) + + return lambda value: ast_from_value( + cast(GraphQLScalarType, arg_type).serialize(value), arg_type + ) + def __str__(self) -> str: return print_ast(self.ast_field) @@ -158,44 +198,3 @@ def query(*fields: DSLField, operation: str = "query") -> DocumentNode: ) ] ) - - -Serializer = Callable[[Any], Optional[ValueNode]] - - -def get_arg_serializer( - arg_type: GraphQLInputType, - known_serializers: Dict[GraphQLInputType, Optional[Serializer]], -) -> Serializer: - if isinstance(arg_type, GraphQLNonNull): - return get_arg_serializer(arg_type.of_type, known_serializers) - elif isinstance(arg_type, GraphQLInputField): - return get_arg_serializer(arg_type.type, known_serializers) - elif isinstance(arg_type, GraphQLInputObjectType): - if arg_type in known_serializers: - return cast(Serializer, known_serializers[arg_type]) - known_serializers[arg_type] = None - serializers = { - k: get_arg_serializer(v, known_serializers) - for k, v in arg_type.fields.items() - } - known_serializers[arg_type] = lambda value: ObjectValueNode( - fields=FrozenList( - ObjectFieldNode(name=NameNode(value=k), value=serializers[k](v)) - for k, v in value.items() - ) - ) - return cast(Serializer, known_serializers[arg_type]) - elif isinstance(arg_type, GraphQLList): - inner_serializer = get_arg_serializer(arg_type.of_type, known_serializers) - return lambda list_values: ListValueNode( - values=FrozenList(inner_serializer(v) for v in list_values) - ) - elif isinstance(arg_type, GraphQLEnumType): - return lambda value: EnumValueNode( - value=cast(GraphQLEnumType, arg_type).serialize(value) - ) - - return lambda value: ast_from_value( - cast(GraphQLScalarType, arg_type).serialize(value), arg_type - ) From 921ea693a78b37a36bf74d04bb4176d632f45e74 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Sat, 7 Nov 2020 01:24:40 +0100 Subject: [PATCH 03/21] DSL Refactor 3 DSLSchema now requires a schema instead of a client query and mutate methods of DSLSchema have been replaced by a gql method which will return a DocumentNode --- gql/dsl.py | 66 +++++++++++++------------ tests/nested_input/test_nested_input.py | 14 ++++-- tests/starwars/test_dsl.py | 41 +++++++++------ 3 files changed, 68 insertions(+), 53 deletions(-) diff --git a/gql/dsl.py b/gql/dsl.py index 6d8106f1..97e17f89 100644 --- a/gql/dsl.py +++ b/gql/dsl.py @@ -1,7 +1,7 @@ from __future__ import annotations from collections.abc import Iterable -from typing import Any, Callable, Dict, List, Optional, Tuple, cast +from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast from graphql import ( ArgumentNode, @@ -13,10 +13,13 @@ GraphQLInputField, GraphQLInputObjectType, GraphQLInputType, + GraphQLInterfaceType, GraphQLList, + GraphQLNamedType, GraphQLNonNull, GraphQLObjectType, GraphQLScalarType, + GraphQLSchema, ListValueNode, NameNode, ObjectFieldNode, @@ -32,32 +35,45 @@ from .utils import to_camel_case +GraphQLTypeWithFields = Union[GraphQLObjectType, GraphQLInterfaceType] + class DSLSchema: - def __init__(self, client): - self.client = client + def __init__(self, schema: GraphQLSchema): + + assert isinstance( + schema, GraphQLSchema + ), "DSLSchema needs a schema as parameter" - @property - def schema(self): - return self.client.schema + self._schema: GraphQLSchema = schema def __getattr__(self, name: str) -> DSLType: - type_def: GraphQLObjectType = self.schema.get_type(name) - return DSLType(type_def) - def query(self, *args, **kwargs): - return self.execute(query(*args, **kwargs)) + type_def: Optional[GraphQLNamedType] = self._schema.get_type(name) - def mutate(self, *args, **kwargs): - return self.query(*args, operation="mutation", **kwargs) + assert isinstance(type_def, GraphQLObjectType) or isinstance( + type_def, GraphQLInterfaceType + ) - def execute(self, document): - return self.client.execute(document) + return DSLType(type_def) + + def gql(self, *fields: DSLField, operation: str = "query") -> DocumentNode: + + return DocumentNode( + definitions=[ + OperationDefinitionNode( + operation=OperationType(operation), + selection_set=SelectionSetNode( + selections=FrozenList(get_ast_fields(fields)) + ), + ) + ] + ) class DSLType: - def __init__(self, type_: GraphQLObjectType): - self._type: GraphQLObjectType = type_ + def __init__(self, type_: GraphQLTypeWithFields): + self._type: GraphQLTypeWithFields = type_ def __getattr__(self, name: str) -> DSLField: formatted_name, field_def = self._get_field(name) @@ -77,7 +93,7 @@ def _get_field(self, name: str) -> Tuple[str, GraphQLField]: def get_ast_fields(fields: Iterable[DSLField]) -> List[FieldNode]: """ - Roughtly equivalent to: [field.ast_field for field in fields] + Equivalent to: [field.ast_field for field in fields] But with a type check for each field in the list Raises a TypeError if any of the provided fields are not of the DSLField type @@ -87,7 +103,7 @@ def get_ast_fields(fields: Iterable[DSLField]) -> List[FieldNode]: if isinstance(field, DSLField): ast_fields.append(field.ast_field) else: - raise TypeError(f'Received incompatible query field: "{field}".') + raise TypeError(f'Received incompatible field: "{field}".') return ast_fields @@ -184,17 +200,3 @@ def _get_arg_serializer(self, arg_type: GraphQLInputType,) -> Serializer: def __str__(self) -> str: return print_ast(self.ast_field) - - -def query(*fields: DSLField, operation: str = "query") -> DocumentNode: - - return DocumentNode( - definitions=[ - OperationDefinitionNode( - operation=OperationType(operation), - selection_set=SelectionSetNode( - selections=FrozenList(get_ast_fields(fields)) - ), - ) - ] - ) diff --git a/tests/nested_input/test_nested_input.py b/tests/nested_input/test_nested_input.py index e8a0ca17..4c69c47c 100644 --- a/tests/nested_input/test_nested_input.py +++ b/tests/nested_input/test_nested_input.py @@ -7,10 +7,14 @@ @pytest.fixture def ds(): - client = Client(schema=NestedInputSchema) - ds = DSLSchema(client) - return ds + return DSLSchema(NestedInputSchema) -def test_nested_input_with_new_get_arg_serializer(ds): - assert ds.query(ds.Query.foo.args(nested={"foo": 1})) == {"foo": 1} +@pytest.fixture +def client(): + return Client(schema=NestedInputSchema) + + +def test_nested_input_with_new_get_arg_serializer(ds, client): + query = ds.gql(ds.Query.foo.args(nested={"foo": 1})) + assert client.execute(query) == {"foo": 1} diff --git a/tests/starwars/test_dsl.py b/tests/starwars/test_dsl.py index 12a2f76b..42f08a69 100644 --- a/tests/starwars/test_dsl.py +++ b/tests/starwars/test_dsl.py @@ -8,9 +8,12 @@ @pytest.fixture def ds(): - client = Client(schema=StarWarsSchema) - ds = DSLSchema(client) - return ds + return DSLSchema(StarWarsSchema) + + +@pytest.fixture +def client(): + return Client(schema=StarWarsSchema) def test_invalid_field_on_type_query(ds): @@ -19,10 +22,10 @@ def test_invalid_field_on_type_query(ds): assert "Field extras does not exist in type Query." in str(exc_info.value) -def test_incompatible_query_field(ds): +def test_incompatible_field(ds): with pytest.raises(Exception) as exc_info: - ds.query("hero") - assert "Received incompatible query field" in str(exc_info.value) + ds.gql("hero") + assert "Received incompatible field" in str(exc_info.value) def test_hero_name_query(ds): @@ -110,16 +113,18 @@ def test_fetch_luke_aliased(ds): assert query == str(query_dsl) -def test_hero_name_query_result(ds): - result = ds.query(ds.Query.hero.select(ds.Character.name)) +def test_hero_name_query_result(ds, client): + query = ds.gql(ds.Query.hero.select(ds.Character.name)) + result = client.execute(query) expected = {"hero": {"name": "R2-D2"}} assert result == expected -def test_arg_serializer_list(ds): - result = ds.query( +def test_arg_serializer_list(ds, client): + query = ds.gql( ds.Query.characters.args(ids=[1000, 1001, 1003]).select(ds.Character.name,) ) + result = client.execute(query) expected = { "characters": [ {"name": "Luke Skywalker"}, @@ -130,18 +135,22 @@ def test_arg_serializer_list(ds): assert result == expected -def test_arg_serializer_enum(ds): - result = ds.query(ds.Query.hero.args(episode=5).select(ds.Character.name)) +def test_arg_serializer_enum(ds, client): + query = ds.gql(ds.Query.hero.args(episode=5).select(ds.Character.name)) + result = client.execute(query) expected = {"hero": {"name": "Luke Skywalker"}} assert result == expected -def test_create_review_mutation_result(ds): - result = ds.mutate( +def test_create_review_mutation_result(ds, client): + + query = ds.gql( ds.Mutation.createReview.args( episode=6, review={"stars": 5, "commentary": "This is a great movie!"} - ).select(ds.Review.stars, ds.Review.commentary) + ).select(ds.Review.stars, ds.Review.commentary), + operation="mutation", ) + result = client.execute(query) expected = {"createReview": {"stars": 5, "commentary": "This is a great movie!"}} assert result == expected @@ -150,4 +159,4 @@ def test_invalid_arg(ds): with pytest.raises( KeyError, match="Argument invalid_arg does not exist in Field: Character." ): - ds.query(ds.Query.hero.args(invalid_arg=5).select(ds.Character.name)) + ds.Query.hero.args(invalid_arg=5).select(ds.Character.name) From 8d34d318eac1b2ae3542fdf0578b20a1352d0c85 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Sat, 7 Nov 2020 01:30:41 +0100 Subject: [PATCH 04/21] DSL Refactor 4 Set get_ast_field as a staticmethod of the DSLField class --- gql/dsl.py | 42 ++++++++++++++++++++---------------------- 1 file changed, 20 insertions(+), 22 deletions(-) diff --git a/gql/dsl.py b/gql/dsl.py index 97e17f89..e8375c8d 100644 --- a/gql/dsl.py +++ b/gql/dsl.py @@ -36,6 +36,7 @@ from .utils import to_camel_case GraphQLTypeWithFields = Union[GraphQLObjectType, GraphQLInterfaceType] +Serializer = Callable[[Any], Optional[ValueNode]] class DSLSchema: @@ -64,7 +65,7 @@ def gql(self, *fields: DSLField, operation: str = "query") -> DocumentNode: OperationDefinitionNode( operation=OperationType(operation), selection_set=SelectionSetNode( - selections=FrozenList(get_ast_fields(fields)) + selections=FrozenList(DSLField.get_ast_fields(fields)) ), ) ] @@ -91,26 +92,6 @@ def _get_field(self, name: str) -> Tuple[str, GraphQLField]: raise KeyError(f"Field {name} does not exist in type {self._type.name}.") -def get_ast_fields(fields: Iterable[DSLField]) -> List[FieldNode]: - """ - Equivalent to: [field.ast_field for field in fields] - But with a type check for each field in the list - - Raises a TypeError if any of the provided fields are not of the DSLField type - """ - ast_fields = [] - for field in fields: - if isinstance(field, DSLField): - ast_fields.append(field.ast_field) - else: - raise TypeError(f'Received incompatible field: "{field}".') - - return ast_fields - - -Serializer = Callable[[Any], Optional[ValueNode]] - - class DSLField: # Definition of field from the schema @@ -127,9 +108,26 @@ def __init__(self, name: str, field: GraphQLField): self.ast_field = FieldNode(name=NameNode(value=name), arguments=FrozenList()) self.known_serializers = dict() + @staticmethod + def get_ast_fields(fields: Iterable[DSLField]) -> List[FieldNode]: + """ + Equivalent to: [field.ast_field for field in fields] + But with a type check for each field in the list + + Raises a TypeError if any of the provided fields are not of the DSLField type + """ + ast_fields = [] + for field in fields: + if isinstance(field, DSLField): + ast_fields.append(field.ast_field) + else: + raise TypeError(f'Received incompatible field: "{field}".') + + return ast_fields + def select(self, *fields: DSLField) -> DSLField: - added_selections: List[FieldNode] = get_ast_fields(fields) + added_selections: List[FieldNode] = self.get_ast_fields(fields) current_selection_set: Optional[SelectionSetNode] = self.ast_field.selection_set From 2b80179d0fd039300e90252fe9c7188c1fbeaa17 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Sat, 7 Nov 2020 01:41:29 +0100 Subject: [PATCH 05/21] Fix annotations for python 3.6 --- gql/dsl.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/gql/dsl.py b/gql/dsl.py index e8375c8d..84261e42 100644 --- a/gql/dsl.py +++ b/gql/dsl.py @@ -1,5 +1,3 @@ -from __future__ import annotations - from collections.abc import Iterable from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast @@ -48,7 +46,7 @@ def __init__(self, schema: GraphQLSchema): self._schema: GraphQLSchema = schema - def __getattr__(self, name: str) -> DSLType: + def __getattr__(self, name: str) -> "DSLType": type_def: Optional[GraphQLNamedType] = self._schema.get_type(name) @@ -58,7 +56,7 @@ def __getattr__(self, name: str) -> DSLType: return DSLType(type_def) - def gql(self, *fields: DSLField, operation: str = "query") -> DocumentNode: + def gql(self, *fields: "DSLField", operation: str = "query") -> DocumentNode: return DocumentNode( definitions=[ @@ -76,7 +74,7 @@ class DSLType: def __init__(self, type_: GraphQLTypeWithFields): self._type: GraphQLTypeWithFields = type_ - def __getattr__(self, name: str) -> DSLField: + def __getattr__(self, name: str) -> "DSLField": formatted_name, field_def = self._get_field(name) return DSLField(formatted_name, field_def) @@ -109,7 +107,7 @@ def __init__(self, name: str, field: GraphQLField): self.known_serializers = dict() @staticmethod - def get_ast_fields(fields: Iterable[DSLField]) -> List[FieldNode]: + def get_ast_fields(fields: Iterable) -> List[FieldNode]: """ Equivalent to: [field.ast_field for field in fields] But with a type check for each field in the list @@ -125,7 +123,7 @@ def get_ast_fields(fields: Iterable[DSLField]) -> List[FieldNode]: return ast_fields - def select(self, *fields: DSLField) -> DSLField: + def select(self, *fields: "DSLField") -> "DSLField": added_selections: List[FieldNode] = self.get_ast_fields(fields) @@ -142,14 +140,14 @@ def select(self, *fields: DSLField) -> DSLField: return self - def __call__(self, **kwargs) -> DSLField: + def __call__(self, **kwargs) -> "DSLField": return self.args(**kwargs) - def alias(self, alias: str) -> DSLField: + def alias(self, alias: str) -> "DSLField": self.ast_field.alias = NameNode(value=alias) return self - def args(self, **kwargs) -> DSLField: + def args(self, **kwargs) -> "DSLField": added_args = [] for name, value in kwargs.items(): arg = self.field.args.get(name) From 79b082b916096c65f851a79f345337bc772289cb Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Sat, 7 Nov 2020 10:38:56 +0100 Subject: [PATCH 06/21] DSL Refactor 5 Move DSLSchema.gql method out of the class to the dsl_gql funciton Add a test_multiple_queries test --- gql/dsl.py | 27 +++++++++++++------------ tests/nested_input/test_nested_input.py | 4 ++-- tests/starwars/test_dsl.py | 25 +++++++++++++++++------ 3 files changed, 35 insertions(+), 21 deletions(-) diff --git a/gql/dsl.py b/gql/dsl.py index 84261e42..e2ae1c5f 100644 --- a/gql/dsl.py +++ b/gql/dsl.py @@ -37,6 +37,20 @@ Serializer = Callable[[Any], Optional[ValueNode]] +def dsl_gql(*fields: "DSLField", operation: str = "query") -> DocumentNode: + + return DocumentNode( + definitions=[ + OperationDefinitionNode( + operation=OperationType(operation), + selection_set=SelectionSetNode( + selections=FrozenList(DSLField.get_ast_fields(fields)) + ), + ) + ] + ) + + class DSLSchema: def __init__(self, schema: GraphQLSchema): @@ -56,19 +70,6 @@ def __getattr__(self, name: str) -> "DSLType": return DSLType(type_def) - def gql(self, *fields: "DSLField", operation: str = "query") -> DocumentNode: - - return DocumentNode( - definitions=[ - OperationDefinitionNode( - operation=OperationType(operation), - selection_set=SelectionSetNode( - selections=FrozenList(DSLField.get_ast_fields(fields)) - ), - ) - ] - ) - class DSLType: def __init__(self, type_: GraphQLTypeWithFields): diff --git a/tests/nested_input/test_nested_input.py b/tests/nested_input/test_nested_input.py index 4c69c47c..2abe5252 100644 --- a/tests/nested_input/test_nested_input.py +++ b/tests/nested_input/test_nested_input.py @@ -1,7 +1,7 @@ import pytest from gql import Client -from gql.dsl import DSLSchema +from gql.dsl import DSLSchema, dsl_gql from tests.nested_input.schema import NestedInputSchema @@ -16,5 +16,5 @@ def client(): def test_nested_input_with_new_get_arg_serializer(ds, client): - query = ds.gql(ds.Query.foo.args(nested={"foo": 1})) + query = dsl_gql(ds.Query.foo.args(nested={"foo": 1})) assert client.execute(query) == {"foo": 1} diff --git a/tests/starwars/test_dsl.py b/tests/starwars/test_dsl.py index 42f08a69..a1861111 100644 --- a/tests/starwars/test_dsl.py +++ b/tests/starwars/test_dsl.py @@ -1,7 +1,7 @@ import pytest from gql import Client -from gql.dsl import DSLSchema +from gql.dsl import DSLSchema, dsl_gql from .schema import StarWarsSchema @@ -24,7 +24,7 @@ def test_invalid_field_on_type_query(ds): def test_incompatible_field(ds): with pytest.raises(Exception) as exc_info: - ds.gql("hero") + dsl_gql("hero") assert "Received incompatible field" in str(exc_info.value) @@ -114,14 +114,14 @@ def test_fetch_luke_aliased(ds): def test_hero_name_query_result(ds, client): - query = ds.gql(ds.Query.hero.select(ds.Character.name)) + query = dsl_gql(ds.Query.hero.select(ds.Character.name)) result = client.execute(query) expected = {"hero": {"name": "R2-D2"}} assert result == expected def test_arg_serializer_list(ds, client): - query = ds.gql( + query = dsl_gql( ds.Query.characters.args(ids=[1000, 1001, 1003]).select(ds.Character.name,) ) result = client.execute(query) @@ -136,7 +136,7 @@ def test_arg_serializer_list(ds, client): def test_arg_serializer_enum(ds, client): - query = ds.gql(ds.Query.hero.args(episode=5).select(ds.Character.name)) + query = dsl_gql(ds.Query.hero.args(episode=5).select(ds.Character.name)) result = client.execute(query) expected = {"hero": {"name": "Luke Skywalker"}} assert result == expected @@ -144,7 +144,7 @@ def test_arg_serializer_enum(ds, client): def test_create_review_mutation_result(ds, client): - query = ds.gql( + query = dsl_gql( ds.Mutation.createReview.args( episode=6, review={"stars": 5, "commentary": "This is a great movie!"} ).select(ds.Review.stars, ds.Review.commentary), @@ -160,3 +160,16 @@ def test_invalid_arg(ds): KeyError, match="Argument invalid_arg does not exist in Field: Character." ): ds.Query.hero.args(invalid_arg=5).select(ds.Character.name) + + +def test_multiple_queries(ds, client): + query = dsl_gql( + ds.Query.hero.select(ds.Character.name), + ds.Query.hero(episode=5).alias("hero_of_episode_5").select(ds.Character.name), + ) + result = client.execute(query) + expected = { + "hero": {"name": "R2-D2"}, + "hero_of_episode_5": {"name": "Luke Skywalker"}, + } + assert result == expected From c926a3dd0ad006484760f2e9f7347dc232cc9f59 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Sat, 7 Nov 2020 10:52:29 +0100 Subject: [PATCH 07/21] DSL Refactor 6 Put the _get_field method of DSLType inside __getattr__ --- gql/dsl.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/gql/dsl.py b/gql/dsl.py index e2ae1c5f..c0133ee3 100644 --- a/gql/dsl.py +++ b/gql/dsl.py @@ -1,5 +1,5 @@ from collections.abc import Iterable -from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast +from typing import Any, Callable, Dict, List, Optional, Union, cast from graphql import ( ArgumentNode, @@ -76,19 +76,18 @@ def __init__(self, type_: GraphQLTypeWithFields): self._type: GraphQLTypeWithFields = type_ def __getattr__(self, name: str) -> "DSLField": - formatted_name, field_def = self._get_field(name) - return DSLField(formatted_name, field_def) - - def _get_field(self, name: str) -> Tuple[str, GraphQLField]: camel_cased_name = to_camel_case(name) if name in self._type.fields: - return name, self._type.fields[name] - - if camel_cased_name in self._type.fields: - return camel_cased_name, self._type.fields[camel_cased_name] + formatted_name = name + field = self._type.fields[name] + elif camel_cased_name in self._type.fields: + formatted_name = camel_cased_name + field = self._type.fields[camel_cased_name] + else: + raise KeyError(f"Field {name} does not exist in type {self._type.name}.") - raise KeyError(f"Field {name} does not exist in type {self._type.name}.") + return DSLField(formatted_name, field) class DSLField: From e0c7e34941cc2510aa0c2d87841b167f829867b8 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Sat, 7 Nov 2020 18:39:47 +0100 Subject: [PATCH 08/21] DSL Refactor 7 The GraphQL type is saved in each DSLField This allows us to compute the operation directly from the fields (no need for the operation argument anymore) In dsl_gql, we check that all the fields have the correct type and that they are root fields (Query, Mutation and Subscription) Add new AttributeError for types which are not in the schema Replacing the KeyError by an AttributeError in the __getattr__ method of DSLType Adding debug log messages when creating fields and types --- gql/dsl.py | 44 ++++++++++++++++++++++++++++------ tests/conftest.py | 2 +- tests/starwars/test_dsl.py | 48 ++++++++++++++++++++++++++++++++++---- 3 files changed, 82 insertions(+), 12 deletions(-) diff --git a/gql/dsl.py b/gql/dsl.py index c0133ee3..49633e15 100644 --- a/gql/dsl.py +++ b/gql/dsl.py @@ -1,3 +1,4 @@ +import logging from collections.abc import Iterable from typing import Any, Callable, Dict, List, Optional, Union, cast @@ -33,11 +34,29 @@ from .utils import to_camel_case +log = logging.getLogger(__name__) + GraphQLTypeWithFields = Union[GraphQLObjectType, GraphQLInterfaceType] Serializer = Callable[[Any], Optional[ValueNode]] -def dsl_gql(*fields: "DSLField", operation: str = "query") -> DocumentNode: +def dsl_gql(*fields: "DSLField") -> DocumentNode: + + # Check that we receive only arguments of type DSLField + # And that they are a root type + for field in fields: + if not isinstance(field, DSLField): + raise TypeError( + f"fields must be instances of DSLField. Received type: {type(field)}" + ) + assert field.type_name in ["Query", "Mutation", "Subscription"], ( + "fields should be root types (Query, Mutation or Subscription)\n" + f"Received: {field.type_name}" + ) + + # Get the operation from the first field + # All the fields must have the same operation + operation = fields[0].type_name.lower() return DocumentNode( definitions=[ @@ -54,9 +73,8 @@ def dsl_gql(*fields: "DSLField", operation: str = "query") -> DocumentNode: class DSLSchema: def __init__(self, schema: GraphQLSchema): - assert isinstance( - schema, GraphQLSchema - ), "DSLSchema needs a schema as parameter" + if not isinstance(schema, GraphQLSchema): + raise TypeError("DSLSchema needs a schema as parameter") self._schema: GraphQLSchema = schema @@ -64,6 +82,9 @@ def __getattr__(self, name: str) -> "DSLType": type_def: Optional[GraphQLNamedType] = self._schema.get_type(name) + if type_def is None: + raise AttributeError(f"Type '{name}' not found in the schema!") + assert isinstance(type_def, GraphQLObjectType) or isinstance( type_def, GraphQLInterfaceType ) @@ -74,6 +95,7 @@ def __getattr__(self, name: str) -> "DSLType": class DSLType: def __init__(self, type_: GraphQLTypeWithFields): self._type: GraphQLTypeWithFields = type_ + log.debug(f"DSLType({type_!r})") def __getattr__(self, name: str) -> "DSLField": camel_cased_name = to_camel_case(name) @@ -85,9 +107,11 @@ def __getattr__(self, name: str) -> "DSLField": formatted_name = camel_cased_name field = self._type.fields[camel_cased_name] else: - raise KeyError(f"Field {name} does not exist in type {self._type.name}.") + raise AttributeError( + f"Field {name} does not exist in type {self._type.name}." + ) - return DSLField(formatted_name, field) + return DSLField(formatted_name, self._type, field) class DSLField: @@ -101,10 +125,12 @@ class DSLField: # Known serializers known_serializers: Dict[GraphQLInputType, Optional[Serializer]] - def __init__(self, name: str, field: GraphQLField): + def __init__(self, name: str, type_: GraphQLTypeWithFields, field: GraphQLField): + self._type: GraphQLTypeWithFields = type_ self.field = field self.ast_field = FieldNode(name=NameNode(value=name), arguments=FrozenList()) self.known_serializers = dict() + log.debug(f"DSLField('{name}',{field!r})") @staticmethod def get_ast_fields(fields: Iterable) -> List[FieldNode]: @@ -194,5 +220,9 @@ def _get_arg_serializer(self, arg_type: GraphQLInputType,) -> Serializer: cast(GraphQLScalarType, arg_type).serialize(value), arg_type ) + @property + def type_name(self): + return self._type.name + def __str__(self) -> str: return print_ast(self.ast_field) diff --git a/tests/conftest.py b/tests/conftest.py index 9484beb4..1865152e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -100,7 +100,7 @@ async def go(app, *, port=None, **kwargs): # type: ignore # Adding debug logs to websocket tests -for name in ["websockets.server", "gql.transport.websockets"]: +for name in ["websockets.server", "gql.transport.websockets", "gql.dsl"]: logger = logging.getLogger(name) logger.setLevel(logging.DEBUG) diff --git a/tests/starwars/test_dsl.py b/tests/starwars/test_dsl.py index a1861111..d83280d7 100644 --- a/tests/starwars/test_dsl.py +++ b/tests/starwars/test_dsl.py @@ -17,14 +17,14 @@ def client(): def test_invalid_field_on_type_query(ds): - with pytest.raises(KeyError) as exc_info: + with pytest.raises(AttributeError) as exc_info: ds.Query.extras.select(ds.Character.name) assert "Field extras does not exist in type Query." in str(exc_info.value) def test_incompatible_field(ds): with pytest.raises(Exception) as exc_info: - dsl_gql("hero") + ds.Query.hero.select("not_a_DSL_FIELD") assert "Received incompatible field" in str(exc_info.value) @@ -147,8 +147,7 @@ def test_create_review_mutation_result(ds, client): query = dsl_gql( ds.Mutation.createReview.args( episode=6, review={"stars": 5, "commentary": "This is a great movie!"} - ).select(ds.Review.stars, ds.Review.commentary), - operation="mutation", + ).select(ds.Review.stars, ds.Review.commentary) ) result = client.execute(query) expected = {"createReview": {"stars": 5, "commentary": "This is a great movie!"}} @@ -173,3 +172,44 @@ def test_multiple_queries(ds, client): "hero_of_episode_5": {"name": "Luke Skywalker"}, } assert result == expected + + +def test_dsl_gql_all_fields_should_be_instances_of_DSLField(ds, client): + with pytest.raises( + TypeError, match="fields must be instances of DSLField. Received type:" + ): + dsl_gql( + ds.Query.hero.select(ds.Character.name), + ds.Query.hero(episode=5) + .alias("hero_of_episode_5") + .select(ds.Character.name), + "I am a string", + ) + + +def test_dsl_gql_all_fields_should_be_a_root_type(ds, client): + with pytest.raises(AssertionError,) as excinfo: + dsl_gql( + ds.Query.hero.select(ds.Character.name), + ds.Query.hero(episode=5) + .alias("hero_of_episode_5") + .select(ds.Character.name), + ds.Character.name, + ) + + assert ( + "fields should be root types (Query, Mutation or Subscription)\n" + "Received: Character" + ) in str(excinfo.value) + + +def test_DSLSchema_requires_a_schema(client): + with pytest.raises(TypeError, match="DSLSchema needs a schema as parameter"): + DSLSchema(client) + + +def test_invalid_type(ds): + with pytest.raises( + AttributeError, match="Type 'invalid_type' not found in the schema!" + ): + ds.invalid_type From 5e130cf5011c3626a8db2ad5f7598798fbd15768 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Sun, 8 Nov 2020 12:58:04 +0100 Subject: [PATCH 09/21] DOC document the dsl code Rename known_serializers to known_arg_serializers --- docs/modules/client.rst | 4 +- docs/modules/dsl.rst | 7 ++ docs/modules/gql.rst | 1 + docs/modules/transport.rst | 4 +- gql/dsl.py | 176 +++++++++++++++++++++++++++++++------ 5 files changed, 161 insertions(+), 31 deletions(-) create mode 100644 docs/modules/dsl.rst diff --git a/docs/modules/client.rst b/docs/modules/client.rst index 954b4e61..69169425 100644 --- a/docs/modules/client.rst +++ b/docs/modules/client.rst @@ -1,5 +1,5 @@ -Client -====== +gql.client +========== .. currentmodule:: gql.client diff --git a/docs/modules/dsl.rst b/docs/modules/dsl.rst new file mode 100644 index 00000000..e8487745 --- /dev/null +++ b/docs/modules/dsl.rst @@ -0,0 +1,7 @@ +gql.dsl +======= + +.. currentmodule:: gql.dsl + +.. automodule:: gql.dsl + :member-order: bysource diff --git a/docs/modules/gql.rst b/docs/modules/gql.rst index 94121ea3..aac47c86 100644 --- a/docs/modules/gql.rst +++ b/docs/modules/gql.rst @@ -20,3 +20,4 @@ Sub-Packages client transport + dsl diff --git a/docs/modules/transport.rst b/docs/modules/transport.rst index dd4627e0..9a3caf6e 100644 --- a/docs/modules/transport.rst +++ b/docs/modules/transport.rst @@ -1,5 +1,5 @@ -Transport -========= +gql.transport +============= .. currentmodule:: gql.transport diff --git a/gql/dsl.py b/gql/dsl.py index 49633e15..46275f1f 100644 --- a/gql/dsl.py +++ b/gql/dsl.py @@ -41,6 +41,32 @@ def dsl_gql(*fields: "DSLField") -> DocumentNode: + """Given arguments of type :class:`DSLField` containing GraphQL requests, + generate a Document which can be executed later in a + gql client or a gql session. + + Similar to the :func:`gql.gql` function but instead of parsing a python + string to describe the request, we are using requests which have been generated + dynamically using instances of :class:`DSLField` which have been generated + by instances of :class:`DSLType` which themselves have been generated from + a :class:`DSLSchema` class. + + The fields arguments should be fields of root GraphQL types + (Query, Mutation or Subscription). + + They should all have the same root type + (you can't mix queries with mutations for example). + + :param fields: root instances of the dynamically generated requests + :type fields: DSLField + :return: a Document which can be later executed or subscribed by a + :class:`Client `, by an + :class:`async session ` or by a + :class:`sync session ` + + :raises TypeError: if an argument is not an instance of :class:`DSLField` + :raises AssertionError: if an argument is not a field of a root type + """ # Check that we receive only arguments of type DSLField # And that they are a root type @@ -71,7 +97,22 @@ def dsl_gql(*fields: "DSLField") -> DocumentNode: class DSLSchema: + """The DSLSchema is the root of the DSL code. + + Attributes of the DSLSchema class are generated automatically + with the `__getattr__` dunder method in order to generate + instances of :class:`DSLType` + """ + def __init__(self, schema: GraphQLSchema): + """Initialize the DSLSchema with the given schema. + + :param schema: a GraphQL Schema provided locally or fetched using + an introspection query. Usually `client.schema` + :type schema: GraphQLSchema + + :raises TypeError: if the argument is not an instance of :class:`GraphQLSchema` + """ if not isinstance(schema, GraphQLSchema): raise TypeError("DSLSchema needs a schema as parameter") @@ -93,9 +134,31 @@ def __getattr__(self, name: str) -> "DSLType": class DSLType: - def __init__(self, type_: GraphQLTypeWithFields): - self._type: GraphQLTypeWithFields = type_ - log.debug(f"DSLType({type_!r})") + """The DSLType represents a GraphQL type for the DSL code. + + It can be a root type (Query, Mutation or Subscription). + Or it can be an interface type (Character in the StarWars schema). + Or it can be an object type (Human in the StarWars schema). + + Instances of this class are generated for you automatically as attributes + of the :class:`DSLSchema` + + Attributes of the DSLType class are generated automatically + with the `__getattr__` dunder method in order to generate + instances of :class:`DSLField` + """ + + def __init__(self, graphql_type: GraphQLTypeWithFields): + """Initialize the DSLType with the GraphQL type. + + .. warning:: + Don't instanciate this class yourself. + Use attributes of the :class:`DSLSchema` instead. + + :param graphql_type: a GraphQL type + """ + self._type: GraphQLTypeWithFields = graphql_type + log.debug(f"DSLType({self._type!r})") def __getattr__(self, name: str) -> "DSLField": camel_cased_name = to_camel_case(name) @@ -115,30 +178,52 @@ def __getattr__(self, name: str) -> "DSLField": class DSLField: - - # Definition of field from the schema - field: GraphQLField - - # Current selection in the query - ast_field: FieldNode - - # Known serializers - known_serializers: Dict[GraphQLInputType, Optional[Serializer]] - - def __init__(self, name: str, type_: GraphQLTypeWithFields, field: GraphQLField): - self._type: GraphQLTypeWithFields = type_ - self.field = field - self.ast_field = FieldNode(name=NameNode(value=name), arguments=FrozenList()) - self.known_serializers = dict() - log.debug(f"DSLField('{name}',{field!r})") + """The DSLField represents a GraphQL field for the DSL code. + + Instances of this class are generated for you automatically as attributes + of the :class:`DSLType` + + If this field contains children fields, then you need to select which ones + you want in the request using the :meth:`select ` + method. + """ + + def __init__( + self, + name: str, + graphql_type: GraphQLTypeWithFields, + graphql_field: GraphQLField, + ): + """Initialize the DSLField. + + .. warning:: + Don't instanciate this class yourself. + Use attributes of the :class:`DSLType` instead. + + :param name: the name of the field + :param graphql_type: the GraphQL type + :param graphql_field: the GraphQL field + """ + self._type: GraphQLTypeWithFields = graphql_type + self.field: GraphQLField = graphql_field + self.ast_field: FieldNode = FieldNode( + name=NameNode(value=name), arguments=FrozenList() + ) + self.known_arg_serializers: Dict[ + GraphQLInputType, Optional[Serializer] + ] = dict() + log.debug(f"DSLField('{name}',{self.field!r})") @staticmethod def get_ast_fields(fields: Iterable) -> List[FieldNode]: """ - Equivalent to: [field.ast_field for field in fields] - But with a type check for each field in the list + :meta private: + + Equivalent to: :code:`[field.ast_field for field in fields]` + But with a type check for each field in the list. - Raises a TypeError if any of the provided fields are not of the DSLField type + :raises TypeError: if any of the provided fields are not instances + of the :class:`DSLField` class. """ ast_fields = [] for field in fields: @@ -150,6 +235,19 @@ def get_ast_fields(fields: Iterable) -> List[FieldNode]: return ast_fields def select(self, *fields: "DSLField") -> "DSLField": + """Select the new children fields + that we want to receive in the request. + + If used multiple times, we will add the new children fields + to the existing children fields + + :param fields: new children fields + :type fields: DSLField + :return: itself + + :raises TypeError: if any of the provided fields are not instances + of the :class:`DSLField` class. + """ added_selections: List[FieldNode] = self.get_ast_fields(fields) @@ -170,10 +268,33 @@ def __call__(self, **kwargs) -> "DSLField": return self.args(**kwargs) def alias(self, alias: str) -> "DSLField": + """Set an alias + + :param alias: the alias + :type alias: str + :return: itself + """ + self.ast_field.alias = NameNode(value=alias) return self def args(self, **kwargs) -> "DSLField": + r"""Set the arguments of a field + + The arguments are parsed to be stored in the AST of this field. + + .. note:: + you can also call the field directly with your arguments. + :code:`ds.Query.human(id=1000)` is equivalent to: + :code:`ds.Query.human.args(id=1000)` + + :param \**kwargs: the arguments (keyword=value) + :return: itself + + :raises KeyError: if any of the provided arguments does not exist + for this field. + """ + added_args = [] for name, value in kwargs.items(): arg = self.field.args.get(name) @@ -193,19 +314,19 @@ def _get_arg_serializer(self, arg_type: GraphQLInputType,) -> Serializer: elif isinstance(arg_type, GraphQLInputField): return self._get_arg_serializer(arg_type.type) elif isinstance(arg_type, GraphQLInputObjectType): - if arg_type in self.known_serializers: - return cast(Serializer, self.known_serializers[arg_type]) - self.known_serializers[arg_type] = None + if arg_type in self.known_arg_serializers: + return cast(Serializer, self.known_arg_serializers[arg_type]) + self.known_arg_serializers[arg_type] = None serializers = { k: self._get_arg_serializer(v) for k, v in arg_type.fields.items() } - self.known_serializers[arg_type] = lambda value: ObjectValueNode( + self.known_arg_serializers[arg_type] = lambda value: ObjectValueNode( fields=FrozenList( ObjectFieldNode(name=NameNode(value=k), value=serializers[k](v)) for k, v in value.items() ) ) - return cast(Serializer, self.known_serializers[arg_type]) + return cast(Serializer, self.known_arg_serializers[arg_type]) elif isinstance(arg_type, GraphQLList): inner_serializer = self._get_arg_serializer(arg_type.of_type) return lambda list_values: ListValueNode( @@ -222,6 +343,7 @@ def _get_arg_serializer(self, arg_type: GraphQLInputType,) -> Serializer: @property def type_name(self): + """:meta private:""" return self._type.name def __str__(self) -> str: From 0ffe5b983b5ce06bd1846cdf92cc346e9e144bdb Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Sun, 8 Nov 2020 15:22:44 +0100 Subject: [PATCH 10/21] DSL Refactor 8 Added __repr__ dunder methods to DSLType and DSLField Added debug logs --- gql/dsl.py | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/gql/dsl.py b/gql/dsl.py index 46275f1f..3c4bd2ea 100644 --- a/gql/dsl.py +++ b/gql/dsl.py @@ -155,10 +155,10 @@ def __init__(self, graphql_type: GraphQLTypeWithFields): Don't instanciate this class yourself. Use attributes of the :class:`DSLSchema` instead. - :param graphql_type: a GraphQL type + :param graphql_type: ther GraphQL type definition from the schema """ self._type: GraphQLTypeWithFields = graphql_type - log.debug(f"DSLType({self._type!r})") + log.debug(f"Creating {self!r})") def __getattr__(self, name: str) -> "DSLField": camel_cased_name = to_camel_case(name) @@ -176,6 +176,9 @@ def __getattr__(self, name: str) -> "DSLField": return DSLField(formatted_name, self._type, field) + def __repr__(self) -> str: + return f"<{self.__class__.__name__} {self._type!r}>" + class DSLField: """The DSLField represents a GraphQL field for the DSL code. @@ -201,8 +204,8 @@ def __init__( Use attributes of the :class:`DSLType` instead. :param name: the name of the field - :param graphql_type: the GraphQL type - :param graphql_field: the GraphQL field + :param graphql_type: the GraphQL type definition from the schema + :param graphql_field: the GraphQL field definition from the schema """ self._type: GraphQLTypeWithFields = graphql_type self.field: GraphQLField = graphql_field @@ -212,7 +215,7 @@ def __init__( self.known_arg_serializers: Dict[ GraphQLInputType, Optional[Serializer] ] = dict() - log.debug(f"DSLField('{name}',{self.field!r})") + log.debug(f"Creating {self!r}") @staticmethod def get_ast_fields(fields: Iterable) -> List[FieldNode]: @@ -262,6 +265,8 @@ def select(self, *fields: "DSLField") -> "DSLField": current_selection_set.selections + added_selections ) + log.debug(f"Added fields: {fields} in {self!r}") + return self def __call__(self, **kwargs) -> "DSLField": @@ -306,6 +311,7 @@ def args(self, **kwargs) -> "DSLField": ArgumentNode(name=NameNode(value=name), value=serialized_value) ) self.ast_field.arguments = FrozenList(self.ast_field.arguments + added_args) + log.debug(f"Added arguments {kwargs} in field {self!r})") return self def _get_arg_serializer(self, arg_type: GraphQLInputType,) -> Serializer: @@ -348,3 +354,9 @@ def type_name(self): def __str__(self) -> str: return print_ast(self.ast_field) + + def __repr__(self) -> str: + return ( + f"<{self.__class__.__name__} {self._type.name}" + f"::{self.ast_field.name.value}>" + ) From dca2832f5a2c75e3b5df46e3161c7f32b61bb9bb Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Sun, 8 Nov 2020 17:39:31 +0100 Subject: [PATCH 11/21] DOCS adding sphinx docs for the DSL module Adding executable examples for async and sync clients Added to README.md features --- README.md | 1 + docs/advanced/dsl_module.rst | 153 +++++++++++++++++++++--- docs/code_examples/aiohttp_async_dsl.py | 53 ++++++++ docs/code_examples/requests_sync.py | 4 +- docs/code_examples/requests_sync_dsl.py | 27 +++++ gql/dsl.py | 4 +- 6 files changed, 222 insertions(+), 20 deletions(-) create mode 100644 docs/code_examples/aiohttp_async_dsl.py create mode 100644 docs/code_examples/requests_sync_dsl.py diff --git a/README.md b/README.md index aa5aa0ac..8fefeb2f 100644 --- a/README.md +++ b/README.md @@ -40,6 +40,7 @@ The main features of GQL are: * Supports [sync or async usage](https://gql.readthedocs.io/en/latest/async/index.html), [allowing concurrent requests](https://gql.readthedocs.io/en/latest/advanced/async_advanced_usage.html#async-advanced-usage) * Supports [File uploads](https://gql.readthedocs.io/en/latest/usage/file_upload.html) * [gql-cli script](https://gql.readthedocs.io/en/latest/gql-cli/intro.html) to execute GraphQL queries from the command line +* [DSL module](https://gql.readthedocs.io/en/latest/advanced/dsl_module.html) to compose GraphQL queries dynamically ## Installation diff --git a/docs/advanced/dsl_module.rst b/docs/advanced/dsl_module.rst index aa6638df..1c461649 100644 --- a/docs/advanced/dsl_module.rst +++ b/docs/advanced/dsl_module.rst @@ -2,33 +2,152 @@ Compose queries dynamically =========================== Instead of providing the GraphQL queries as a Python String, it is also possible to create GraphQL queries dynamically. -Using the DSL module, we can create a query using a Domain Specific Language which is created from the schema. +Using the :mod:`DSL module `, we can create a query using a Domain Specific Language which is created from the schema. + +The following code: + +.. code-block:: python + + ds = DSLSchema(StarWarsSchema) + + query = dsl_gql( + ds.Query.hero.select( + ds.Character.id, + ds.Character.name, + ds.Character.friends.select(ds.Character.name), + ) + ) + +will generate a query equivalent to: .. code-block:: python - from gql.dsl import DSLSchema + query = gql(""" + query { + hero { + id + name + friends { + name + } + } + } + """) + +How to use +---------- + +First generate the root using the :class:`DSLSchema `:: + + ds = DSLSchema(client.schema) + +Then use auto-generated attributes of the :code:`ds` instance +to get a root type (Query, Mutation or Subscription). +This will generate a :class:`DSLType ` instance:: + + ds.Query - client = Client(schema=StarWarsSchema) - ds = DSLSchema(client) +From this root type, you use auto-generated attributes to get a field. +This will generate a :class:`DSLField ` instance:: - query_dsl = ds.Query.hero.select( + ds.Query.hero + +hero is a GraphQL object type and needs children fields. By default, +there is no children fields selected. To select the fields that you want +in your query, you use the :meth:`select ` method. + +To generate the children fields, we use the same method as above to auto-generate the fields +from the :code:`ds` instance +(ie :code:`ds.Character.name` is the field `name` of the type `Character`):: + + ds.Query.hero.select(ds.Character.name) + +The select method return the same instance, so it is possible to chain the calls:: + + ds.Query.hero.select(ds.Character.name).select(ds.Character.id) + +Or do it sequencially:: + + hero_query = ds.Query.hero + + hero_query.select(ds.Character.name) + hero_query.select(ds.Character.id) + +As you can select children fields of any object type, you can construct your complete query tree:: + + ds.Query.hero.select( ds.Character.id, ds.Character.name, - ds.Character.friends.select(ds.Character.name,), + ds.Character.friends.select(ds.Character.name), ) -will create a query equivalent to: +Once your query is completed and you have selected all the fields you want, +use the :func:`dsl_gql ` function to convert your query into +a document which will be able to get executed in the client or a session:: -.. code-block:: python + query = dsl_gql( + ds.Query.hero.select( + ds.Character.id, + ds.Character.name, + ds.Character.friends.select(ds.Character.name), + ) + ) + + result = client.execute(query) + +Arguments +^^^^^^^^^ + +It is possible to add arguments to any field simply by calling it +with the required arguments:: + + ds.Query.human(id="1000").select(ds.Human.name) + +It can also be done using the :meth:`args ` method:: + + ds.Query.human.args(id="1000").select(ds.Human.name) + +Alias +^^^^^ + +You can set an alias of a field using the :meth:`alias ` method:: + + ds.Query.human.args(id=1000).alias("luke").select(ds.Character.name) + +Mutations +^^^^^^^^^ + +It works the same way for mutations. Example:: + + query = dsl_gql( + ds.Mutation.createReview.args( + episode=6, review={"stars": 5, "commentary": "This is a great movie!"} + ).select(ds.Review.stars, ds.Review.commentary) + ) + +Multiple requests +^^^^^^^^^^^^^^^^^ + +It is possible to create a document with multiple requests:: + + query = dsl_gql( + ds.Query.hero.select(ds.Character.name), + ds.Query.hero(episode=5).alias("hero_of_episode_5").select(ds.Character.name), + ) + +But you have to take care that the root type is always the same. It is not possible +to mix queries and mutations for example. + +Executable examples +------------------- + +Async example +^^^^^^^^^^^^^ + +.. literalinclude:: ../code_examples/aiohttp_async_dsl.py - hero { - id - name - friends { - name - } - } +Sync example +^^^^^^^^^^^^^ -.. warning:: +.. literalinclude:: ../code_examples/requests_sync_dsl.py - Please note that the DSL module is still considered experimental in GQL 3 and is subject to changes diff --git a/docs/code_examples/aiohttp_async_dsl.py b/docs/code_examples/aiohttp_async_dsl.py new file mode 100644 index 00000000..48352d03 --- /dev/null +++ b/docs/code_examples/aiohttp_async_dsl.py @@ -0,0 +1,53 @@ +import asyncio + +from gql import Client +from gql.dsl import DSLSchema, dsl_gql +from gql.transport.aiohttp import AIOHTTPTransport + + +async def main(): + + transport = AIOHTTPTransport(url="https://countries.trevorblades.com/graphql") + + client = Client(transport=transport, fetch_schema_from_transport=True) + + # Using `async with` on the client will start a connection on the transport + # and provide a `session` variable to execute queries on this connection. + # Because we requested to fetch the schema from the transport, + # GQL will fetch the schema just after the establishment of the first session + async with client as session: + + # Instanciate the root of the DSL Schema as ds + ds = DSLSchema(client.schema) + + # Create the query using dynamically generated attributes from ds + query = dsl_gql( + ds.Query.continents(filter={"code": {"eq": "EU"}}).select( + ds.Continent.code, ds.Continent.name + ) + ) + + result = await session.execute(query) + print(result) + + # This can also be written as: + + # I want to query the continents + query_continents = ds.Query.continents + + # I want to get only the continents with code equal to "EU" + query_continents(filter={"code": {"eq": "EU"}}) + + # I want this query to return the code and name fields + query_continents.select(ds.Continent.code) + query_continents.select(ds.Continent.name) + + # I generate a document from my query to be able to execute it + query = dsl_gql(query_continents) + + # Execute the query + result = await session.execute(query) + print(result) + + +asyncio.run(main()) diff --git a/docs/code_examples/requests_sync.py b/docs/code_examples/requests_sync.py index d1054844..53b1e2a7 100644 --- a/docs/code_examples/requests_sync.py +++ b/docs/code_examples/requests_sync.py @@ -1,11 +1,11 @@ from gql import Client, gql from gql.transport.requests import RequestsHTTPTransport -sample_transport = RequestsHTTPTransport( +transport = RequestsHTTPTransport( url="https://countries.trevorblades.com/", verify=True, retries=3, ) -client = Client(transport=sample_transport, fetch_schema_from_transport=True,) +client = Client(transport=transport, fetch_schema_from_transport=True) query = gql( """ diff --git a/docs/code_examples/requests_sync_dsl.py b/docs/code_examples/requests_sync_dsl.py new file mode 100644 index 00000000..7cec520a --- /dev/null +++ b/docs/code_examples/requests_sync_dsl.py @@ -0,0 +1,27 @@ +from gql import Client +from gql.dsl import DSLSchema, dsl_gql +from gql.transport.requests import RequestsHTTPTransport + +transport = RequestsHTTPTransport( + url="https://countries.trevorblades.com/", verify=True, retries=3, +) + +client = Client(transport=transport, fetch_schema_from_transport=True) + +# Using `with` on the sync client will start a connection on the transport +# and provide a `session` variable to execute queries on this connection. +# Because we requested to fetch the schema from the transport, +# GQL will fetch the schema just after the establishment of the first session +with client as session: + + # We should have received the schema now that the session is established + assert client.schema is not None + + # Instanciate the root of the DSL Schema as ds + ds = DSLSchema(client.schema) + + # Create the query using dynamically generated attributes from ds + query = dsl_gql(ds.Query.continents.select(ds.Continent.code, ds.Continent.name)) + + result = session.execute(query) + print(result) diff --git a/gql/dsl.py b/gql/dsl.py index 3c4bd2ea..83da37c7 100644 --- a/gql/dsl.py +++ b/gql/dsl.py @@ -115,7 +115,9 @@ def __init__(self, schema: GraphQLSchema): """ if not isinstance(schema, GraphQLSchema): - raise TypeError("DSLSchema needs a schema as parameter") + raise TypeError( + f"DSLSchema needs a schema as parameter. Received: {type(schema)}" + ) self._schema: GraphQLSchema = schema From 7501c172cb79bfdf7a7dc07704a21af079acc63d Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Sun, 8 Nov 2020 17:51:02 +0100 Subject: [PATCH 12/21] Fix typo and modify doc following @KingDarBoja advice --- gql/dsl.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/gql/dsl.py b/gql/dsl.py index 83da37c7..44193024 100644 --- a/gql/dsl.py +++ b/gql/dsl.py @@ -47,8 +47,8 @@ def dsl_gql(*fields: "DSLField") -> DocumentNode: Similar to the :func:`gql.gql` function but instead of parsing a python string to describe the request, we are using requests which have been generated - dynamically using instances of :class:`DSLField` which have been generated - by instances of :class:`DSLType` which themselves have been generated from + dynamically using instances of :class:`DSLField`, generated + by instances of :class:`DSLType` which themselves originated from a :class:`DSLSchema` class. The fields arguments should be fields of root GraphQL types @@ -157,7 +157,7 @@ def __init__(self, graphql_type: GraphQLTypeWithFields): Don't instanciate this class yourself. Use attributes of the :class:`DSLSchema` instead. - :param graphql_type: ther GraphQL type definition from the schema + :param graphql_type: the GraphQL type definition from the schema """ self._type: GraphQLTypeWithFields = graphql_type log.debug(f"Creating {self!r})") From 9974899bd7247684a4420d08be46278b6fc1c455 Mon Sep 17 00:00:00 2001 From: KingDarBoja Date: Sun, 8 Nov 2020 15:00:26 -0500 Subject: [PATCH 13/21] Add alias to selection as keyword argument at DSL --- gql/dsl.py | 29 +++++++++++++++++++++-------- tests/starwars/test_dsl.py | 21 +++++++++++++++++++++ 2 files changed, 42 insertions(+), 8 deletions(-) diff --git a/gql/dsl.py b/gql/dsl.py index 44193024..45e40210 100644 --- a/gql/dsl.py +++ b/gql/dsl.py @@ -154,7 +154,7 @@ def __init__(self, graphql_type: GraphQLTypeWithFields): """Initialize the DSLType with the GraphQL type. .. warning:: - Don't instanciate this class yourself. + Don't instantiate this class yourself. Use attributes of the :class:`DSLSchema` instead. :param graphql_type: the GraphQL type definition from the schema @@ -198,11 +198,12 @@ def __init__( name: str, graphql_type: GraphQLTypeWithFields, graphql_field: GraphQLField, + alias: str = None, ): """Initialize the DSLField. .. warning:: - Don't instanciate this class yourself. + Don't instantiate this class yourself. Use attributes of the :class:`DSLType` instead. :param name: the name of the field @@ -212,7 +213,7 @@ def __init__( self._type: GraphQLTypeWithFields = graphql_type self.field: GraphQLField = graphql_field self.ast_field: FieldNode = FieldNode( - name=NameNode(value=name), arguments=FrozenList() + name=NameNode(value=name), arguments=FrozenList(), alias=alias ) self.known_arg_serializers: Dict[ GraphQLInputType, Optional[Serializer] @@ -239,7 +240,9 @@ def get_ast_fields(fields: Iterable) -> List[FieldNode]: return ast_fields - def select(self, *fields: "DSLField") -> "DSLField": + def select( + self, *fields: "DSLField", **fields_with_alias: "DSLField" + ) -> "DSLField": """Select the new children fields that we want to receive in the request. @@ -255,16 +258,20 @@ def select(self, *fields: "DSLField") -> "DSLField": """ added_selections: List[FieldNode] = self.get_ast_fields(fields) - + added_selections_with_alias: List[FieldNode] = self.get_ast_fields( + [field.alias(alias) for alias, field in fields_with_alias.items()] + ) current_selection_set: Optional[SelectionSetNode] = self.ast_field.selection_set if current_selection_set is None: self.ast_field.selection_set = SelectionSetNode( - selections=FrozenList(added_selections) + selections=FrozenList(added_selections + added_selections_with_alias) ) else: current_selection_set.selections = FrozenList( - current_selection_set.selections + added_selections + current_selection_set.selections + + added_selections + + added_selections_with_alias ) log.debug(f"Added fields: {fields} in {self!r}") @@ -277,6 +284,12 @@ def __call__(self, **kwargs) -> "DSLField": def alias(self, alias: str) -> "DSLField": """Set an alias + .. note:: + You can also pass the alias directly at the + :meth:`select ` method. + :code:`ds.Query.human.select(my_name=ds.Character.name)` is equivalent to: + :code:`ds.Query.human.select(ds.Character.name.alias("my_name"))` + :param alias: the alias :type alias: str :return: itself @@ -291,7 +304,7 @@ def args(self, **kwargs) -> "DSLField": The arguments are parsed to be stored in the AST of this field. .. note:: - you can also call the field directly with your arguments. + You can also call the field directly with your arguments. :code:`ds.Query.human(id=1000)` is equivalent to: :code:`ds.Query.human.args(id=1000)` diff --git a/tests/starwars/test_dsl.py b/tests/starwars/test_dsl.py index d83280d7..5fa85398 100644 --- a/tests/starwars/test_dsl.py +++ b/tests/starwars/test_dsl.py @@ -113,6 +113,27 @@ def test_fetch_luke_aliased(ds): assert query == str(query_dsl) +def test_fetch_name_aliased(ds: DSLSchema): + query = """ +human(id: "1000") { + my_name: name +} + """.strip() + query_dsl = ds.Query.human.args(id=1000).select(ds.Character.name.alias("my_name")) + print(str(query_dsl)) + assert query == str(query_dsl) + + +def test_fetch_name_aliased_as_kwargs(ds: DSLSchema): + query = """ +human(id: "1000") { + my_name: name +} + """.strip() + query_dsl = ds.Query.human.args(id=1000).select(my_name=ds.Character.name,) + assert query == str(query_dsl) + + def test_hero_name_query_result(ds, client): query = dsl_gql(ds.Query.hero.select(ds.Character.name)) result = client.execute(query) From db654f94b4f522bc67c0b19809c25b73f1d2ac53 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Wed, 11 Nov 2020 12:32:22 +0100 Subject: [PATCH 14/21] small cleaning --- gql/dsl.py | 29 +++++++++++++++-------------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/gql/dsl.py b/gql/dsl.py index 45e40210..7a3834dc 100644 --- a/gql/dsl.py +++ b/gql/dsl.py @@ -198,7 +198,6 @@ def __init__( name: str, graphql_type: GraphQLTypeWithFields, graphql_field: GraphQLField, - alias: str = None, ): """Initialize the DSLField. @@ -213,7 +212,7 @@ def __init__( self._type: GraphQLTypeWithFields = graphql_type self.field: GraphQLField = graphql_field self.ast_field: FieldNode = FieldNode( - name=NameNode(value=name), arguments=FrozenList(), alias=alias + name=NameNode(value=name), arguments=FrozenList() ) self.known_arg_serializers: Dict[ GraphQLInputType, Optional[Serializer] @@ -243,35 +242,37 @@ def get_ast_fields(fields: Iterable) -> List[FieldNode]: def select( self, *fields: "DSLField", **fields_with_alias: "DSLField" ) -> "DSLField": - """Select the new children fields + r"""Select the new children fields that we want to receive in the request. If used multiple times, we will add the new children fields - to the existing children fields + to the existing children fields. - :param fields: new children fields - :type fields: DSLField + :param \*fields: new children fields + :type \*fields: DSLField + :param \**fields_with_alias: new children fields with alias as key + :type \**fields_with_alias: DSLField :return: itself :raises TypeError: if any of the provided fields are not instances of the :class:`DSLField` class. """ - added_selections: List[FieldNode] = self.get_ast_fields(fields) - added_selections_with_alias: List[FieldNode] = self.get_ast_fields( - [field.alias(alias) for alias, field in fields_with_alias.items()] - ) + added_fields: List["DSLField"] = list(fields) + [ + field.alias(alias) for alias, field in fields_with_alias.items() + ] + + added_selections: List[FieldNode] = self.get_ast_fields(added_fields) + current_selection_set: Optional[SelectionSetNode] = self.ast_field.selection_set if current_selection_set is None: self.ast_field.selection_set = SelectionSetNode( - selections=FrozenList(added_selections + added_selections_with_alias) + selections=FrozenList(added_selections) ) else: current_selection_set.selections = FrozenList( - current_selection_set.selections - + added_selections - + added_selections_with_alias + current_selection_set.selections + added_selections ) log.debug(f"Added fields: {fields} in {self!r}") From fa8e0feaeeee05be1392eb2b9bc91b8f9b96ee9d Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Thu, 12 Nov 2020 21:23:30 +0100 Subject: [PATCH 15/21] Remove GraphQLTypeWithFields alias --- gql/dsl.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/gql/dsl.py b/gql/dsl.py index 7a3834dc..6e786a57 100644 --- a/gql/dsl.py +++ b/gql/dsl.py @@ -36,7 +36,6 @@ log = logging.getLogger(__name__) -GraphQLTypeWithFields = Union[GraphQLObjectType, GraphQLInterfaceType] Serializer = Callable[[Any], Optional[ValueNode]] @@ -150,7 +149,7 @@ class DSLType: instances of :class:`DSLField` """ - def __init__(self, graphql_type: GraphQLTypeWithFields): + def __init__(self, graphql_type: Union[GraphQLObjectType, GraphQLInterfaceType]): """Initialize the DSLType with the GraphQL type. .. warning:: @@ -159,7 +158,7 @@ def __init__(self, graphql_type: GraphQLTypeWithFields): :param graphql_type: the GraphQL type definition from the schema """ - self._type: GraphQLTypeWithFields = graphql_type + self._type: Union[GraphQLObjectType, GraphQLInterfaceType] = graphql_type log.debug(f"Creating {self!r})") def __getattr__(self, name: str) -> "DSLField": @@ -196,7 +195,7 @@ class DSLField: def __init__( self, name: str, - graphql_type: GraphQLTypeWithFields, + graphql_type: Union[GraphQLObjectType, GraphQLInterfaceType], graphql_field: GraphQLField, ): """Initialize the DSLField. @@ -209,7 +208,7 @@ def __init__( :param graphql_type: the GraphQL type definition from the schema :param graphql_field: the GraphQL field definition from the schema """ - self._type: GraphQLTypeWithFields = graphql_type + self._type: Union[GraphQLObjectType, GraphQLInterfaceType] = graphql_type self.field: GraphQLField = graphql_field self.ast_field: FieldNode = FieldNode( name=NameNode(value=name), arguments=FrozenList() @@ -330,7 +329,7 @@ def args(self, **kwargs) -> "DSLField": log.debug(f"Added arguments {kwargs} in field {self!r})") return self - def _get_arg_serializer(self, arg_type: GraphQLInputType,) -> Serializer: + def _get_arg_serializer(self, arg_type: GraphQLInputType) -> Serializer: if isinstance(arg_type, GraphQLNonNull): return self._get_arg_serializer(arg_type.of_type) elif isinstance(arg_type, GraphQLInputField): From ec57e4615931c17395ac11cd928be76d6057078d Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Thu, 12 Nov 2020 21:25:40 +0100 Subject: [PATCH 16/21] Rephrase DSLType comment --- gql/dsl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gql/dsl.py b/gql/dsl.py index 6e786a57..3cb6839b 100644 --- a/gql/dsl.py +++ b/gql/dsl.py @@ -138,8 +138,8 @@ class DSLType: """The DSLType represents a GraphQL type for the DSL code. It can be a root type (Query, Mutation or Subscription). + Or it can be any other object type (Human in the StarWars schema). Or it can be an interface type (Character in the StarWars schema). - Or it can be an object type (Human in the StarWars schema). Instances of this class are generated for you automatically as attributes of the :class:`DSLSchema` From 2b8f985a9e9054e95157a3acf88aeb18208e3a5c Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Thu, 12 Nov 2020 22:42:55 +0100 Subject: [PATCH 17/21] fix _get_arg_serializer typing --- gql/dsl.py | 24 +++++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/gql/dsl.py b/gql/dsl.py index 3cb6839b..4e082b55 100644 --- a/gql/dsl.py +++ b/gql/dsl.py @@ -9,7 +9,6 @@ FieldNode, GraphQLEnumType, GraphQLField, - GraphQLInputField, GraphQLInputObjectType, GraphQLInputType, GraphQLInterfaceType, @@ -330,16 +329,25 @@ def args(self, **kwargs) -> "DSLField": return self def _get_arg_serializer(self, arg_type: GraphQLInputType) -> Serializer: + """Recursive function used to get a argument serializer function + for a specific GraphQL input type. + + The only possible sort of types are: + GraphQLScalarType, GraphQLEnumType, GraphQLInputObjectType, GraphQLWrappingType + GraphQLWrappingType can be GraphQLList or GraphQLNonNull + """ + + log.debug(f"_get_arg_serializer({arg_type!r})") + if isinstance(arg_type, GraphQLNonNull): return self._get_arg_serializer(arg_type.of_type) - elif isinstance(arg_type, GraphQLInputField): - return self._get_arg_serializer(arg_type.type) + elif isinstance(arg_type, GraphQLInputObjectType): if arg_type in self.known_arg_serializers: return cast(Serializer, self.known_arg_serializers[arg_type]) self.known_arg_serializers[arg_type] = None serializers = { - k: self._get_arg_serializer(v) for k, v in arg_type.fields.items() + k: self._get_arg_serializer(v.type) for k, v in arg_type.fields.items() } self.known_arg_serializers[arg_type] = lambda value: ObjectValueNode( fields=FrozenList( @@ -348,18 +356,24 @@ def _get_arg_serializer(self, arg_type: GraphQLInputType) -> Serializer: ) ) return cast(Serializer, self.known_arg_serializers[arg_type]) + elif isinstance(arg_type, GraphQLList): inner_serializer = self._get_arg_serializer(arg_type.of_type) return lambda list_values: ListValueNode( values=FrozenList(inner_serializer(v) for v in list_values) ) + elif isinstance(arg_type, GraphQLEnumType): return lambda value: EnumValueNode( value=cast(GraphQLEnumType, arg_type).serialize(value) ) + # Impossible to be another type here + assert isinstance(arg_type, GraphQLScalarType) + return lambda value: ast_from_value( - cast(GraphQLScalarType, arg_type).serialize(value), arg_type + cast(GraphQLScalarType, arg_type).serialize(value), + cast(GraphQLScalarType, arg_type), ) @property From 49eecd3eba0e86477093c3b77a08143c2f822903 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Fri, 13 Nov 2020 00:14:01 +0100 Subject: [PATCH 18/21] Really Fix nested input arguments this time --- gql/dsl.py | 23 ++++++++++------------- tests/nested_input/schema.py | 9 ++++++--- tests/nested_input/test_nested_input.py | 20 +++++++++++++++++--- 3 files changed, 33 insertions(+), 19 deletions(-) diff --git a/gql/dsl.py b/gql/dsl.py index 4e082b55..f0c95956 100644 --- a/gql/dsl.py +++ b/gql/dsl.py @@ -1,6 +1,6 @@ import logging from collections.abc import Iterable -from typing import Any, Callable, Dict, List, Optional, Union, cast +from typing import Any, Callable, List, Optional, Union, cast from graphql import ( ArgumentNode, @@ -212,9 +212,6 @@ def __init__( self.ast_field: FieldNode = FieldNode( name=NameNode(value=name), arguments=FrozenList() ) - self.known_arg_serializers: Dict[ - GraphQLInputType, Optional[Serializer] - ] = dict() log.debug(f"Creating {self!r}") @staticmethod @@ -343,19 +340,19 @@ def _get_arg_serializer(self, arg_type: GraphQLInputType) -> Serializer: return self._get_arg_serializer(arg_type.of_type) elif isinstance(arg_type, GraphQLInputObjectType): - if arg_type in self.known_arg_serializers: - return cast(Serializer, self.known_arg_serializers[arg_type]) - self.known_arg_serializers[arg_type] = None - serializers = { - k: self._get_arg_serializer(v.type) for k, v in arg_type.fields.items() - } - self.known_arg_serializers[arg_type] = lambda value: ObjectValueNode( + return lambda value: ObjectValueNode( fields=FrozenList( - ObjectFieldNode(name=NameNode(value=k), value=serializers[k](v)) + ObjectFieldNode( + name=NameNode(value=k), + value=( + self._get_arg_serializer( + cast(GraphQLInputObjectType, arg_type).fields[k].type + ) + )(v), + ) for k, v in value.items() ) ) - return cast(Serializer, self.known_arg_serializers[arg_type]) elif isinstance(arg_type, GraphQLList): inner_serializer = self._get_arg_serializer(arg_type.of_type) diff --git a/tests/nested_input/schema.py b/tests/nested_input/schema.py index f27a94e8..bd5a0507 100644 --- a/tests/nested_input/schema.py +++ b/tests/nested_input/schema.py @@ -1,3 +1,5 @@ +import json + from graphql import ( GraphQLArgument, GraphQLField, @@ -6,6 +8,7 @@ GraphQLInt, GraphQLObjectType, GraphQLSchema, + GraphQLString, ) nestedInput = GraphQLInputObjectType( @@ -19,10 +22,10 @@ queryType = GraphQLObjectType( "Query", fields=lambda: { - "foo": GraphQLField( + "echo": GraphQLField( args={"nested": GraphQLArgument(type_=nestedInput)}, - resolve=lambda *args, **kwargs: 1, - type_=GraphQLInt, + resolve=lambda *args, **kwargs: json.dumps(kwargs["nested"]), + type_=GraphQLString, ), }, ) diff --git a/tests/nested_input/test_nested_input.py b/tests/nested_input/test_nested_input.py index 2abe5252..b8997e95 100644 --- a/tests/nested_input/test_nested_input.py +++ b/tests/nested_input/test_nested_input.py @@ -15,6 +15,20 @@ def client(): return Client(schema=NestedInputSchema) -def test_nested_input_with_new_get_arg_serializer(ds, client): - query = dsl_gql(ds.Query.foo.args(nested={"foo": 1})) - assert client.execute(query) == {"foo": 1} +def test_nested_input(ds, client): + query = dsl_gql(ds.Query.echo.args(nested={"foo": 1})) + assert client.execute(query) == {"echo": '{"foo": 1}'} + + +def test_nested_input_2(ds, client): + query = dsl_gql(ds.Query.echo.args(nested={"foo": 1, "child": {"foo": 2}})) + assert client.execute(query) == {"echo": '{"foo": 1, "child": {"foo": 2}}'} + + +def test_nested_input_3(ds, client): + query = dsl_gql( + ds.Query.echo.args(nested={"foo": 1, "child": {"foo": 2, "child": {"foo": 3}}}) + ) + assert client.execute(query) == { + "echo": '{"foo": 1, "child": {"foo": 2, "child": {"foo": 3}}}' + } From dbfc221df591932182eadc401375b10f4953ed70 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Fri, 13 Nov 2020 15:14:16 +0100 Subject: [PATCH 19/21] The _get_arg_serializer method is not needed... --- gql/dsl.py | 93 +++++++++++++++--------------------------------------- 1 file changed, 25 insertions(+), 68 deletions(-) diff --git a/gql/dsl.py b/gql/dsl.py index f0c95956..e60f4752 100644 --- a/gql/dsl.py +++ b/gql/dsl.py @@ -1,31 +1,21 @@ import logging from collections.abc import Iterable -from typing import Any, Callable, List, Optional, Union, cast +from typing import List, Optional, Union from graphql import ( ArgumentNode, DocumentNode, - EnumValueNode, FieldNode, - GraphQLEnumType, + GraphQLArgument, GraphQLField, - GraphQLInputObjectType, - GraphQLInputType, GraphQLInterfaceType, - GraphQLList, GraphQLNamedType, - GraphQLNonNull, GraphQLObjectType, - GraphQLScalarType, GraphQLSchema, - ListValueNode, NameNode, - ObjectFieldNode, - ObjectValueNode, OperationDefinitionNode, OperationType, SelectionSetNode, - ValueNode, ast_from_value, print_ast, ) @@ -35,8 +25,6 @@ log = logging.getLogger(__name__) -Serializer = Callable[[Any], Optional[ValueNode]] - def dsl_gql(*fields: "DSLField") -> DocumentNode: """Given arguments of type :class:`DSLField` containing GraphQL requests, @@ -311,67 +299,36 @@ def args(self, **kwargs) -> "DSLField": for this field. """ - added_args = [] - for name, value in kwargs.items(): - arg = self.field.args.get(name) - if not arg: - raise KeyError(f"Argument {name} does not exist in {self.field}.") - arg_type_serializer = self._get_arg_serializer(arg.type) - serialized_value = arg_type_serializer(value) - added_args.append( - ArgumentNode(name=NameNode(value=name), value=serialized_value) - ) - self.ast_field.arguments = FrozenList(self.ast_field.arguments + added_args) - log.debug(f"Added arguments {kwargs} in field {self!r})") - return self + assert self.ast_field.arguments is not None - def _get_arg_serializer(self, arg_type: GraphQLInputType) -> Serializer: - """Recursive function used to get a argument serializer function - for a specific GraphQL input type. + self.ast_field.arguments = FrozenList( + self.ast_field.arguments + + [ + ArgumentNode( + name=NameNode(value=name), + value=ast_from_value(value, self._get_argument(name).type), + ) + for name, value in kwargs.items() + ] + ) - The only possible sort of types are: - GraphQLScalarType, GraphQLEnumType, GraphQLInputObjectType, GraphQLWrappingType - GraphQLWrappingType can be GraphQLList or GraphQLNonNull - """ + log.debug(f"Added arguments {kwargs} in field {self!r})") - log.debug(f"_get_arg_serializer({arg_type!r})") - - if isinstance(arg_type, GraphQLNonNull): - return self._get_arg_serializer(arg_type.of_type) - - elif isinstance(arg_type, GraphQLInputObjectType): - return lambda value: ObjectValueNode( - fields=FrozenList( - ObjectFieldNode( - name=NameNode(value=k), - value=( - self._get_arg_serializer( - cast(GraphQLInputObjectType, arg_type).fields[k].type - ) - )(v), - ) - for k, v in value.items() - ) - ) + return self - elif isinstance(arg_type, GraphQLList): - inner_serializer = self._get_arg_serializer(arg_type.of_type) - return lambda list_values: ListValueNode( - values=FrozenList(inner_serializer(v) for v in list_values) - ) + def _get_argument(self, name: str) -> GraphQLArgument: + """Method used to return the GraphQLArgument definition + of an argument from its name. - elif isinstance(arg_type, GraphQLEnumType): - return lambda value: EnumValueNode( - value=cast(GraphQLEnumType, arg_type).serialize(value) - ) + :raises KeyError: if the provided argument does not exist + for this field. + """ + arg = self.field.args.get(name) - # Impossible to be another type here - assert isinstance(arg_type, GraphQLScalarType) + if arg is None: + raise KeyError(f"Argument {name} does not exist in {self.field}.") - return lambda value: ast_from_value( - cast(GraphQLScalarType, arg_type).serialize(value), - cast(GraphQLScalarType, arg_type), - ) + return arg @property def type_name(self): From 3201c057ee82196797404f2e63ea405a67f79d4d Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Sat, 14 Nov 2020 16:53:23 +0100 Subject: [PATCH 20/21] Allow aliases as keyword arguments for dsl_gql + new operation_name argument --- gql/dsl.py | 55 +++++++++++++++++++++++++++++--------- tests/starwars/test_dsl.py | 33 ++++++++++++++++++++++- 2 files changed, 74 insertions(+), 14 deletions(-) diff --git a/gql/dsl.py b/gql/dsl.py index e60f4752..a197e0bf 100644 --- a/gql/dsl.py +++ b/gql/dsl.py @@ -1,6 +1,5 @@ import logging -from collections.abc import Iterable -from typing import List, Optional, Union +from typing import Dict, Iterable, List, Optional, Tuple, Union from graphql import ( ArgumentNode, @@ -26,8 +25,12 @@ log = logging.getLogger(__name__) -def dsl_gql(*fields: "DSLField") -> DocumentNode: - """Given arguments of type :class:`DSLField` containing GraphQL requests, +def dsl_gql( + *fields: "DSLField", + operation_name: Optional[str] = None, + **fields_with_alias: "DSLField", +) -> DocumentNode: + r"""Given arguments of type :class:`DSLField` containing GraphQL requests, generate a Document which can be executed later in a gql client or a gql session. @@ -43,8 +46,12 @@ def dsl_gql(*fields: "DSLField") -> DocumentNode: They should all have the same root type (you can't mix queries with mutations for example). - :param fields: root instances of the dynamically generated requests - :type fields: DSLField + :param \*fields: root instances of the dynamically generated requests + :type \*fields: DSLField + :param \**fields_with_alias: root instances fields with alias as key + :type \**fields_with_alias: DSLField + :param operation_name: optional operation name + :type operation_name: str :return: a Document which can be later executed or subscribed by a :class:`Client `, by an :class:`async session ` or by a @@ -54,9 +61,13 @@ def dsl_gql(*fields: "DSLField") -> DocumentNode: :raises AssertionError: if an argument is not a field of a root type """ + all_fields: Tuple["DSLField", ...] = DSLField.get_aliased_fields( + fields, fields_with_alias + ) + # Check that we receive only arguments of type DSLField # And that they are a root type - for field in fields: + for field in all_fields: if not isinstance(field, DSLField): raise TypeError( f"fields must be instances of DSLField. Received type: {type(field)}" @@ -68,15 +79,16 @@ def dsl_gql(*fields: "DSLField") -> DocumentNode: # Get the operation from the first field # All the fields must have the same operation - operation = fields[0].type_name.lower() + operation = all_fields[0].type_name.lower() return DocumentNode( definitions=[ OperationDefinitionNode( operation=OperationType(operation), selection_set=SelectionSetNode( - selections=FrozenList(DSLField.get_ast_fields(fields)) + selections=FrozenList(DSLField.get_ast_fields(all_fields)) ), + **({"name": NameNode(value=operation_name)} if operation_name else {}), ) ] ) @@ -203,7 +215,7 @@ def __init__( log.debug(f"Creating {self!r}") @staticmethod - def get_ast_fields(fields: Iterable) -> List[FieldNode]: + def get_ast_fields(fields: Iterable["DSLField"]) -> List[FieldNode]: """ :meta private: @@ -222,6 +234,23 @@ def get_ast_fields(fields: Iterable) -> List[FieldNode]: return ast_fields + @staticmethod + def get_aliased_fields( + fields: Iterable["DSLField"], fields_with_alias: Dict[str, "DSLField"] + ) -> Tuple["DSLField", ...]: + """ + :meta private: + + Concatenate all the fields (with or without alias) in a Tuple. + + Set the requested alias for the fields with alias. + """ + + return ( + *fields, + *(field.alias(alias) for alias, field in fields_with_alias.items()), + ) + def select( self, *fields: "DSLField", **fields_with_alias: "DSLField" ) -> "DSLField": @@ -241,9 +270,9 @@ def select( of the :class:`DSLField` class. """ - added_fields: List["DSLField"] = list(fields) + [ - field.alias(alias) for alias, field in fields_with_alias.items() - ] + added_fields: Tuple["DSLField", ...] = self.get_aliased_fields( + fields, fields_with_alias + ) added_selections: List[FieldNode] = self.get_ast_fields(added_fields) diff --git a/tests/starwars/test_dsl.py b/tests/starwars/test_dsl.py index 5fa85398..49a366fe 100644 --- a/tests/starwars/test_dsl.py +++ b/tests/starwars/test_dsl.py @@ -182,7 +182,7 @@ def test_invalid_arg(ds): ds.Query.hero.args(invalid_arg=5).select(ds.Character.name) -def test_multiple_queries(ds, client): +def test_multiple_root_fields(ds, client): query = dsl_gql( ds.Query.hero.select(ds.Character.name), ds.Query.hero(episode=5).alias("hero_of_episode_5").select(ds.Character.name), @@ -195,6 +195,37 @@ def test_multiple_queries(ds, client): assert result == expected +def test_root_fields_aliased(ds, client): + query = dsl_gql( + ds.Query.hero.select(ds.Character.name), + hero_of_episode_5=ds.Query.hero(episode=5).select(ds.Character.name), + ) + result = client.execute(query) + expected = { + "hero": {"name": "R2-D2"}, + "hero_of_episode_5": {"name": "Luke Skywalker"}, + } + assert result == expected + + +def test_operation_name(ds): + query = dsl_gql( + ds.Query.hero.select(ds.Character.name), operation_name="GetHeroName", + ) + + from graphql import print_ast + + assert ( + print_ast(query) + == """query GetHeroName { + hero { + name + } +} +""" + ) + + def test_dsl_gql_all_fields_should_be_instances_of_DSLField(ds, client): with pytest.raises( TypeError, match="fields must be instances of DSLField. Received type:" From 27e0eb7cca4fd25b447a6ce2834c1413f7c061ff Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Sat, 14 Nov 2020 23:06:01 +0100 Subject: [PATCH 21/21] Another refactor to allow multiple operations in documents --- docs/advanced/dsl_module.rst | 113 +++++++++++++++---- docs/code_examples/aiohttp_async_dsl.py | 10 +- docs/code_examples/requests_sync_dsl.py | 6 +- gql/dsl.py | 138 +++++++++++++++++------- tests/nested_input/test_nested_input.py | 14 ++- tests/starwars/test_dsl.py | 121 +++++++++++++++------ 6 files changed, 299 insertions(+), 103 deletions(-) diff --git a/docs/advanced/dsl_module.rst b/docs/advanced/dsl_module.rst index 1c461649..2e60f045 100644 --- a/docs/advanced/dsl_module.rst +++ b/docs/advanced/dsl_module.rst @@ -11,10 +11,12 @@ The following code: ds = DSLSchema(StarWarsSchema) query = dsl_gql( - ds.Query.hero.select( - ds.Character.id, - ds.Character.name, - ds.Character.friends.select(ds.Character.name), + DSLQuery( + ds.Query.hero.select( + ds.Character.id, + ds.Character.name, + ds.Character.friends.select(ds.Character.name), + ) ) ) @@ -81,11 +83,12 @@ As you can select children fields of any object type, you can construct your com ds.Character.friends.select(ds.Character.name), ) -Once your query is completed and you have selected all the fields you want, -use the :func:`dsl_gql ` function to convert your query into -a document which will be able to get executed in the client or a session:: +Once your root query fields are defined, you can put them in an operation using +:class:`DSLQuery `, +:class:`DSLMutation ` or +:class:`DSLSubscription `:: - query = dsl_gql( + DSLQuery( ds.Query.hero.select( ds.Character.id, ds.Character.name, @@ -93,6 +96,21 @@ a document which will be able to get executed in the client or a session:: ) ) + +Once your operations are defined, +use the :func:`dsl_gql ` function to convert your operations into +a document which will be able to get executed in the client or a session:: + + query = dsl_gql( + DSLQuery( + ds.Query.hero.select( + ds.Character.id, + ds.Character.name, + ds.Character.friends.select(ds.Character.name), + ) + ) + ) + result = client.execute(query) Arguments @@ -107,36 +125,91 @@ It can also be done using the :meth:`args ` method:: ds.Query.human.args(id="1000").select(ds.Human.name) -Alias -^^^^^ +Aliases +^^^^^^^ You can set an alias of a field using the :meth:`alias ` method:: ds.Query.human.args(id=1000).alias("luke").select(ds.Character.name) +It is also possible to set the alias directly using keyword arguments of an operation:: + + DSLQuery( + luke=ds.Query.human.args(id=1000).select(ds.Character.name) + ) + +Or using keyword arguments in the :meth:`select ` method:: + + ds.Query.hero.select( + my_name=ds.Character.name + ) + Mutations ^^^^^^^^^ -It works the same way for mutations. Example:: +For the mutations, you need to start from root fields starting from :code:`ds.Mutation` +then you need to create the GraphQL operation using the class +:class:`DSLMutation `. Example:: query = dsl_gql( - ds.Mutation.createReview.args( - episode=6, review={"stars": 5, "commentary": "This is a great movie!"} - ).select(ds.Review.stars, ds.Review.commentary) + DSLMutation( + ds.Mutation.createReview.args( + episode=6, review={"stars": 5, "commentary": "This is a great movie!"} + ).select(ds.Review.stars, ds.Review.commentary) + ) ) -Multiple requests -^^^^^^^^^^^^^^^^^ +Subscriptions +^^^^^^^^^^^^^ -It is possible to create a document with multiple requests:: +For the subscriptions, you need to start from root fields starting from :code:`ds.Subscription` +then you need to create the GraphQL operation using the class +:class:`DSLSubscription `. Example:: query = dsl_gql( + DSLSubscription( + ds.Subscription.reviewAdded(episode=6).select(ds.Review.stars, ds.Review.commentary) + ) + ) + +Multiple fields in an operation +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +It is possible to create an operation with multiple fields:: + + DSLQuery( ds.Query.hero.select(ds.Character.name), - ds.Query.hero(episode=5).alias("hero_of_episode_5").select(ds.Character.name), + hero_of_episode_5=ds.Query.hero(episode=5).select(ds.Character.name), + ) + +Operation name +^^^^^^^^^^^^^^ + +You can set the operation name of an operation using a keyword argument +to :func:`dsl_gql `:: + + query = dsl_gql( + GetHeroName=DSLQuery(ds.Query.hero.select(ds.Character.name)) ) -But you have to take care that the root type is always the same. It is not possible -to mix queries and mutations for example. +will generate the request:: + + query GetHeroName { + hero { + name + } + } + +Multiple operations in a document +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +It is possible to create an Document with multiple operations:: + + query = dsl_gql( + operation_name_1=DSLQuery( ... ), + operation_name_2=DSLQuery( ... ), + operation_name_3=DSLMutation( ... ), + ) Executable examples ------------------- diff --git a/docs/code_examples/aiohttp_async_dsl.py b/docs/code_examples/aiohttp_async_dsl.py index 48352d03..d558ef6d 100644 --- a/docs/code_examples/aiohttp_async_dsl.py +++ b/docs/code_examples/aiohttp_async_dsl.py @@ -1,7 +1,7 @@ import asyncio from gql import Client -from gql.dsl import DSLSchema, dsl_gql +from gql.dsl import DSLQuery, DSLSchema, dsl_gql from gql.transport.aiohttp import AIOHTTPTransport @@ -22,8 +22,10 @@ async def main(): # Create the query using dynamically generated attributes from ds query = dsl_gql( - ds.Query.continents(filter={"code": {"eq": "EU"}}).select( - ds.Continent.code, ds.Continent.name + DSLQuery( + ds.Query.continents(filter={"code": {"eq": "EU"}}).select( + ds.Continent.code, ds.Continent.name + ) ) ) @@ -43,7 +45,7 @@ async def main(): query_continents.select(ds.Continent.name) # I generate a document from my query to be able to execute it - query = dsl_gql(query_continents) + query = dsl_gql(DSLQuery(query_continents)) # Execute the query result = await session.execute(query) diff --git a/docs/code_examples/requests_sync_dsl.py b/docs/code_examples/requests_sync_dsl.py index 7cec520a..23c40e18 100644 --- a/docs/code_examples/requests_sync_dsl.py +++ b/docs/code_examples/requests_sync_dsl.py @@ -1,5 +1,5 @@ from gql import Client -from gql.dsl import DSLSchema, dsl_gql +from gql.dsl import DSLQuery, DSLSchema, dsl_gql from gql.transport.requests import RequestsHTTPTransport transport = RequestsHTTPTransport( @@ -21,7 +21,9 @@ ds = DSLSchema(client.schema) # Create the query using dynamically generated attributes from ds - query = dsl_gql(ds.Query.continents.select(ds.Continent.code, ds.Continent.name)) + query = dsl_gql( + DSLQuery(ds.Query.continents.select(ds.Continent.code, ds.Continent.name)) + ) result = session.execute(query) print(result) diff --git a/gql/dsl.py b/gql/dsl.py index a197e0bf..72abfcb9 100644 --- a/gql/dsl.py +++ b/gql/dsl.py @@ -1,4 +1,5 @@ import logging +from abc import ABC from typing import Dict, Iterable, List, Optional, Tuple, Union from graphql import ( @@ -26,70 +27,59 @@ def dsl_gql( - *fields: "DSLField", - operation_name: Optional[str] = None, - **fields_with_alias: "DSLField", + *operations: "DSLOperation", **operations_with_name: "DSLOperation" ) -> DocumentNode: - r"""Given arguments of type :class:`DSLField` containing GraphQL requests, + r"""Given arguments instances of :class:`DSLOperation` + containing GraphQL operations, generate a Document which can be executed later in a gql client or a gql session. Similar to the :func:`gql.gql` function but instead of parsing a python - string to describe the request, we are using requests which have been generated + string to describe the request, we are using operations which have been generated dynamically using instances of :class:`DSLField`, generated by instances of :class:`DSLType` which themselves originated from a :class:`DSLSchema` class. - The fields arguments should be fields of root GraphQL types - (Query, Mutation or Subscription). + :param \*operations: the GraphQL operations + :type \*operations: DSLOperation (DSLQuery, DSLMutation, DSLSubscription) + :param \**operations_with_name: the GraphQL operations with an operation name + :type \**operations_with_name: DSLOperation (DSLQuery, DSLMutation, DSLSubscription) - They should all have the same root type - (you can't mix queries with mutations for example). - - :param \*fields: root instances of the dynamically generated requests - :type \*fields: DSLField - :param \**fields_with_alias: root instances fields with alias as key - :type \**fields_with_alias: DSLField - :param operation_name: optional operation name - :type operation_name: str :return: a Document which can be later executed or subscribed by a :class:`Client `, by an :class:`async session ` or by a :class:`sync session ` - :raises TypeError: if an argument is not an instance of :class:`DSLField` - :raises AssertionError: if an argument is not a field of a root type + :raises TypeError: if an argument is not an instance of :class:`DSLOperation` """ - all_fields: Tuple["DSLField", ...] = DSLField.get_aliased_fields( - fields, fields_with_alias + # Concatenate operations without and with name + all_operations: Tuple["DSLOperation", ...] = ( + *operations, + *(operation for operation in operations_with_name.values()), ) - # Check that we receive only arguments of type DSLField - # And that they are a root type - for field in all_fields: - if not isinstance(field, DSLField): + # Set the operation name + for name, operation in operations_with_name.items(): + operation.name = name + + # Check the type + for operation in all_operations: + if not isinstance(operation, DSLOperation): raise TypeError( - f"fields must be instances of DSLField. Received type: {type(field)}" + "Operations should be instances of DSLOperation " + "(DSLQuery, DSLMutation or DSLSubscription).\n" + f"Received: {type(operation)}." ) - assert field.type_name in ["Query", "Mutation", "Subscription"], ( - "fields should be root types (Query, Mutation or Subscription)\n" - f"Received: {field.type_name}" - ) - - # Get the operation from the first field - # All the fields must have the same operation - operation = all_fields[0].type_name.lower() return DocumentNode( definitions=[ OperationDefinitionNode( - operation=OperationType(operation), - selection_set=SelectionSetNode( - selections=FrozenList(DSLField.get_ast_fields(all_fields)) - ), - **({"name": NameNode(value=operation_name)} if operation_name else {}), + operation=OperationType(operation.operation_type), + selection_set=operation.selection_set, + **({"name": NameNode(value=operation.name)} if operation.name else {}), ) + for operation in all_operations ] ) @@ -133,6 +123,77 @@ def __getattr__(self, name: str) -> "DSLType": return DSLType(type_def) +class DSLOperation(ABC): + """Interface for GraphQL operations. + + Inherited by + :class:`DSLQuery `, + :class:`DSLMutation ` and + :class:`DSLSubscription ` + """ + + operation_type: OperationType + + def __init__( + self, *fields: "DSLField", **fields_with_alias: "DSLField", + ): + r"""Given arguments of type :class:`DSLField` containing GraphQL requests, + generate an operation which can be converted to a Document + using the :func:`dsl_gql `. + + The fields arguments should be fields of root GraphQL types + (Query, Mutation or Subscription) and correspond to the + operation_type of this operation. + + :param \*fields: root instances of the dynamically generated requests + :type \*fields: DSLField + :param \**fields_with_alias: root instances fields with alias as key + :type \**fields_with_alias: DSLField + + :raises TypeError: if an argument is not an instance of :class:`DSLField` + :raises AssertionError: if an argument is not a field which correspond + to the operation type + """ + + self.name: Optional[str] = None + + # Concatenate fields without and with alias + all_fields: Tuple["DSLField", ...] = DSLField.get_aliased_fields( + fields, fields_with_alias + ) + + # Check that we receive only arguments of type DSLField + # And that the root type correspond to the operation + for field in all_fields: + if not isinstance(field, DSLField): + raise TypeError( + ( + "fields must be instances of DSLField. " + f"Received type: {type(field)}" + ) + ) + assert field.type_name.upper() == self.operation_type.name, ( + f"Invalid root field for operation {self.operation_type.name}.\n" + f"Received: {field.type_name}" + ) + + self.selection_set: SelectionSetNode = SelectionSetNode( + selections=FrozenList(DSLField.get_ast_fields(all_fields)) + ) + + +class DSLQuery(DSLOperation): + operation_type = OperationType.QUERY + + +class DSLMutation(DSLOperation): + operation_type = OperationType.MUTATION + + +class DSLSubscription(DSLOperation): + operation_type = OperationType.SUBSCRIPTION + + class DSLType: """The DSLType represents a GraphQL type for the DSL code. @@ -270,6 +331,7 @@ def select( of the :class:`DSLField` class. """ + # Concatenate fields without and with alias added_fields: Tuple["DSLField", ...] = self.get_aliased_fields( fields, fields_with_alias ) diff --git a/tests/nested_input/test_nested_input.py b/tests/nested_input/test_nested_input.py index b8997e95..0712be32 100644 --- a/tests/nested_input/test_nested_input.py +++ b/tests/nested_input/test_nested_input.py @@ -1,7 +1,7 @@ import pytest from gql import Client -from gql.dsl import DSLSchema, dsl_gql +from gql.dsl import DSLQuery, DSLSchema, dsl_gql from tests.nested_input.schema import NestedInputSchema @@ -16,18 +16,24 @@ def client(): def test_nested_input(ds, client): - query = dsl_gql(ds.Query.echo.args(nested={"foo": 1})) + query = dsl_gql(DSLQuery(ds.Query.echo.args(nested={"foo": 1}))) assert client.execute(query) == {"echo": '{"foo": 1}'} def test_nested_input_2(ds, client): - query = dsl_gql(ds.Query.echo.args(nested={"foo": 1, "child": {"foo": 2}})) + query = dsl_gql( + DSLQuery(ds.Query.echo.args(nested={"foo": 1, "child": {"foo": 2}})) + ) assert client.execute(query) == {"echo": '{"foo": 1, "child": {"foo": 2}}'} def test_nested_input_3(ds, client): query = dsl_gql( - ds.Query.echo.args(nested={"foo": 1, "child": {"foo": 2, "child": {"foo": 3}}}) + DSLQuery( + ds.Query.echo.args( + nested={"foo": 1, "child": {"foo": 2, "child": {"foo": 3}}} + ) + ) ) assert client.execute(query) == { "echo": '{"foo": 1, "child": {"foo": 2, "child": {"foo": 3}}}' diff --git a/tests/starwars/test_dsl.py b/tests/starwars/test_dsl.py index 49a366fe..5807e87f 100644 --- a/tests/starwars/test_dsl.py +++ b/tests/starwars/test_dsl.py @@ -1,7 +1,8 @@ import pytest +from graphql import print_ast from gql import Client -from gql.dsl import DSLSchema, dsl_gql +from gql.dsl import DSLMutation, DSLQuery, DSLSchema, DSLSubscription, dsl_gql from .schema import StarWarsSchema @@ -135,7 +136,7 @@ def test_fetch_name_aliased_as_kwargs(ds: DSLSchema): def test_hero_name_query_result(ds, client): - query = dsl_gql(ds.Query.hero.select(ds.Character.name)) + query = dsl_gql(DSLQuery(ds.Query.hero.select(ds.Character.name))) result = client.execute(query) expected = {"hero": {"name": "R2-D2"}} assert result == expected @@ -143,7 +144,9 @@ def test_hero_name_query_result(ds, client): def test_arg_serializer_list(ds, client): query = dsl_gql( - ds.Query.characters.args(ids=[1000, 1001, 1003]).select(ds.Character.name,) + DSLQuery( + ds.Query.characters.args(ids=[1000, 1001, 1003]).select(ds.Character.name,) + ) ) result = client.execute(query) expected = { @@ -157,7 +160,7 @@ def test_arg_serializer_list(ds, client): def test_arg_serializer_enum(ds, client): - query = dsl_gql(ds.Query.hero.args(episode=5).select(ds.Character.name)) + query = dsl_gql(DSLQuery(ds.Query.hero.args(episode=5).select(ds.Character.name))) result = client.execute(query) expected = {"hero": {"name": "Luke Skywalker"}} assert result == expected @@ -166,15 +169,38 @@ def test_arg_serializer_enum(ds, client): def test_create_review_mutation_result(ds, client): query = dsl_gql( - ds.Mutation.createReview.args( - episode=6, review={"stars": 5, "commentary": "This is a great movie!"} - ).select(ds.Review.stars, ds.Review.commentary) + DSLMutation( + ds.Mutation.createReview.args( + episode=6, review={"stars": 5, "commentary": "This is a great movie!"} + ).select(ds.Review.stars, ds.Review.commentary) + ) ) result = client.execute(query) expected = {"createReview": {"stars": 5, "commentary": "This is a great movie!"}} assert result == expected +def test_subscription(ds): + + query = dsl_gql( + DSLSubscription( + ds.Subscription.reviewAdded(episode=6).select( + ds.Review.stars, ds.Review.commentary + ) + ) + ) + assert ( + print_ast(query) + == """subscription { + reviewAdded(episode: JEDI) { + stars + commentary + } +} +""" + ) + + def test_invalid_arg(ds): with pytest.raises( KeyError, match="Argument invalid_arg does not exist in Field: Character." @@ -184,8 +210,12 @@ def test_invalid_arg(ds): def test_multiple_root_fields(ds, client): query = dsl_gql( - ds.Query.hero.select(ds.Character.name), - ds.Query.hero(episode=5).alias("hero_of_episode_5").select(ds.Character.name), + DSLQuery( + ds.Query.hero.select(ds.Character.name), + ds.Query.hero(episode=5) + .alias("hero_of_episode_5") + .select(ds.Character.name), + ) ) result = client.execute(query) expected = { @@ -197,8 +227,10 @@ def test_multiple_root_fields(ds, client): def test_root_fields_aliased(ds, client): query = dsl_gql( - ds.Query.hero.select(ds.Character.name), - hero_of_episode_5=ds.Query.hero(episode=5).select(ds.Character.name), + DSLQuery( + ds.Query.hero.select(ds.Character.name), + hero_of_episode_5=ds.Query.hero(episode=5).select(ds.Character.name), + ) ) result = client.execute(query) expected = { @@ -209,11 +241,28 @@ def test_root_fields_aliased(ds, client): def test_operation_name(ds): - query = dsl_gql( - ds.Query.hero.select(ds.Character.name), operation_name="GetHeroName", + query = dsl_gql(GetHeroName=DSLQuery(ds.Query.hero.select(ds.Character.name),)) + + assert ( + print_ast(query) + == """query GetHeroName { + hero { + name + } +} +""" ) - from graphql import print_ast + +def test_multiple_operations(ds): + query = dsl_gql( + GetHeroName=DSLQuery(ds.Query.hero.select(ds.Character.name)), + CreateReviewMutation=DSLMutation( + ds.Mutation.createReview.args( + episode=6, review={"stars": 5, "commentary": "This is a great movie!"} + ).select(ds.Review.stars, ds.Review.commentary) + ), + ) assert ( print_ast(query) @@ -222,37 +271,39 @@ def test_operation_name(ds): name } } + +mutation CreateReviewMutation { + createReview(episode: JEDI, review: {stars: 5, \ +commentary: "This is a great movie!"}) { + stars + commentary + } +} """ ) -def test_dsl_gql_all_fields_should_be_instances_of_DSLField(ds, client): +def test_dsl_query_all_fields_should_be_instances_of_DSLField(): with pytest.raises( TypeError, match="fields must be instances of DSLField. Received type:" ): - dsl_gql( - ds.Query.hero.select(ds.Character.name), - ds.Query.hero(episode=5) - .alias("hero_of_episode_5") - .select(ds.Character.name), - "I am a string", - ) + DSLQuery("I am a string") -def test_dsl_gql_all_fields_should_be_a_root_type(ds, client): - with pytest.raises(AssertionError,) as excinfo: - dsl_gql( - ds.Query.hero.select(ds.Character.name), - ds.Query.hero(episode=5) - .alias("hero_of_episode_5") - .select(ds.Character.name), - ds.Character.name, - ) +def test_dsl_query_all_fields_should_correspond_to_the_root_type(ds): + with pytest.raises(AssertionError) as excinfo: + DSLQuery(ds.Character.name) - assert ( - "fields should be root types (Query, Mutation or Subscription)\n" - "Received: Character" - ) in str(excinfo.value) + assert ("Invalid root field for operation QUERY.\n" "Received: Character") in str( + excinfo.value + ) + + +def test_dsl_gql_all_arguments_should_be_operations(): + with pytest.raises( + TypeError, match="Operations should be instances of DSLOperation " + ): + dsl_gql("I am a string") def test_DSLSchema_requires_a_schema(client):