Skip to content

Commit 0de4f8e

Browse files
committed
Fix N+1 problem for one-to-one and many-to-one relationships.
1 parent 0544f81 commit 0de4f8e

File tree

7 files changed

+375
-44
lines changed

7 files changed

+375
-44
lines changed

graphene_sqlalchemy/resolver.py

Whitespace-only changes.

graphene_sqlalchemy/tests/conftest.py

Lines changed: 15 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import pytest
22
from sqlalchemy import create_engine
3-
from sqlalchemy.orm import scoped_session, sessionmaker
3+
from sqlalchemy.orm import sessionmaker
44

55
import graphene
66

@@ -23,19 +23,17 @@ def convert_composite_class(composite, registry):
2323

2424

2525
@pytest.yield_fixture(scope="function")
26-
def session():
27-
db = create_engine(test_db_url)
28-
connection = db.engine.connect()
29-
transaction = connection.begin()
30-
Base.metadata.create_all(connection)
31-
32-
# options = dict(bind=connection, binds={})
33-
session_factory = sessionmaker(bind=connection)
34-
session = scoped_session(session_factory)
35-
36-
yield session
37-
38-
# Finalize test here
39-
transaction.rollback()
40-
connection.close()
41-
session.remove()
26+
def session_factory():
27+
engine = create_engine(test_db_url)
28+
Base.metadata.create_all(engine)
29+
30+
yield sessionmaker(bind=engine)
31+
32+
# SQLite in-memory db is deleted when its connection is closed.
33+
# https://www.sqlite.org/inmemorydb.html
34+
engine.dispose()
35+
36+
37+
@pytest.fixture(scope="function")
38+
def session(session_factory):
39+
return session_factory()
Lines changed: 203 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,203 @@
1+
import contextlib
2+
import logging
3+
4+
import graphene
5+
6+
from ..types import SQLAlchemyObjectType
7+
from .models import Article, Reporter
8+
from .utils import to_std_dicts
9+
10+
11+
class MockLoggingHandler(logging.Handler):
12+
"""Intercept and store log messages in a list."""
13+
def __init__(self, *args, **kwargs):
14+
self.messages = []
15+
logging.Handler.__init__(self, *args, **kwargs)
16+
17+
def emit(self, record):
18+
self.messages.append(record.getMessage())
19+
20+
21+
@contextlib.contextmanager
22+
def mock_sqlalchemy_logging_handler():
23+
logging.basicConfig()
24+
sql_logger = logging.getLogger('sqlalchemy.engine')
25+
previous_level = sql_logger.level
26+
27+
sql_logger.setLevel(logging.INFO)
28+
mock_logging_handler = MockLoggingHandler()
29+
mock_logging_handler.setLevel(logging.INFO)
30+
sql_logger.addHandler(mock_logging_handler)
31+
32+
yield mock_logging_handler
33+
34+
sql_logger.setLevel(previous_level)
35+
36+
37+
def make_fixture(session):
38+
reporter_1 = Reporter(
39+
first_name='Reporter_1',
40+
)
41+
session.add(reporter_1)
42+
reporter_2 = Reporter(
43+
first_name='Reporter_2',
44+
)
45+
session.add(reporter_2)
46+
47+
article_1 = Article(headline='Article_1')
48+
article_1.reporter = reporter_1
49+
session.add(article_1)
50+
51+
article_2 = Article(headline='Article_2')
52+
article_2.reporter = reporter_2
53+
session.add(article_2)
54+
55+
session.commit()
56+
session.close()
57+
58+
59+
def get_schema(session):
60+
class ReporterType(SQLAlchemyObjectType):
61+
class Meta:
62+
model = Reporter
63+
64+
class ArticleType(SQLAlchemyObjectType):
65+
class Meta:
66+
model = Article
67+
68+
class Query(graphene.ObjectType):
69+
articles = graphene.Field(graphene.List(ArticleType))
70+
reporters = graphene.Field(graphene.List(ReporterType))
71+
72+
def resolve_articles(self, _info):
73+
return session.query(Article).all()
74+
75+
def resolve_reporters(self, _info):
76+
return session.query(Reporter).all()
77+
78+
return graphene.Schema(query=Query)
79+
80+
81+
def test_many_to_one(session_factory):
82+
session = session_factory()
83+
make_fixture(session)
84+
schema = get_schema(session)
85+
86+
with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler:
87+
# Starts new session to fully reset the engine / connection logging level
88+
session = session_factory()
89+
result = schema.execute("""
90+
query {
91+
articles {
92+
headline
93+
reporter {
94+
firstName
95+
}
96+
}
97+
}
98+
""", context_value={"session": session})
99+
messages = sqlalchemy_logging_handler.messages
100+
101+
assert len(messages) == 5
102+
assert messages == [
103+
'BEGIN (implicit)',
104+
105+
'SELECT articles.id AS articles_id, '
106+
'articles.headline AS articles_headline, '
107+
'articles.pub_date AS articles_pub_date, '
108+
'articles.reporter_id AS articles_reporter_id \n'
109+
'FROM articles',
110+
'()',
111+
112+
'SELECT reporters.id AS reporters_id, '
113+
'(SELECT CAST(count(reporters.id) AS INTEGER) AS anon_2 \nFROM reporters) AS anon_1, '
114+
'reporters.first_name AS reporters_first_name, '
115+
'reporters.last_name AS reporters_last_name, '
116+
'reporters.email AS reporters_email, '
117+
'reporters.favorite_pet_kind AS reporters_favorite_pet_kind \n'
118+
'FROM reporters \n'
119+
'WHERE reporters.id IN (?, ?)',
120+
'(1, 2)',
121+
]
122+
123+
assert not result.errors
124+
result = to_std_dicts(result.data)
125+
assert result == {
126+
"articles": [
127+
{
128+
"headline": "Article_1",
129+
"reporter": {
130+
"firstName": "Reporter_1",
131+
},
132+
},
133+
{
134+
"headline": "Article_2",
135+
"reporter": {
136+
"firstName": "Reporter_2",
137+
},
138+
},
139+
],
140+
}
141+
142+
143+
def test_one_to_one(session_factory):
144+
session = session_factory()
145+
make_fixture(session)
146+
schema = get_schema(session)
147+
148+
with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler:
149+
# Starts new session to fully reset the engine / connection logging level
150+
session = session_factory()
151+
result = schema.execute("""
152+
query {
153+
reporters {
154+
firstName
155+
favoriteArticle {
156+
headline
157+
}
158+
}
159+
}
160+
""", context_value={"session": session})
161+
messages = sqlalchemy_logging_handler.messages
162+
163+
assert len(messages) == 5
164+
assert messages == [
165+
'BEGIN (implicit)',
166+
167+
'SELECT (SELECT CAST(count(reporters.id) AS INTEGER) AS anon_2 \nFROM reporters) AS anon_1, '
168+
'reporters.id AS reporters_id, '
169+
'reporters.first_name AS reporters_first_name, '
170+
'reporters.last_name AS reporters_last_name, '
171+
'reporters.email AS reporters_email, '
172+
'reporters.favorite_pet_kind AS reporters_favorite_pet_kind \n'
173+
'FROM reporters',
174+
'()',
175+
176+
'SELECT articles.reporter_id AS articles_reporter_id, '
177+
'articles.id AS articles_id, '
178+
'articles.headline AS articles_headline, '
179+
'articles.pub_date AS articles_pub_date \n'
180+
'FROM articles \n'
181+
'WHERE articles.reporter_id IN (?, ?) '
182+
'ORDER BY articles.reporter_id',
183+
'(1, 2)'
184+
]
185+
186+
assert not result.errors
187+
result = to_std_dicts(result.data)
188+
assert result == {
189+
"reporters": [
190+
{
191+
"firstName": "Reporter_1",
192+
"favoriteArticle": {
193+
"headline": "Article_1",
194+
},
195+
},
196+
{
197+
"firstName": "Reporter_2",
198+
"favoriteArticle": {
199+
"headline": "Article_2",
200+
},
201+
},
202+
],
203+
}

graphene_sqlalchemy/tests/test_query.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,7 @@
55
from ..fields import SQLAlchemyConnectionField
66
from ..types import ORMField, SQLAlchemyObjectType
77
from .models import Article, CompositeFullName, Editor, HairKind, Pet, Reporter
8-
9-
10-
def to_std_dicts(value):
11-
"""Convert nested ordered dicts to normal dicts for better comparison."""
12-
if isinstance(value, dict):
13-
return {k: to_std_dicts(v) for k, v in value.items()}
14-
elif isinstance(value, list):
15-
return [to_std_dicts(v) for v in value]
16-
else:
17-
return value
8+
from .utils import to_std_dicts
189

1910

2011
def add_test_data(session):

graphene_sqlalchemy/tests/utils.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
def to_std_dicts(value):
2+
"""Convert nested ordered dicts to normal dicts for better comparison."""
3+
if isinstance(value, dict):
4+
return {k: to_std_dicts(v) for k, v in value.items()}
5+
elif isinstance(value, list):
6+
return [to_std_dicts(v) for v in value]
7+
else:
8+
return value

0 commit comments

Comments
 (0)