Skip to content

Adding select_related and prefetch_related optimizations #220

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions graphene_django/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from graphene.relay import ConnectionField, PageInfo
from graphql_relay.connection.arrayconnection import connection_from_list_slice

from .optimization import optimize_queryset
from .settings import graphene_settings
from .utils import DJANGO_FILTER_INSTALLED, maybe_queryset

Expand All @@ -23,7 +24,12 @@ def model(self):

@staticmethod
def list_resolver(resolver, root, args, context, info):
return maybe_queryset(resolver(root, args, context, info))
qs = maybe_queryset(resolver(root, args, context, info))

if isinstance(qs, QuerySet):
qs = optimize_queryset(qs, info)

return qs

def get_resolver(self, parent_resolver):
return partial(self.list_resolver, parent_resolver)
Expand Down Expand Up @@ -62,14 +68,15 @@ def merge_querysets(cls, default_queryset, queryset):
return queryset & default_queryset

@classmethod
def resolve_connection(cls, connection, default_manager, args, iterable):
def resolve_connection(cls, connection, default_manager, args, info, iterable):
if iterable is None:
iterable = default_manager
iterable = maybe_queryset(iterable)
if isinstance(iterable, QuerySet):
if iterable is not default_manager:
default_queryset = maybe_queryset(default_manager)
iterable = cls.merge_querysets(default_queryset, iterable)
iterable = optimize_queryset(iterable, info)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Interestingly, this optimisation does work because it's come after the call to merge_querysets!

