From ce5741f31e8077f9db5abc1c9e065ebae3d1cc7f Mon Sep 17 00:00:00 2001 From: nickharris Date: Wed, 21 Mar 2018 17:35:43 -0600 Subject: [PATCH 1/5] Updated SQLAlchemyConnectionField to be 2.0 compliant --- graphene_sqlalchemy/converter.py | 2 +- graphene_sqlalchemy/fields.py | 43 +++++++++---------- .../tests/test_connectionfactory.py | 5 ++- graphene_sqlalchemy/tests/test_query.py | 14 ++++-- graphene_sqlalchemy/types.py | 3 ++ 5 files changed, 40 insertions(+), 27 deletions(-) diff --git a/graphene_sqlalchemy/converter.py b/graphene_sqlalchemy/converter.py index 0d745ce3..08d3b7ab 100644 --- a/graphene_sqlalchemy/converter.py +++ b/graphene_sqlalchemy/converter.py @@ -36,7 +36,7 @@ def dynamic_type(): return Field(_type) elif direction in (interfaces.ONETOMANY, interfaces.MANYTOMANY): if _type._meta.connection: - return createConnectionField(_type) + return createConnectionField(_type._meta.connection) return Field(List(_type)) return Dynamic(dynamic_type) diff --git a/graphene_sqlalchemy/fields.py b/graphene_sqlalchemy/fields.py index bb084b3a..045ecec8 100644 --- a/graphene_sqlalchemy/fields.py +++ b/graphene_sqlalchemy/fields.py @@ -1,5 +1,5 @@ from functools import partial - +from promise import is_thenable, Promise from sqlalchemy.orm.query import Query from graphene.relay import ConnectionField @@ -19,39 +19,38 @@ def model(self): def get_query(cls, model, info, **args): return get_query(model, info.context) - @property - def type(self): - from .types import SQLAlchemyObjectType - _type = super(ConnectionField, self).type - assert issubclass(_type, SQLAlchemyObjectType), ( - "SQLAlchemyConnectionField only accepts SQLAlchemyObjectType types" - ) - assert _type._meta.connection, "The type {} doesn't have a connection".format(_type.__name__) - return _type._meta.connection - @classmethod - def connection_resolver(cls, resolver, connection, model, root, info, **args): - iterable = resolver(root, info, **args) - if iterable is None: - iterable = cls.get_query(model, info, **args) - if isinstance(iterable, Query): - _len = iterable.count() + def resolve_connection(cls, connection_type, model, info, args, resolved): + if resolved is None: + resolved = cls.get_query(model, info, **args) + if isinstance(resolved, Query): + _len = resolved.count() else: - _len = len(iterable) + _len = len(resolved) connection = connection_from_list_slice( - iterable, + resolved, args, slice_start=0, list_length=_len, list_slice_length=_len, - connection_type=connection, + connection_type=connection_type, pageinfo_type=PageInfo, - edge_type=connection.Edge, + edge_type=connection_type.Edge, ) - connection.iterable = iterable + connection.iterable = resolved connection.length = _len return connection + @classmethod + def connection_resolver(cls, resolver, connection_type, model, root, info, **args): + resolved = resolver(root, info, **args) + + on_resolve = partial(cls.resolve_connection, connection_type, model, info, args) + if is_thenable(resolved): + return Promise.resolve(resolved).then(on_resolve) + + return on_resolve(resolved) + def get_resolver(self, parent_resolver): return partial(self.connection_resolver, parent_resolver, self.type, self.model) diff --git a/graphene_sqlalchemy/tests/test_connectionfactory.py b/graphene_sqlalchemy/tests/test_connectionfactory.py index 867c5261..2d9109a5 100644 --- a/graphene_sqlalchemy/tests/test_connectionfactory.py +++ b/graphene_sqlalchemy/tests/test_connectionfactory.py @@ -22,7 +22,10 @@ def LXResolver(root, args, context, info): return SQLAlchemyConnectionField.connection_resolver(LXResolver, connection, model, root, args, context, info) def createLXConnectionField(table): - return LXConnectionField(table, filter=table.filter(), order_by=graphene.List(of_type=table.order_by)) + class LXConnection(graphene.relay.Connection, node=table): + pass + + return LXConnectionField(LXConnection, filter=table.filter(), order_by=graphene.List(of_type=table.order_by)) registerConnectionFieldFactory(createLXConnectionField) unregisterConnectionFieldFactory() diff --git a/graphene_sqlalchemy/tests/test_query.py b/graphene_sqlalchemy/tests/test_query.py index 12dd1fad..5654a1e0 100644 --- a/graphene_sqlalchemy/tests/test_query.py +++ b/graphene_sqlalchemy/tests/test_query.py @@ -139,7 +139,7 @@ class Meta: interfaces = (Node, ) @classmethod - def get_node(cls, id, info): + def get_node(cls, info, id): return Reporter(id=2, first_name='Cookie Monster') class ArticleNode(SQLAlchemyObjectType): @@ -152,11 +152,15 @@ class Meta: # def get_node(cls, id, info): # return Article(id=1, headline='Article node') + class ArticleConnection(graphene.relay.Connection): + class Meta: + node = ArticleNode + class Query(graphene.ObjectType): node = Node.Field() reporter = graphene.Field(ReporterNode) article = graphene.Field(ArticleNode) - all_articles = SQLAlchemyConnectionField(ArticleNode) + all_articles = SQLAlchemyConnectionField(ArticleConnection) def resolve_reporter(self, *args, **kwargs): return session.query(Reporter).first() @@ -238,9 +242,13 @@ class Meta: model = Editor interfaces = (Node, ) + class EditorConnection(graphene.relay.Connection): + class Meta: + node = EditorNode + class Query(graphene.ObjectType): node = Node.Field() - all_editors = SQLAlchemyConnectionField(EditorNode) + all_editors = SQLAlchemyConnectionField(EditorConnection) query = ''' query EditorQuery { diff --git a/graphene_sqlalchemy/types.py b/graphene_sqlalchemy/types.py index 69bf310c..5afcc8ea 100644 --- a/graphene_sqlalchemy/types.py +++ b/graphene_sqlalchemy/types.py @@ -117,6 +117,9 @@ def __init_subclass_with_meta__(cls, model=None, registry=None, skip_registry=Fa connection = Connection.create_type('{}Connection'.format(cls.__name__), node=cls) if connection is not None: + if not issubclass(connection, Connection) and hasattr(connection, '__call__'): + connection = connection() + assert issubclass(connection, Connection), ( "The connection must be a Connection. Received {}" ).format(connection.__name__) From d0609698da1df58d2d1e6d2cbc71eece65d77bed Mon Sep 17 00:00:00 2001 From: nickharris Date: Wed, 21 Mar 2018 17:53:05 -0600 Subject: [PATCH 2/5] fixed Connection class definition syntax error --- graphene_sqlalchemy/tests/test_connectionfactory.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/graphene_sqlalchemy/tests/test_connectionfactory.py b/graphene_sqlalchemy/tests/test_connectionfactory.py index 2d9109a5..6222a431 100644 --- a/graphene_sqlalchemy/tests/test_connectionfactory.py +++ b/graphene_sqlalchemy/tests/test_connectionfactory.py @@ -22,9 +22,9 @@ def LXResolver(root, args, context, info): return SQLAlchemyConnectionField.connection_resolver(LXResolver, connection, model, root, args, context, info) def createLXConnectionField(table): - class LXConnection(graphene.relay.Connection, node=table): - pass - + class LXConnection(graphene.relay.Connection): + class Meta: + node = table return LXConnectionField(LXConnection, filter=table.filter(), order_by=graphene.List(of_type=table.order_by)) registerConnectionFieldFactory(createLXConnectionField) From 80a5b0a711e597741eee54747481e01705cd8a3a Mon Sep 17 00:00:00 2001 From: nickharris Date: Wed, 21 Mar 2018 18:05:22 -0600 Subject: [PATCH 3/5] fixed indentation error --- graphene_sqlalchemy/types.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graphene_sqlalchemy/types.py b/graphene_sqlalchemy/types.py index 5afcc8ea..b9777b99 100644 --- a/graphene_sqlalchemy/types.py +++ b/graphene_sqlalchemy/types.py @@ -118,7 +118,7 @@ def __init_subclass_with_meta__(cls, model=None, registry=None, skip_registry=Fa if connection is not None: if not issubclass(connection, Connection) and hasattr(connection, '__call__'): - connection = connection() + connection = connection() assert issubclass(connection, Connection), ( "The connection must be a Connection. Received {}" From 164fd72f8cffc3340422ecfeedbb9597cc9a58df Mon Sep 17 00:00:00 2001 From: nickharris Date: Wed, 21 Mar 2018 18:14:07 -0600 Subject: [PATCH 4/5] removed callable connection meta field feature --- graphene_sqlalchemy/types.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/graphene_sqlalchemy/types.py b/graphene_sqlalchemy/types.py index b9777b99..69bf310c 100644 --- a/graphene_sqlalchemy/types.py +++ b/graphene_sqlalchemy/types.py @@ -117,9 +117,6 @@ def __init_subclass_with_meta__(cls, model=None, registry=None, skip_registry=Fa connection = Connection.create_type('{}Connection'.format(cls.__name__), node=cls) if connection is not None: - if not issubclass(connection, Connection) and hasattr(connection, '__call__'): - connection = connection() - assert issubclass(connection, Connection), ( "The connection must be a Connection. Received {}" ).format(connection.__name__) From 65e13737189e272c7b08ddffc03ec600169dbe63 Mon Sep 17 00:00:00 2001 From: nickharris Date: Wed, 21 Mar 2018 18:33:15 -0600 Subject: [PATCH 5/5] added promise test for connection resolver --- graphene_sqlalchemy/tests/test_types.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/graphene_sqlalchemy/tests/test_types.py b/graphene_sqlalchemy/tests/test_types.py index 2766c2ab..282e321f 100644 --- a/graphene_sqlalchemy/tests/test_types.py +++ b/graphene_sqlalchemy/tests/test_types.py @@ -1,11 +1,13 @@ from collections import OrderedDict from graphene import Field, Int, Interface, ObjectType -from graphene.relay import Node, is_node +from graphene.relay import Node, is_node, Connection import six +from promise import Promise from ..registry import Registry from ..types import SQLAlchemyObjectType, SQLAlchemyObjectTypeOptions from .models import Article, Reporter +from ..fields import SQLAlchemyConnectionField registry = Registry() @@ -158,3 +160,13 @@ def test_objecttype_with_custom_options(): 'favorite_article'] assert ReporterWithCustomOptions._meta.custom_option == 'custom_option' assert isinstance(ReporterWithCustomOptions._meta.fields['custom_field'].type, Int) + + +def test_promise_connection_resolver(): + class TestConnection(Connection): + class Meta: + node = ReporterWithCustomOptions + + resolver = lambda *args, **kwargs: Promise.resolve([]) + result = SQLAlchemyConnectionField.connection_resolver(resolver, TestConnection, ReporterWithCustomOptions, None, None) + assert result is not None