Skip to content

Make DjangoConnectionField compatible with Promise-based iterables. #180

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 24, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 34 additions & 23 deletions graphene_django/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from django.db.models.query import QuerySet

from promise import Promise

from graphene.types import Field, List
from graphene.relay import ConnectionField, PageInfo
from graphql_relay.connection.arrayconnection import connection_from_list_slice
Expand Down Expand Up @@ -59,6 +61,32 @@ def get_manager(self):
def merge_querysets(cls, default_queryset, queryset):
return default_queryset & queryset

@classmethod
def resolve_connection(cls, connection, default_manager, args, iterable):
if iterable is None:
iterable = default_manager
iterable = maybe_queryset(iterable)
if isinstance(iterable, QuerySet):
if iterable is not default_manager:
default_queryset = maybe_queryset(default_manager)
iterable = cls.merge_querysets(default_queryset, iterable)
_len = iterable.count()
else:
_len = len(iterable)
connection = connection_from_list_slice(
iterable,
args,
slice_start=0,
list_length=_len,
list_slice_length=_len,
connection_type=connection,
edge_type=connection.Edge,
pageinfo_type=PageInfo,
)
connection.iterable = iterable
connection.length = _len
return connection

@classmethod
def connection_resolver(cls, resolver, connection, default_manager, max_limit,
enforce_first_or_last, root, args, context, info):
Expand All @@ -84,29 +112,12 @@ def connection_resolver(cls, resolver, connection, default_manager, max_limit,
args['last'] = min(last, max_limit)

iterable = resolver(root, args, context, info)
if iterable is None:
iterable = default_manager
iterable = maybe_queryset(iterable)
if isinstance(iterable, QuerySet):
if iterable is not default_manager:
default_queryset = maybe_queryset(default_manager)
iterable = cls.merge_querysets(default_queryset, iterable)
_len = iterable.count()
else:
_len = len(iterable)
connection = connection_from_list_slice(
iterable,
args,
slice_start=0,
list_length=_len,
list_slice_length=_len,
connection_type=connection,
edge_type=connection.Edge,
pageinfo_type=PageInfo,
)
connection.iterable = iterable
connection.length = _len
return connection
on_resolve = partial(cls.resolve_connection, connection, default_manager, args)

if Promise.is_thenable(iterable):
return Promise.resolve(iterable).then(on_resolve)

return on_resolve(iterable)

def get_resolver(self, parent_resolver):
return partial(
Expand Down
143 changes: 143 additions & 0 deletions graphene_django/tests/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,3 +545,146 @@ class Query(graphene.ObjectType):
assert result.data == expected

graphene_settings.RELAY_CONNECTION_ENFORCE_FIRST_OR_LAST = False


def test_should_query_promise_connectionfields():
from promise import Promise

class ReporterType(DjangoObjectType):

class Meta:
model = Reporter
interfaces = (Node, )

class Query(graphene.ObjectType):
all_reporters = DjangoConnectionField(ReporterType)

def resolve_all_reporters(self, *args, **kwargs):
return Promise.resolve([Reporter(id=1)])

schema = graphene.Schema(query=Query)
query = '''
query ReporterPromiseConnectionQuery {
allReporters(first: 1) {
edges {
node {
id
}
}
}
}
'''

expected = {
'allReporters': {
'edges': [{
'node': {
'id': 'UmVwb3J0ZXJUeXBlOjE='
}
}]
}
}

result = schema.execute(query)
assert not result.errors
assert result.data == expected


def test_should_query_dataloader_fields():
from promise import Promise
from promise.dataloader import DataLoader

def article_batch_load_fn(keys):
queryset = Article.objects.filter(reporter_id__in=keys)
return Promise.resolve([
[article for article in queryset if article.reporter_id == id]
for id in keys
])

article_loader = DataLoader(article_batch_load_fn)

class ArticleType(DjangoObjectType):

class Meta:
model = Article
interfaces = (Node, )

class ReporterType(DjangoObjectType):

class Meta:
model = Reporter
interfaces = (Node, )

articles = DjangoConnectionField(ArticleType)

def resolve_articles(self, *args, **kwargs):
return article_loader.load(self.id)

class Query(graphene.ObjectType):
all_reporters = DjangoConnectionField(ReporterType)

r = Reporter.objects.create(
first_name='John',
last_name='Doe',
email='johndoe@example.com',
a_choice=1
)
Article.objects.create(
headline='Article Node 1',
pub_date=datetime.date.today(),
reporter=r,
editor=r,
lang='es'
)
Article.objects.create(
headline='Article Node 2',
pub_date=datetime.date.today(),
reporter=r,
editor=r,
lang='en'
)

schema = graphene.Schema(query=Query)
query = '''
query ReporterPromiseConnectionQuery {
allReporters(first: 1) {
edges {
node {
id
articles(first: 2) {
edges {
node {
headline
}
}
}
}
}
}
}
'''

expected = {
'allReporters': {
'edges': [{
'node': {
'id': 'UmVwb3J0ZXJUeXBlOjE=',
'articles': {
'edges': [{
'node': {
'headline': 'Article Node 1',
}
}, {
'node': {
'headline': 'Article Node 2'
}
}]
}
}
}]
}
}

result = schema.execute(query)
assert not result.errors
assert result.data == expected
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
'Django>=1.6.0',
'iso8601',
'singledispatch>=3.4.0.3',
'promise>=2.0',
],
setup_requires=[
'pytest-runner',
Expand Down