@@ -565,6 +565,70 @@ def resolve_reporters(self, info):
565
565
assert len (select_statements ) == 2
566
566
567
567
568
+ @pytest .mark .asyncio
569
+ def test_batch_sorting_with_custom_ormfield (session_factory ):
570
+ session = session_factory ()
571
+ reporter_1 = Reporter (first_name = 'Reporter_1' )
572
+ session .add (reporter_1 )
573
+ reporter_2 = Reporter (first_name = 'Reporter_2' )
574
+ session .add (reporter_2 )
575
+ session .commit ()
576
+ session .close ()
577
+
578
+ class ReporterType (SQLAlchemyObjectType ):
579
+ class Meta :
580
+ model = Reporter
581
+ name = "Reporter"
582
+ interfaces = (relay .Node ,)
583
+ batching = True
584
+ connection_class = Connection
585
+
586
+ firstname = ORMField (model_attr = "first_name" )
587
+
588
+ class Query (graphene .ObjectType ):
589
+ node = relay .Node .Field ()
590
+ reporters = BatchSQLAlchemyConnectionField (ReporterType .connection )
591
+
592
+ class ReporterType (SQLAlchemyObjectType ):
593
+ class Meta :
594
+ model = Reporter
595
+ interfaces = (relay .Node ,)
596
+ batching = True
597
+
598
+ schema = graphene .Schema (query = Query )
599
+
600
+ # Test one-to-one and many-to-one relationships
601
+ with mock_sqlalchemy_logging_handler () as sqlalchemy_logging_handler :
602
+ # Starts new session to fully reset the engine / connection logging level
603
+ session = session_factory ()
604
+ result = schema .execute ("""
605
+ query {
606
+ reporters(sort: [FIRSTNAME_DESC]) {
607
+ edges {
608
+ node {
609
+ firstname
610
+ }
611
+ }
612
+ }
613
+ }
614
+ """ , context_value = {"session" : session })
615
+ messages = sqlalchemy_logging_handler .messages
616
+
617
+ result = to_std_dicts (result .data )
618
+ assert result == {
619
+ "reporters" : {"edges" : [
620
+ {"node" : {
621
+ "firstname" : "Reporter_2" ,
622
+ }},
623
+ {"node" : {
624
+ "firstname" : "Reporter_1" ,
625
+ }},
626
+ ]}
627
+ }
628
+ select_statements = [message for message in messages if 'SELECT' in message and 'FROM reporters' in message ]
629
+ assert len (select_statements ) == 2
630
+
631
+
568
632
@pytest .mark .asyncio
569
633
async def test_connection_factory_field_overrides_batching_is_false (session_factory ):
570
634
session = session_factory ()
0 commit comments