Skip to content

Commit 340b29d

Browse files
committed
Fix syntax error in Batching from signature changes
1 parent 63888d3 commit 340b29d

File tree

6 files changed

+36
-20
lines changed

6 files changed

+36
-20
lines changed

graphene_sqlalchemy/batching.py

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
from sqlalchemy.orm import Session, strategies
44
from sqlalchemy.orm.query import QueryContext
55

6+
from .utils import is_sqlalchemy_version_less_than
7+
68

79
def get_batch_resolver(relationship_prop):
810

@@ -52,15 +54,30 @@ async def batch_load_fn(self, parents):
5254
states = [(sqlalchemy.inspect(parent), True) for parent in parents]
5355

5456
# For our purposes, the query_context will only used to get the session
55-
query_context = QueryContext(session.query(parent_mapper.entity))
56-
57-
selectin_loader._load_for_path(
58-
query_context,
59-
parent_mapper._path_registry,
60-
states,
61-
None,
62-
child_mapper,
63-
)
57+
query_context = None
58+
if is_sqlalchemy_version_less_than('1.4'):
59+
query_context = QueryContext(session.query(parent_mapper.entity))
60+
else:
61+
parent_mapper_query = session.query(parent_mapper.entity)
62+
query_context = parent_mapper_query._compile_context()
63+
64+
if is_sqlalchemy_version_less_than('1.4'):
65+
selectin_loader._load_for_path(
66+
query_context,
67+
parent_mapper._path_registry,
68+
states,
69+
None,
70+
child_mapper
71+
)
72+
else:
73+
selectin_loader._load_for_path(
74+
query_context,
75+
parent_mapper._path_registry,
76+
states,
77+
None,
78+
child_mapper,
79+
None
80+
)
6481

6582
return [getattr(parent, relationship_prop.key) for parent in parents]
6683

graphene_sqlalchemy/tests/test_batching.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@
1010
default_connection_field_factory)
1111
from ..types import ORMField, SQLAlchemyObjectType
1212
from .models import Article, HairKind, Pet, Reporter
13-
from .utils import is_sqlalchemy_version_less_than, to_std_dicts
13+
from ..utils import is_sqlalchemy_version_less_than
14+
from .utils import to_std_dicts
1415

1516

1617
class MockLoggingHandler(logging.Handler):

graphene_sqlalchemy/tests/test_benchmark.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
from graphene import relay
55

66
from ..types import SQLAlchemyObjectType
7+
from ..utils import is_sqlalchemy_version_less_than
78
from .models import Article, HairKind, Pet, Reporter
8-
from .utils import is_sqlalchemy_version_less_than
99

1010
if is_sqlalchemy_version_less_than('1.2'):
1111
pytest.skip('SQL batching only works for SQLAlchemy 1.2+', allow_module_level=True)

graphene_sqlalchemy/tests/test_converter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ class Model(declarative_base()):
4747
return convert_sqlalchemy_column(column_prop, get_global_registry(), mock_resolver)
4848

4949

50-
def _test_should_unknown_sqlalchemy_field_raise_exception():
50+
def test_should_unknown_sqlalchemy_field_raise_exception():
5151
# TODO: SQLALchemy does not export types.Binary, remove or update this test
5252
re_err = "Don't know how to convert the SQLAlchemy field"
5353
with pytest.raises(Exception, match=re_err):

graphene_sqlalchemy/tests/utils.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,3 @@
1-
import pkg_resources
2-
3-
41
def to_std_dicts(value):
52
"""Convert nested ordered dicts to normal dicts for better comparison."""
63
if isinstance(value, dict):
@@ -9,8 +6,3 @@ def to_std_dicts(value):
96
return [to_std_dicts(v) for v in value]
107
else:
118
return value
12-
13-
14-
def is_sqlalchemy_version_less_than(version_string):
15-
"""Check the installed SQLAlchemy version"""
16-
return pkg_resources.get_distribution('SQLAlchemy').parsed_version < pkg_resources.parse_version(version_string)

graphene_sqlalchemy/utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import re
22
import warnings
33

4+
import pkg_resources
45
from sqlalchemy.exc import ArgumentError
56
from sqlalchemy.orm import class_mapper, object_mapper
67
from sqlalchemy.orm.exc import UnmappedClassError, UnmappedInstanceError
@@ -140,3 +141,8 @@ def sort_argument_for_model(cls, has_default=True):
140141
enum.default = None
141142

142143
return Argument(List(enum), default_value=enum.default)
144+
145+
146+
def is_sqlalchemy_version_less_than(version_string):
147+
"""Check the installed SQLAlchemy version"""
148+
return pkg_resources.get_distribution('SQLAlchemy').parsed_version < pkg_resources.parse_version(version_string)

0 commit comments

Comments
 (0)