Skip to content

Commit 3201c05

Browse files
committed
Allow aliases as keyword arguments for dsl_gql + new operation_name argument
1 parent dbfc221 commit 3201c05

File tree

2 files changed

+74
-14
lines changed

2 files changed

+74
-14
lines changed

gql/dsl.py

Lines changed: 42 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import logging
2-
from collections.abc import Iterable
3-
from typing import List, Optional, Union
2+
from typing import Dict, Iterable, List, Optional, Tuple, Union
43

54
from graphql import (
65
ArgumentNode,
@@ -26,8 +25,12 @@
2625
log = logging.getLogger(__name__)
2726

2827

29-
def dsl_gql(*fields: "DSLField") -> DocumentNode:
30-
"""Given arguments of type :class:`DSLField` containing GraphQL requests,
28+
def dsl_gql(
29+
*fields: "DSLField",
30+
operation_name: Optional[str] = None,
31+
**fields_with_alias: "DSLField",
32+
) -> DocumentNode:
33+
r"""Given arguments of type :class:`DSLField` containing GraphQL requests,
3134
generate a Document which can be executed later in a
3235
gql client or a gql session.
3336
@@ -43,8 +46,12 @@ def dsl_gql(*fields: "DSLField") -> DocumentNode:
4346
They should all have the same root type
4447
(you can't mix queries with mutations for example).
4548
46-
:param fields: root instances of the dynamically generated requests
47-
:type fields: DSLField
49+
:param \*fields: root instances of the dynamically generated requests
50+
:type \*fields: DSLField
51+
:param \**fields_with_alias: root instances fields with alias as key
52+
:type \**fields_with_alias: DSLField
53+
:param operation_name: optional operation name
54+
:type operation_name: str
4855
:return: a Document which can be later executed or subscribed by a
4956
:class:`Client <gql.client.Client>`, by an
5057
:class:`async session <gql.client.AsyncClientSession>` or by a
@@ -54,9 +61,13 @@ def dsl_gql(*fields: "DSLField") -> DocumentNode:
5461
:raises AssertionError: if an argument is not a field of a root type
5562
"""
5663

64+
all_fields: Tuple["DSLField", ...] = DSLField.get_aliased_fields(
65+
fields, fields_with_alias
66+
)
67+
5768
# Check that we receive only arguments of type DSLField
5869
# And that they are a root type
59-
for field in fields:
70+
for field in all_fields:
6071
if not isinstance(field, DSLField):
6172
raise TypeError(
6273
f"fields must be instances of DSLField. Received type: {type(field)}"
@@ -68,15 +79,16 @@ def dsl_gql(*fields: "DSLField") -> DocumentNode:
6879

6980
# Get the operation from the first field
7081
# All the fields must have the same operation
71-
operation = fields[0].type_name.lower()
82+
operation = all_fields[0].type_name.lower()
7283

7384
return DocumentNode(
7485
definitions=[
7586
OperationDefinitionNode(
7687
operation=OperationType(operation),
7788
selection_set=SelectionSetNode(
78-
selections=FrozenList(DSLField.get_ast_fields(fields))
89+
selections=FrozenList(DSLField.get_ast_fields(all_fields))
7990
),
91+
**({"name": NameNode(value=operation_name)} if operation_name else {}),
8092
)
8193
]
8294
)
@@ -203,7 +215,7 @@ def __init__(
203215
log.debug(f"Creating {self!r}")
204216

205217
@staticmethod
206-
def get_ast_fields(fields: Iterable) -> List[FieldNode]:
218+
def get_ast_fields(fields: Iterable["DSLField"]) -> List[FieldNode]:
207219
"""
208220
:meta private:
209221
@@ -222,6 +234,23 @@ def get_ast_fields(fields: Iterable) -> List[FieldNode]:
222234

223235
return ast_fields
224236

237+
@staticmethod
238+
def get_aliased_fields(
239+
fields: Iterable["DSLField"], fields_with_alias: Dict[str, "DSLField"]
240+
) -> Tuple["DSLField", ...]:
241+
"""
242+
:meta private:
243+
244+
Concatenate all the fields (with or without alias) in a Tuple.
245+
246+
Set the requested alias for the fields with alias.
247+
"""
248+
249+
return (
250+
*fields,
251+
*(field.alias(alias) for alias, field in fields_with_alias.items()),
252+
)
253+
225254
def select(
226255
self, *fields: "DSLField", **fields_with_alias: "DSLField"
227256
) -> "DSLField":
@@ -241,9 +270,9 @@ def select(
241270
of the :class:`DSLField` class.
242271
"""
243272

244-
added_fields: List["DSLField"] = list(fields) + [
245-
field.alias(alias) for alias, field in fields_with_alias.items()
246-
]
273+
added_fields: Tuple["DSLField", ...] = self.get_aliased_fields(
274+
fields, fields_with_alias
275+
)
247276

248277
added_selections: List[FieldNode] = self.get_ast_fields(added_fields)
249278

tests/starwars/test_dsl.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ def test_invalid_arg(ds):
182182
ds.Query.hero.args(invalid_arg=5).select(ds.Character.name)
183183

184184

185-
def test_multiple_queries(ds, client):
185+
def test_multiple_root_fields(ds, client):
186186
query = dsl_gql(
187187
ds.Query.hero.select(ds.Character.name),
188188
ds.Query.hero(episode=5).alias("hero_of_episode_5").select(ds.Character.name),
@@ -195,6 +195,37 @@ def test_multiple_queries(ds, client):
195195
assert result == expected
196196

197197

198+
def test_root_fields_aliased(ds, client):
199+
query = dsl_gql(
200+
ds.Query.hero.select(ds.Character.name),
201+
hero_of_episode_5=ds.Query.hero(episode=5).select(ds.Character.name),
202+
)
203+
result = client.execute(query)
204+
expected = {
205+
"hero": {"name": "R2-D2"},
206+
"hero_of_episode_5": {"name": "Luke Skywalker"},
207+
}
208+
assert result == expected
209+
210+
211+
def test_operation_name(ds):
212+
query = dsl_gql(
213+
ds.Query.hero.select(ds.Character.name), operation_name="GetHeroName",
214+
)
215+
216+
from graphql import print_ast
217+
218+
assert (
219+
print_ast(query)
220+
== """query GetHeroName {
221+
hero {
222+
name
223+
}
224+
}
225+
"""
226+
)
227+
228+
198229
def test_dsl_gql_all_fields_should_be_instances_of_DSLField(ds, client):
199230
with pytest.raises(
200231
TypeError, match="fields must be instances of DSLField. Received type:"

0 commit comments

Comments
 (0)