1
1
import logging
2
- from collections .abc import Iterable
3
- from typing import List , Optional , Union
2
+ from typing import Dict , Iterable , List , Optional , Tuple , Union
4
3
5
4
from graphql import (
6
5
ArgumentNode ,
26
25
log = logging .getLogger (__name__ )
27
26
28
27
29
- def dsl_gql (* fields : "DSLField" ) -> DocumentNode :
30
- """Given arguments of type :class:`DSLField` containing GraphQL requests,
28
+ def dsl_gql (
29
+ * fields : "DSLField" ,
30
+ operation_name : Optional [str ] = None ,
31
+ ** fields_with_alias : "DSLField" ,
32
+ ) -> DocumentNode :
33
+ r"""Given arguments of type :class:`DSLField` containing GraphQL requests,
31
34
generate a Document which can be executed later in a
32
35
gql client or a gql session.
33
36
@@ -43,8 +46,12 @@ def dsl_gql(*fields: "DSLField") -> DocumentNode:
43
46
They should all have the same root type
44
47
(you can't mix queries with mutations for example).
45
48
46
- :param fields: root instances of the dynamically generated requests
47
- :type fields: DSLField
49
+ :param \*fields: root instances of the dynamically generated requests
50
+ :type \*fields: DSLField
51
+ :param \**fields_with_alias: root instances fields with alias as key
52
+ :type \**fields_with_alias: DSLField
53
+ :param operation_name: optional operation name
54
+ :type operation_name: str
48
55
:return: a Document which can be later executed or subscribed by a
49
56
:class:`Client <gql.client.Client>`, by an
50
57
:class:`async session <gql.client.AsyncClientSession>` or by a
@@ -54,9 +61,13 @@ def dsl_gql(*fields: "DSLField") -> DocumentNode:
54
61
:raises AssertionError: if an argument is not a field of a root type
55
62
"""
56
63
64
+ all_fields : Tuple ["DSLField" , ...] = DSLField .get_aliased_fields (
65
+ fields , fields_with_alias
66
+ )
67
+
57
68
# Check that we receive only arguments of type DSLField
58
69
# And that they are a root type
59
- for field in fields :
70
+ for field in all_fields :
60
71
if not isinstance (field , DSLField ):
61
72
raise TypeError (
62
73
f"fields must be instances of DSLField. Received type: { type (field )} "
@@ -68,15 +79,16 @@ def dsl_gql(*fields: "DSLField") -> DocumentNode:
68
79
69
80
# Get the operation from the first field
70
81
# All the fields must have the same operation
71
- operation = fields [0 ].type_name .lower ()
82
+ operation = all_fields [0 ].type_name .lower ()
72
83
73
84
return DocumentNode (
74
85
definitions = [
75
86
OperationDefinitionNode (
76
87
operation = OperationType (operation ),
77
88
selection_set = SelectionSetNode (
78
- selections = FrozenList (DSLField .get_ast_fields (fields ))
89
+ selections = FrozenList (DSLField .get_ast_fields (all_fields ))
79
90
),
91
+ ** ({"name" : NameNode (value = operation_name )} if operation_name else {}),
80
92
)
81
93
]
82
94
)
@@ -203,7 +215,7 @@ def __init__(
203
215
log .debug (f"Creating { self !r} " )
204
216
205
217
@staticmethod
206
- def get_ast_fields (fields : Iterable ) -> List [FieldNode ]:
218
+ def get_ast_fields (fields : Iterable [ "DSLField" ] ) -> List [FieldNode ]:
207
219
"""
208
220
:meta private:
209
221
@@ -222,6 +234,23 @@ def get_ast_fields(fields: Iterable) -> List[FieldNode]:
222
234
223
235
return ast_fields
224
236
237
+ @staticmethod
238
+ def get_aliased_fields (
239
+ fields : Iterable ["DSLField" ], fields_with_alias : Dict [str , "DSLField" ]
240
+ ) -> Tuple ["DSLField" , ...]:
241
+ """
242
+ :meta private:
243
+
244
+ Concatenate all the fields (with or without alias) in a Tuple.
245
+
246
+ Set the requested alias for the fields with alias.
247
+ """
248
+
249
+ return (
250
+ * fields ,
251
+ * (field .alias (alias ) for alias , field in fields_with_alias .items ()),
252
+ )
253
+
225
254
def select (
226
255
self , * fields : "DSLField" , ** fields_with_alias : "DSLField"
227
256
) -> "DSLField" :
@@ -241,9 +270,9 @@ def select(
241
270
of the :class:`DSLField` class.
242
271
"""
243
272
244
- added_fields : List ["DSLField" ] = list ( fields ) + [
245
- field . alias ( alias ) for alias , field in fields_with_alias . items ()
246
- ]
273
+ added_fields : Tuple ["DSLField" , ... ] = self . get_aliased_fields (
274
+ fields , fields_with_alias
275
+ )
247
276
248
277
added_selections : List [FieldNode ] = self .get_ast_fields (added_fields )
249
278
0 commit comments