diff --git a/examples/flask_sqlalchemy/schema.py b/examples/flask_sqlalchemy/schema.py index 5d9d3b72..e851ee73 100644 --- a/examples/flask_sqlalchemy/schema.py +++ b/examples/flask_sqlalchemy/schema.py @@ -25,6 +25,8 @@ class Role(SQLAlchemyObjectType): class Meta: model = RoleModel interfaces = (relay.Node, ) + # Disable the total count on this connection + total_count = False class Query(graphene.ObjectType): diff --git a/graphene_sqlalchemy/tests/test_query.py b/graphene_sqlalchemy/tests/test_query.py index 12dd1fad..411646c8 100644 --- a/graphene_sqlalchemy/tests/test_query.py +++ b/graphene_sqlalchemy/tests/test_query.py @@ -365,3 +365,33 @@ class Mutation(graphene.ObjectType): result = schema.execute(query, context_value={'session': session}) assert not result.errors assert result.data == expected + + +def test_should_return_total_count(session): + setup_fixtures(session) + + class ReporterNode(SQLAlchemyObjectType): + + class Meta: + model = Reporter + interfaces = (Node, ) + + class Query(graphene.ObjectType): + all_article = SQLAlchemyConnectionField(ReporterNode) + + query = ''' + { + allArticle { + totalCount + } + } + ''' + expected = { + 'allArticle': { + 'totalCount': session.query(Reporter).count() + }, + } + schema = graphene.Schema(query=Query) + result = schema.execute(query, context_value={'session': session}) + assert not result.errors + assert result.data == expected diff --git a/graphene_sqlalchemy/tests/test_types.py b/graphene_sqlalchemy/tests/test_types.py index 3f017aae..6639be9a 100644 --- a/graphene_sqlalchemy/tests/test_types.py +++ b/graphene_sqlalchemy/tests/test_types.py @@ -1,10 +1,10 @@ from graphene import Field, Int, Interface, ObjectType -from graphene.relay import Node, is_node +from graphene.relay import Node, is_node, Connection import six from ..registry import Registry -from ..types import SQLAlchemyObjectType +from ..types import SQLAlchemyObjectType, ConnectionWithCount from .models import Article, Reporter registry = Registry() @@ -116,3 +116,22 @@ def test_custom_objecttype_registered(): 'pets', 'articles', 'favorite_article'] + +def test_total_count(): + class TotalCount(SQLAlchemyObjectType): + class Meta: + model = Article + interfaces = (Node, ) + registry = registry + + class NoTotalCount(SQLAlchemyObjectType): + class Meta: + model = Reporter + interfaces = (Node, ) + registry = registry + total_count = False + + assert issubclass(TotalCount._meta.connection, ConnectionWithCount) + assert not issubclass(NoTotalCount._meta.connection, ConnectionWithCount) + assert issubclass(TotalCount._meta.connection, Connection) + assert issubclass(NoTotalCount._meta.connection, Connection) diff --git a/graphene_sqlalchemy/types.py b/graphene_sqlalchemy/types.py index 04d1a8a6..88864e9f 100644 --- a/graphene_sqlalchemy/types.py +++ b/graphene_sqlalchemy/types.py @@ -4,7 +4,7 @@ from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.orm.exc import NoResultFound -from graphene import Field # , annotate, ResolveInfo +from graphene import Field, Int, NonNull from graphene.relay import Connection, Node from graphene.types.objecttype import ObjectType, ObjectTypeOptions from graphene.types.utils import yank_fields_from_attrs @@ -86,11 +86,22 @@ class SQLAlchemyObjectTypeOptions(ObjectTypeOptions): id = None # type: str +class ConnectionWithCount(Connection): + '''Class that adds `totalCount` to a connection field''' + class Meta: + abstract = True + + total_count = NonNull(Int) + + def resolve_total_count(self, info, **kwargs): + return self.length + + class SQLAlchemyObjectType(ObjectType): @classmethod def __init_subclass_with_meta__(cls, model=None, registry=None, skip_registry=False, only_fields=(), exclude_fields=(), connection=None, - use_connection=None, interfaces=(), id=None, **options): + use_connection=None, interfaces=(), id=None, total_count=True, **options): assert is_mapped_class(model), ( 'You need to pass a valid SQLAlchemy Model in ' '{}.Meta, received "{}".' @@ -114,7 +125,8 @@ def __init_subclass_with_meta__(cls, model=None, registry=None, skip_registry=Fa if use_connection and not connection: # We create the connection automatically - connection = Connection.create_type('{}Connection'.format(cls.__name__), node=cls) + connection_class = ConnectionWithCount if total_count else Connection + connection = connection_class.create_type('{}Connection'.format(cls.__name__), node=cls) if connection is not None: assert issubclass(connection, Connection), (