diff --git a/graphene_sqlalchemy/converter.py b/graphene_sqlalchemy/converter.py index 0d745ce3..08d3b7ab 100644 --- a/graphene_sqlalchemy/converter.py +++ b/graphene_sqlalchemy/converter.py @@ -36,7 +36,7 @@ def dynamic_type(): return Field(_type) elif direction in (interfaces.ONETOMANY, interfaces.MANYTOMANY): if _type._meta.connection: - return createConnectionField(_type) + return createConnectionField(_type._meta.connection) return Field(List(_type)) return Dynamic(dynamic_type) diff --git a/graphene_sqlalchemy/fields.py b/graphene_sqlalchemy/fields.py index bb084b3a..045ecec8 100644 --- a/graphene_sqlalchemy/fields.py +++ b/graphene_sqlalchemy/fields.py @@ -1,5 +1,5 @@ from functools import partial - +from promise import is_thenable, Promise from sqlalchemy.orm.query import Query from graphene.relay import ConnectionField @@ -19,39 +19,38 @@ def model(self): def get_query(cls, model, info, **args): return get_query(model, info.context) - @property - def type(self): - from .types import SQLAlchemyObjectType - _type = super(ConnectionField, self).type - assert issubclass(_type, SQLAlchemyObjectType), ( - "SQLAlchemyConnectionField only accepts SQLAlchemyObjectType types" - ) - assert _type._meta.connection, "The type {} doesn't have a connection".format(_type.__name__) - return _type._meta.connection - @classmethod - def connection_resolver(cls, resolver, connection, model, root, info, **args): - iterable = resolver(root, info, **args) - if iterable is None: - iterable = cls.get_query(model, info, **args) - if isinstance(iterable, Query): - _len = iterable.count() + def resolve_connection(cls, connection_type, model, info, args, resolved): + if resolved is None: + resolved = cls.get_query(model, info, **args) + if isinstance(resolved, Query): + _len = resolved.count() else: - _len = len(iterable) + _len = len(resolved) connection = connection_from_list_slice( - iterable, + resolved, args, slice_start=0, list_length=_len, list_slice_length=_len, - connection_type=connection, + connection_type=connection_type, pageinfo_type=PageInfo, - edge_type=connection.Edge, + edge_type=connection_type.Edge, ) - connection.iterable = iterable + connection.iterable = resolved connection.length = _len return connection + @classmethod + def connection_resolver(cls, resolver, connection_type, model, root, info, **args): + resolved = resolver(root, info, **args) + + on_resolve = partial(cls.resolve_connection, connection_type, model, info, args) + if is_thenable(resolved): + return Promise.resolve(resolved).then(on_resolve) + + return on_resolve(resolved) + def get_resolver(self, parent_resolver): return partial(self.connection_resolver, parent_resolver, self.type, self.model) diff --git a/graphene_sqlalchemy/tests/test_connectionfactory.py b/graphene_sqlalchemy/tests/test_connectionfactory.py index 867c5261..6222a431 100644 --- a/graphene_sqlalchemy/tests/test_connectionfactory.py +++ b/graphene_sqlalchemy/tests/test_connectionfactory.py @@ -22,7 +22,10 @@ def LXResolver(root, args, context, info): return SQLAlchemyConnectionField.connection_resolver(LXResolver, connection, model, root, args, context, info) def createLXConnectionField(table): - return LXConnectionField(table, filter=table.filter(), order_by=graphene.List(of_type=table.order_by)) + class LXConnection(graphene.relay.Connection): + class Meta: + node = table + return LXConnectionField(LXConnection, filter=table.filter(), order_by=graphene.List(of_type=table.order_by)) registerConnectionFieldFactory(createLXConnectionField) unregisterConnectionFieldFactory() diff --git a/graphene_sqlalchemy/tests/test_query.py b/graphene_sqlalchemy/tests/test_query.py index 12dd1fad..5654a1e0 100644 --- a/graphene_sqlalchemy/tests/test_query.py +++ b/graphene_sqlalchemy/tests/test_query.py @@ -139,7 +139,7 @@ class Meta: interfaces = (Node, ) @classmethod - def get_node(cls, id, info): + def get_node(cls, info, id): return Reporter(id=2, first_name='Cookie Monster') class ArticleNode(SQLAlchemyObjectType): @@ -152,11 +152,15 @@ class Meta: # def get_node(cls, id, info): # return Article(id=1, headline='Article node') + class ArticleConnection(graphene.relay.Connection): + class Meta: + node = ArticleNode + class Query(graphene.ObjectType): node = Node.Field() reporter = graphene.Field(ReporterNode) article = graphene.Field(ArticleNode) - all_articles = SQLAlchemyConnectionField(ArticleNode) + all_articles = SQLAlchemyConnectionField(ArticleConnection) def resolve_reporter(self, *args, **kwargs): return session.query(Reporter).first() @@ -238,9 +242,13 @@ class Meta: model = Editor interfaces = (Node, ) + class EditorConnection(graphene.relay.Connection): + class Meta: + node = EditorNode + class Query(graphene.ObjectType): node = Node.Field() - all_editors = SQLAlchemyConnectionField(EditorNode) + all_editors = SQLAlchemyConnectionField(EditorConnection) query = ''' query EditorQuery { diff --git a/graphene_sqlalchemy/tests/test_types.py b/graphene_sqlalchemy/tests/test_types.py index 2766c2ab..282e321f 100644 --- a/graphene_sqlalchemy/tests/test_types.py +++ b/graphene_sqlalchemy/tests/test_types.py @@ -1,11 +1,13 @@ from collections import OrderedDict from graphene import Field, Int, Interface, ObjectType -from graphene.relay import Node, is_node +from graphene.relay import Node, is_node, Connection import six +from promise import Promise from ..registry import Registry from ..types import SQLAlchemyObjectType, SQLAlchemyObjectTypeOptions from .models import Article, Reporter +from ..fields import SQLAlchemyConnectionField registry = Registry() @@ -158,3 +160,13 @@ def test_objecttype_with_custom_options(): 'favorite_article'] assert ReporterWithCustomOptions._meta.custom_option == 'custom_option' assert isinstance(ReporterWithCustomOptions._meta.fields['custom_field'].type, Int) + + +def test_promise_connection_resolver(): + class TestConnection(Connection): + class Meta: + node = ReporterWithCustomOptions + + resolver = lambda *args, **kwargs: Promise.resolve([]) + result = SQLAlchemyConnectionField.connection_resolver(resolver, TestConnection, ReporterWithCustomOptions, None, None) + assert result is not None