_len = iterable.count()
else:
_len = len(iterable)
Expand Down Expand Up @@ -112,7 +119,7 @@ def connection_resolver(cls, resolver, connection, default_manager, max_limit,
args['last'] = min(last, max_limit)

iterable = resolver(root, args, context, info)
on_resolve = partial(cls.resolve_connection, connection, default_manager, args)
on_resolve = partial(cls.resolve_connection, connection, default_manager, args, info)

if Promise.is_thenable(iterable):
return Promise.resolve(iterable).then(on_resolve)
Expand Down
2 changes: 2 additions & 0 deletions graphene_django/filter/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# from graphene.relay import is_node
from graphene.types.argument import to_arguments
from ..fields import DjangoConnectionField
from ..optimization import optimize_queryset
from .utils import get_filtering_args_from_filterset, get_filterset_class


Expand Down Expand Up @@ -75,6 +76,7 @@ def connection_resolver(cls, resolver, connection, default_manager, max_limit,
data=filter_kwargs,
queryset=default_manager.get_queryset()
).qs
qs = optimize_queryset(qs, info)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that currently these optimisations will be lost when line 66 in this file runs. That intersects the filtered queryset here with the resolved one, and loses the optimisations in the process.

We've "solved" this by instead resolving in this method, then passing the resolved queryset to the filterset (defaulting to the default_manager.get_queryset() if the resolver doesn't give us a value), and then dropping the merge behaviour above. This may or may not work in the general case, I really don't know, but it's working for us (we've implemented a similar function to your optimize_queryset).


return super(DjangoFilterConnectionField, cls).connection_resolver(
resolver,
Expand Down
100 changes: 100 additions & 0 deletions graphene_django/optimization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
from collections import namedtuple

try:
from django.db.models.fields.reverse_related import ForeignObjectRel
except ImportError:
# Django 1.7 doesn't have the reverse_related distinction
from django.db.models.fields.related import ForeignObjectRel

from django.db.models import ForeignKey
from graphene.utils.str_converters import to_snake_case

from .registry import get_global_registry
from .utils import get_related_model

REGISTRY = get_global_registry()
SELECT = 'select'
PREFETCH = 'prefetch'
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could make this an Enum instead.

RelatedSelection = namedtuple('RelatedSelection', ['name', 'fetch_type'])


def model_fields_as_dict(model):
return dict((f.name, f) for f in model._meta.get_fields())


def find_model_selections(ast):
selections = ast.selection_set.selections

for selection in selections:
if selection.name.value == 'edges':
for sub_selection in selection.selection_set.selections:
if sub_selection.name.value == 'node':
return sub_selection.selection_set.selections

return selections


def get_related_fetches_for_model(model, graphql_ast):
model_fields = model_fields_as_dict(model)
selections = find_model_selections(graphql_ast)

graphene_obj_type = REGISTRY.get_type_for_model(model)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the behaviour of this when there are multiple types mapped to the same model? As far as I remember that case isn't prohibited.

optimizations = {}
if graphene_obj_type and graphene_obj_type._meta.optimizations:
optimizations = graphene_obj_type._meta.optimizations

relateds = []

for selection in selections:
selection_name = to_snake_case(selection.name.value)
selection_field = model_fields.get(selection_name, None)

try:
related_model = get_related_model(selection_field)
except:
# This is not a ForeignKey or Relation, check manual optimizations
manual_optimizations = optimizations.get(selection_name)
if manual_optimizations:
for manual_select in manual_optimizations.get(SELECT, []):
relateds.append(RelatedSelection(manual_select, SELECT))
for manual_prefetch in manual_optimizations.get(PREFETCH, []):
relateds.append(RelatedSelection(manual_prefetch, PREFETCH))

continue

query_name = selection_field.name
if isinstance(selection_field, ForeignObjectRel):
query_name = selection_field.field.related_query_name()

nested_relateds = get_related_fetches_for_model(related_model, selection)

related_type = PREFETCH # default to prefetch, it's safer
if isinstance(selection_field, ForeignKey):
related_type = SELECT # we can only select for ForeignKeys

if nested_relateds:
for related in nested_relateds:
full_name = '{0}__{1}'.format(query_name, related.name)

nested_related_type = PREFETCH
if related_type == SELECT and related.fetch_type == SELECT:
nested_related_type = related_type

relateds.append(RelatedSelection(full_name, nested_related_type))
else:
relateds.append(RelatedSelection(query_name, related_type))

return relateds


def optimize_queryset(queryset, graphql_info):
base_ast = graphql_info.field_asts[0]
relateds = get_related_fetches_for_model(queryset.model, base_ast)

for related in relateds:
if related.fetch_type == SELECT:
queryset = queryset.select_related(related.name)
else:
queryset = queryset.prefetch_related(related.name)

return queryset
153 changes: 153 additions & 0 deletions graphene_django/tests/test_optimization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
from datetime import date
from django.db import connection
from django.test import TestCase
from django.test.utils import CaptureQueriesContext
import graphene

from ..fields import DjangoConnectionField, DjangoListField
from ..optimization import optimize_queryset
from ..types import DjangoObjectType
from .models import (
Article as ArticleModel,
Reporter as ReporterModel
)


class Article(DjangoObjectType):
class Meta:
model = ArticleModel
interfaces = (graphene.relay.Node,)


class Reporter(DjangoObjectType):
favorite_pet = graphene.Field(lambda: Reporter)

class Meta:
model = ReporterModel
#interfaces = (graphene.relay.Node,)
optimizations = {
'favorite_pet': {
'prefetch': ['pets']
}
}

def resolve_favorite_pet(self, *args):
for pet in self.pets.all():
if pet.last_name == 'Kent':
return pet


class RootQuery(graphene.ObjectType):
article = graphene.Field(Article, id=graphene.ID())
articles = DjangoConnectionField(Article)
reporters = DjangoListField(Reporter)

def resolve_article(self, args, context, info):
qs = ArticleModel.objects
qs = optimize_queryset(qs, info)
return qs.get(**args)

def resolve_reporters(self, args, context, info):
return ReporterModel.objects


schema = graphene.Schema(query=RootQuery)


class TestOptimization(TestCase):

@classmethod
def setUpTestData(cls):
cls.reporter = ReporterModel.objects.create(
first_name='Clark', last_name='Kent',
email='ckent@dailyplanet.com', a_choice='this'
)
cls.editor = ReporterModel.objects.create(
first_name='Perry', last_name='White',
email='pwhite@dailyplanet.com', a_choice='this'
)
cls.article = ArticleModel.objects.create(
headline='Superman Saves the Day',
pub_date=date.today(),
reporter=cls.reporter,
editor=cls.editor
)
cls.other_article = ArticleModel.objects.create(
headline='Lex Luthor is SO Rich',
pub_date=date.today(),
reporter=cls.reporter,
editor=cls.editor
)
cls.editor.pets.add(cls.reporter)

def test_select_related(self):
query = """query GetArticle($articleId: ID!){
article(id: $articleId) {
headline
reporter {
email
}
editor {
email
}
}
}"""

variables = {'articleId': str(self.article.id)}

with CaptureQueriesContext(connection) as query_context:
results = schema.execute(query, variable_values=variables)

returned_article = results.data['article']
assert returned_article['headline'] == self.article.headline
assert returned_article['reporter']['email'] == self.reporter.email
assert returned_article['editor']['email'] == self.editor.email

self.assertEqual(len(query_context.captured_queries), 1)

def test_prefetch_related(self):
query = """query {
articles {
edges {
node {
headline
editor {
email
pets {
email
}
}
}
}
}
}"""

with CaptureQueriesContext(connection) as query_context:
results = schema.execute(query)

returned_articles = results.data['articles']['edges']
assert len(returned_articles) == 2

self.assertEqual(len(query_context.captured_queries), 4)

def test_manual(self):
query = """query {
reporters {
email
favoritePet {
email
}
}
}"""

with CaptureQueriesContext(connection) as query_context:
results = schema.execute(query)

returned_reporters = results.data['reporters']
assert len(returned_reporters) == 2

returned_editor = [reporter for reporter in returned_reporters
if reporter['email'] == self.editor.email][0]
assert returned_editor['favoritePet']['email'] == self.reporter.email

self.assertEqual(len(query_context.captured_queries), 2)
8 changes: 6 additions & 2 deletions graphene_django/tests/test_types.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from mock import patch
from mock import Mock, patch

from graphene import Interface, ObjectType, Schema
from graphene.relay import Node
Expand Down Expand Up @@ -38,7 +38,11 @@ def test_django_interface():

@patch('graphene_django.tests.models.Article.objects.get', return_value=Article(id=1))
def test_django_get_node(get):
article = Article.get_node(1, None, None)
ast_mock = Mock()
ast_mock.selection_set.selections = []
info_mock = Mock(field_asts=[ast_mock])

article = Article.get_node(1, None, info_mock)
get.assert_called_with(pk=1)
assert article.id == 1

Expand Down
6 changes: 5 additions & 1 deletion graphene_django/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from graphene.utils.is_base_type import is_base_type

from .converter import convert_django_field_with_choices
from .optimization import optimize_queryset
from .registry import Registry, get_global_registry
from .utils import (DJANGO_FILTER_INSTALLED, get_model_fields,
is_valid_django_model)
Expand Down Expand Up @@ -55,6 +56,7 @@ def __new__(cls, name, bases, attrs):
only_fields=(),
exclude_fields=(),
interfaces=(),
optimizations=None,
skip_registry=False,
registry=None
)
Expand Down Expand Up @@ -118,7 +120,9 @@ def is_type_of(cls, root, context, info):

@classmethod
def get_node(cls, id, context, info):
query = cls._meta.model._default_manager
query = optimize_queryset(query, info)
try:
return cls._meta.model.objects.get(pk=id)
return query.get(pk=id)
except cls._meta.model.DoesNotExist:
return None