Skip to content

Commit 7c52aa3

Browse files
authored
Merge pull request #180 from arianon/master
Make DjangoConnectionField compatible with Promise-based iterables.
2 parents 3667157 + bfcac1d commit 7c52aa3

File tree

3 files changed

+178
-23
lines changed

3 files changed

+178
-23
lines changed

graphene_django/fields.py

Lines changed: 34 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
from django.db.models.query import QuerySet
44

5+
from promise import Promise
6+
57
from graphene.types import Field, List
68
from graphene.relay import ConnectionField, PageInfo
79
from graphql_relay.connection.arrayconnection import connection_from_list_slice
@@ -59,6 +61,32 @@ def get_manager(self):
5961
def merge_querysets(cls, default_queryset, queryset):
6062
return default_queryset & queryset
6163

64+
@classmethod
65+
def resolve_connection(cls, connection, default_manager, args, iterable):
66+
if iterable is None:
67+
iterable = default_manager
68+
iterable = maybe_queryset(iterable)
69+
if isinstance(iterable, QuerySet):
70+
if iterable is not default_manager:
71+
default_queryset = maybe_queryset(default_manager)
72+
iterable = cls.merge_querysets(default_queryset, iterable)
73+
_len = iterable.count()
74+
else:
75+
_len = len(iterable)
76+
connection = connection_from_list_slice(
77+
iterable,
78+
args,
79+
slice_start=0,
80+
list_length=_len,
81+
list_slice_length=_len,
82+
connection_type=connection,
83+
edge_type=connection.Edge,
84+
pageinfo_type=PageInfo,
85+
)
86+
connection.iterable = iterable
87+
connection.length = _len
88+
return connection
89+
6290
@classmethod
6391
def connection_resolver(cls, resolver, connection, default_manager, max_limit,
6492
enforce_first_or_last, root, args, context, info):
@@ -84,29 +112,12 @@ def connection_resolver(cls, resolver, connection, default_manager, max_limit,
84112
args['last'] = min(last, max_limit)
85113

86114
iterable = resolver(root, args, context, info)
87-
if iterable is None:
88-
iterable = default_manager
89-
iterable = maybe_queryset(iterable)
90-
if isinstance(iterable, QuerySet):
91-
if iterable is not default_manager:
92-
default_queryset = maybe_queryset(default_manager)
93-
iterable = cls.merge_querysets(default_queryset, iterable)
94-
_len = iterable.count()
95-
else:
96-
_len = len(iterable)
97-
connection = connection_from_list_slice(
98-
iterable,
99-
args,
100-
slice_start=0,
101-
list_length=_len,
102-
list_slice_length=_len,
103-
connection_type=connection,
104-
edge_type=connection.Edge,
105-
pageinfo_type=PageInfo,
106-
)
107-
connection.iterable = iterable
108-
connection.length = _len
109-
return connection
115+
on_resolve = partial(cls.resolve_connection, connection, default_manager, args)
116+
117+
if Promise.is_thenable(iterable):
118+
return Promise.resolve(iterable).then(on_resolve)
119+
120+
return on_resolve(iterable)
110121

111122
def get_resolver(self, parent_resolver):
112123
return partial(

graphene_django/tests/test_query.py

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -545,3 +545,146 @@ class Query(graphene.ObjectType):
545545
assert result.data == expected
546546

547547
graphene_settings.RELAY_CONNECTION_ENFORCE_FIRST_OR_LAST = False
548+
549+
550+
def test_should_query_promise_connectionfields():
551+
from promise import Promise
552+
553+
class ReporterType(DjangoObjectType):
554+
555+
class Meta:
556+
model = Reporter
557+
interfaces = (Node, )
558+
559+
class Query(graphene.ObjectType):
560+
all_reporters = DjangoConnectionField(ReporterType)
561+
562+
def resolve_all_reporters(self, *args, **kwargs):
563+
return Promise.resolve([Reporter(id=1)])
564+
565+
schema = graphene.Schema(query=Query)
566+
query = '''
567+
query ReporterPromiseConnectionQuery {
568+
allReporters(first: 1) {
569+
edges {
570+
node {
571+
id
572+
}
573+
}
574+
}
575+
}
576+
'''
577+
578+
expected = {
579+
'allReporters': {
580+
'edges': [{
581+
'node': {
582+
'id': 'UmVwb3J0ZXJUeXBlOjE='
583+
}
584+
}]
585+
}
586+
}
587+
588+
result = schema.execute(query)
589+
assert not result.errors
590+
assert result.data == expected
591+
592+
593+
def test_should_query_dataloader_fields():
594+
from promise import Promise
595+
from promise.dataloader import DataLoader
596+
597+
def article_batch_load_fn(keys):
598+
queryset = Article.objects.filter(reporter_id__in=keys)
599+
return Promise.resolve([
600+
[article for article in queryset if article.reporter_id == id]
601+
for id in keys
602+
])
603+
604+
article_loader = DataLoader(article_batch_load_fn)
605+
606+
class ArticleType(DjangoObjectType):
607+
608+
class Meta:
609+
model = Article
610+
interfaces = (Node, )
611+
612+
class ReporterType(DjangoObjectType):
613+
614+
class Meta:
615+
model = Reporter
616+
interfaces = (Node, )
617+
618+
articles = DjangoConnectionField(ArticleType)
619+
620+
def resolve_articles(self, *args, **kwargs):
621+
return article_loader.load(self.id)
622+
623+
class Query(graphene.ObjectType):
624+
all_reporters = DjangoConnectionField(ReporterType)
625+
626+
r = Reporter.objects.create(
627+
first_name='John',
628+
last_name='Doe',
629+
email='johndoe@example.com',
630+
a_choice=1
631+
)
632+
Article.objects.create(
633+
headline='Article Node 1',
634+
pub_date=datetime.date.today(),
635+
reporter=r,
636+
editor=r,
637+
lang='es'
638+
)
639+
Article.objects.create(
640+
headline='Article Node 2',
641+
pub_date=datetime.date.today(),
642+
reporter=r,
643+
editor=r,
644+
lang='en'
645+
)
646+
647+
schema = graphene.Schema(query=Query)
648+
query = '''
649+
query ReporterPromiseConnectionQuery {
650+
allReporters(first: 1) {
651+
edges {
652+
node {
653+
id
654+
articles(first: 2) {
655+
edges {
656+
node {
657+
headline
658+
}
659+
}
660+
}
661+
}
662+
}
663+
}
664+
}
665+
'''
666+
667+
expected = {
668+
'allReporters': {
669+
'edges': [{
670+
'node': {
671+
'id': 'UmVwb3J0ZXJUeXBlOjE=',
672+
'articles': {
673+
'edges': [{
674+
'node': {
675+
'headline': 'Article Node 1',
676+
}
677+
}, {
678+
'node': {
679+
'headline': 'Article Node 2'
680+
}
681+
}]
682+
}
683+
}
684+
}]
685+
}
686+
}
687+
688+
result = schema.execute(query)
689+
assert not result.errors
690+
assert result.data == expected

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
'Django>=1.8.0',
4848
'iso8601',
4949
'singledispatch>=3.4.0.3',
50+
'promise>=2.0',
5051
],
5152
setup_requires=[
5253
'pytest-runner',

0 commit comments

Comments
 (0)