From 18dd02a079c14c4dfd051b6ce1dabeb904cc14e5 Mon Sep 17 00:00:00 2001 From: Geert JM Vanderkelen Date: Tue, 21 Nov 2017 12:37:28 +0100 Subject: [PATCH] Allow creation of custom connections. Fixes #65. Previously, it was not possible to create custom connections without monkey patching graphene.relay.connection. Custom connections are useful when all connections need a total count for paging, for example. We change the SQLAlchemyObjectType.__init_subclass_with_meta method so that when the connection argument is a class, it is used to create the new connection type. A test case has been added to check usage of custom connections. The `conftest.py` was added to allow pytest using database related fixtures in other places than `test_query.py`. --- graphene_sqlalchemy/tests/conftest.py | 40 +++++++++++ graphene_sqlalchemy/tests/test_query.py | 89 ++----------------------- graphene_sqlalchemy/tests/test_types.py | 83 +++++++++++++++++++++-- graphene_sqlalchemy/types.py | 28 +++++--- 4 files changed, 142 insertions(+), 98 deletions(-) create mode 100644 graphene_sqlalchemy/tests/conftest.py diff --git a/graphene_sqlalchemy/tests/conftest.py b/graphene_sqlalchemy/tests/conftest.py new file mode 100644 index 00000000..05963829 --- /dev/null +++ b/graphene_sqlalchemy/tests/conftest.py @@ -0,0 +1,40 @@ +import pytest +from sqlalchemy import create_engine +from sqlalchemy.orm import scoped_session, sessionmaker + +from .models import Article, Base, Editor, Reporter +from ..registry import reset_global_registry + +db = create_engine('sqlite:///test_sqlalchemy.sqlite3') + + +@pytest.yield_fixture(scope='function') +def session(): + reset_global_registry() + connection = db.engine.connect() + transaction = connection.begin() + Base.metadata.create_all(connection) + + # options = dict(bind=connection, binds={}) + session_factory = sessionmaker(bind=connection) + session = scoped_session(session_factory) + + yield session + + # Finalize test here + transaction.rollback() + connection.close() + session.remove() + +@pytest.yield_fixture(scope='function') +def setup_fixtures(session): + reporter = Reporter(first_name='ABA', last_name='X') + session.add(reporter) + reporter2 = Reporter(first_name='ABO', last_name='Y') + session.add(reporter2) + article = Article(headline='Hi!') + article.reporter = reporter + session.add(article) + editor = Editor(name="John") + session.add(editor) + session.commit() diff --git a/graphene_sqlalchemy/tests/test_query.py b/graphene_sqlalchemy/tests/test_query.py index 12dd1fad..5c644991 100644 --- a/graphene_sqlalchemy/tests/test_query.py +++ b/graphene_sqlalchemy/tests/test_query.py @@ -1,54 +1,12 @@ -import pytest -from sqlalchemy import create_engine -from sqlalchemy.orm import scoped_session, sessionmaker - import graphene from graphene.relay import Node -from ..registry import reset_global_registry +from .models import Article, Editor, Reporter from ..fields import SQLAlchemyConnectionField from ..types import SQLAlchemyObjectType -from .models import Article, Base, Editor, Pet, Reporter - -db = create_engine('sqlite:///test_sqlalchemy.sqlite3') - - -@pytest.yield_fixture(scope='function') -def session(): - reset_global_registry() - connection = db.engine.connect() - transaction = connection.begin() - Base.metadata.create_all(connection) - - # options = dict(bind=connection, binds={}) - session_factory = sessionmaker(bind=connection) - session = scoped_session(session_factory) - - yield session - - # Finalize test here - transaction.rollback() - connection.close() - session.remove() -def setup_fixtures(session): - pet = Pet(name='Lassie', pet_kind='dog') - session.add(pet) - reporter = Reporter(first_name='ABA', last_name='X') - session.add(reporter) - reporter2 = Reporter(first_name='ABO', last_name='Y') - session.add(reporter2) - article = Article(headline='Hi!') - article.reporter = reporter - session.add(article) - editor = Editor(name="John") - session.add(editor) - session.commit() - - -def test_should_query_well(session): - setup_fixtures(session) +def test_should_query_well(session, setup_fixtures): class ReporterType(SQLAlchemyObjectType): @@ -95,42 +53,7 @@ def resolve_reporters(self, *args, **kwargs): assert result.data == expected -def test_should_query_enums(session): - setup_fixtures(session) - - class PetType(SQLAlchemyObjectType): - - class Meta: - model = Pet - - class Query(graphene.ObjectType): - pet = graphene.Field(PetType) - - def resolve_pet(self, *args, **kwargs): - return session.query(Pet).first() - - query = ''' - query PetQuery { - pet { - name, - petKind - } - } - ''' - expected = { - 'pet': { - 'name': 'Lassie', - 'petKind': 'dog' - } - } - schema = graphene.Schema(query=Query) - result = schema.execute(query) - assert not result.errors - assert result.data == expected, result.data - - -def test_should_node(session): - setup_fixtures(session) +def test_should_node(session, setup_fixtures): class ReporterNode(SQLAlchemyObjectType): @@ -229,8 +152,7 @@ def resolve_article(self, *args, **kwargs): assert result.data == expected -def test_should_custom_identifier(session): - setup_fixtures(session) +def test_should_custom_identifier(session, setup_fixtures): class EditorNode(SQLAlchemyObjectType): @@ -279,8 +201,7 @@ class Query(graphene.ObjectType): assert result.data == expected -def test_should_mutate_well(session): - setup_fixtures(session) +def test_should_mutate_well(session, setup_fixtures): class EditorNode(SQLAlchemyObjectType): diff --git a/graphene_sqlalchemy/tests/test_types.py b/graphene_sqlalchemy/tests/test_types.py index 3f017aae..5d745917 100644 --- a/graphene_sqlalchemy/tests/test_types.py +++ b/graphene_sqlalchemy/tests/test_types.py @@ -1,10 +1,10 @@ - -from graphene import Field, Int, Interface, ObjectType +from graphene import Field, Int, Interface, ObjectType, Connection, Schema from graphene.relay import Node, is_node -import six +import pytest from ..registry import Registry from ..types import SQLAlchemyObjectType +from ..fields import SQLAlchemyConnectionField from .models import Article, Reporter registry = Registry() @@ -72,8 +72,6 @@ def test_node_replacedfield(): def test_object_type(): - - class Human(SQLAlchemyObjectType): '''Human description''' @@ -90,7 +88,6 @@ class Meta: assert is_node(Human) - # Test Custom SQLAlchemyObjectType Implementation class CustomSQLAlchemyObjectType(SQLAlchemyObjectType): class Meta: @@ -116,3 +113,77 @@ def test_custom_objecttype_registered(): 'pets', 'articles', 'favorite_article'] + + +def test_custom_connection(session, setup_fixtures): + exp_counter = 123 + + class CustomConnection(Connection): + class Meta: + abstract = True + + counter = Int() + + @staticmethod + def resolve_counter(*args, **kwargs): + return exp_counter + + class ArticleType(SQLAlchemyObjectType): + class Meta: + model = Article + connection = CustomConnection + interfaces = (Node,) + registry = registry + + class Query(ObjectType): + articles = SQLAlchemyConnectionField(ArticleType) + + schema = Schema(query=Query) + result = schema.execute("query { articles { counter edges { node { headline }}}}", + context_value={'session': session}) + + assert not result.errors + assert result.data['articles']['counter'] == exp_counter + assert result.data['articles']['edges'][0]['node']['headline'] == 'Hi!' + + +def test_automatically_created_connection(): + expected = "ArticleTypeConnection" + + class ArticleType(SQLAlchemyObjectType): + class Meta: + model = Article + interfaces = (Node,) + registry = registry + + assert ArticleType._meta.connection.__name__ == expected + + +def test_passing_connection_instance(): + expected = "CnxHumanType" + + class HumanType(SQLAlchemyObjectType): + class Meta: + model = Reporter + interfaces = (Node,) + + class ArticleType(SQLAlchemyObjectType): + class Meta: + model = Article + interfaces = (Node,) + connection = Connection.create_type(expected, node=HumanType) + registry = registry + + assert ArticleType._meta.connection.__name__ == expected + + +def test_passing_incorrect_connection_instance(): + with pytest.raises(AssertionError) as excinfo: + class ArticleType(SQLAlchemyObjectType): + class Meta: + model = Article + interfaces = (Node,) + connection = 'spam' + registry = registry + + assert "The connection must be a Connection. Received" in str(excinfo.value) diff --git a/graphene_sqlalchemy/types.py b/graphene_sqlalchemy/types.py index 04d1a8a6..20f27578 100644 --- a/graphene_sqlalchemy/types.py +++ b/graphene_sqlalchemy/types.py @@ -1,4 +1,5 @@ from collections import OrderedDict +from inspect import isclass from sqlalchemy.inspection import inspect as sqlalchemyinspect from sqlalchemy.ext.hybrid import hybrid_property @@ -112,20 +113,31 @@ def __init_subclass_with_meta__(cls, model=None, registry=None, skip_registry=Fa if use_connection is None and interfaces: use_connection = any((issubclass(interface, Node) for interface in interfaces)) - if use_connection and not connection: - # We create the connection automatically - connection = Connection.create_type('{}Connection'.format(cls.__name__), node=cls) - - if connection is not None: - assert issubclass(connection, Connection), ( + cnx = None + if use_connection: + if connection and isclass(connection) and issubclass(connection, Connection) and \ + not hasattr(connection, '_meta'): + # Create connection type automatically using given class + cnx = connection.create_type('{}Connection'.format(cls.__name__), node=cls) + elif not connection: + # Create connection type automatically using graphene.relay.Connection + cnx = Connection.create_type('{}Connection'.format(cls.__name__), node=cls) + else: + cnx = connection + + if cnx is not None: + assert isclass(cnx), ( + "The connection must be a Connection. Received {}" + ).format(type(cnx)) + assert issubclass(cnx, Connection), ( "The connection must be a Connection. Received {}" - ).format(connection.__name__) + ).format(cnx.__name__) _meta = SQLAlchemyObjectTypeOptions(cls) _meta.model = model _meta.registry = registry _meta.fields = sqla_fields - _meta.connection = connection + _meta.connection = cnx _meta.id = id or 'id' super(SQLAlchemyObjectType, cls).__init_subclass_with_meta__(_meta=_meta, interfaces=interfaces, **options)