diff --git a/graphene_sqlalchemy/converter.py b/graphene_sqlalchemy/converter.py index 05690697..9b110cfb 100644 --- a/graphene_sqlalchemy/converter.py +++ b/graphene_sqlalchemy/converter.py @@ -8,7 +8,7 @@ from graphene.relay import is_node from graphene.types.json import JSONString -from .fields import SQLAlchemyConnectionField +from .fields import createConnectionField try: from sqlalchemy_utils import ChoiceType, JSONType, ScalarListType @@ -36,7 +36,7 @@ def dynamic_type(): elif (direction == interfaces.ONETOMANY or direction == interfaces.MANYTOMANY): if is_node(_type): - return SQLAlchemyConnectionField(_type) + return createConnectionField(_type) return Field(List(_type)) return Dynamic(dynamic_type) diff --git a/graphene_sqlalchemy/fields.py b/graphene_sqlalchemy/fields.py index de1d301f..8ac27942 100644 --- a/graphene_sqlalchemy/fields.py +++ b/graphene_sqlalchemy/fields.py @@ -41,3 +41,20 @@ def connection_resolver(cls, resolver, connection, model, root, args, context, i def get_resolver(self, parent_resolver): return partial(self.connection_resolver, parent_resolver, self.type, self.model) + + +__connectionFactory = SQLAlchemyConnectionField + + +def createConnectionField(_type): + return __connectionFactory(_type) + + +def registerConnectionFieldFactory(factoryMethod): + global __connectionFactory + __connectionFactory = factoryMethod + + +def unregisterConnectionFieldFactory(): + global __connectionFactory + __connectionFactory = SQLAlchemyConnectionField diff --git a/graphene_sqlalchemy/tests/test_connectionfactory.py b/graphene_sqlalchemy/tests/test_connectionfactory.py new file mode 100644 index 00000000..867c5261 --- /dev/null +++ b/graphene_sqlalchemy/tests/test_connectionfactory.py @@ -0,0 +1,28 @@ +from graphene_sqlalchemy.fields import SQLAlchemyConnectionField, registerConnectionFieldFactory, unregisterConnectionFieldFactory +import graphene + +def test_register(): + class LXConnectionField(SQLAlchemyConnectionField): + @classmethod + def _applyQueryArgs(cls, model, q, args): + return q + + @classmethod + def connection_resolver(cls, resolver, connection, model, root, args, context, info): + + def LXResolver(root, args, context, info): + iterable = resolver(root, args, context, info) + if iterable is None: + iterable = cls.get_query(model, context, info, args) + + # We accept always a query here. All LX-queries can be filtered and sorted + iterable = cls._applyQueryArgs(model, iterable, args) + return iterable + + return SQLAlchemyConnectionField.connection_resolver(LXResolver, connection, model, root, args, context, info) + + def createLXConnectionField(table): + return LXConnectionField(table, filter=table.filter(), order_by=graphene.List(of_type=table.order_by)) + + registerConnectionFieldFactory(createLXConnectionField) + unregisterConnectionFieldFactory()