Skip to content

Added support for total count on relay connections #104

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions examples/flask_sqlalchemy/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ class Role(SQLAlchemyObjectType):
class Meta:
model = RoleModel
interfaces = (relay.Node, )
# Disable the total count on this connection
total_count = False


class Query(graphene.ObjectType):
Expand Down
30 changes: 30 additions & 0 deletions graphene_sqlalchemy/tests/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,3 +365,33 @@ class Mutation(graphene.ObjectType):
result = schema.execute(query, context_value={'session': session})
assert not result.errors
assert result.data == expected


def test_should_return_total_count(session):
setup_fixtures(session)

class ReporterNode(SQLAlchemyObjectType):

class Meta:
model = Reporter
interfaces = (Node, )

class Query(graphene.ObjectType):
all_article = SQLAlchemyConnectionField(ReporterNode)

query = '''
{
allArticle {
totalCount
}
}
'''
expected = {
'allArticle': {
'totalCount': session.query(Reporter).count()
},
}
schema = graphene.Schema(query=Query)
result = schema.execute(query, context_value={'session': session})
assert not result.errors
assert result.data == expected
23 changes: 21 additions & 2 deletions graphene_sqlalchemy/tests/test_types.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@

from graphene import Field, Int, Interface, ObjectType
from graphene.relay import Node, is_node
from graphene.relay import Node, is_node, Connection
import six

from ..registry import Registry
from ..types import SQLAlchemyObjectType
from ..types import SQLAlchemyObjectType, ConnectionWithCount
from .models import Article, Reporter

registry = Registry()
Expand Down Expand Up @@ -116,3 +116,22 @@ def test_custom_objecttype_registered():
'pets',
'articles',
'favorite_article']

def test_total_count():
class TotalCount(SQLAlchemyObjectType):
class Meta:
model = Article
interfaces = (Node, )
registry = registry

class NoTotalCount(SQLAlchemyObjectType):
class Meta:
model = Reporter
interfaces = (Node, )
registry = registry
total_count = False

assert issubclass(TotalCount._meta.connection, ConnectionWithCount)
assert not issubclass(NoTotalCount._meta.connection, ConnectionWithCount)
assert issubclass(TotalCount._meta.connection, Connection)
assert issubclass(NoTotalCount._meta.connection, Connection)
18 changes: 15 additions & 3 deletions graphene_sqlalchemy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm.exc import NoResultFound

from graphene import Field # , annotate, ResolveInfo
from graphene import Field, Int, NonNull
from graphene.relay import Connection, Node
from graphene.types.objecttype import ObjectType, ObjectTypeOptions
from graphene.types.utils import yank_fields_from_attrs
Expand Down Expand Up @@ -86,11 +86,22 @@ class SQLAlchemyObjectTypeOptions(ObjectTypeOptions):
id = None # type: str


class ConnectionWithCount(Connection):
'''Class that adds `totalCount` to a connection field'''
class Meta:
abstract = True

total_count = NonNull(Int)

def resolve_total_count(self, info, **kwargs):
return self.length


class SQLAlchemyObjectType(ObjectType):
@classmethod
def __init_subclass_with_meta__(cls, model=None, registry=None, skip_registry=False,
only_fields=(), exclude_fields=(), connection=None,
use_connection=None, interfaces=(), id=None, **options):
use_connection=None, interfaces=(), id=None, total_count=True, **options):
assert is_mapped_class(model), (
'You need to pass a valid SQLAlchemy Model in '
'{}.Meta, received "{}".'
Expand All @@ -114,7 +125,8 @@ def __init_subclass_with_meta__(cls, model=None, registry=None, skip_registry=Fa

if use_connection and not connection:
# We create the connection automatically
connection = Connection.create_type('{}Connection'.format(cls.__name__), node=cls)
connection_class = ConnectionWithCount if total_count else Connection
connection = connection_class.create_type('{}Connection'.format(cls.__name__), node=cls)

if connection is not None:
assert issubclass(connection, Connection), (
Expand Down