Skip to content

Commit 06c90cb

Browse files
committed
fix: create special enum filters. Code pending refactor.
1 parent 87bbd6f commit 06c90cb

File tree

5 files changed

+202
-112
lines changed

5 files changed

+202
-112
lines changed

graphene_sqlalchemy/filters.py

Lines changed: 42 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,6 @@ def __init_subclass_with_meta__(
4949
logic_functions = _get_functions_by_regex(".+_logic$", "_logic$", cls)
5050

5151
new_filter_fields = {}
52-
print(f"Generating Filter for {cls.__name__} with model {model} ")
5352
# Generate Graphene Fields from the filter functions based on type hints
5453
for field_name, _annotations in logic_functions:
5554
assert (
@@ -70,9 +69,6 @@ def __init_subclass_with_meta__(
7069
_meta.fields = filter_fields
7170
_meta.fields.update(new_filter_fields)
7271

73-
for field in _meta.fields:
74-
print(f"Added field {field} of type {_meta.fields[field].type}")
75-
7672
_meta.model = model
7773

7874
super(BaseTypeFilter, cls).__init_subclass_with_meta__(_meta=_meta, **options)
@@ -289,6 +285,48 @@ def execute_filters(
289285

290286
return query, clauses
291287

288+
class SQLEnumFilter(FieldFilter):
289+
"""Basic Filter for Scalars in Graphene.
290+
We want this filter to use Dynamic fields so it provides the base
291+
filtering methods ("eq, nEq") for different types of scalars.
292+
The Dynamic fields will resolve to Meta.filtered_type"""
293+
class Meta:
294+
graphene_type = graphene.Enum
295+
296+
# Abstract methods can be marked using ScalarFilterInputType. See comment on the init method
297+
@classmethod
298+
def eq_filter(
299+
cls, query, field, val: ScalarFilterInputType
300+
) -> Union[Tuple[Query, Any], Any]:
301+
return field == val.value
302+
303+
@classmethod
304+
def n_eq_filter(
305+
cls, query, field, val: ScalarFilterInputType
306+
) -> Union[Tuple[Query, Any], Any]:
307+
return not_(field == val.value)
308+
309+
class PyEnumFilter(FieldFilter):
310+
"""Basic Filter for Scalars in Graphene.
311+
We want this filter to use Dynamic fields so it provides the base
312+
filtering methods ("eq, nEq") for different types of scalars.
313+
The Dynamic fields will resolve to Meta.filtered_type"""
314+
class Meta:
315+
graphene_type = graphene.Enum
316+
317+
# Abstract methods can be marked using ScalarFilterInputType. See comment on the init method
318+
@classmethod
319+
def eq_filter(
320+
cls, query, field, val: ScalarFilterInputType
321+
) -> Union[Tuple[Query, Any], Any]:
322+
return field == val
323+
324+
@classmethod
325+
def n_eq_filter(
326+
cls, query, field, val: ScalarFilterInputType
327+
) -> Union[Tuple[Query, Any], Any]:
328+
return not_(field == val)
329+
292330

293331
class StringFilter(FieldFilter):
294332
class Meta:

graphene_sqlalchemy/registry.py

Lines changed: 41 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,16 @@
11
from collections import defaultdict
22
from typing import TYPE_CHECKING, List, Type
33

4-
from sqlalchemy.types import Enum as SQLAlchemyEnumType
5-
64
import graphene
75
from graphene import Enum
86
from graphene.types.base import BaseType
7+
from sqlalchemy.types import Enum as SQLAlchemyEnumType
98

109
if TYPE_CHECKING: # pragma: no_cover
1110
from graphene_sqlalchemy.filters import (
1211
FieldFilter,
1312
BaseTypeFilter,
14-
RelationshipFilter,
15-
)
13+
RelationshipFilter, )
1614

1715

1816
class Registry(object):
@@ -81,7 +79,7 @@ def register_sort_enum(self, obj_type, sort_enum: Enum):
8179
from .types import SQLAlchemyObjectType
8280

8381
if not isinstance(obj_type, type) or not issubclass(
84-
obj_type, SQLAlchemyObjectType
82+
obj_type, SQLAlchemyObjectType
8583
):
8684
raise TypeError(
8785
"Expected SQLAlchemyObjectType, but got: {!r}".format(obj_type)
@@ -94,7 +92,7 @@ def get_sort_enum_for_object_type(self, obj_type: graphene.ObjectType):
9492
return self._registry_sort_enums.get(obj_type)
9593

9694
def register_union_type(
97-
self, union: Type[graphene.Union], obj_types: List[Type[graphene.ObjectType]]
95+
self, union: Type[graphene.Union], obj_types: List[Type[graphene.ObjectType]]
9896
):
9997
if not issubclass(union, graphene.Union):
10098
raise TypeError("Expected graphene.Union, but got: {!r}".format(union))
@@ -112,7 +110,7 @@ def get_union_for_object_types(self, obj_types: List[Type[graphene.ObjectType]])
112110

113111
# Filter Scalar Fields of Object Types
114112
def register_filter_for_scalar_type(
115-
self, scalar_type: Type[graphene.Scalar], filter_obj: Type["FieldFilter"]
113+
self, scalar_type: Type[graphene.Scalar], filter_obj: Type["FieldFilter"]
116114
):
117115
from .filters import FieldFilter
118116

@@ -123,21 +121,49 @@ def register_filter_for_scalar_type(
123121
raise TypeError("Expected ScalarFilter, but got: {!r}".format(filter_obj))
124122
self._registry_scalar_filters[scalar_type] = filter_obj
125123

124+
def get_filter_for_sql_enum_type(
125+
self, enum_type: Type[graphene.Enum]
126+
) -> Type["FieldFilter"]:
127+
from .filters import SQLEnumFilter
128+
129+
filter_type = self._registry_scalar_filters.get(enum_type)
130+
if not filter_type:
131+
filter_type = SQLEnumFilter.create_type(
132+
f"Default{enum_type.__name__}EnumFilter", graphene_type=enum_type
133+
)
134+
self._registry_scalar_filters[enum_type] = filter_type
135+
return filter_type
136+
137+
def get_filter_for_py_enum_type(
138+
self, enum_type: Type[graphene.Enum]
139+
) -> Type["FieldFilter"]:
140+
from .filters import PyEnumFilter
141+
142+
filter_type = self._registry_scalar_filters.get(enum_type)
143+
if not filter_type:
144+
filter_type = PyEnumFilter.create_type(
145+
f"Default{enum_type.__name__}EnumFilter", graphene_type=enum_type
146+
)
147+
self._registry_scalar_filters[enum_type] = filter_type
148+
return filter_type
149+
126150
def get_filter_for_scalar_type(
127-
self, scalar_type: Type[graphene.Scalar]
151+
self, scalar_type: Type[graphene.Scalar]
128152
) -> Type["FieldFilter"]:
129153
from .filters import FieldFilter
130154

131155
filter_type = self._registry_scalar_filters.get(scalar_type)
132156
if not filter_type:
133-
return FieldFilter.create_type(
157+
filter_type = FieldFilter.create_type(
134158
f"Default{scalar_type.__name__}ScalarFilter", graphene_type=scalar_type
135159
)
160+
self._registry_scalar_filters[scalar_type] = filter_type
161+
136162
return filter_type
137163

138164
# TODO register enums automatically
139165
def register_filter_for_enum_type(
140-
self, enum_type: Type[graphene.Enum], filter_obj: Type["FieldFilter"]
166+
self, enum_type: Type[graphene.Enum], filter_obj: Type["FieldFilter"]
141167
):
142168
from .filters import FieldFilter
143169

@@ -148,16 +174,11 @@ def register_filter_for_enum_type(
148174
raise TypeError("Expected FieldFilter, but got: {!r}".format(filter_obj))
149175
self._registry_scalar_filters[enum_type] = filter_obj
150176

151-
def get_filter_for_enum_type(
152-
self, enum_type: Type[graphene.Enum]
153-
) -> Type["FieldFilter"]:
154-
return self._registry_enum_type_filters.get(enum_type)
155-
156177
# Filter Base Types
157178
def register_filter_for_base_type(
158-
self,
159-
base_type: Type[BaseType],
160-
filter_obj: Type["BaseTypeFilter"],
179+
self,
180+
base_type: Type[BaseType],
181+
filter_obj: Type["BaseTypeFilter"],
161182
):
162183
from .filters import BaseTypeFilter
163184

@@ -175,7 +196,7 @@ def get_filter_for_base_type(self, base_type: Type[BaseType]):
175196

176197
# Filter Relationships between base types
177198
def register_relationship_filter_for_base_type(
178-
self, base_type: BaseType, filter_obj: Type["RelationshipFilter"]
199+
self, base_type: BaseType, filter_obj: Type["RelationshipFilter"]
179200
):
180201
from .filters import RelationshipFilter
181202

@@ -189,7 +210,7 @@ def register_relationship_filter_for_base_type(
189210
self._registry_relationship_filters[base_type] = filter_obj
190211

191212
def get_relationship_filter_for_base_type(
192-
self, base_type: Type[BaseType]
213+
self, base_type: Type[BaseType]
193214
) -> "RelationshipFilter":
194215
return self._registry_relationship_filters.get(base_type)
195216

graphene_sqlalchemy/tests/conftest.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
1+
from typing import Literal
2+
3+
import graphene
14
import pytest
25
import pytest_asyncio
36
from sqlalchemy import create_engine
47
from sqlalchemy.orm import sessionmaker
58

6-
import graphene
79
from graphene_sqlalchemy.utils import SQL_VERSION_HIGHER_EQUAL_THAN_1_4
8-
10+
from .models import Base, CompositeFullName
911
from ..converter import convert_sqlalchemy_composite
1012
from ..registry import reset_global_registry
11-
from .models import Base, CompositeFullName
1213

1314
if SQL_VERSION_HIGHER_EQUAL_THAN_1_4:
1415
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
@@ -25,23 +26,32 @@ def convert_composite_class(composite, registry):
2526
return graphene.Field(graphene.Int)
2627

2728

28-
@pytest.fixture(params=[False, True])
29-
def async_session(request):
29+
# make a typed literal for session one is sync and one is async
30+
SESSION_TYPE = Literal["sync", "session_factory"]
31+
32+
33+
@pytest.fixture(params=["sync", "async"])
34+
def session_type(request) -> SESSION_TYPE:
3035
return request.param
3136

3237

3338
@pytest.fixture
34-
def test_db_url(async_session: bool):
35-
if async_session:
39+
def async_session(session_type):
40+
return session_type == "async"
41+
42+
43+
@pytest.fixture
44+
def test_db_url(session_type: SESSION_TYPE):
45+
if session_type == "async":
3646
return "sqlite+aiosqlite://"
3747
else:
3848
return "sqlite://"
3949

4050

4151
@pytest.mark.asyncio
4252
@pytest_asyncio.fixture(scope="function")
43-
async def session_factory(async_session: bool, test_db_url: str):
44-
if async_session:
53+
async def session_factory(session_type: SESSION_TYPE, test_db_url: str):
54+
if session_type == "async":
4555
if not SQL_VERSION_HIGHER_EQUAL_THAN_1_4:
4656
pytest.skip("Async Sessions only work in sql alchemy 1.4 and above")
4757
engine = create_async_engine(test_db_url)

graphene_sqlalchemy/tests/test_filters.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -218,20 +218,21 @@ async def test_filter_enum(session):
218218
query = """
219219
query {
220220
reporters (filter: {
221-
favoritePetKind: {eq: "dog"}
221+
favoritePetKind: {eq: DOG}
222222
}
223223
) {
224224
edges {
225225
node {
226226
firstName
227227
lastName
228+
favoritePetKind
228229
}
229230
}
230231
}
231232
}
232233
"""
233234
expected = {
234-
"pets": {"edges": [{"node": {"firstName": "Jane", "lastName": "Roe"}}]},
235+
"reporters": {"edges": [{"node": {"firstName": "Jane", "lastName": "Roe", "favoritePetKind": "DOG"}}]},
235236
}
236237
schema = graphene.Schema(query=Query)
237238
result = await schema.execute_async(query, context_value={"session": session})
@@ -243,7 +244,7 @@ async def test_filter_enum(session):
243244
pets (filter: {
244245
and: [
245246
{ hairKind: {eq: LONG} },
246-
{ petKind: {eq: "dog"} }
247+
{ petKind: {eq: DOG} }
247248
]}) {
248249
edges {
249250
node {
@@ -842,7 +843,7 @@ async def test_filter_logic_and(session):
842843
reporters (filter: {
843844
and: [
844845
{ firstName: { eq: "John" } },
845-
{ favoritePetKind: { eq: "cat" } },
846+
{ favoritePetKind: { eq: CAT } },
846847
]
847848
}) {
848849
edges {
@@ -874,13 +875,14 @@ async def test_filter_logic_or(session):
874875
reporters (filter: {
875876
or: [
876877
{ lastName: { eq: "Woe" } },
877-
{ favoritePetKind: { eq: "dog" } },
878+
{ favoritePetKind: { eq: DOG } },
878879
]
879880
}) {
880881
edges {
881882
node {
882883
firstName
883884
lastName
885+
favoritePetKind
884886
}
885887
}
886888
}
@@ -889,8 +891,8 @@ async def test_filter_logic_or(session):
889891
expected = {
890892
"reporters": {
891893
"edges": [
892-
{"node": {"firstName": "John", "lastName": "Woe"}},
893-
# {"node": {"firstName": "Jane", "lastName": "Roe"}},
894+
{"node": {"firstName": "John", "lastName": "Woe", "favoritePetKind": "CAT"}},
895+
{"node": {"firstName": "Jane", "lastName": "Roe", "favoritePetKind": "DOG"}},
894896
]
895897
}
896898
}

0 commit comments

Comments
 (0)