Skip to content

Commit 269d2e9

Browse files
committed
Improved GraphQL DSL 😊. Fixed #12
1 parent 8257777 commit 269d2e9

File tree

3 files changed

+111
-52
lines changed

3 files changed

+111
-52
lines changed

‎gql/dsl.py

Lines changed: 43 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,44 @@
55
import six
66
from graphql.language import ast
77
from graphql.language.printer import print_ast
8-
from graphql.type import (GraphQLField, GraphQLFieldDefinition, GraphQLList,
8+
from graphql.type import (GraphQLField, GraphQLList,
99
GraphQLNonNull, GraphQLEnumType)
1010

11+
from .utils import to_camel_case
12+
13+
14+
class DSLSchema(object):
15+
def __init__(self, client):
16+
self.client = client
17+
18+
@property
19+
def schema(self):
20+
return self.client.schema
21+
22+
def __getattr__(self, name):
23+
type_def = self.schema.get_type(name)
24+
return DSLType(type_def)
25+
26+
27+
class DSLType(object):
28+
def __init__(self, type):
29+
self.type = type
30+
31+
def __getattr__(self, name):
32+
formatted_name, field_def = self.get_field(name)
33+
return DSLField(formatted_name, field_def)
34+
35+
def get_field(self, name):
36+
camel_cased_name = to_camel_case(name)
37+
38+
if name in self.type.fields:
39+
return name, self.type.fields[name]
40+
41+
if camel_cased_name in self.type.fields:
42+
return camel_cased_name, self.type.fields[camel_cased_name]
43+
44+
raise KeyError('Field {} doesnt exist in type {}.'.format(name, self.type.name))
45+
1146

1247
def selections(*fields):
1348
for _field in fields:
@@ -30,9 +65,9 @@ def get_ast_value(value):
3065

3166
class DSLField(object):
3267

33-
def __init__(self, field):
68+
def __init__(self, name, field):
3469
self.field = field
35-
self.ast_field = ast.Field(name=ast.Name(value=field.name), arguments=[])
70+
self.ast_field = ast.Field(name=ast.Name(value=name), arguments=[])
3671
self.selection_set = None
3772

3873
def get(self, *fields):
@@ -41,21 +76,16 @@ def get(self, *fields):
4176
self.ast_field.selection_set.selections.extend(selections(*fields))
4277
return self
4378

79+
def __call__(self, *args, **kwargs):
80+
return self.get(*args, **kwargs)
81+
4482
def alias(self, alias):
4583
self.ast_field.alias = ast.Name(value=alias)
4684
return self
4785

48-
def get_field_args(self):
49-
if isinstance(self.field, GraphQLFieldDefinition):
50-
# The args will be an array
51-
return {
52-
arg.name: arg for arg in self.field.args
53-
}
54-
return self.field.args
55-
5686
def args(self, **args):
5787
for name, value in args.items():
58-
arg = self.get_field_args().get(name)
88+
arg = self.field.args.get(name)
5989
arg_type_serializer = get_arg_serializer(arg.type)
6090
value = arg_type_serializer(value)
6191
self.ast_field.arguments.append(
@@ -75,7 +105,7 @@ def __str__(self):
75105

76106

77107
def field(field, **args):
78-
if isinstance(field, (GraphQLField, GraphQLFieldDefinition)):
108+
if isinstance(field, GraphQLField):
79109
return DSLField(field).args(**args)
80110
elif isinstance(field, DSLField):
81111
return field

‎gql/utils.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import re
2+
3+
4+
# From this response in Stackoverflow
5+
# http://stackoverflow.com/a/19053800/1072990
6+
def to_camel_case(snake_str):
7+
components = snake_str.split('_')
8+
# We capitalize the first letter of each component except the first one
9+
# with the 'title' method and join them together.
10+
return components[0] + "".join(x.title() if x else '_' for x in components[1:])
11+
12+
13+
# From this response in Stackoverflow
14+
# http://stackoverflow.com/a/1176023/1072990
15+
def to_snake_case(name):
16+
s1 = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name)
17+
return re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1).lower()
18+
19+
20+
def to_const(string):
21+
return re.sub('[\W|^]+', '_', string).upper()

‎tests/starwars/test_dsl.py

Lines changed: 47 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,52 @@
1-
from gql import dsl
1+
import pytest
2+
3+
from gql import Client
4+
from gql.dsl import DSLSchema
25

