Skip to content

Commit ac57fd4

Browse files
committed
Enable sorting when batching is enabled
1 parent dfee3e9 commit ac57fd4

File tree

4 files changed

+336
-153
lines changed

4 files changed

+336
-153
lines changed

graphene_sqlalchemy/batching.py

Lines changed: 36 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,28 @@
1+
from asyncio import get_event_loop
2+
13
import aiodataloader
24
import sqlalchemy
35
from sqlalchemy.orm import Session, strategies
46
from sqlalchemy.orm.query import QueryContext
57

68
from .utils import is_sqlalchemy_version_less_than
79

10+
# Cache this across `batch_load_fn` calls
11+
# This is so SQL string generation is cached under-the-hood via `bakery`
12+
# Caching the relationship loader for each relationship prop.
13+
RELATIONSHIP_LOADERS_CACHE = {}
814

9-
def get_batch_resolver(relationship_prop):
1015

11-
# Cache this across `batch_load_fn` calls
12-
# This is so SQL string generation is cached under-the-hood via `bakery`
13-
selectin_loader = strategies.SelectInLoader(relationship_prop, (('lazy', 'selectin'),))
16+
def get_batch_resolver(relationship_prop):
1417

1518
class RelationshipLoader(aiodataloader.DataLoader):
1619
cache = False
1720

21+
def __init__(self, relationship_prop, selectin_loader):
22+
super().__init__()
23+
self.relationship_prop = relationship_prop
24+
self.selectin_loader = selectin_loader
25+
1826
async def batch_load_fn(self, parents):
1927
"""
2028
Batch loads the relationships of all the parents as one SQL statement.
@@ -38,8 +46,8 @@ async def batch_load_fn(self, parents):
3846
SQLAlchemy's main maitainer suggestion.
3947
See https://git.io/JewQ7
4048
"""
41-
child_mapper = relationship_prop.mapper
42-
parent_mapper = relationship_prop.parent
49+
child_mapper = self.relationship_prop.mapper
50+
parent_mapper = self.relationship_prop.parent
4351
session = Session.object_session(parents[0])
4452

4553
# These issues are very unlikely to happen in practice...
@@ -62,26 +70,42 @@ async def batch_load_fn(self, parents):
6270
query_context = parent_mapper_query._compile_context()
6371

6472
if is_sqlalchemy_version_less_than('1.4'):
65-
selectin_loader._load_for_path(
73+
self.selectin_loader._load_for_path(
6674
query_context,
6775
parent_mapper._path_registry,
6876
states,
6977
None,
7078
child_mapper
7179
)
7280
else:
73-
selectin_loader._load_for_path(
81+
self.selectin_loader._load_for_path(
7482
query_context,
7583
parent_mapper._path_registry,
7684
states,
7785
None,
7886
child_mapper,
7987
None
8088
)
81-
82-
return [getattr(parent, relationship_prop.key) for parent in parents]
83-
84-
loader = RelationshipLoader()
89+
return [getattr(parent, self.relationship_prop.key) for parent in parents]
90+
91+
def _get_loader(relationship_prop):
92+
"""Retrieve the cached loader of the given relationship."""
93+
loader = RELATIONSHIP_LOADERS_CACHE.get(relationship_prop, None)
94+
if loader is None:
95+
selectin_loader = strategies.SelectInLoader(
96+
relationship_prop,
97+
(('lazy', 'selectin'),)
98+
)
99+
loader = RelationshipLoader(
100+
relationship_prop=relationship_prop,
101+
selectin_loader=selectin_loader
102+
)
103+
RELATIONSHIP_LOADERS_CACHE[relationship_prop] = loader
104+
else:
105+
loader.loop = get_event_loop()
106+
return loader
107+
108+
loader = _get_loader(relationship_prop)
85109

86110
async def resolve(root, info, **args):
87111
return await loader.load(root)

graphene_sqlalchemy/fields.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -129,20 +129,31 @@ def get_query(cls, model, info, sort=None, **args):
129129
return query
130130

131131

132-
class BatchSQLAlchemyConnectionField(UnsortedSQLAlchemyConnectionField):
132+
class BatchSQLAlchemyConnectionField(SQLAlchemyConnectionField):
133133
"""
134134
This is currently experimental.
135135
The API and behavior may change in future versions.
136136
Use at your own risk.
137137
"""
138138

139-
def wrap_resolve(self, parent_resolver):
140-
return partial(
141-
self.connection_resolver,
142-
self.resolver,
143-
get_nullable_type(self.type),
144-
self.model,
145-
)
139+
@classmethod
140+
def connection_resolver(cls, resolver, connection_type, model, root, info, **args):
141+
if root is None:
142+
resolved = resolver(root, info, **args)
143+
on_resolve = partial(cls.resolve_connection, connection_type, model, info, args)
144+
else:
145+
relationship_prop = None
146+
for relationship in root.__class__.__mapper__.relationships:
147+
if relationship.mapper.class_ == model:
148+
relationship_prop = relationship
149+
break
150+
resolved = get_batch_resolver(relationship_prop)(root, info, **args)
151+
on_resolve = partial(cls.resolve_connection, connection_type, root, info, args)
152+
153+
if is_thenable(resolved):
154+
return Promise.resolve(resolved).then(on_resolve)
155+
156+
return on_resolve(resolved)
146157

147158
@classmethod
148159
def from_relationship(cls, relationship, registry, **field_kwargs):

graphene_sqlalchemy/tests/models.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,24 @@ class Article(Base):
110110
headline = Column(String(100))
111111
pub_date = Column(Date())
112112
reporter_id = Column(Integer(), ForeignKey("reporters.id"))
113+
readers = relationship(
114+
"Reader", secondary="articles_readers", back_populates="articles"
115+
)
116+
117+
118+
class Reader(Base):
119+
__tablename__ = "readers"
120+
id = Column(Integer(), primary_key=True)
121+
name = Column(String(100))
122+
articles = relationship(
123+
"Article", secondary="articles_readers", back_populates="readers"
124+
)
125+
126+
127+
class ArticleReader(Base):
128+
__tablename__ = "articles_readers"
129+
article_id = Column(Integer(), ForeignKey("articles.id"), primary_key=True)
130+
reader_id = Column(Integer(), ForeignKey("readers.id"), primary_key=True)
113131

114132

115133
class ReflectedEditor(type):

0 commit comments

Comments
 (0)