diff --git a/graphene_django/filter/fields.py b/graphene_django/filter/fields.py index 62f4b1a13..338becb72 100644 --- a/graphene_django/filter/fields.py +++ b/graphene_django/filter/fields.py @@ -111,7 +111,7 @@ def get_resolver(self, parent_resolver): return partial( self.connection_resolver, parent_resolver, - self.type, + self.connection_type, self.get_manager(), self.max_limit, self.enforce_first_or_last, diff --git a/graphene_django/filter/tests/test_fields.py b/graphene_django/filter/tests/test_fields.py index 99876b6ab..5b8471518 100644 --- a/graphene_django/filter/tests/test_fields.py +++ b/graphene_django/filter/tests/test_fields.py @@ -818,3 +818,38 @@ class Query(ObjectType): } """ ) + + +def test_required_filter_field(): + class ReporterType(DjangoObjectType): + class Meta: + model = Reporter + interfaces = (Node,) + filter_fields = () + + class Query(ObjectType): + all_reporters = DjangoFilterConnectionField(ReporterType, required=True) + + def resolve_all_reporters(self, info, **args): + return Reporter.objects.all() + + Reporter.objects.create(first_name="John", last_name="Doe") + schema = Schema(query=Query) + + query = """ + query NodeFilteringQuery { + allReporters(first: 1) { + edges { + node { + firstName + } + } + } + } + """ + expected = {"allReporters": {"edges": [{"node": {"firstName": "John"}}]}} + + result = schema.execute(query) + + assert not result.errors + assert result.data == expected diff --git a/graphene_django/tests/test_fields.py b/graphene_django/tests/test_fields.py new file mode 100644 index 000000000..9a362ce2c --- /dev/null +++ b/graphene_django/tests/test_fields.py @@ -0,0 +1,42 @@ +import pytest + +from graphene import ObjectType, Schema +from graphene.relay import Node +from graphene_django import DjangoConnectionField, DjangoObjectType +from graphene_django.tests.models import Article, Pet, Reporter + +pytestmark = pytest.mark.django_db + + +def test_required_connection_field(): + class ReporterType(DjangoObjectType): + class Meta: + model = Reporter + interfaces = (Node,) + + class Query(ObjectType): + all_reporters = DjangoConnectionField(ReporterType, required=True) + + def resolve_all_reporters(self, info, **args): + return Reporter.objects.all() + + Reporter.objects.create(first_name="John", last_name="Doe") + + schema = Schema(query=Query) + query = """ + query NodeFilteringQuery { + allReporters { + edges { + node { + firstName + } + } + } + } + """ + + expected = {"allReporters": {"edges": [{"node": {"firstName": "John"}}]}} + + result = schema.execute(query) + assert not result.errors + assert result.data == expected