diff --git a/graphene_django/tests/models.py b/graphene_django/tests/models.py index 735f23648..e7298383f 100644 --- a/graphene_django/tests/models.py +++ b/graphene_django/tests/models.py @@ -46,6 +46,7 @@ class Reporter(models.Model): a_choice = models.IntegerField(choices=CHOICES, null=True, blank=True) objects = models.Manager() doe_objects = DoeReporterManager() + fans = models.ManyToManyField(Person) reporter_type = models.IntegerField( "Reporter Type", @@ -90,6 +91,16 @@ class Meta: objects = CNNReporterManager() +class APNewsReporter(Reporter): + """ + This class only inherits from Reporter for testing multi table inheritence + similar to what you'd see in django-polymorphic + """ + + alias = models.CharField(max_length=30) + objects = models.Manager() + + class Article(models.Model): headline = models.CharField(max_length=100) pub_date = models.DateField(auto_now_add=True) diff --git a/graphene_django/tests/test_query.py b/graphene_django/tests/test_query.py index 68bdc7d94..91bacbdf3 100644 --- a/graphene_django/tests/test_query.py +++ b/graphene_django/tests/test_query.py @@ -15,7 +15,16 @@ from ..fields import DjangoConnectionField from ..types import DjangoObjectType from ..utils import DJANGO_FILTER_INSTALLED -from .models import Article, CNNReporter, Film, FilmDetails, Person, Pet, Reporter +from .models import ( + Article, + CNNReporter, + Film, + FilmDetails, + Person, + Pet, + Reporter, + APNewsReporter, +) def test_should_query_only_fields(): @@ -1064,6 +1073,301 @@ class Query(graphene.ObjectType): assert result.data == expected +def test_model_inheritance_support_reverse_relationships(): + """ + This test asserts that we can query reverse relationships for all Reporters and proxied Reporters and multi table Reporters. + """ + + class FilmType(DjangoObjectType): + class Meta: + model = Film + fields = "__all__" + + class ReporterType(DjangoObjectType): + class Meta: + model = Reporter + interfaces = (Node,) + use_connection = True + fields = "__all__" + + class CNNReporterType(DjangoObjectType): + class Meta: + model = CNNReporter + interfaces = (Node,) + use_connection = True + fields = "__all__" + + class APNewsReporterType(DjangoObjectType): + class Meta: + model = APNewsReporter + interfaces = (Node,) + use_connection = True + fields = "__all__" + + film = Film.objects.create(genre="do") + + reporter = Reporter.objects.create( + first_name="John", last_name="Doe", email="johndoe@example.com", a_choice=1 + ) + + cnn_reporter = CNNReporter.objects.create( + first_name="Some", + last_name="Guy", + email="someguy@cnn.com", + a_choice=1, + reporter_type=2, # set this guy to be CNN + ) + + ap_news_reporter = APNewsReporter.objects.create( + first_name="John", last_name="Doe", email="johndoe@example.com", a_choice=1 + ) + + film.reporters.add(cnn_reporter, ap_news_reporter) + film.save() + + class Query(graphene.ObjectType): + all_reporters = DjangoConnectionField(ReporterType) + cnn_reporters = DjangoConnectionField(CNNReporterType) + ap_news_reporters = DjangoConnectionField(APNewsReporterType) + + schema = graphene.Schema(query=Query) + query = """ + query ProxyModelQuery { + allReporters { + edges { + node { + id + films { + id + } + } + } + } + cnnReporters { + edges { + node { + id + films { + id + } + } + } + } + apNewsReporters { + edges { + node { + id + films { + id + } + } + } + } + } + """ + + expected = { + "allReporters": { + "edges": [ + { + "node": { + "id": to_global_id("ReporterType", reporter.id), + "films": [], + }, + }, + { + "node": { + "id": to_global_id("ReporterType", cnn_reporter.id), + "films": [{"id": f"{film.id}"}], + }, + }, + { + "node": { + "id": to_global_id("ReporterType", ap_news_reporter.id), + "films": [{"id": f"{film.id}"}], + }, + }, + ] + }, + "cnnReporters": { + "edges": [ + { + "node": { + "id": to_global_id("CNNReporterType", cnn_reporter.id), + "films": [{"id": f"{film.id}"}], + } + } + ] + }, + "apNewsReporters": { + "edges": [ + { + "node": { + "id": to_global_id("APNewsReporterType", ap_news_reporter.id), + "films": [{"id": f"{film.id}"}], + } + } + ] + }, + } + + result = schema.execute(query) + assert result.data == expected + + +def test_model_inheritance_support_local_relationships(): + """ + This test asserts that we can query local relationships for all Reporters and proxied Reporters and multi table Reporters. + """ + + class PersonType(DjangoObjectType): + class Meta: + model = Person + fields = "__all__" + + class ReporterType(DjangoObjectType): + class Meta: + model = Reporter + interfaces = (Node,) + use_connection = True + fields = "__all__" + + class CNNReporterType(DjangoObjectType): + class Meta: + model = CNNReporter + interfaces = (Node,) + use_connection = True + fields = "__all__" + + class APNewsReporterType(DjangoObjectType): + class Meta: + model = APNewsReporter + interfaces = (Node,) + use_connection = True + fields = "__all__" + + film = Film.objects.create(genre="do") + + reporter = Reporter.objects.create( + first_name="John", last_name="Doe", email="johndoe@example.com", a_choice=1 + ) + + reporter_fan = Person.objects.create(name="Reporter Fan") + + reporter.fans.add(reporter_fan) + reporter.save() + + cnn_reporter = CNNReporter.objects.create( + first_name="Some", + last_name="Guy", + email="someguy@cnn.com", + a_choice=1, + reporter_type=2, # set this guy to be CNN + ) + cnn_fan = Person.objects.create(name="CNN Fan") + cnn_reporter.fans.add(cnn_fan) + cnn_reporter.save() + + ap_news_reporter = APNewsReporter.objects.create( + first_name="John", last_name="Doe", email="johndoe@example.com", a_choice=1 + ) + ap_news_fan = Person.objects.create(name="AP News Fan") + ap_news_reporter.fans.add(ap_news_fan) + ap_news_reporter.save() + + film.reporters.add(cnn_reporter, ap_news_reporter) + film.save() + + class Query(graphene.ObjectType): + all_reporters = DjangoConnectionField(ReporterType) + cnn_reporters = DjangoConnectionField(CNNReporterType) + ap_news_reporters = DjangoConnectionField(APNewsReporterType) + + schema = graphene.Schema(query=Query) + query = """ + query ProxyModelQuery { + allReporters { + edges { + node { + id + fans { + name + } + } + } + } + cnnReporters { + edges { + node { + id + fans { + name + } + } + } + } + apNewsReporters { + edges { + node { + id + fans { + name + } + } + } + } + } + """ + + expected = { + "allReporters": { + "edges": [ + { + "node": { + "id": to_global_id("ReporterType", reporter.id), + "fans": [{"name": f"{reporter_fan.name}"}], + }, + }, + { + "node": { + "id": to_global_id("ReporterType", cnn_reporter.id), + "fans": [{"name": f"{cnn_fan.name}"}], + }, + }, + { + "node": { + "id": to_global_id("ReporterType", ap_news_reporter.id), + "fans": [{"name": f"{ap_news_fan.name}"}], + }, + }, + ] + }, + "cnnReporters": { + "edges": [ + { + "node": { + "id": to_global_id("CNNReporterType", cnn_reporter.id), + "fans": [{"name": f"{cnn_fan.name}"}], + } + } + ] + }, + "apNewsReporters": { + "edges": [ + { + "node": { + "id": to_global_id("APNewsReporterType", ap_news_reporter.id), + "fans": [{"name": f"{ap_news_fan.name}"}], + } + } + ] + }, + } + + result = schema.execute(query) + assert result.data == expected + + def test_should_resolve_get_queryset_connectionfields(): reporter_1 = Reporter.objects.create( first_name="John", last_name="Doe", email="johndoe@example.com", a_choice=1 diff --git a/graphene_django/tests/test_schema.py b/graphene_django/tests/test_schema.py index ff2d8a668..93cbd9f05 100644 --- a/graphene_django/tests/test_schema.py +++ b/graphene_django/tests/test_schema.py @@ -33,17 +33,18 @@ class Meta: fields = "__all__" fields = list(ReporterType2._meta.fields.keys()) - assert fields[:-2] == [ + assert fields[:-3] == [ "id", "first_name", "last_name", "email", "pets", "a_choice", + "fans", "reporter_type", ] - assert sorted(fields[-2:]) == ["articles", "films"] + assert sorted(fields[-3:]) == ["apnewsreporter", "articles", "films"] def test_should_map_only_few_fields(): diff --git a/graphene_django/tests/test_types.py b/graphene_django/tests/test_types.py index fad26e2ab..fd85ef140 100644 --- a/graphene_django/tests/test_types.py +++ b/graphene_django/tests/test_types.py @@ -67,16 +67,17 @@ def test_django_get_node(get): def test_django_objecttype_map_correct_fields(): fields = Reporter._meta.fields fields = list(fields.keys()) - assert fields[:-2] == [ + assert fields[:-3] == [ "id", "first_name", "last_name", "email", "pets", "a_choice", + "fans", "reporter_type", ] - assert sorted(fields[-2:]) == ["articles", "films"] + assert sorted(fields[-3:]) == ["apnewsreporter", "articles", "films"] def test_django_objecttype_with_node_have_correct_fields(): diff --git a/graphene_django/tests/test_utils.py b/graphene_django/tests/test_utils.py index fa269b44c..4e6861eda 100644 --- a/graphene_django/tests/test_utils.py +++ b/graphene_django/tests/test_utils.py @@ -4,8 +4,8 @@ from django.utils.translation import gettext_lazy from unittest.mock import patch -from ..utils import camelize, get_model_fields, GraphQLTestCase -from .models import Film, Reporter +from ..utils import camelize, get_model_fields, get_reverse_fields, GraphQLTestCase +from .models import Film, Reporter, CNNReporter, APNewsReporter from ..utils.testing import graphql_query @@ -19,6 +19,18 @@ def test_get_model_fields_no_duplication(): assert len(film_fields) == len(film_name_set) +def test_get_reverse_fields_includes_proxied_models(): + reporter_fields = get_reverse_fields(Reporter, []) + cnn_reporter_fields = get_reverse_fields(CNNReporter, []) + ap_news_reporter_fields = get_reverse_fields(APNewsReporter, []) + + assert ( + len(list(reporter_fields)) + == len(list(cnn_reporter_fields)) + == len(list(ap_news_reporter_fields)) + ) + + def test_camelize(): assert camelize({}) == {} assert camelize("value_a") == "value_a" diff --git a/graphene_django/utils/utils.py b/graphene_django/utils/utils.py index e0b8b5f9a..d7993e7b2 100644 --- a/graphene_django/utils/utils.py +++ b/graphene_django/utils/utils.py @@ -37,18 +37,52 @@ def camelize(data): return data +def _get_model_ancestry(model): + model_ancestry = [model] + + for base in model.__bases__: + if is_valid_django_model(base) and getattr(base, "_meta", False): + model_ancestry.append(base) + return model_ancestry + + def get_reverse_fields(model, local_field_names): - for name, attr in model.__dict__.items(): - # Don't duplicate any local fields - if name in local_field_names: - continue + """ + Searches through the model's ancestry and gets reverse relationships the models + Yields a tuple of (field.name, field) + """ + model_ancestry = _get_model_ancestry(model) - # "rel" for FK and M2M relations and "related" for O2O Relations - related = getattr(attr, "rel", None) or getattr(attr, "related", None) - if isinstance(related, models.ManyToOneRel): - yield (name, related) - elif isinstance(related, models.ManyToManyRel) and not related.symmetrical: - yield (name, related) + for _model in model_ancestry: + for name, attr in _model.__dict__.items(): + # Don't duplicate any local fields + if name in local_field_names: + continue + + # "rel" for FK and M2M relations and "related" for O2O Relations + related = getattr(attr, "rel", None) or getattr(attr, "related", None) + if isinstance(related, models.ManyToOneRel): + yield (name, related) + elif isinstance(related, models.ManyToManyRel) and not related.symmetrical: + yield (name, related) + + +def get_local_fields(model): + """ + Searches through the model's ancestry and gets the fields on the models + Returns a dict of {field.name: field} + """ + model_ancestry = _get_model_ancestry(model) + + local_fields_dict = {} + for _model in model_ancestry: + for field in sorted( + list(_model._meta.fields) + list(_model._meta.local_many_to_many) + ): + if field.name not in local_fields_dict: + local_fields_dict[field.name] = field + + return list(local_fields_dict.items()) def maybe_queryset(value): @@ -58,17 +92,14 @@ def maybe_queryset(value): def get_model_fields(model): - local_fields = [ - (field.name, field) - for field in sorted( - list(model._meta.fields) + list(model._meta.local_many_to_many) - ) - ] - - # Make sure we don't duplicate local fields with "reverse" version - local_field_names = [field[0] for field in local_fields] + """ + Gets all the fields and relationships on the Django model and its ancestry. + Prioritizes local fields and relationships over the reverse relationships of the same name + Returns a tuple of (field.name, field) + """ + local_fields = get_local_fields(model) + local_field_names = {field[0] for field in local_fields} reverse_fields = get_reverse_fields(model, local_field_names) - all_fields = local_fields + list(reverse_fields) return all_fields