From 494703061da07860ee50bb7f2128f1290650c3ff Mon Sep 17 00:00:00 2001 From: Erik Wrede Date: Fri, 27 Jan 2023 13:46:06 +0100 Subject: [PATCH 1/2] feat!: configurable key to ID conversion --- graphene_sqlalchemy/converter.py | 42 +++++++++++++++------ graphene_sqlalchemy/tests/test_converter.py | 26 +++++++++++-- 2 files changed, 53 insertions(+), 15 deletions(-) diff --git a/graphene_sqlalchemy/converter.py b/graphene_sqlalchemy/converter.py index 8c7cd7a..3d1c554 100644 --- a/graphene_sqlalchemy/converter.py +++ b/graphene_sqlalchemy/converter.py @@ -93,6 +93,14 @@ def set_non_null_many_relationships(non_null_flag): use_non_null_many_relationships = non_null_flag +use_id_type_for_keys = True + + +def set_id_for_keys(id_flag): + global use_id_type_for_keys + use_id_type_for_keys = id_flag + + def get_column_doc(column): return getattr(column, "doc", None) @@ -259,18 +267,34 @@ def inner(fn): convert_sqlalchemy_composite.register = _register_composite_class +def _is_primary_or_foreign_key(column): + return getattr(column, "primary_key", False) or ( + len(getattr(column, "foreign_keys", [])) > 0 + ) + + def convert_sqlalchemy_column(column_prop, registry, resolver, **field_kwargs): column = column_prop.columns[0] - # The converter expects a type to find the right conversion function. - # If we get an instance instead, we need to convert it to a type. - # The conversion function will still be able to access the instance via the column argument. + # We only use the converter if no type was specified using the ORMField if "type_" not in field_kwargs: - column_type = getattr(column, "type", None) - if not isinstance(column_type, type): - column_type = type(column_type) + # If the column is a primary key, we use the ID typ + if use_id_type_for_keys and _is_primary_or_foreign_key(column): + field_type = graphene.ID + else: + # The converter expects a type to find the right conversion function. + # If we get an instance instead, we need to convert it to a type. + # The conversion function will still be able to access the instance via the column argument. + column_type = getattr(column, "type", None) + if not isinstance(column_type, type): + column_type = type(column_type) + + field_type = convert_sqlalchemy_type( + column_type, column=column, registry=registry + ) + field_kwargs.setdefault( "type_", - convert_sqlalchemy_type(column_type, column=column, registry=registry), + field_type, ) field_kwargs.setdefault("required", not is_column_nullable(column)) field_kwargs.setdefault("description", get_column_doc(column)) @@ -385,10 +409,6 @@ def convert_column_to_int_or_id( registry: Registry = None, **kwargs, ): - # fixme drop the primary key processing from here in another pr - if column is not None: - if getattr(column, "primary_key", False) is True: - return graphene.ID return graphene.Int diff --git a/graphene_sqlalchemy/tests/test_converter.py b/graphene_sqlalchemy/tests/test_converter.py index f70a50f..ff26e9b 100644 --- a/graphene_sqlalchemy/tests/test_converter.py +++ b/graphene_sqlalchemy/tests/test_converter.py @@ -5,7 +5,7 @@ import pytest import sqlalchemy import sqlalchemy_utils as sqa_utils -from sqlalchemy import Column, func, select, types +from sqlalchemy import Column, ForeignKey, func, select, types from sqlalchemy.dialects import postgresql from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.hybrid import hybrid_property @@ -42,11 +42,13 @@ def mock_resolver(): pass -def get_field(sqlalchemy_type, **column_kwargs): +def get_field(sqlalchemy_type, *column_args, **column_kwargs): class Model(declarative_base()): __tablename__ = "model" id_ = Column(types.Integer, primary_key=True) - column = Column(sqlalchemy_type, doc="Custom Help Text", **column_kwargs) + column = Column( + sqlalchemy_type, *column_args, doc="Custom Help Text", **column_kwargs + ) column_prop = inspect(Model).column_attrs["column"] return convert_sqlalchemy_column(column_prop, get_global_registry(), mock_resolver) @@ -381,12 +383,28 @@ def test_should_integer_convert_int(): assert get_field(types.Integer()).type == graphene.Int -def test_should_primary_integer_convert_id(): +def test_should_key_integer_convert_id(): assert get_field(types.Integer(), primary_key=True).type == graphene.NonNull( graphene.ID ) +def test_should_primary_string_convert_id(): + assert get_field(types.String(), primary_key=True).type == graphene.NonNull( + graphene.ID + ) + + +def test_should_primary_uuid_convert_id(): + assert get_field(sqa_utils.UUIDType, primary_key=True).type == graphene.NonNull( + graphene.ID + ) + + +def test_should_foreign_key_convert_id(): + assert get_field(types.Integer(), ForeignKey("model.id_")).type == graphene.ID + + def test_should_boolean_convert_boolean(): assert get_field(types.Boolean()).type == graphene.Boolean From b0a9b8474c5c57635bd323652954f45dc100af69 Mon Sep 17 00:00:00 2001 From: Erik Wrede Date: Fri, 6 Oct 2023 23:28:08 +0200 Subject: [PATCH 2/2] merge master --- graphene_sqlalchemy/tests/models.py | 5 ++- graphene_sqlalchemy/types.py | 54 ++++++++++++++++------------- graphene_sqlalchemy/utils.py | 8 ++++- 3 files changed, 38 insertions(+), 29 deletions(-) diff --git a/graphene_sqlalchemy/tests/models.py b/graphene_sqlalchemy/tests/models.py index b638b5d..77f6aee 100644 --- a/graphene_sqlalchemy/tests/models.py +++ b/graphene_sqlalchemy/tests/models.py @@ -26,11 +26,10 @@ from graphene_sqlalchemy.utils import SQL_VERSION_HIGHER_EQUAL_THAN_1_4, SQL_VERSION_HIGHER_EQUAL_THAN_2 # fmt: off -import sqlalchemy if SQL_VERSION_HIGHER_EQUAL_THAN_2: - from sqlalchemy.sql.sqltypes import HasExpressionLookup # noqa # isort:skip + from sqlalchemy.sql.sqltypes import HasExpressionLookup # noqa # isort:skip else: - from sqlalchemy.sql.sqltypes import _LookupExpressionAdapter as HasExpressionLookup # noqa # isort:skip + from sqlalchemy.sql.sqltypes import _LookupExpressionAdapter as HasExpressionLookup # noqa # isort:skip # fmt: on PetKind = Enum("cat", "dog", name="pet_kind") diff --git a/graphene_sqlalchemy/types.py b/graphene_sqlalchemy/types.py index 226d1e8..66db1e6 100644 --- a/graphene_sqlalchemy/types.py +++ b/graphene_sqlalchemy/types.py @@ -408,13 +408,15 @@ class SQLAlchemyObjectType(SQLAlchemyBase, ObjectType): Usage: - class MyModel(Base): - id = Column(Integer(), primary_key=True) - name = Column(String()) + .. code-block:: python - class MyType(SQLAlchemyObjectType): - class Meta: - model = MyModel + class MyModel(Base): + id = Column(Integer(), primary_key=True) + name = Column(String()) + + class MyType(SQLAlchemyObjectType): + class Meta: + model = MyModel """ @classmethod @@ -450,30 +452,32 @@ class SQLAlchemyInterface(SQLAlchemyBase, Interface): Usage (using joined table inheritance): - class MyBaseModel(Base): - id = Column(Integer(), primary_key=True) - type = Column(String()) - name = Column(String()) + .. code-block:: python - __mapper_args__ = { - "polymorphic_on": type, - } + class MyBaseModel(Base): + id = Column(Integer(), primary_key=True) + type = Column(String()) + name = Column(String()) - class MyChildModel(Base): - date = Column(Date()) + __mapper_args__ = { + "polymorphic_on": type, + } - __mapper_args__ = { - "polymorphic_identity": "child", - } + class MyChildModel(Base): + date = Column(Date()) - class MyBaseType(SQLAlchemyInterface): - class Meta: - model = MyBaseModel + __mapper_args__ = { + "polymorphic_identity": "child", + } - class MyChildType(SQLAlchemyObjectType): - class Meta: - model = MyChildModel - interfaces = (MyBaseType,) + class MyBaseType(SQLAlchemyInterface): + class Meta: + model = MyBaseModel + + class MyChildType(SQLAlchemyObjectType): + class Meta: + model = MyChildModel + interfaces = (MyBaseType,) """ @classmethod diff --git a/graphene_sqlalchemy/utils.py b/graphene_sqlalchemy/utils.py index ac5be88..bb9386e 100644 --- a/graphene_sqlalchemy/utils.py +++ b/graphene_sqlalchemy/utils.py @@ -27,12 +27,18 @@ def is_graphene_version_less_than(version_string): # pragma: no cover SQL_VERSION_HIGHER_EQUAL_THAN_1_4 = False -if not is_sqlalchemy_version_less_than("1.4"): +if not is_sqlalchemy_version_less_than("1.4"): # pragma: no cover from sqlalchemy.ext.asyncio import AsyncSession SQL_VERSION_HIGHER_EQUAL_THAN_1_4 = True +SQL_VERSION_HIGHER_EQUAL_THAN_2 = False + +if not is_sqlalchemy_version_less_than("2.0.0b1"): # pragma: no cover + SQL_VERSION_HIGHER_EQUAL_THAN_2 = True + + def get_session(context): return context.get("session")