Skip to content

Commit 064adc7

Browse files
committed
feat(filters): support filter aliasing (PR #378)
1 parent c38ebb3 commit 064adc7

File tree

3 files changed

+97
-39
lines changed

3 files changed

+97
-39
lines changed

graphene_sqlalchemy/filters.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import re
22
from typing import Any, Dict, List, Tuple, Type, TypeVar, Union
33

4+
from graphql import Undefined
45
from sqlalchemy import and_, not_, or_
56
from sqlalchemy.orm import Query, aliased # , selectinload
67

@@ -15,6 +16,31 @@
1516
"BaseTypeFilterSelf", Dict[str, Any], InputObjectTypeContainer
1617
)
1718

19+
class SQLAlchemyFilterInputField(graphene.InputField):
20+
def __init__(
21+
self,
22+
type_,
23+
model_attr,
24+
name=None,
25+
default_value=Undefined,
26+
deprecation_reason=None,
27+
description=None,
28+
required=False,
29+
_creation_counter=None,
30+
**extra_args,
31+
):
32+
super(SQLAlchemyFilterInputField, self).__init__(
33+
type_,
34+
name,
35+
default_value,
36+
deprecation_reason,
37+
description,
38+
required,
39+
_creation_counter,
40+
**extra_args,
41+
)
42+
43+
self.model_attr = model_attr
1844

