diff --git a/graphene_django/fields.py b/graphene_django/fields.py index c2a2a8fd4..f82e4b258 100644 --- a/graphene_django/fields.py +++ b/graphene_django/fields.py @@ -2,6 +2,8 @@ from django.db.models.query import QuerySet +from promise import Promise + from graphene.types import Field, List from graphene.relay import ConnectionField, PageInfo from graphql_relay.connection.arrayconnection import connection_from_list_slice @@ -59,6 +61,32 @@ def get_manager(self): def merge_querysets(cls, default_queryset, queryset): return default_queryset & queryset + @classmethod + def resolve_connection(cls, connection, default_manager, args, iterable): + if iterable is None: + iterable = default_manager + iterable = maybe_queryset(iterable) + if isinstance(iterable, QuerySet): + if iterable is not default_manager: + default_queryset = maybe_queryset(default_manager) + iterable = cls.merge_querysets(default_queryset, iterable) + _len = iterable.count() + else: + _len = len(iterable) + connection = connection_from_list_slice( + iterable, + args, + slice_start=0, + list_length=_len, + list_slice_length=_len, + connection_type=connection, + edge_type=connection.Edge, + pageinfo_type=PageInfo, + ) + connection.iterable = iterable + connection.length = _len + return connection + @classmethod def connection_resolver(cls, resolver, connection, default_manager, max_limit, enforce_first_or_last, root, args, context, info): @@ -84,29 +112,12 @@ def connection_resolver(cls, resolver, connection, default_manager, max_limit, args['last'] = min(last, max_limit) iterable = resolver(root, args, context, info) - if iterable is None: - iterable = default_manager - iterable = maybe_queryset(iterable) - if isinstance(iterable, QuerySet): - if iterable is not default_manager: - default_queryset = maybe_queryset(default_manager) - iterable = cls.merge_querysets(default_queryset, iterable) - _len = iterable.count() - else: - _len = len(iterable) - connection = connection_from_list_slice( - iterable, - args, - slice_start=0, - list_length=_len, - list_slice_length=_len, - connection_type=connection, - edge_type=connection.Edge, - pageinfo_type=PageInfo, - ) - connection.iterable = iterable - connection.length = _len - return connection + on_resolve = partial(cls.resolve_connection, connection, default_manager, args) + + if Promise.is_thenable(iterable): + return Promise.resolve(iterable).then(on_resolve) + + return on_resolve(iterable) def get_resolver(self, parent_resolver): return partial( diff --git a/graphene_django/tests/test_query.py b/graphene_django/tests/test_query.py index c1deebbcd..1041f7e50 100644 --- a/graphene_django/tests/test_query.py +++ b/graphene_django/tests/test_query.py @@ -545,3 +545,146 @@ class Query(graphene.ObjectType): assert result.data == expected graphene_settings.RELAY_CONNECTION_ENFORCE_FIRST_OR_LAST = False + + +def test_should_query_promise_connectionfields(): + from promise import Promise + + class ReporterType(DjangoObjectType): + + class Meta: + model = Reporter + interfaces = (Node, ) + + class Query(graphene.ObjectType): + all_reporters = DjangoConnectionField(ReporterType) + + def resolve_all_reporters(self, *args, **kwargs): + return Promise.resolve([Reporter(id=1)]) + + schema = graphene.Schema(query=Query) + query = ''' + query ReporterPromiseConnectionQuery { + allReporters(first: 1) { + edges { + node { + id + } + } + } + } + ''' + + expected = { + 'allReporters': { + 'edges': [{ + 'node': { + 'id': 'UmVwb3J0ZXJUeXBlOjE=' + } + }] + } + } + + result = schema.execute(query) + assert not result.errors + assert result.data == expected + + +def test_should_query_dataloader_fields(): + from promise import Promise + from promise.dataloader import DataLoader + + def article_batch_load_fn(keys): + queryset = Article.objects.filter(reporter_id__in=keys) + return Promise.resolve([ + [article for article in queryset if article.reporter_id == id] + for id in keys + ]) + + article_loader = DataLoader(article_batch_load_fn) + + class ArticleType(DjangoObjectType): + + class Meta: + model = Article + interfaces = (Node, ) + + class ReporterType(DjangoObjectType): + + class Meta: + model = Reporter + interfaces = (Node, ) + + articles = DjangoConnectionField(ArticleType) + + def resolve_articles(self, *args, **kwargs): + return article_loader.load(self.id) + + class Query(graphene.ObjectType): + all_reporters = DjangoConnectionField(ReporterType) + + r = Reporter.objects.create( + first_name='John', + last_name='Doe', + email='johndoe@example.com', + a_choice=1 + ) + Article.objects.create( + headline='Article Node 1', + pub_date=datetime.date.today(), + reporter=r, + editor=r, + lang='es' + ) + Article.objects.create( + headline='Article Node 2', + pub_date=datetime.date.today(), + reporter=r, + editor=r, + lang='en' + ) + + schema = graphene.Schema(query=Query) + query = ''' + query ReporterPromiseConnectionQuery { + allReporters(first: 1) { + edges { + node { + id + articles(first: 2) { + edges { + node { + headline + } + } + } + } + } + } + } + ''' + + expected = { + 'allReporters': { + 'edges': [{ + 'node': { + 'id': 'UmVwb3J0ZXJUeXBlOjE=', + 'articles': { + 'edges': [{ + 'node': { + 'headline': 'Article Node 1', + } + }, { + 'node': { + 'headline': 'Article Node 2' + } + }] + } + } + }] + } + } + + result = schema.execute(query) + assert not result.errors + assert result.data == expected diff --git a/setup.py b/setup.py index 2d2e57880..0be5297e9 100644 --- a/setup.py +++ b/setup.py @@ -47,6 +47,7 @@ 'Django>=1.6.0', 'iso8601', 'singledispatch>=3.4.0.3', + 'promise>=2.0', ], setup_requires=[ 'pytest-runner',