diff --git a/graphene_django/fields.py b/graphene_django/fields.py index eb1215ead..e6daa889e 100644 --- a/graphene_django/fields.py +++ b/graphene_django/fields.py @@ -1,13 +1,12 @@ from functools import partial from django.db.models.query import QuerySet -from graphene import NonNull - +from graphql_relay.connection.arrayconnection import connection_from_list_slice from promise import Promise -from graphene.types import Field, List +from graphene import NonNull from graphene.relay import ConnectionField, PageInfo -from graphql_relay.connection.arrayconnection import connection_from_list_slice +from graphene.types import Field, List from .settings import graphene_settings from .utils import maybe_queryset @@ -15,19 +14,43 @@ class DjangoListField(Field): def __init__(self, _type, *args, **kwargs): + from .types import DjangoObjectType + + if isinstance(_type, NonNull): + _type = _type.of_type + + assert issubclass( + _type, DjangoObjectType + ), "DjangoListField only accepts DjangoObjectType types" + # Django would never return a Set of None vvvvvvv super(DjangoListField, self).__init__(List(NonNull(_type)), *args, **kwargs) @property def model(self): - return self.type.of_type._meta.node._meta.model + _type = self.type.of_type + if isinstance(_type, NonNull): + _type = _type.of_type + return _type._meta.model @staticmethod - def list_resolver(resolver, root, info, **args): - return maybe_queryset(resolver(root, info, **args)) + def list_resolver(django_object_type, resolver, root, info, **args): + queryset = maybe_queryset(resolver(root, info, **args)) + if queryset is None: + # Default to Django Model queryset + # N.B. This happens if DjangoListField is used in the top level Query object + model = django_object_type._meta.model + queryset = maybe_queryset( + django_object_type.get_queryset(model.objects, info) + ) + return queryset def get_resolver(self, parent_resolver): - return partial(self.list_resolver, parent_resolver) + _type = self.type + if isinstance(_type, NonNull): + _type = _type.of_type + django_object_type = _type.of_type.of_type + return partial(self.list_resolver, django_object_type, parent_resolver) class DjangoConnectionField(ConnectionField): diff --git a/graphene_django/tests/test_fields.py b/graphene_django/tests/test_fields.py new file mode 100644 index 000000000..f6abf000c --- /dev/null +++ b/graphene_django/tests/test_fields.py @@ -0,0 +1,199 @@ +import datetime + +import pytest + +from graphene import List, NonNull, ObjectType, Schema, String + +from ..fields import DjangoListField +from ..types import DjangoObjectType +from .models import Article as ArticleModel +from .models import Reporter as ReporterModel + + +@pytest.mark.django_db +class TestDjangoListField: + def test_only_django_object_types(self): + class TestType(ObjectType): + foo = String() + + with pytest.raises(AssertionError): + list_field = DjangoListField(TestType) + + def test_non_null_type(self): + class Reporter(DjangoObjectType): + class Meta: + model = ReporterModel + fields = ("first_name",) + + list_field = DjangoListField(NonNull(Reporter)) + + assert isinstance(list_field.type, List) + assert isinstance(list_field.type.of_type, NonNull) + assert list_field.type.of_type.of_type is Reporter + + def test_get_django_model(self): + class Reporter(DjangoObjectType): + class Meta: + model = ReporterModel + fields = ("first_name",) + + list_field = DjangoListField(Reporter) + assert list_field.model is ReporterModel + + def test_list_field_default_queryset(self): + class Reporter(DjangoObjectType): + class Meta: + model = ReporterModel + fields = ("first_name",) + + class Query(ObjectType): + reporters = DjangoListField(Reporter) + + schema = Schema(query=Query) + + query = """ + query { + reporters { + firstName + } + } + """ + + ReporterModel.objects.create(first_name="Tara", last_name="West") + ReporterModel.objects.create(first_name="Debra", last_name="Payne") + + result = schema.execute(query) + + assert not result.errors + assert result.data == { + "reporters": [{"firstName": "Tara"}, {"firstName": "Debra"}] + } + + def test_override_resolver(self): + class Reporter(DjangoObjectType): + class Meta: + model = ReporterModel + fields = ("first_name",) + + class Query(ObjectType): + reporters = DjangoListField(Reporter) + + def resolve_reporters(_, info): + return ReporterModel.objects.filter(first_name="Tara") + + schema = Schema(query=Query) + + query = """ + query { + reporters { + firstName + } + } + """ + + ReporterModel.objects.create(first_name="Tara", last_name="West") + ReporterModel.objects.create(first_name="Debra", last_name="Payne") + + result = schema.execute(query) + + assert not result.errors + assert result.data == {"reporters": [{"firstName": "Tara"}]} + + def test_nested_list_field(self): + class Article(DjangoObjectType): + class Meta: + model = ArticleModel + fields = ("headline",) + + class Reporter(DjangoObjectType): + class Meta: + model = ReporterModel + fields = ("first_name", "articles") + + class Query(ObjectType): + reporters = DjangoListField(Reporter) + + schema = Schema(query=Query) + + query = """ + query { + reporters { + firstName + articles { + headline + } + } + } + """ + + r1 = ReporterModel.objects.create(first_name="Tara", last_name="West") + ReporterModel.objects.create(first_name="Debra", last_name="Payne") + + ArticleModel.objects.create( + headline="Amazing news", + reporter=r1, + pub_date=datetime.date.today(), + pub_date_time=datetime.datetime.now(), + editor=r1, + ) + + result = schema.execute(query) + + assert not result.errors + assert result.data == { + "reporters": [ + {"firstName": "Tara", "articles": [{"headline": "Amazing news"}]}, + {"firstName": "Debra", "articles": []}, + ] + } + + def test_override_resolver_nested_list_field(self): + class Article(DjangoObjectType): + class Meta: + model = ArticleModel + fields = ("headline",) + + class Reporter(DjangoObjectType): + class Meta: + model = ReporterModel + fields = ("first_name", "articles") + + def resolve_reporters(reporter, info): + return reporter.articles.all() + + class Query(ObjectType): + reporters = DjangoListField(Reporter) + + schema = Schema(query=Query) + + query = """ + query { + reporters { + firstName + articles { + headline + } + } + } + """ + + r1 = ReporterModel.objects.create(first_name="Tara", last_name="West") + ReporterModel.objects.create(first_name="Debra", last_name="Payne") + + ArticleModel.objects.create( + headline="Amazing news", + reporter=r1, + pub_date=datetime.date.today(), + pub_date_time=datetime.datetime.now(), + editor=r1, + ) + + result = schema.execute(query) + + assert not result.errors + assert result.data == { + "reporters": [ + {"firstName": "Tara", "articles": [{"headline": "Amazing news"}]}, + {"firstName": "Debra", "articles": []}, + ] + }