1945
def _get_functions_by_regex(
2046
regex: str, subtract_regex: str, class_: Type
@@ -138,7 +164,8 @@ def execute_filters(
138164
# Check with a profiler is required to determine necessity
139165
input_field = cls._meta.fields[field]
140166
if isinstance(input_field, graphene.Dynamic):
141-
field_filter_type = input_field.get_type().type
167+
input_field = input_field.get_type()
168+
field_filter_type = input_field.type
142169
else:
143170
field_filter_type = cls._meta.fields[field].type
144171
# raise Exception
@@ -155,7 +182,8 @@ def execute_filters(
155182
)
156183
clauses.extend(_clauses)
157184
else:
158-
model_field = getattr(model, field)
185+
# Get the model attr from the inputfield in case the field is aliased in graphql
186+
model_field = getattr(model, input_field.model_attr or field)
159187
if issubclass(field_filter_type, BaseTypeFilter):
160188
# Get the model to join on the Filter Query
161189
joined_model = field_filter_type._meta.model

graphene_sqlalchemy/tests/test_filters.py

Lines changed: 45 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,8 @@
1-
import pytest
2-
from sqlalchemy.sql.operators import is_
3-
41
import graphene
2+
import pytest
53
from graphene import Connection, relay
4+
from sqlalchemy.sql.operators import is_
65

7-
from ..fields import SQLAlchemyConnectionField
8-
from ..filters import FloatFilter
9-
from ..types import ORMField, SQLAlchemyObjectType
106
from .models import (
117
Article,
128
Editor,
@@ -20,6 +16,10 @@
2016
Tag,
2117
)
2218
from .utils import eventually_await_session, to_std_dicts
19+
from ..fields import SQLAlchemyConnectionField
20+
from ..filters import FloatFilter
21+
from ..types import ORMField, SQLAlchemyObjectType
22+
2323

2424
# TODO test that generated schema is correct for all examples with:
2525
# with open('schema.gql', 'w') as fp:
@@ -110,26 +110,13 @@ class Meta:
110110

111111
class Query(graphene.ObjectType):
112112
node = relay.Node.Field()
113-
# # TODO how to create filterable singular field?
114-
# article = graphene.Field(ArticleType)
115113
articles = SQLAlchemyConnectionField(ArticleType.connection)
116-
# image = graphene.Field(ImageType)
117114
images = SQLAlchemyConnectionField(ImageType.connection)
118115
readers = SQLAlchemyConnectionField(ReaderType.connection)
119-
# reporter = graphene.Field(ReporterType)
120116
reporters = SQLAlchemyConnectionField(ReporterType.connection)
121117
pets = SQLAlchemyConnectionField(PetType.connection)
122118
tags = SQLAlchemyConnectionField(TagType.connection)
123119

124-
# def resolve_article(self, _info):
125-
# return session.query(Article).first()
126-
127-
# def resolve_image(self, _info):
128-
# return session.query(Image).first()
129-
130-
# def resolve_reporter(self, _info):
131-
# return session.query(Reporter).first()
132-
133120
return Query
134121

135122

@@ -159,6 +146,44 @@ async def test_filter_simple(session):
159146
assert_and_raise_result(result, expected)
160147

161148

149+
@pytest.mark.asyncio
150+
async def test_filter_alias(session):
151+
"""
152+
Test aliasing of column names in the type
153+
"""
154+
await add_test_data(session)
155+
156+
class ReporterType(SQLAlchemyObjectType):
157+
class Meta:
158+
model = Reporter
159+
name = "Reporter"
160+
interfaces = (relay.Node,)
161+
162+
lastNameAlias = ORMField(model_attr="last_name")
163+
164+
class Query(graphene.ObjectType):
165+
node = relay.Node.Field()
166+
reporters = SQLAlchemyConnectionField(ReporterType.connection)
167+
168+
query = """
169+
query {
170+
reporters (filter: {lastNameAlias: {eq: "Roe", like: "%oe"}}) {
171+
edges {
172+
node {
173+
firstName
174+
}
175+
}
176+
}
177+
}
178+
"""
179+
expected = {
180+
"reporters": {"edges": [{"node": {"firstName": "Jane"}}]},
181+
}
182+
schema = graphene.Schema(query=Query)
183+
result = await schema.execute_async(query, context_value={"session": session})
184+
assert_and_raise_result(result, expected)
185+
186+
162187
# Test a custom filter type
163188
@pytest.mark.asyncio
164189
async def test_filter_custom_type(session):
@@ -1084,7 +1109,7 @@ async def test_filter_hybrid_property(session):
10841109
result = to_std_dicts(result.data)
10851110
assert len(result["carts"]["edges"]) == 1
10861111
assert (
1087-
len(result["carts"]["edges"][0]["node"]["hybridPropShoppingCartItemList"]) == 2
1112+
len(result["carts"]["edges"][0]["node"]["hybridPropShoppingCartItemList"]) == 2
10881113
)
10891114

10901115

graphene_sqlalchemy/types.py

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
sort_argument_for_object_type,
3333
sort_enum_for_object_type,
3434
)
35-
from .filters import BaseTypeFilter, FieldFilter, RelationshipFilter
35+
from .filters import BaseTypeFilter, FieldFilter, RelationshipFilter, SQLAlchemyFilterInputField
3636
from .registry import Registry, get_global_registry
3737
from .resolvers import get_attr_resolver, get_custom_resolver
3838
from .utils import (
@@ -151,13 +151,13 @@ def filter_field_from_field(
151151
type_,
152152
registry: Registry,
153153
model_attr: Any,
154-
) -> Optional[Union[graphene.InputField, graphene.Dynamic]]:
154+
model_attr_name: str
155+
) -> Optional[graphene.InputField]:
155156
# Field might be a SQLAlchemyObjectType, due to hybrid properties
156157
if issubclass(type_, SQLAlchemyObjectType):
157158
filter_class = registry.get_filter_for_base_type(type_)
158-
return graphene.InputField(filter_class)
159159
# Enum Special Case
160-
if issubclass(type_, graphene.Enum) and isinstance(model_attr, ColumnProperty):
160+
elif issubclass(type_, graphene.Enum) and isinstance(model_attr, ColumnProperty):
161161
column = model_attr.columns[0]
162162
model_enum_type: Optional[sqlalchemy.types.Enum] = getattr(column, "type", None)
163163
if not getattr(model_enum_type, "enum_class", None):
@@ -168,16 +168,16 @@ def filter_field_from_field(
168168
filter_class = registry.get_filter_for_scalar_type(type_)
169169
if not filter_class:
170170
warnings.warn(
171-
f"No compatible filters found for {field.type}. Skipping field."
171+
f"No compatible filters found for {field.type} with db name {model_attr_name}. Skipping field."
172172
)
173173
return None
174-
return graphene.InputField(filter_class)
174+
return SQLAlchemyFilterInputField(filter_class, model_attr_name)
175175

176176

177177
def resolve_dynamic_relationship_filter(
178178
field: graphene.Dynamic,
179179
registry: Registry,
180-
model_attr: Any,
180+
model_attr_name: str
181181
) -> Optional[Union[graphene.InputField, graphene.Dynamic]]:
182182
# Resolve Dynamic Type
183183
type_ = get_nullable_type(field.get_type())
@@ -200,39 +200,44 @@ def resolve_dynamic_relationship_filter(
200200
reg_res = None
201201

202202
if not reg_res:
203+
warnings.warn(
204+
f"No compatible filters found for {field} with db name {model_attr_name}. Skipping field."
205+
)
203206
return None
204207

205-
return graphene.InputField(reg_res)
208+
return SQLAlchemyFilterInputField(reg_res, model_attr_name)
206209

207210

208211
def filter_field_from_type_field(
209212
field: Union[graphene.Field, graphene.Dynamic, Type[UnmountedType]],
210213
registry: Registry,
211214
filter_type: Optional[Type],
212215
model_attr: Any,
216+
model_attr_name: str
213217
) -> Optional[Union[graphene.InputField, graphene.Dynamic]]:
214218
# If a custom filter type was set for this field, use it here
215219
if filter_type:
216-
return graphene.InputField(filter_type)
220+
return SQLAlchemyFilterInputField(filter_type, model_attr_name)
217221
elif issubclass(type(field), graphene.Scalar):
218222
filter_class = registry.get_filter_for_scalar_type(type(field))
219-
return graphene.InputField(filter_class)
223+
return SQLAlchemyFilterInputField(filter_class, model_attr_name)
220224
# If the generated field is Dynamic, it is always a relationship
221225
# (due to graphene-sqlalchemy's conversion mechanism).
222226
elif isinstance(field, graphene.Dynamic):
223-
return Dynamic(partial(resolve_dynamic_relationship_filter, field, registry, model_attr))
224-
elif isinstance(field, graphene.Field):
225-
if inspect.isfunction(field._type) or isinstance(field._type, partial):
226-
return Dynamic(lambda: filter_field_from_field(field, get_nullable_type(field.type), registry, model_attr))
227-
else:
228-
return filter_field_from_field(field, get_nullable_type(field.type), registry, model_attr)
227+
return Dynamic(partial(resolve_dynamic_relationship_filter, field, registry, model_attr_name))
229228
# Unsupported but theoretically possible cases, please drop us an issue with reproduction if you need them
230229
elif isinstance(field, graphene.List) or isinstance(field._type, graphene.List):
231230
# Pure lists are not yet supported
232231
pass
233232
elif isinstance(field._type, graphene.Dynamic):
234233
# Fields with nested dynamic Dynamic are not yet supported
235234
pass
235+
# Order matters, this comes last as field._type == list also matches Field
236+
elif isinstance(field, graphene.Field):
237+
if inspect.isfunction(field._type) or isinstance(field._type, partial):
238+
return Dynamic(lambda: filter_field_from_field(field, get_nullable_type(field.type), registry, model_attr, model_attr_name))
239+
else:
240+
return filter_field_from_field(field, get_nullable_type(field.type), registry, model_attr, model_attr_name)
236241

237242

238243
def get_polymorphic_on(model):
@@ -372,7 +377,7 @@ def construct_fields_and_filters(
372377
fields[orm_field_name] = field
373378
if filtering_enabled_for_field:
374379
filters[orm_field_name] = filter_field_from_type_field(
375-
field, registry, filter_type, attr
380+
field, registry, filter_type, attr, attr_name
376381
)
377382

378383
return fields, filters

0 commit comments

Comments
 (0)