diff --git a/README.md b/README.md index aa5aa0ac..8fefeb2f 100644 --- a/README.md +++ b/README.md @@ -40,6 +40,7 @@ The main features of GQL are: * Supports [sync or async usage](https://gql.readthedocs.io/en/latest/async/index.html), [allowing concurrent requests](https://gql.readthedocs.io/en/latest/advanced/async_advanced_usage.html#async-advanced-usage) * Supports [File uploads](https://gql.readthedocs.io/en/latest/usage/file_upload.html) * [gql-cli script](https://gql.readthedocs.io/en/latest/gql-cli/intro.html) to execute GraphQL queries from the command line +* [DSL module](https://gql.readthedocs.io/en/latest/advanced/dsl_module.html) to compose GraphQL queries dynamically ## Installation diff --git a/docs/advanced/dsl_module.rst b/docs/advanced/dsl_module.rst index aa6638df..2e60f045 100644 --- a/docs/advanced/dsl_module.rst +++ b/docs/advanced/dsl_module.rst @@ -2,33 +2,225 @@ Compose queries dynamically =========================== Instead of providing the GraphQL queries as a Python String, it is also possible to create GraphQL queries dynamically. -Using the DSL module, we can create a query using a Domain Specific Language which is created from the schema. +Using the :mod:`DSL module `, we can create a query using a Domain Specific Language which is created from the schema. + +The following code: .. code-block:: python - from gql.dsl import DSLSchema + ds = DSLSchema(StarWarsSchema) + + query = dsl_gql( + DSLQuery( + ds.Query.hero.select( + ds.Character.id, + ds.Character.name, + ds.Character.friends.select(ds.Character.name), + ) + ) + ) + +will generate a query equivalent to: + +.. code-block:: python + + query = gql(""" + query { + hero { + id + name + friends { + name + } + } + } + """) + +How to use +---------- + +First generate the root using the :class:`DSLSchema `:: + + ds = DSLSchema(client.schema) + +Then use auto-generated attributes of the :code:`ds` instance +to get a root type (Query, Mutation or Subscription). +This will generate a :class:`DSLType ` instance:: + + ds.Query + +From this root type, you use auto-generated attributes to get a field. +This will generate a :class:`DSLField ` instance:: + + ds.Query.hero + +hero is a GraphQL object type and needs children fields. By default, +there is no children fields selected. To select the fields that you want +in your query, you use the :meth:`select ` method. + +To generate the children fields, we use the same method as above to auto-generate the fields +from the :code:`ds` instance +(ie :code:`ds.Character.name` is the field `name` of the type `Character`):: + + ds.Query.hero.select(ds.Character.name) - client = Client(schema=StarWarsSchema) - ds = DSLSchema(client) +The select method return the same instance, so it is possible to chain the calls:: - query_dsl = ds.Query.hero.select( + ds.Query.hero.select(ds.Character.name).select(ds.Character.id) + +Or do it sequencially:: + + hero_query = ds.Query.hero + + hero_query.select(ds.Character.name) + hero_query.select(ds.Character.id) + +As you can select children fields of any object type, you can construct your complete query tree:: + + ds.Query.hero.select( ds.Character.id, ds.Character.name, - ds.Character.friends.select(ds.Character.name,), + ds.Character.friends.select(ds.Character.name), ) -will create a query equivalent to: +Once your root query fields are defined, you can put them in an operation using +:class:`DSLQuery `, +:class:`DSLMutation ` or +:class:`DSLSubscription `:: -.. code-block:: python + DSLQuery( + ds.Query.hero.select( + ds.Character.id, + ds.Character.name, + ds.Character.friends.select(ds.Character.name), + ) + ) + + +Once your operations are defined, +use the :func:`dsl_gql ` function to convert your operations into +a document which will be able to get executed in the client or a session:: + + query = dsl_gql( + DSLQuery( + ds.Query.hero.select( + ds.Character.id, + ds.Character.name, + ds.Character.friends.select(ds.Character.name), + ) + ) + ) + + result = client.execute(query) + +Arguments +^^^^^^^^^ + +It is possible to add arguments to any field simply by calling it +with the required arguments:: + + ds.Query.human(id="1000").select(ds.Human.name) + +It can also be done using the :meth:`args ` method:: + + ds.Query.human.args(id="1000").select(ds.Human.name) - hero { - id - name - friends { - name - } +Aliases +^^^^^^^ + +You can set an alias of a field using the :meth:`alias ` method:: + + ds.Query.human.args(id=1000).alias("luke").select(ds.Character.name) + +It is also possible to set the alias directly using keyword arguments of an operation:: + + DSLQuery( + luke=ds.Query.human.args(id=1000).select(ds.Character.name) + ) + +Or using keyword arguments in the :meth:`select ` method:: + + ds.Query.hero.select( + my_name=ds.Character.name + ) + +Mutations +^^^^^^^^^ + +For the mutations, you need to start from root fields starting from :code:`ds.Mutation` +then you need to create the GraphQL operation using the class +:class:`DSLMutation `. Example:: + + query = dsl_gql( + DSLMutation( + ds.Mutation.createReview.args( + episode=6, review={"stars": 5, "commentary": "This is a great movie!"} + ).select(ds.Review.stars, ds.Review.commentary) + ) + ) + +Subscriptions +^^^^^^^^^^^^^ + +For the subscriptions, you need to start from root fields starting from :code:`ds.Subscription` +then you need to create the GraphQL operation using the class +:class:`DSLSubscription `. Example:: + + query = dsl_gql( + DSLSubscription( + ds.Subscription.reviewAdded(episode=6).select(ds.Review.stars, ds.Review.commentary) + ) + ) + +Multiple fields in an operation +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +It is possible to create an operation with multiple fields:: + + DSLQuery( + ds.Query.hero.select(ds.Character.name), + hero_of_episode_5=ds.Query.hero(episode=5).select(ds.Character.name), + ) + +Operation name +^^^^^^^^^^^^^^ + +You can set the operation name of an operation using a keyword argument +to :func:`dsl_gql `:: + + query = dsl_gql( + GetHeroName=DSLQuery(ds.Query.hero.select(ds.Character.name)) + ) + +will generate the request:: + + query GetHeroName { + hero { + name + } } -.. warning:: +Multiple operations in a document +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +It is possible to create an Document with multiple operations:: + + query = dsl_gql( + operation_name_1=DSLQuery( ... ), + operation_name_2=DSLQuery( ... ), + operation_name_3=DSLMutation( ... ), + ) + +Executable examples +------------------- + +Async example +^^^^^^^^^^^^^ + +.. literalinclude:: ../code_examples/aiohttp_async_dsl.py + +Sync example +^^^^^^^^^^^^^ + +.. literalinclude:: ../code_examples/requests_sync_dsl.py - Please note that the DSL module is still considered experimental in GQL 3 and is subject to changes diff --git a/docs/code_examples/aiohttp_async_dsl.py b/docs/code_examples/aiohttp_async_dsl.py new file mode 100644 index 00000000..d558ef6d --- /dev/null +++ b/docs/code_examples/aiohttp_async_dsl.py @@ -0,0 +1,55 @@ +import asyncio + +from gql import Client +from gql.dsl import DSLQuery, DSLSchema, dsl_gql +from gql.transport.aiohttp import AIOHTTPTransport + + +async def main(): + + transport = AIOHTTPTransport(url="https://countries.trevorblades.com/graphql") + + client = Client(transport=transport, fetch_schema_from_transport=True) + + # Using `async with` on the client will start a connection on the transport + # and provide a `session` variable to execute queries on this connection. + # Because we requested to fetch the schema from the transport, + # GQL will fetch the schema just after the establishment of the first session + async with client as session: + + # Instanciate the root of the DSL Schema as ds + ds = DSLSchema(client.schema) + + # Create the query using dynamically generated attributes from ds + query = dsl_gql( + DSLQuery( + ds.Query.continents(filter={"code": {"eq": "EU"}}).select( + ds.Continent.code, ds.Continent.name + ) + ) + ) + + result = await session.execute(query) + print(result) + + # This can also be written as: + + # I want to query the continents + query_continents = ds.Query.continents + + # I want to get only the continents with code equal to "EU" + query_continents(filter={"code": {"eq": "EU"}}) + + # I want this query to return the code and name fields + query_continents.select(ds.Continent.code) + query_continents.select(ds.Continent.name) + + # I generate a document from my query to be able to execute it + query = dsl_gql(DSLQuery(query_continents)) + + # Execute the query + result = await session.execute(query) + print(result) + + +asyncio.run(main()) diff --git a/docs/code_examples/requests_sync.py b/docs/code_examples/requests_sync.py index d1054844..53b1e2a7 100644 --- a/docs/code_examples/requests_sync.py +++ b/docs/code_examples/requests_sync.py @@ -1,11 +1,11 @@ from gql import Client, gql from gql.transport.requests import RequestsHTTPTransport -sample_transport = RequestsHTTPTransport( +transport = RequestsHTTPTransport( url="https://countries.trevorblades.com/", verify=True, retries=3, ) -client = Client(transport=sample_transport, fetch_schema_from_transport=True,) +client = Client(transport=transport, fetch_schema_from_transport=True) query = gql( """ diff --git a/docs/code_examples/requests_sync_dsl.py b/docs/code_examples/requests_sync_dsl.py new file mode 100644 index 00000000..23c40e18 --- /dev/null +++ b/docs/code_examples/requests_sync_dsl.py @@ -0,0 +1,29 @@ +from gql import Client +from gql.dsl import DSLQuery, DSLSchema, dsl_gql +from gql.transport.requests import RequestsHTTPTransport + +transport = RequestsHTTPTransport( + url="https://countries.trevorblades.com/", verify=True, retries=3, +) + +client = Client(transport=transport, fetch_schema_from_transport=True) + +# Using `with` on the sync client will start a connection on the transport +# and provide a `session` variable to execute queries on this connection. +# Because we requested to fetch the schema from the transport, +# GQL will fetch the schema just after the establishment of the first session +with client as session: + + # We should have received the schema now that the session is established + assert client.schema is not None + + # Instanciate the root of the DSL Schema as ds + ds = DSLSchema(client.schema) + + # Create the query using dynamically generated attributes from ds + query = dsl_gql( + DSLQuery(ds.Query.continents.select(ds.Continent.code, ds.Continent.name)) + ) + + result = session.execute(query) + print(result) diff --git a/docs/modules/client.rst b/docs/modules/client.rst index 954b4e61..69169425 100644 --- a/docs/modules/client.rst +++ b/docs/modules/client.rst @@ -1,5 +1,5 @@ -Client -====== +gql.client +========== .. currentmodule:: gql.client diff --git a/docs/modules/dsl.rst b/docs/modules/dsl.rst new file mode 100644 index 00000000..e8487745 --- /dev/null +++ b/docs/modules/dsl.rst @@ -0,0 +1,7 @@ +gql.dsl +======= + +.. currentmodule:: gql.dsl + +.. automodule:: gql.dsl + :member-order: bysource diff --git a/docs/modules/gql.rst b/docs/modules/gql.rst index 94121ea3..aac47c86 100644 --- a/docs/modules/gql.rst +++ b/docs/modules/gql.rst @@ -20,3 +20,4 @@ Sub-Packages client transport + dsl diff --git a/docs/modules/transport.rst b/docs/modules/transport.rst index dd4627e0..9a3caf6e 100644 --- a/docs/modules/transport.rst +++ b/docs/modules/transport.rst @@ -1,5 +1,5 @@ -Transport -========= +gql.transport +============= .. currentmodule:: gql.transport diff --git a/gql/dsl.py b/gql/dsl.py index 27894fcf..72abfcb9 100644 --- a/gql/dsl.py +++ b/gql/dsl.py @@ -1,20 +1,18 @@ -from collections.abc import Iterable -from functools import partial +import logging +from abc import ABC +from typing import Dict, Iterable, List, Optional, Tuple, Union from graphql import ( ArgumentNode, DocumentNode, - EnumValueNode, FieldNode, - GraphQLEnumType, - GraphQLInputField, - GraphQLInputObjectType, - GraphQLList, - GraphQLNonNull, - ListValueNode, + GraphQLArgument, + GraphQLField, + GraphQLInterfaceType, + GraphQLNamedType, + GraphQLObjectType, + GraphQLSchema, NameNode, - ObjectFieldNode, - ObjectValueNode, OperationDefinitionNode, OperationType, SelectionSetNode, @@ -25,155 +23,414 @@ from .utils import to_camel_case +log = logging.getLogger(__name__) -class DSLSchema(object): - def __init__(self, client): - self.client = client - @property - def schema(self): - return self.client.schema +def dsl_gql( + *operations: "DSLOperation", **operations_with_name: "DSLOperation" +) -> DocumentNode: + r"""Given arguments instances of :class:`DSLOperation` + containing GraphQL operations, + generate a Document which can be executed later in a + gql client or a gql session. + + Similar to the :func:`gql.gql` function but instead of parsing a python + string to describe the request, we are using operations which have been generated + dynamically using instances of :class:`DSLField`, generated + by instances of :class:`DSLType` which themselves originated from + a :class:`DSLSchema` class. + + :param \*operations: the GraphQL operations + :type \*operations: DSLOperation (DSLQuery, DSLMutation, DSLSubscription) + :param \**operations_with_name: the GraphQL operations with an operation name + :type \**operations_with_name: DSLOperation (DSLQuery, DSLMutation, DSLSubscription) + + :return: a Document which can be later executed or subscribed by a + :class:`Client `, by an + :class:`async session ` or by a + :class:`sync session ` + + :raises TypeError: if an argument is not an instance of :class:`DSLOperation` + """ + + # Concatenate operations without and with name + all_operations: Tuple["DSLOperation", ...] = ( + *operations, + *(operation for operation in operations_with_name.values()), + ) + + # Set the operation name + for name, operation in operations_with_name.items(): + operation.name = name + + # Check the type + for operation in all_operations: + if not isinstance(operation, DSLOperation): + raise TypeError( + "Operations should be instances of DSLOperation " + "(DSLQuery, DSLMutation or DSLSubscription).\n" + f"Received: {type(operation)}." + ) + + return DocumentNode( + definitions=[ + OperationDefinitionNode( + operation=OperationType(operation.operation_type), + selection_set=operation.selection_set, + **({"name": NameNode(value=operation.name)} if operation.name else {}), + ) + for operation in all_operations + ] + ) + + +class DSLSchema: + """The DSLSchema is the root of the DSL code. + + Attributes of the DSLSchema class are generated automatically + with the `__getattr__` dunder method in order to generate + instances of :class:`DSLType` + """ + + def __init__(self, schema: GraphQLSchema): + """Initialize the DSLSchema with the given schema. + + :param schema: a GraphQL Schema provided locally or fetched using + an introspection query. Usually `client.schema` + :type schema: GraphQLSchema + + :raises TypeError: if the argument is not an instance of :class:`GraphQLSchema` + """ + + if not isinstance(schema, GraphQLSchema): + raise TypeError( + f"DSLSchema needs a schema as parameter. Received: {type(schema)}" + ) + + self._schema: GraphQLSchema = schema + + def __getattr__(self, name: str) -> "DSLType": + + type_def: Optional[GraphQLNamedType] = self._schema.get_type(name) + + if type_def is None: + raise AttributeError(f"Type '{name}' not found in the schema!") + + assert isinstance(type_def, GraphQLObjectType) or isinstance( + type_def, GraphQLInterfaceType + ) - def __getattr__(self, name): - type_def = self.schema.get_type(name) return DSLType(type_def) - def query(self, *args, **kwargs): - return self.execute(query(*args, **kwargs)) - def mutate(self, *args, **kwargs): - return self.query(*args, operation="mutation", **kwargs) +class DSLOperation(ABC): + """Interface for GraphQL operations. + + Inherited by + :class:`DSLQuery `, + :class:`DSLMutation ` and + :class:`DSLSubscription ` + """ + + operation_type: OperationType + + def __init__( + self, *fields: "DSLField", **fields_with_alias: "DSLField", + ): + r"""Given arguments of type :class:`DSLField` containing GraphQL requests, + generate an operation which can be converted to a Document + using the :func:`dsl_gql `. + + The fields arguments should be fields of root GraphQL types + (Query, Mutation or Subscription) and correspond to the + operation_type of this operation. + + :param \*fields: root instances of the dynamically generated requests + :type \*fields: DSLField + :param \**fields_with_alias: root instances fields with alias as key + :type \**fields_with_alias: DSLField + + :raises TypeError: if an argument is not an instance of :class:`DSLField` + :raises AssertionError: if an argument is not a field which correspond + to the operation type + """ + + self.name: Optional[str] = None + + # Concatenate fields without and with alias + all_fields: Tuple["DSLField", ...] = DSLField.get_aliased_fields( + fields, fields_with_alias + ) + + # Check that we receive only arguments of type DSLField + # And that the root type correspond to the operation + for field in all_fields: + if not isinstance(field, DSLField): + raise TypeError( + ( + "fields must be instances of DSLField. " + f"Received type: {type(field)}" + ) + ) + assert field.type_name.upper() == self.operation_type.name, ( + f"Invalid root field for operation {self.operation_type.name}.\n" + f"Received: {field.type_name}" + ) + + self.selection_set: SelectionSetNode = SelectionSetNode( + selections=FrozenList(DSLField.get_ast_fields(all_fields)) + ) + + +class DSLQuery(DSLOperation): + operation_type = OperationType.QUERY + + +class DSLMutation(DSLOperation): + operation_type = OperationType.MUTATION + + +class DSLSubscription(DSLOperation): + operation_type = OperationType.SUBSCRIPTION - def execute(self, document): - return self.client.execute(document) +class DSLType: + """The DSLType represents a GraphQL type for the DSL code. -class DSLType(object): - def __init__(self, type_): - self._type = type_ + It can be a root type (Query, Mutation or Subscription). + Or it can be any other object type (Human in the StarWars schema). + Or it can be an interface type (Character in the StarWars schema). - def __getattr__(self, name): - formatted_name, field_def = self.get_field(name) - return DSLField(formatted_name, field_def) + Instances of this class are generated for you automatically as attributes + of the :class:`DSLSchema` - def get_field(self, name): + Attributes of the DSLType class are generated automatically + with the `__getattr__` dunder method in order to generate + instances of :class:`DSLField` + """ + + def __init__(self, graphql_type: Union[GraphQLObjectType, GraphQLInterfaceType]): + """Initialize the DSLType with the GraphQL type. + + .. warning:: + Don't instantiate this class yourself. + Use attributes of the :class:`DSLSchema` instead. + + :param graphql_type: the GraphQL type definition from the schema + """ + self._type: Union[GraphQLObjectType, GraphQLInterfaceType] = graphql_type + log.debug(f"Creating {self!r})") + + def __getattr__(self, name: str) -> "DSLField": camel_cased_name = to_camel_case(name) if name in self._type.fields: - return name, self._type.fields[name] + formatted_name = name + field = self._type.fields[name] + elif camel_cased_name in self._type.fields: + formatted_name = camel_cased_name + field = self._type.fields[camel_cased_name] + else: + raise AttributeError( + f"Field {name} does not exist in type {self._type.name}." + ) - if camel_cased_name in self._type.fields: - return camel_cased_name, self._type.fields[camel_cased_name] + return DSLField(formatted_name, self._type, field) - raise KeyError(f"Field {name} does not exist in type {self._type.name}.") + def __repr__(self) -> str: + return f"<{self.__class__.__name__} {self._type!r}>" -def selections(*fields): - for _field in fields: - yield selection_field(_field).ast +class DSLField: + """The DSLField represents a GraphQL field for the DSL code. + Instances of this class are generated for you automatically as attributes + of the :class:`DSLType` -class DSLField(object): - def __init__(self, name, field): - self.field = field - self.ast_field = FieldNode(name=NameNode(value=name), arguments=FrozenList()) - self.selection_set = None + If this field contains children fields, then you need to select which ones + you want in the request using the :meth:`select ` + method. + """ - def select(self, *fields): - selection_set = self.ast_field.selection_set - added_selections = selections(*fields) - if selection_set: - selection_set.selections = FrozenList( - selection_set.selections + list(added_selections) - ) - else: + def __init__( + self, + name: str, + graphql_type: Union[GraphQLObjectType, GraphQLInterfaceType], + graphql_field: GraphQLField, + ): + """Initialize the DSLField. + + .. warning:: + Don't instantiate this class yourself. + Use attributes of the :class:`DSLType` instead. + + :param name: the name of the field + :param graphql_type: the GraphQL type definition from the schema + :param graphql_field: the GraphQL field definition from the schema + """ + self._type: Union[GraphQLObjectType, GraphQLInterfaceType] = graphql_type + self.field: GraphQLField = graphql_field + self.ast_field: FieldNode = FieldNode( + name=NameNode(value=name), arguments=FrozenList() + ) + log.debug(f"Creating {self!r}") + + @staticmethod + def get_ast_fields(fields: Iterable["DSLField"]) -> List[FieldNode]: + """ + :meta private: + + Equivalent to: :code:`[field.ast_field for field in fields]` + But with a type check for each field in the list. + + :raises TypeError: if any of the provided fields are not instances + of the :class:`DSLField` class. + """ + ast_fields = [] + for field in fields: + if isinstance(field, DSLField): + ast_fields.append(field.ast_field) + else: + raise TypeError(f'Received incompatible field: "{field}".') + + return ast_fields + + @staticmethod + def get_aliased_fields( + fields: Iterable["DSLField"], fields_with_alias: Dict[str, "DSLField"] + ) -> Tuple["DSLField", ...]: + """ + :meta private: + + Concatenate all the fields (with or without alias) in a Tuple. + + Set the requested alias for the fields with alias. + """ + + return ( + *fields, + *(field.alias(alias) for alias, field in fields_with_alias.items()), + ) + + def select( + self, *fields: "DSLField", **fields_with_alias: "DSLField" + ) -> "DSLField": + r"""Select the new children fields + that we want to receive in the request. + + If used multiple times, we will add the new children fields + to the existing children fields. + + :param \*fields: new children fields + :type \*fields: DSLField + :param \**fields_with_alias: new children fields with alias as key + :type \**fields_with_alias: DSLField + :return: itself + + :raises TypeError: if any of the provided fields are not instances + of the :class:`DSLField` class. + """ + + # Concatenate fields without and with alias + added_fields: Tuple["DSLField", ...] = self.get_aliased_fields( + fields, fields_with_alias + ) + + added_selections: List[FieldNode] = self.get_ast_fields(added_fields) + + current_selection_set: Optional[SelectionSetNode] = self.ast_field.selection_set + + if current_selection_set is None: self.ast_field.selection_set = SelectionSetNode( selections=FrozenList(added_selections) ) + else: + current_selection_set.selections = FrozenList( + current_selection_set.selections + added_selections + ) + + log.debug(f"Added fields: {fields} in {self!r}") + return self - def __call__(self, **kwargs): + def __call__(self, **kwargs) -> "DSLField": return self.args(**kwargs) - def alias(self, alias): + def alias(self, alias: str) -> "DSLField": + """Set an alias + + .. note:: + You can also pass the alias directly at the + :meth:`select ` method. + :code:`ds.Query.human.select(my_name=ds.Character.name)` is equivalent to: + :code:`ds.Query.human.select(ds.Character.name.alias("my_name"))` + + :param alias: the alias + :type alias: str + :return: itself + """ + self.ast_field.alias = NameNode(value=alias) return self - def args(self, **kwargs): - added_args = [] - for name, value in kwargs.items(): - arg = self.field.args.get(name) - if not arg: - raise KeyError(f"Argument {name} does not exist in {self.field}.") - arg_type_serializer = get_arg_serializer(arg.type, known_serializers=dict()) - serialized_value = arg_type_serializer(value) - added_args.append( - ArgumentNode(name=NameNode(value=name), value=serialized_value) - ) - ast_field = self.ast_field - ast_field.arguments = FrozenList(ast_field.arguments + added_args) - return self + def args(self, **kwargs) -> "DSLField": + r"""Set the arguments of a field - @property - def ast(self): - return self.ast_field + The arguments are parsed to be stored in the AST of this field. - def __str__(self): - return print_ast(self.ast_field) + .. note:: + You can also call the field directly with your arguments. + :code:`ds.Query.human(id=1000)` is equivalent to: + :code:`ds.Query.human.args(id=1000)` + :param \**kwargs: the arguments (keyword=value) + :return: itself -def selection_field(field): - if isinstance(field, DSLField): - return field + :raises KeyError: if any of the provided arguments does not exist + for this field. + """ - raise TypeError(f'Received incompatible query field: "{field}".') + assert self.ast_field.arguments is not None + self.ast_field.arguments = FrozenList( + self.ast_field.arguments + + [ + ArgumentNode( + name=NameNode(value=name), + value=ast_from_value(value, self._get_argument(name).type), + ) + for name, value in kwargs.items() + ] + ) -def query(*fields, **kwargs): - if "operation" not in kwargs: - kwargs["operation"] = "query" - return DocumentNode( - definitions=[ - OperationDefinitionNode( - operation=OperationType(kwargs["operation"]), - selection_set=SelectionSetNode( - selections=FrozenList(selections(*fields)) - ), - ) - ] - ) + log.debug(f"Added arguments {kwargs} in field {self!r})") + + return self + def _get_argument(self, name: str) -> GraphQLArgument: + """Method used to return the GraphQLArgument definition + of an argument from its name. -def serialize_list(serializer, list_values): - assert isinstance( - list_values, Iterable - ), f'Expected iterable, received "{list_values!r}".' - return ListValueNode(values=FrozenList(serializer(v) for v in list_values)) - - -def get_arg_serializer(arg_type, known_serializers): - if isinstance(arg_type, GraphQLNonNull): - return get_arg_serializer(arg_type.of_type, known_serializers) - if isinstance(arg_type, GraphQLInputField): - return get_arg_serializer(arg_type.type, known_serializers) - if isinstance(arg_type, GraphQLInputObjectType): - if arg_type in known_serializers: - return known_serializers[arg_type] - known_serializers[arg_type] = None - serializers = { - k: get_arg_serializer(v, known_serializers) - for k, v in arg_type.fields.items() - } - known_serializers[arg_type] = lambda value: ObjectValueNode( - fields=FrozenList( - ObjectFieldNode(name=NameNode(value=k), value=serializers[k](v)) - for k, v in value.items() - ) + :raises KeyError: if the provided argument does not exist + for this field. + """ + arg = self.field.args.get(name) + + if arg is None: + raise KeyError(f"Argument {name} does not exist in {self.field}.") + + return arg + + @property + def type_name(self): + """:meta private:""" + return self._type.name + + def __str__(self) -> str: + return print_ast(self.ast_field) + + def __repr__(self) -> str: + return ( + f"<{self.__class__.__name__} {self._type.name}" + f"::{self.ast_field.name.value}>" ) - return known_serializers[arg_type] - if isinstance(arg_type, GraphQLList): - inner_serializer = get_arg_serializer(arg_type.of_type, known_serializers) - return partial(serialize_list, inner_serializer) - if isinstance(arg_type, GraphQLEnumType): - return lambda value: EnumValueNode(value=arg_type.serialize(value)) - return lambda value: ast_from_value(arg_type.serialize(value), arg_type) diff --git a/tests/conftest.py b/tests/conftest.py index 9484beb4..1865152e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -100,7 +100,7 @@ async def go(app, *, port=None, **kwargs): # type: ignore # Adding debug logs to websocket tests -for name in ["websockets.server", "gql.transport.websockets"]: +for name in ["websockets.server", "gql.transport.websockets", "gql.dsl"]: logger = logging.getLogger(name) logger.setLevel(logging.DEBUG) diff --git a/tests/nested_input/schema.py b/tests/nested_input/schema.py index f27a94e8..bd5a0507 100644 --- a/tests/nested_input/schema.py +++ b/tests/nested_input/schema.py @@ -1,3 +1,5 @@ +import json + from graphql import ( GraphQLArgument, GraphQLField, @@ -6,6 +8,7 @@ GraphQLInt, GraphQLObjectType, GraphQLSchema, + GraphQLString, ) nestedInput = GraphQLInputObjectType( @@ -19,10 +22,10 @@ queryType = GraphQLObjectType( "Query", fields=lambda: { - "foo": GraphQLField( + "echo": GraphQLField( args={"nested": GraphQLArgument(type_=nestedInput)}, - resolve=lambda *args, **kwargs: 1, - type_=GraphQLInt, + resolve=lambda *args, **kwargs: json.dumps(kwargs["nested"]), + type_=GraphQLString, ), }, ) diff --git a/tests/nested_input/test_nested_input.py b/tests/nested_input/test_nested_input.py index 037d1518..0712be32 100644 --- a/tests/nested_input/test_nested_input.py +++ b/tests/nested_input/test_nested_input.py @@ -1,63 +1,40 @@ -from functools import partial - import pytest -from graphql import ( - EnumValueNode, - GraphQLEnumType, - GraphQLInputField, - GraphQLInputObjectType, - GraphQLList, - GraphQLNonNull, - NameNode, - ObjectFieldNode, - ObjectValueNode, - ast_from_value, -) -from graphql.pyutils import FrozenList - -import gql.dsl as dsl + from gql import Client -from gql.dsl import DSLSchema, serialize_list +from gql.dsl import DSLQuery, DSLSchema, dsl_gql from tests.nested_input.schema import NestedInputSchema -# back up the new func -new_get_arg_serializer = dsl.get_arg_serializer - - -def old_get_arg_serializer(arg_type, known_serializers=None): - if isinstance(arg_type, GraphQLNonNull): - return old_get_arg_serializer(arg_type.of_type) - if isinstance(arg_type, GraphQLInputField): - return old_get_arg_serializer(arg_type.type) - if isinstance(arg_type, GraphQLInputObjectType): - serializers = {k: old_get_arg_serializer(v) for k, v in arg_type.fields.items()} - return lambda value: ObjectValueNode( - fields=FrozenList( - ObjectFieldNode(name=NameNode(value=k), value=serializers[k](v)) - for k, v in value.items() - ) - ) - if isinstance(arg_type, GraphQLList): - inner_serializer = old_get_arg_serializer(arg_type.of_type) - return partial(serialize_list, inner_serializer) - if isinstance(arg_type, GraphQLEnumType): - return lambda value: EnumValueNode(value=arg_type.serialize(value)) - return lambda value: ast_from_value(arg_type.serialize(value), arg_type) - @pytest.fixture def ds(): - client = Client(schema=NestedInputSchema) - ds = DSLSchema(client) - return ds + return DSLSchema(NestedInputSchema) + +@pytest.fixture +def client(): + return Client(schema=NestedInputSchema) + + +def test_nested_input(ds, client): + query = dsl_gql(DSLQuery(ds.Query.echo.args(nested={"foo": 1}))) + assert client.execute(query) == {"echo": '{"foo": 1}'} -def test_nested_input_with_old_get_arg_serializer(ds): - dsl.get_arg_serializer = old_get_arg_serializer - with pytest.raises(RecursionError, match="maximum recursion depth exceeded"): - ds.query(ds.Query.foo.args(nested={"foo": 1})) +def test_nested_input_2(ds, client): + query = dsl_gql( + DSLQuery(ds.Query.echo.args(nested={"foo": 1, "child": {"foo": 2}})) + ) + assert client.execute(query) == {"echo": '{"foo": 1, "child": {"foo": 2}}'} -def test_nested_input_with_new_get_arg_serializer(ds): - dsl.get_arg_serializer = new_get_arg_serializer - assert ds.query(ds.Query.foo.args(nested={"foo": 1})) == {"foo": 1} + +def test_nested_input_3(ds, client): + query = dsl_gql( + DSLQuery( + ds.Query.echo.args( + nested={"foo": 1, "child": {"foo": 2, "child": {"foo": 3}}} + ) + ) + ) + assert client.execute(query) == { + "echo": '{"foo": 1, "child": {"foo": 2, "child": {"foo": 3}}}' + } diff --git a/tests/starwars/test_dsl.py b/tests/starwars/test_dsl.py index 12a2f76b..5807e87f 100644 --- a/tests/starwars/test_dsl.py +++ b/tests/starwars/test_dsl.py @@ -1,28 +1,32 @@ import pytest +from graphql import print_ast from gql import Client -from gql.dsl import DSLSchema +from gql.dsl import DSLMutation, DSLQuery, DSLSchema, DSLSubscription, dsl_gql from .schema import StarWarsSchema @pytest.fixture def ds(): - client = Client(schema=StarWarsSchema) - ds = DSLSchema(client) - return ds + return DSLSchema(StarWarsSchema) + + +@pytest.fixture +def client(): + return Client(schema=StarWarsSchema) def test_invalid_field_on_type_query(ds): - with pytest.raises(KeyError) as exc_info: + with pytest.raises(AttributeError) as exc_info: ds.Query.extras.select(ds.Character.name) assert "Field extras does not exist in type Query." in str(exc_info.value) -def test_incompatible_query_field(ds): +def test_incompatible_field(ds): with pytest.raises(Exception) as exc_info: - ds.query("hero") - assert "Received incompatible query field" in str(exc_info.value) + ds.Query.hero.select("not_a_DSL_FIELD") + assert "Received incompatible field" in str(exc_info.value) def test_hero_name_query(ds): @@ -110,16 +114,41 @@ def test_fetch_luke_aliased(ds): assert query == str(query_dsl) -def test_hero_name_query_result(ds): - result = ds.query(ds.Query.hero.select(ds.Character.name)) +def test_fetch_name_aliased(ds: DSLSchema): + query = """ +human(id: "1000") { + my_name: name +} + """.strip() + query_dsl = ds.Query.human.args(id=1000).select(ds.Character.name.alias("my_name")) + print(str(query_dsl)) + assert query == str(query_dsl) + + +def test_fetch_name_aliased_as_kwargs(ds: DSLSchema): + query = """ +human(id: "1000") { + my_name: name +} + """.strip() + query_dsl = ds.Query.human.args(id=1000).select(my_name=ds.Character.name,) + assert query == str(query_dsl) + + +def test_hero_name_query_result(ds, client): + query = dsl_gql(DSLQuery(ds.Query.hero.select(ds.Character.name))) + result = client.execute(query) expected = {"hero": {"name": "R2-D2"}} assert result == expected -def test_arg_serializer_list(ds): - result = ds.query( - ds.Query.characters.args(ids=[1000, 1001, 1003]).select(ds.Character.name,) +def test_arg_serializer_list(ds, client): + query = dsl_gql( + DSLQuery( + ds.Query.characters.args(ids=[1000, 1001, 1003]).select(ds.Character.name,) + ) ) + result = client.execute(query) expected = { "characters": [ {"name": "Luke Skywalker"}, @@ -130,24 +159,160 @@ def test_arg_serializer_list(ds): assert result == expected -def test_arg_serializer_enum(ds): - result = ds.query(ds.Query.hero.args(episode=5).select(ds.Character.name)) +def test_arg_serializer_enum(ds, client): + query = dsl_gql(DSLQuery(ds.Query.hero.args(episode=5).select(ds.Character.name))) + result = client.execute(query) expected = {"hero": {"name": "Luke Skywalker"}} assert result == expected -def test_create_review_mutation_result(ds): - result = ds.mutate( - ds.Mutation.createReview.args( - episode=6, review={"stars": 5, "commentary": "This is a great movie!"} - ).select(ds.Review.stars, ds.Review.commentary) +def test_create_review_mutation_result(ds, client): + + query = dsl_gql( + DSLMutation( + ds.Mutation.createReview.args( + episode=6, review={"stars": 5, "commentary": "This is a great movie!"} + ).select(ds.Review.stars, ds.Review.commentary) + ) ) + result = client.execute(query) expected = {"createReview": {"stars": 5, "commentary": "This is a great movie!"}} assert result == expected +def test_subscription(ds): + + query = dsl_gql( + DSLSubscription( + ds.Subscription.reviewAdded(episode=6).select( + ds.Review.stars, ds.Review.commentary + ) + ) + ) + assert ( + print_ast(query) + == """subscription { + reviewAdded(episode: JEDI) { + stars + commentary + } +} +""" + ) + + def test_invalid_arg(ds): with pytest.raises( KeyError, match="Argument invalid_arg does not exist in Field: Character." ): - ds.query(ds.Query.hero.args(invalid_arg=5).select(ds.Character.name)) + ds.Query.hero.args(invalid_arg=5).select(ds.Character.name) + + +def test_multiple_root_fields(ds, client): + query = dsl_gql( + DSLQuery( + ds.Query.hero.select(ds.Character.name), + ds.Query.hero(episode=5) + .alias("hero_of_episode_5") + .select(ds.Character.name), + ) + ) + result = client.execute(query) + expected = { + "hero": {"name": "R2-D2"}, + "hero_of_episode_5": {"name": "Luke Skywalker"}, + } + assert result == expected + + +def test_root_fields_aliased(ds, client): + query = dsl_gql( + DSLQuery( + ds.Query.hero.select(ds.Character.name), + hero_of_episode_5=ds.Query.hero(episode=5).select(ds.Character.name), + ) + ) + result = client.execute(query) + expected = { + "hero": {"name": "R2-D2"}, + "hero_of_episode_5": {"name": "Luke Skywalker"}, + } + assert result == expected + + +def test_operation_name(ds): + query = dsl_gql(GetHeroName=DSLQuery(ds.Query.hero.select(ds.Character.name),)) + + assert ( + print_ast(query) + == """query GetHeroName { + hero { + name + } +} +""" + ) + + +def test_multiple_operations(ds): + query = dsl_gql( + GetHeroName=DSLQuery(ds.Query.hero.select(ds.Character.name)), + CreateReviewMutation=DSLMutation( + ds.Mutation.createReview.args( + episode=6, review={"stars": 5, "commentary": "This is a great movie!"} + ).select(ds.Review.stars, ds.Review.commentary) + ), + ) + + assert ( + print_ast(query) + == """query GetHeroName { + hero { + name + } +} + +mutation CreateReviewMutation { + createReview(episode: JEDI, review: {stars: 5, \ +commentary: "This is a great movie!"}) { + stars + commentary + } +} +""" + ) + + +def test_dsl_query_all_fields_should_be_instances_of_DSLField(): + with pytest.raises( + TypeError, match="fields must be instances of DSLField. Received type:" + ): + DSLQuery("I am a string") + + +def test_dsl_query_all_fields_should_correspond_to_the_root_type(ds): + with pytest.raises(AssertionError) as excinfo: + DSLQuery(ds.Character.name) + + assert ("Invalid root field for operation QUERY.\n" "Received: Character") in str( + excinfo.value + ) + + +def test_dsl_gql_all_arguments_should_be_operations(): + with pytest.raises( + TypeError, match="Operations should be instances of DSLOperation " + ): + dsl_gql("I am a string") + + +def test_DSLSchema_requires_a_schema(client): + with pytest.raises(TypeError, match="DSLSchema needs a schema as parameter"): + DSLSchema(client) + + +def test_invalid_type(ds): + with pytest.raises( + AttributeError, match="Type 'invalid_type' not found in the schema!" + ): + ds.invalid_type