Skip to content

Commit d812c9b

Browse files
committed
Added sort support to SQLAlchemyConnectionField
1 parent 4827ce2 commit d812c9b

File tree

3 files changed

+66
-6
lines changed

3 files changed

+66
-6
lines changed

examples/flask_sqlalchemy/schema.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import graphene
22
from graphene import relay
3-
from graphene_sqlalchemy import SQLAlchemyConnectionField, SQLAlchemyObjectType
3+
from graphene_sqlalchemy import SQLAlchemyConnectionField, SQLAlchemyObjectType, utils
44
from models import Department as DepartmentModel
55
from models import Employee as EmployeeModel
66
from models import Role as RoleModel
@@ -27,11 +27,17 @@ class Meta:
2727
interfaces = (relay.Node, )
2828

2929

30+
SortEnumEmployee = utils.sort_enum_for_model(
31+
EmployeeModel, 'SortEnumEmployee',
32+
lambda c, d: c.upper() + ('_ASC' if d else '_DESC'))
33+
34+
3035
class Query(graphene.ObjectType):
3136
node = relay.Node.Field()
32-
all_employees = SQLAlchemyConnectionField(Employee)
33-
all_roles = SQLAlchemyConnectionField(Role)
34-
role = graphene.Field(Role)
37+
all_employees = SQLAlchemyConnectionField(
38+
Employee, sort=graphene.Argument(SortEnumEmployee, default_value=EmployeeModel.id))
39+
all_roles = SQLAlchemyConnectionField(Role, sort=utils.sort_argument_for_model(RoleModel))
40+
all_departments = SQLAlchemyConnectionField(Department)
3541

3642

3743
schema = graphene.Schema(query=Query, types=[Department, Employee, Role])

graphene_sqlalchemy/fields.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from collections import Iterable
12
from functools import partial
23

34
from sqlalchemy.orm.query import Query
@@ -16,8 +17,11 @@ def model(self):
1617
return self.type._meta.node._meta.model
1718

1819
@classmethod
19-
def get_query(cls, model, info, **args):
20-
return get_query(model, info.context)
20+
def get_query(cls, model, info, sort=None, **args):
21+
query = get_query(model, info.context)
22+
if sort is not None:
23+
query = query.order_by(*sort) if isinstance(sort, Iterable) else query.order_by(sort)
24+
return query
2125

2226
@property
2327
def type(self):

graphene_sqlalchemy/utils.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1+
from graphene import Argument, Enum, List
12
from sqlalchemy.exc import ArgumentError
3+
from sqlalchemy.inspection import inspect
24
from sqlalchemy.orm import class_mapper, object_mapper
35
from sqlalchemy.orm.exc import UnmappedClassError, UnmappedInstanceError
46

@@ -34,3 +36,51 @@ def is_mapped_instance(cls):
3436
return False
3537
else:
3638
return True
39+
40+
41+
def _symbol_name(column_name, is_asc):
42+
return column_name + ('_asc' if is_asc else '_desc')
43+
44+
45+
def _sort_enum_for_model(cls, name=None, symbol_name=_symbol_name):
46+
name = name or cls.__name__ + 'SortEnum'
47+
items = []
48+
default = []
49+
for column in inspect(cls).columns.values():
50+
asc = symbol_name(column.name, True), column.asc()
51+
desc = symbol_name(column.name, False), column.desc()
52+
if column.primary_key:
53+
default.append(asc[1])
54+
items.extend((asc, desc))
55+
return Enum(name, items), default
56+
57+
58+
def sort_enum_for_model(cls, name=None, symbol_name=_symbol_name):
59+
'''Create Graphene Enum for sorting a SQLAlchemy class query
60+
61+
Parameters
62+
- cls : Sqlalchemy model class
63+
Model used to create the sort enumerator
64+
- name : str, optional, default None
65+
Name to use for the enumerator. If not provided it will be set to `cls.__name__ + 'SortEnum'`
66+
- symbol_name : function, optional, default `_symbol_name`
67+
Function which takes the column name and a boolean indicating if the sort direction is ascending,
68+
and returns the symbol name for the current column and sort direction.
69+
The default function will create, for a column named 'foo', the symbols 'foo_asc' and 'foo_desc'
70+
71+
Returns
72+
- Enum
73+
The Graphene enumerator
74+
'''
75+
enum, _ = _sort_enum_for_model(cls, name, symbol_name)
76+
return enum
77+
78+
79+
def sort_argument_for_model(cls, has_default=True):
80+
'''Returns an Graphene argument for the sort field that accepts a list of sorting directions for a model.
81+
If `has_default` is True (the default) it will sort the result by the primary key(s)
82+
'''
83+
enum, default = _sort_enum_for_model(cls)
84+
if not has_default:
85+
default = None
86+
return Argument(List(enum), default_value=default)

0 commit comments

Comments
 (0)