36
from .schema import characterInterface, humanType, queryType
47

58

69
# We construct a Simple DSL objects for easy field referencing
710

8-
class Query(object):
9-
hero = queryType.get_fields()['hero']
10-
human = queryType.get_fields()['human']
11+
# class Query(object):
12+
# hero = queryType.fields['hero']
13+
# human = queryType.fields['human']
14+
15+
16+
# class Character(object):
17+
# id = characterInterface.fields['id']
18+
# name = characterInterface.fields['name']
19+
# friends = characterInterface.fields['friends']
20+
# appears_in = characterInterface.fields['appearsIn']
1121

1222

13-
class Character(object):
14-
id = characterInterface.get_fields()['id']
15-
name = characterInterface.get_fields()['name']
16-
friends = characterInterface.get_fields()['friends']
17-
appears_in = characterInterface.get_fields()['appearsIn']
23+
# class Human(object):
24+
# name = humanType.fields['name']
1825

1926

20-
class Human(object):
21-
name = humanType.get_fields()['name']
27+
from .schema import StarWarsSchema
2228

2329

24-
def test_hero_name_query():
30+
@pytest.fixture
31+
def ds():
32+
client = Client(schema=StarWarsSchema)
33+
ds = DSLSchema(client)
34+
return ds
35+
36+
37+
def test_hero_name_query(ds):
2538
query = '''
2639
hero {
2740
name
2841
}
2942
'''.strip()
30-
query_dsl = dsl.field(Query.hero).get(
31-
Character.name
43+
query_dsl = ds.Query.hero(
44+
ds.Character.name
3245
)
3346
assert query == str(query_dsl)
3447

3548

36-
def test_hero_name_and_friends_query():
49+
def test_hero_name_and_friends_query(ds):
3750
query = '''
3851
hero {
3952
id
@@ -43,17 +56,17 @@ def test_hero_name_and_friends_query():
4356
}
4457
}
4558
'''.strip()
46-
query_dsl = dsl.field(Query.hero).get(
47-
Character.id,
48-
Character.name,
49-
dsl.field(Character.friends).get(
50-
Character.name,
59+
query_dsl = ds.Query.hero(
60+
ds.Character.id,
61+
ds.Character.name,
62+
ds.Character.friends(
63+
ds.Character.name,
5164
)
5265
)
5366
assert query == str(query_dsl)
5467

5568

56-
def test_nested_query():
69+
def test_nested_query(ds):
5770
query = '''
5871
hero {
5972
name
@@ -66,27 +79,27 @@ def test_nested_query():
6679
}
6780
}
6881
'''.strip()
69-
query_dsl = dsl.field(Query.hero).get(
70-
Character.name,
71-
dsl.field(Character.friends).get(
72-
Character.name,
73-
Character.appears_in,
74-
dsl.field(Character.friends).get(
75-
Character.name
82+
query_dsl = ds.Query.hero(
83+
ds.Character.name,
84+
ds.Character.friends(
85+
ds.Character.name,
86+
ds.Character.appears_in,
87+
ds.Character.friends(
88+
ds.Character.name
7689
)
7790
)
7891
)
7992
assert query == str(query_dsl)
8093

8194

82-
def test_fetch_luke_query():
95+
def test_fetch_luke_query(ds):
8396
query = '''
8497
human(id: "1000") {
8598
name
8699
}
87100
'''.strip()
88-
query_dsl = dsl.field(Query.human, id="1000").get(
89-
Human.name,
101+
query_dsl = ds.Query.human.args(id="1000").get(
102+
ds.Human.name,
90103
)
91104

92105
assert query == str(query_dsl)
@@ -153,19 +166,14 @@ def test_fetch_luke_query():
153166
# assert result.data == expected
154167

155168

156-
def test_fetch_luke_aliased():
169+
def test_fetch_luke_aliased(ds):
157170
query = '''
158171
luke: human(id: "1000") {
159172
name
160173
}
161174
'''.strip()
162-
expected = {
163-
'luke': {
164-
'name': 'Luke Skywalker',
165-
}
166-
}
167-
query_dsl = dsl.field(Query.human, id=1000).alias('luke').get(
168-
Character.name,
175+
query_dsl = ds.Query.human.args(id=1000).alias('luke').get(
176+
ds.Character.name,
169177
)
170178
assert query == str(query_dsl)
171179

0 commit comments

Comments
 (0)