Skip to content

Commit d6746e3

Browse files
committed
add test for batch sorting with custom ormfield
1 parent 515060b commit d6746e3

File tree

1 file changed

+64
-0
lines changed

1 file changed

+64
-0
lines changed

graphene_sqlalchemy/tests/test_batching.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -565,6 +565,70 @@ def resolve_reporters(self, info):
565565
assert len(select_statements) == 2
566566

567567

568+
@pytest.mark.asyncio
569+
def test_batch_sorting_with_custom_ormfield(session_factory):
570+
session = session_factory()
571+
reporter_1 = Reporter(first_name='Reporter_1')
572+
session.add(reporter_1)
573+
reporter_2 = Reporter(first_name='Reporter_2')
574+
session.add(reporter_2)
575+
session.commit()
576+
session.close()
577+
578+
class ReporterType(SQLAlchemyObjectType):
579+
class Meta:
580+
model = Reporter
581+
name = "Reporter"
582+
interfaces = (relay.Node,)
583+
batching = True
584+
connection_class = Connection
585+
586+
firstname = ORMField(model_attr="first_name")
587+
588+
class Query(graphene.ObjectType):
589+
node = relay.Node.Field()
590+
reporters = BatchSQLAlchemyConnectionField(ReporterType.connection)
591+
592+
class ReporterType(SQLAlchemyObjectType):
593+
class Meta:
594+
model = Reporter
595+
interfaces = (relay.Node,)
596+
batching = True
597+
598+
schema = graphene.Schema(query=Query)
599+
600+
# Test one-to-one and many-to-one relationships
601+
with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler:
602+
# Starts new session to fully reset the engine / connection logging level
603+
session = session_factory()
604+
result = schema.execute("""
605+
query {
606+
reporters(sort: [FIRSTNAME_DESC]) {
607+
edges {
608+
node {
609+
firstname
610+
}
611+
}
612+
}
613+
}
614+
""", context_value={"session": session})
615+
messages = sqlalchemy_logging_handler.messages
616+
617+
result = to_std_dicts(result.data)
618+
assert result == {
619+
"reporters": {"edges": [
620+
{"node": {
621+
"firstname": "Reporter_2",
622+
}},
623+
{"node": {
624+
"firstname": "Reporter_1",
625+
}},
626+
]}
627+
}
628+
select_statements = [message for message in messages if 'SELECT' in message and 'FROM reporters' in message]
629+
assert len(select_statements) == 2
630+
631+
568632
@pytest.mark.asyncio
569633
async def test_connection_factory_field_overrides_batching_is_false(session_factory):
570634
session = session_factory()

0 commit comments

Comments
 (0)