diff --git a/graphene_sqlalchemy/fields.py b/graphene_sqlalchemy/fields.py index 4a283ad8..bf5ed3a5 100644 --- a/graphene_sqlalchemy/fields.py +++ b/graphene_sqlalchemy/fields.py @@ -11,6 +11,18 @@ class UnsortedSQLAlchemyConnectionField(ConnectionField): + @property + def type(self): + from .types import SQLAlchemyObjectType + _type = super(ConnectionField, self).type + if issubclass(_type, Connection): + return _type + assert issubclass(_type, SQLAlchemyObjectType), ( + "SQLALchemyConnectionField only accepts SQLAlchemyObjectType types, not {}" + ).format(_type.__name__) + assert _type._meta.connection, "The type {} doesn't have a connection".format(_type.__name__) + return _type._meta.connection + @property def model(self): return self.type._meta.node._meta.model diff --git a/graphene_sqlalchemy/types.py b/graphene_sqlalchemy/types.py index 69bf310c..5d17747d 100644 --- a/graphene_sqlalchemy/types.py +++ b/graphene_sqlalchemy/types.py @@ -90,7 +90,8 @@ class SQLAlchemyObjectType(ObjectType): @classmethod def __init_subclass_with_meta__(cls, model=None, registry=None, skip_registry=False, only_fields=(), exclude_fields=(), connection=None, - use_connection=None, interfaces=(), id=None, _meta=None, **options): + connection_class=None, use_connection=None, interfaces=(), + id=None, _meta=None, **options): assert is_mapped_class(model), ( 'You need to pass a valid SQLAlchemy Model in ' '{}.Meta, received "{}".' @@ -114,7 +115,11 @@ def __init_subclass_with_meta__(cls, model=None, registry=None, skip_registry=Fa if use_connection and not connection: # We create the connection automatically - connection = Connection.create_type('{}Connection'.format(cls.__name__), node=cls) + if not connection_class: + connection_class = Connection + + connection = connection_class.create_type( + '{}Connection'.format(cls.__name__), node=cls) if connection is not None: assert issubclass(connection, Connection), (