diff --git a/graphene_sqlalchemy/batching.py b/graphene_sqlalchemy/batching.py index 85cc8855..e56b1e4c 100644 --- a/graphene_sqlalchemy/batching.py +++ b/graphene_sqlalchemy/batching.py @@ -1,13 +1,30 @@ +"""The dataloader uses "select in loading" strategy to load related entities.""" +from typing import Any + import aiodataloader import sqlalchemy from sqlalchemy.orm import Session, strategies from sqlalchemy.orm.query import QueryContext -from .utils import is_sqlalchemy_version_less_than +from .utils import (is_graphene_version_less_than, + is_sqlalchemy_version_less_than) -def get_batch_resolver(relationship_prop): +def get_data_loader_impl() -> Any: # pragma: no cover + """Graphene >= 3.1.1 ships a copy of aiodataloader with minor fixes. To preserve backward-compatibility, + aiodataloader is used in conjunction with older versions of graphene""" + if is_graphene_version_less_than("3.1.1"): + from aiodataloader import DataLoader + else: + from graphene.utils.dataloader import DataLoader + + return DataLoader + +DataLoader = get_data_loader_impl() + + +def get_batch_resolver(relationship_prop): # Cache this across `batch_load_fn` calls # This is so SQL string generation is cached under-the-hood via `bakery` selectin_loader = strategies.SelectInLoader(relationship_prop, (('lazy', 'selectin'),)) diff --git a/graphene_sqlalchemy/utils.py b/graphene_sqlalchemy/utils.py index f6ee9b62..27117c0c 100644 --- a/graphene_sqlalchemy/utils.py +++ b/graphene_sqlalchemy/utils.py @@ -151,11 +151,16 @@ def sort_argument_for_model(cls, has_default=True): return Argument(List(enum), default_value=enum.default) -def is_sqlalchemy_version_less_than(version_string): +def is_sqlalchemy_version_less_than(version_string): # pragma: no cover """Check the installed SQLAlchemy version""" return pkg_resources.get_distribution('SQLAlchemy').parsed_version < pkg_resources.parse_version(version_string) +def is_graphene_version_less_than(version_string): # pragma: no cover + """Check the installed graphene version""" + return pkg_resources.get_distribution('graphene').parsed_version < pkg_resources.parse_version(version_string) + + class singledispatchbymatchfunction: """ Inspired by @singledispatch, this is a variant that works using a matcher function @@ -197,6 +202,7 @@ def safe_isinstance_checker(arg): return isinstance(arg, cls) except TypeError: pass + return safe_isinstance_checker @@ -210,5 +216,6 @@ def registry_sqlalchemy_model_from_str(model_name: str) -> Optional[Any]: class DummyImport: """The dummy module returns 'object' for a query for any member""" + def __getattr__(self, name): return object