Skip to content

SQLAlchemyConnectionField Graphene 2.0 + Promise Support #120

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jun 4, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion graphene_sqlalchemy/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def dynamic_type():
return Field(_type)
elif direction in (interfaces.ONETOMANY, interfaces.MANYTOMANY):
if _type._meta.connection:
return createConnectionField(_type)
return createConnectionField(_type._meta.connection)
return Field(List(_type))

return Dynamic(dynamic_type)
Expand Down
43 changes: 21 additions & 22 deletions graphene_sqlalchemy/fields.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from functools import partial

from promise import is_thenable, Promise
from sqlalchemy.orm.query import Query

from graphene.relay import ConnectionField
Expand All @@ -19,39 +19,38 @@ def model(self):
def get_query(cls, model, info, **args):
return get_query(model, info.context)

@property
def type(self):
from .types import SQLAlchemyObjectType
_type = super(ConnectionField, self).type
assert issubclass(_type, SQLAlchemyObjectType), (
"SQLAlchemyConnectionField only accepts SQLAlchemyObjectType types"
)
assert _type._meta.connection, "The type {} doesn't have a connection".format(_type.__name__)
return _type._meta.connection

@classmethod
def connection_resolver(cls, resolver, connection, model, root, info, **args):
iterable = resolver(root, info, **args)
if iterable is None:
iterable = cls.get_query(model, info, **args)
if isinstance(iterable, Query):
_len = iterable.count()
def resolve_connection(cls, connection_type, model, info, args, resolved):
if resolved is None:
resolved = cls.get_query(model, info, **args)
if isinstance(resolved, Query):
_len = resolved.count()
else:
_len = len(iterable)
_len = len(resolved)
connection = connection_from_list_slice(
iterable,
resolved,
args,
slice_start=0,
list_length=_len,
list_slice_length=_len,
connection_type=connection,
connection_type=connection_type,
pageinfo_type=PageInfo,
edge_type=connection.Edge,
edge_type=connection_type.Edge,
)
connection.iterable = iterable
connection.iterable = resolved
connection.length = _len
return connection

@classmethod
def connection_resolver(cls, resolver, connection_type, model, root, info, **args):
resolved = resolver(root, info, **args)

on_resolve = partial(cls.resolve_connection, connection_type, model, info, args)
if is_thenable(resolved):
return Promise.resolve(resolved).then(on_resolve)

return on_resolve(resolved)

def get_resolver(self, parent_resolver):
return partial(self.connection_resolver, parent_resolver, self.type, self.model)

Expand Down
5 changes: 4 additions & 1 deletion graphene_sqlalchemy/tests/test_connectionfactory.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@ def LXResolver(root, args, context, info):
return SQLAlchemyConnectionField.connection_resolver(LXResolver, connection, model, root, args, context, info)

def createLXConnectionField(table):
return LXConnectionField(table, filter=table.filter(), order_by=graphene.List(of_type=table.order_by))
class LXConnection(graphene.relay.Connection):
class Meta:
node = table
return LXConnectionField(LXConnection, filter=table.filter(), order_by=graphene.List(of_type=table.order_by))

registerConnectionFieldFactory(createLXConnectionField)
unregisterConnectionFieldFactory()
14 changes: 11 additions & 3 deletions graphene_sqlalchemy/tests/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ class Meta:
interfaces = (Node, )

@classmethod
def get_node(cls, id, info):
def get_node(cls, info, id):
return Reporter(id=2, first_name='Cookie Monster')

class ArticleNode(SQLAlchemyObjectType):
Expand All @@ -152,11 +152,15 @@ class Meta:
# def get_node(cls, id, info):
# return Article(id=1, headline='Article node')

class ArticleConnection(graphene.relay.Connection):
class Meta:
node = ArticleNode

class Query(graphene.ObjectType):
node = Node.Field()
reporter = graphene.Field(ReporterNode)
article = graphene.Field(ArticleNode)
all_articles = SQLAlchemyConnectionField(ArticleNode)
all_articles = SQLAlchemyConnectionField(ArticleConnection)

def resolve_reporter(self, *args, **kwargs):
return session.query(Reporter).first()
Expand Down Expand Up @@ -238,9 +242,13 @@ class Meta:
model = Editor
interfaces = (Node, )

class EditorConnection(graphene.relay.Connection):
class Meta:
node = EditorNode

class Query(graphene.ObjectType):
node = Node.Field()
all_editors = SQLAlchemyConnectionField(EditorNode)
all_editors = SQLAlchemyConnectionField(EditorConnection)

query = '''
query EditorQuery {
Expand Down
14 changes: 13 additions & 1 deletion graphene_sqlalchemy/tests/test_types.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from collections import OrderedDict
from graphene import Field, Int, Interface, ObjectType
from graphene.relay import Node, is_node
from graphene.relay import Node, is_node, Connection
import six
from promise import Promise

from ..registry import Registry
from ..types import SQLAlchemyObjectType, SQLAlchemyObjectTypeOptions
from .models import Article, Reporter
from ..fields import SQLAlchemyConnectionField

registry = Registry()

Expand Down Expand Up @@ -158,3 +160,13 @@ def test_objecttype_with_custom_options():
'favorite_article']
assert ReporterWithCustomOptions._meta.custom_option == 'custom_option'
assert isinstance(ReporterWithCustomOptions._meta.fields['custom_field'].type, Int)


def test_promise_connection_resolver():
class TestConnection(Connection):
class Meta:
node = ReporterWithCustomOptions

resolver = lambda *args, **kwargs: Promise.resolve([])
result = SQLAlchemyConnectionField.connection_resolver(resolver, TestConnection, ReporterWithCustomOptions, None, None)
assert result is not None