Skip to content

Commit ce5741f

Browse files
committed
Updated SQLAlchemyConnectionField to be 2.0 compliant
1 parent a2fe926 commit ce5741f

File tree

5 files changed

+40
-27
lines changed

5 files changed

+40
-27
lines changed

graphene_sqlalchemy/converter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def dynamic_type():
3636
return Field(_type)
3737
elif direction in (interfaces.ONETOMANY, interfaces.MANYTOMANY):
3838
if _type._meta.connection:
39-
return createConnectionField(_type)
39+
return createConnectionField(_type._meta.connection)
4040
return Field(List(_type))
4141

4242
return Dynamic(dynamic_type)

graphene_sqlalchemy/fields.py

Lines changed: 21 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from functools import partial
2-
2+
from promise import is_thenable, Promise
33
from sqlalchemy.orm.query import Query
44

55
from graphene.relay import ConnectionField
@@ -19,39 +19,38 @@ def model(self):
1919
def get_query(cls, model, info, **args):
2020
return get_query(model, info.context)
2121

22-
@property
23-
def type(self):
24-
from .types import SQLAlchemyObjectType
25-
_type = super(ConnectionField, self).type
26-
assert issubclass(_type, SQLAlchemyObjectType), (
27-
"SQLAlchemyConnectionField only accepts SQLAlchemyObjectType types"
28-
)
29-
assert _type._meta.connection, "The type {} doesn't have a connection".format(_type.__name__)
30-
return _type._meta.connection
31-
3222
@classmethod
33-
def connection_resolver(cls, resolver, connection, model, root, info, **args):
34-
iterable = resolver(root, info, **args)
35-
if iterable is None:
36-
iterable = cls.get_query(model, info, **args)
37-
if isinstance(iterable, Query):
38-
_len = iterable.count()
23+
def resolve_connection(cls, connection_type, model, info, args, resolved):
24+
if resolved is None:
25+
resolved = cls.get_query(model, info, **args)
26+
if isinstance(resolved, Query):
27+
_len = resolved.count()
3928
else:
40-
_len = len(iterable)
29+
_len = len(resolved)
4130
connection = connection_from_list_slice(
42-
iterable,
31+
resolved,
4332
args,
4433
slice_start=0,
4534
list_length=_len,
4635
list_slice_length=_len,
47-
connection_type=connection,
36+
connection_type=connection_type,
4837
pageinfo_type=PageInfo,
49-
edge_type=connection.Edge,
38+
edge_type=connection_type.Edge,
5039
)
51-
connection.iterable = iterable
40+
connection.iterable = resolved
5241
connection.length = _len
5342
return connection
5443

44+
@classmethod
45+
def connection_resolver(cls, resolver, connection_type, model, root, info, **args):
46+
resolved = resolver(root, info, **args)
47+
48+
on_resolve = partial(cls.resolve_connection, connection_type, model, info, args)
49+
if is_thenable(resolved):
50+
return Promise.resolve(resolved).then(on_resolve)
51+
52+
return on_resolve(resolved)
53+
5554
def get_resolver(self, parent_resolver):
5655
return partial(self.connection_resolver, parent_resolver, self.type, self.model)
5756

graphene_sqlalchemy/tests/test_connectionfactory.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,10 @@ def LXResolver(root, args, context, info):
2222
return SQLAlchemyConnectionField.connection_resolver(LXResolver, connection, model, root, args, context, info)
2323

2424
def createLXConnectionField(table):
25-
return LXConnectionField(table, filter=table.filter(), order_by=graphene.List(of_type=table.order_by))
25+
class LXConnection(graphene.relay.Connection, node=table):
26+
pass
27+
28+
return LXConnectionField(LXConnection, filter=table.filter(), order_by=graphene.List(of_type=table.order_by))
2629

2730
registerConnectionFieldFactory(createLXConnectionField)
2831
unregisterConnectionFieldFactory()

graphene_sqlalchemy/tests/test_query.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ class Meta:
139139
interfaces = (Node, )
140140

141141
@classmethod
142-
def get_node(cls, id, info):
142+
def get_node(cls, info, id):
143143
return Reporter(id=2, first_name='Cookie Monster')
144144

145145
class ArticleNode(SQLAlchemyObjectType):
@@ -152,11 +152,15 @@ class Meta:
152152
# def get_node(cls, id, info):
153153
# return Article(id=1, headline='Article node')
154154

155+
class ArticleConnection(graphene.relay.Connection):
156+
class Meta:
157+
node = ArticleNode
158+
155159
class Query(graphene.ObjectType):
156160
node = Node.Field()
157161
reporter = graphene.Field(ReporterNode)
158162
article = graphene.Field(ArticleNode)
159-
all_articles = SQLAlchemyConnectionField(ArticleNode)
163+
all_articles = SQLAlchemyConnectionField(ArticleConnection)
160164

161165
def resolve_reporter(self, *args, **kwargs):
162166
return session.query(Reporter).first()
@@ -238,9 +242,13 @@ class Meta:
238242
model = Editor
239243
interfaces = (Node, )
240244

245+
class EditorConnection(graphene.relay.Connection):
246+
class Meta:
247+
node = EditorNode
248+
241249
class Query(graphene.ObjectType):
242250
node = Node.Field()
243-
all_editors = SQLAlchemyConnectionField(EditorNode)
251+
all_editors = SQLAlchemyConnectionField(EditorConnection)
244252

245253
query = '''
246254
query EditorQuery {

graphene_sqlalchemy/types.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,9 @@ def __init_subclass_with_meta__(cls, model=None, registry=None, skip_registry=Fa
117117
connection = Connection.create_type('{}Connection'.format(cls.__name__), node=cls)
118118

119119
if connection is not None:
120+
if not issubclass(connection, Connection) and hasattr(connection, '__call__'):
121+
connection = connection()
122+
120123
assert issubclass(connection, Connection), (
121124
"The connection must be a Connection. Received {}"
122125
).format(connection.__name__)

0 commit comments

Comments
 (0)