diff --git a/graphene_django/fields.py b/graphene_django/fields.py index c6dcd26af..1ab5e534b 100644 --- a/graphene_django/fields.py +++ b/graphene_django/fields.py @@ -8,6 +8,7 @@ from graphene.relay import ConnectionField, PageInfo from graphql_relay.connection.arrayconnection import connection_from_list_slice +from .optimization import optimize_queryset from .settings import graphene_settings from .utils import DJANGO_FILTER_INSTALLED, maybe_queryset @@ -23,7 +24,12 @@ def model(self): @staticmethod def list_resolver(resolver, root, args, context, info): - return maybe_queryset(resolver(root, args, context, info)) + qs = maybe_queryset(resolver(root, args, context, info)) + + if isinstance(qs, QuerySet): + qs = optimize_queryset(qs, info) + + return qs def get_resolver(self, parent_resolver): return partial(self.list_resolver, parent_resolver) @@ -62,7 +68,7 @@ def merge_querysets(cls, default_queryset, queryset): return queryset & default_queryset @classmethod - def resolve_connection(cls, connection, default_manager, args, iterable): + def resolve_connection(cls, connection, default_manager, args, info, iterable): if iterable is None: iterable = default_manager iterable = maybe_queryset(iterable) @@ -70,6 +76,7 @@ def resolve_connection(cls, connection, default_manager, args, iterable): if iterable is not default_manager: default_queryset = maybe_queryset(default_manager) iterable = cls.merge_querysets(default_queryset, iterable) + iterable = optimize_queryset(iterable, info) _len = iterable.count() else: _len = len(iterable) @@ -112,7 +119,7 @@ def connection_resolver(cls, resolver, connection, default_manager, max_limit, args['last'] = min(last, max_limit) iterable = resolver(root, args, context, info) - on_resolve = partial(cls.resolve_connection, connection, default_manager, args) + on_resolve = partial(cls.resolve_connection, connection, default_manager, args, info) if Promise.is_thenable(iterable): return Promise.resolve(iterable).then(on_resolve) diff --git a/graphene_django/filter/fields.py b/graphene_django/filter/fields.py index fc414bf51..ce8e6a290 100644 --- a/graphene_django/filter/fields.py +++ b/graphene_django/filter/fields.py @@ -4,6 +4,7 @@ # from graphene.relay import is_node from graphene.types.argument import to_arguments from ..fields import DjangoConnectionField +from ..optimization import optimize_queryset from .utils import get_filtering_args_from_filterset, get_filterset_class @@ -75,6 +76,7 @@ def connection_resolver(cls, resolver, connection, default_manager, max_limit, data=filter_kwargs, queryset=default_manager.get_queryset() ).qs + qs = optimize_queryset(qs, info) return super(DjangoFilterConnectionField, cls).connection_resolver( resolver, diff --git a/graphene_django/optimization.py b/graphene_django/optimization.py new file mode 100644 index 000000000..21c3bca38 --- /dev/null +++ b/graphene_django/optimization.py @@ -0,0 +1,100 @@ +from collections import namedtuple + +try: + from django.db.models.fields.reverse_related import ForeignObjectRel +except ImportError: + # Django 1.7 doesn't have the reverse_related distinction + from django.db.models.fields.related import ForeignObjectRel + +from django.db.models import ForeignKey +from graphene.utils.str_converters import to_snake_case + +from .registry import get_global_registry +from .utils import get_related_model + +REGISTRY = get_global_registry() +SELECT = 'select' +PREFETCH = 'prefetch' +RelatedSelection = namedtuple('RelatedSelection', ['name', 'fetch_type']) + + +def model_fields_as_dict(model): + return dict((f.name, f) for f in model._meta.get_fields()) + + +def find_model_selections(ast): + selections = ast.selection_set.selections + + for selection in selections: + if selection.name.value == 'edges': + for sub_selection in selection.selection_set.selections: + if sub_selection.name.value == 'node': + return sub_selection.selection_set.selections + + return selections + + +def get_related_fetches_for_model(model, graphql_ast): + model_fields = model_fields_as_dict(model) + selections = find_model_selections(graphql_ast) + + graphene_obj_type = REGISTRY.get_type_for_model(model) + optimizations = {} + if graphene_obj_type and graphene_obj_type._meta.optimizations: + optimizations = graphene_obj_type._meta.optimizations + + relateds = [] + + for selection in selections: + selection_name = to_snake_case(selection.name.value) + selection_field = model_fields.get(selection_name, None) + + try: + related_model = get_related_model(selection_field) + except: + # This is not a ForeignKey or Relation, check manual optimizations + manual_optimizations = optimizations.get(selection_name) + if manual_optimizations: + for manual_select in manual_optimizations.get(SELECT, []): + relateds.append(RelatedSelection(manual_select, SELECT)) + for manual_prefetch in manual_optimizations.get(PREFETCH, []): + relateds.append(RelatedSelection(manual_prefetch, PREFETCH)) + + continue + + query_name = selection_field.name + if isinstance(selection_field, ForeignObjectRel): + query_name = selection_field.field.related_query_name() + + nested_relateds = get_related_fetches_for_model(related_model, selection) + + related_type = PREFETCH # default to prefetch, it's safer + if isinstance(selection_field, ForeignKey): + related_type = SELECT # we can only select for ForeignKeys + + if nested_relateds: + for related in nested_relateds: + full_name = '{0}__{1}'.format(query_name, related.name) + + nested_related_type = PREFETCH + if related_type == SELECT and related.fetch_type == SELECT: + nested_related_type = related_type + + relateds.append(RelatedSelection(full_name, nested_related_type)) + else: + relateds.append(RelatedSelection(query_name, related_type)) + + return relateds + + +def optimize_queryset(queryset, graphql_info): + base_ast = graphql_info.field_asts[0] + relateds = get_related_fetches_for_model(queryset.model, base_ast) + + for related in relateds: + if related.fetch_type == SELECT: + queryset = queryset.select_related(related.name) + else: + queryset = queryset.prefetch_related(related.name) + + return queryset diff --git a/graphene_django/tests/test_optimization.py b/graphene_django/tests/test_optimization.py new file mode 100644 index 000000000..afe58d6e8 --- /dev/null +++ b/graphene_django/tests/test_optimization.py @@ -0,0 +1,153 @@ +from datetime import date +from django.db import connection +from django.test import TestCase +from django.test.utils import CaptureQueriesContext +import graphene + +from ..fields import DjangoConnectionField, DjangoListField +from ..optimization import optimize_queryset +from ..types import DjangoObjectType +from .models import ( + Article as ArticleModel, + Reporter as ReporterModel +) + + +class Article(DjangoObjectType): + class Meta: + model = ArticleModel + interfaces = (graphene.relay.Node,) + + +class Reporter(DjangoObjectType): + favorite_pet = graphene.Field(lambda: Reporter) + + class Meta: + model = ReporterModel + #interfaces = (graphene.relay.Node,) + optimizations = { + 'favorite_pet': { + 'prefetch': ['pets'] + } + } + + def resolve_favorite_pet(self, *args): + for pet in self.pets.all(): + if pet.last_name == 'Kent': + return pet + + +class RootQuery(graphene.ObjectType): + article = graphene.Field(Article, id=graphene.ID()) + articles = DjangoConnectionField(Article) + reporters = DjangoListField(Reporter) + + def resolve_article(self, args, context, info): + qs = ArticleModel.objects + qs = optimize_queryset(qs, info) + return qs.get(**args) + + def resolve_reporters(self, args, context, info): + return ReporterModel.objects + + +schema = graphene.Schema(query=RootQuery) + + +class TestOptimization(TestCase): + + @classmethod + def setUpTestData(cls): + cls.reporter = ReporterModel.objects.create( + first_name='Clark', last_name='Kent', + email='ckent@dailyplanet.com', a_choice='this' + ) + cls.editor = ReporterModel.objects.create( + first_name='Perry', last_name='White', + email='pwhite@dailyplanet.com', a_choice='this' + ) + cls.article = ArticleModel.objects.create( + headline='Superman Saves the Day', + pub_date=date.today(), + reporter=cls.reporter, + editor=cls.editor + ) + cls.other_article = ArticleModel.objects.create( + headline='Lex Luthor is SO Rich', + pub_date=date.today(), + reporter=cls.reporter, + editor=cls.editor + ) + cls.editor.pets.add(cls.reporter) + + def test_select_related(self): + query = """query GetArticle($articleId: ID!){ + article(id: $articleId) { + headline + reporter { + email + } + editor { + email + } + } + }""" + + variables = {'articleId': str(self.article.id)} + + with CaptureQueriesContext(connection) as query_context: + results = schema.execute(query, variable_values=variables) + + returned_article = results.data['article'] + assert returned_article['headline'] == self.article.headline + assert returned_article['reporter']['email'] == self.reporter.email + assert returned_article['editor']['email'] == self.editor.email + + self.assertEqual(len(query_context.captured_queries), 1) + + def test_prefetch_related(self): + query = """query { + articles { + edges { + node { + headline + editor { + email + pets { + email + } + } + } + } + } + }""" + + with CaptureQueriesContext(connection) as query_context: + results = schema.execute(query) + + returned_articles = results.data['articles']['edges'] + assert len(returned_articles) == 2 + + self.assertEqual(len(query_context.captured_queries), 4) + + def test_manual(self): + query = """query { + reporters { + email + favoritePet { + email + } + } + }""" + + with CaptureQueriesContext(connection) as query_context: + results = schema.execute(query) + + returned_reporters = results.data['reporters'] + assert len(returned_reporters) == 2 + + returned_editor = [reporter for reporter in returned_reporters + if reporter['email'] == self.editor.email][0] + assert returned_editor['favoritePet']['email'] == self.reporter.email + + self.assertEqual(len(query_context.captured_queries), 2) diff --git a/graphene_django/tests/test_types.py b/graphene_django/tests/test_types.py index 0ae12c02f..1172cc961 100644 --- a/graphene_django/tests/test_types.py +++ b/graphene_django/tests/test_types.py @@ -1,4 +1,4 @@ -from mock import patch +from mock import Mock, patch from graphene import Interface, ObjectType, Schema from graphene.relay import Node @@ -38,7 +38,11 @@ def test_django_interface(): @patch('graphene_django.tests.models.Article.objects.get', return_value=Article(id=1)) def test_django_get_node(get): - article = Article.get_node(1, None, None) + ast_mock = Mock() + ast_mock.selection_set.selections = [] + info_mock = Mock(field_asts=[ast_mock]) + + article = Article.get_node(1, None, info_mock) get.assert_called_with(pk=1) assert article.id == 1 diff --git a/graphene_django/types.py b/graphene_django/types.py index bb0a2f1f1..970e6fa38 100644 --- a/graphene_django/types.py +++ b/graphene_django/types.py @@ -10,6 +10,7 @@ from graphene.utils.is_base_type import is_base_type from .converter import convert_django_field_with_choices +from .optimization import optimize_queryset from .registry import Registry, get_global_registry from .utils import (DJANGO_FILTER_INSTALLED, get_model_fields, is_valid_django_model) @@ -55,6 +56,7 @@ def __new__(cls, name, bases, attrs): only_fields=(), exclude_fields=(), interfaces=(), + optimizations=None, skip_registry=False, registry=None ) @@ -118,7 +120,9 @@ def is_type_of(cls, root, context, info): @classmethod def get_node(cls, id, context, info): + query = cls._meta.model._default_manager + query = optimize_queryset(query, info) try: - return cls._meta.model.objects.get(pk=id) + return query.get(pk=id) except cls._meta.model.DoesNotExist: return None