Skip to content

Commit eb54da9

Browse files
authored
Add mor and __gor__ filters (#21)
* Adding mor and gor for or clause * Fix and optimize filters * Add usage documents
1 parent 43800c1 commit eb54da9

File tree

3 files changed

+139
-36
lines changed

3 files changed

+139
-36
lines changed

docs/advanced/filter.md

Lines changed: 41 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -73,11 +73,12 @@ items = await item_crud.select_models(
7373

7474
运算符需要多个值,且仅允许元组,列表,集合
7575

76-
```python
77-
# 获取年龄在 30 - 40 岁之间的员工
76+
```python title="__between"
77+
# 获取年龄在 30 - 40 岁之间且名字在目标列表的员工
7878
items = await item_crud.select_models(
7979
session=db,
8080
age__between=[30, 40],
81+
name__in=['bob', 'lucy'],
8182
)
8283
```
8384

@@ -86,7 +87,7 @@ items = await item_crud.select_models(
8687
可以通过将多个过滤器链接在一起来实现 AND 子句
8788

8889
```python
89-
# 获取年龄在 30 以上,薪资大于 2w 的员工
90+
# 获取年龄在 30 以上,薪资大于 20k 的员工
9091
items = await item_crud.select_models(
9192
session=db,
9293
age__gt=30,
@@ -100,14 +101,48 @@ items = await item_crud.select_models(
100101

101102
每个键都应是库已支持的过滤器,仅允许字典
102103

103-
```python
104+
```python title="__or"
104105
# 获取年龄在 40 岁以上或 30 岁以下的员工
105106
items = await item_crud.select_models(
106107
session=db,
107108
age__or={'gt': 40, 'lt': 30},
108109
)
109110
```
110111

112+
## MOR
113+
114+
!!! note
115+
116+
`or` 过滤器的高级用法,每个键都应是库已支持的过滤器,仅允许字典
117+
118+
```python title="__mor"
119+
# 获取年龄等于 30 岁和 40 岁的员工
120+
items = await item_crud.select_models(
121+
session=db,
122+
age__mor={'eq': [30, 40]}, # (1)
123+
)
124+
```
125+
126+
1. 原因:在 python 字典中,不允许存在相同的键值;<br/>
127+
场景:我有一个列,需要多个相同条件但不同条件值的查询,此时,你应该使用 `mor` 过滤器,正如此示例一样使用它
128+
129+
## GOR
130+
131+
!!! note
132+
133+
`or` 过滤器的更高级用法,每个值都应是一个已受支持的条件过滤器,它应该是一个数组
134+
135+
```python title="__gor__"
136+
# 获取年龄在 30 - 40 岁之间且薪资大于 20k 的员工
137+
items = await item_crud.select_models(
138+
session=db,
139+
__gor__=[
140+
{'age__between': [30, 40]},
141+
{'payroll__gt': 20000}
142+
]
143+
)
144+
```
145+
111146
## 算数
112147

113148
!!! note
@@ -119,9 +154,9 @@ items = await item_crud.select_models(
119154
`condition`:此值将作为运算后的比较值,比较条件取决于使用的过滤器
120155

121156
```python
122-
# 获取薪资打八折以后仍高于 15000 的员工
157+
# 获取薪资打八折以后仍高于 20k 的员工
123158
items = await item_crud.select_models(
124159
session=db,
125-
payroll__mul={'value': 0.8, 'condition': {'gt': 15000}},
160+
payroll__mul={'value': 0.8, 'condition': {'gt': 20000}},
126161
)
127162
```

sqlalchemy_crud_plus/utils.py

Lines changed: 74 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from sqlalchemy import ColumnElement, Select, and_, asc, desc, func, or_, select
88
from sqlalchemy.ext.asyncio import AsyncSession
9+
from sqlalchemy.orm import InstrumentedAttribute
910
from sqlalchemy.orm.util import AliasedClass
1011

1112
from sqlalchemy_crud_plus.errors import ColumnSortError, ModelColumnError, SelectOperatorError
@@ -70,7 +71,7 @@ def get_sqlalchemy_filter(operator: str, value: Any, allow_arithmetic: bool = Tr
7071
raise SelectOperatorError(f'Nested arithmetic operations are not allowed: {operator}')
7172

7273
sqlalchemy_filter = _SUPPORTED_FILTERS.get(operator)
73-
if sqlalchemy_filter is None:
74+
if sqlalchemy_filter is None and operator not in ['or', 'mor', '__gor']:
7475
warnings.warn(
7576
f'The operator <{operator}> is not yet supported, only {", ".join(_SUPPORTED_FILTERS.keys())}.',
7677
SyntaxWarning,
@@ -80,48 +81,92 @@ def get_sqlalchemy_filter(operator: str, value: Any, allow_arithmetic: bool = Tr
8081
return sqlalchemy_filter
8182

8283

83-
def get_column(model: Type[Model] | AliasedClass, field_name: str):
84+
def get_column(model: Type[Model] | AliasedClass, field_name: str) -> InstrumentedAttribute | None:
8485
column = getattr(model, field_name, None)
8586
if column is None:
8687
raise ModelColumnError(f'Column {field_name} is not found in {model}')
8788
return column
8889

8990

91+
def _create_or_filters(column: str, op: str, value: Any) -> list[ColumnElement | None]:
92+
or_filters = []
93+
if op == 'or':
94+
for or_op, or_value in value.items():
95+
sqlalchemy_filter = get_sqlalchemy_filter(or_op, or_value)
96+
if sqlalchemy_filter is not None:
97+
or_filters.append(sqlalchemy_filter(column)(or_value))
98+
elif op == 'mor':
99+
for or_op, or_values in value.items():
100+
for or_value in or_values:
101+
sqlalchemy_filter = get_sqlalchemy_filter(or_op, or_value)
102+
if sqlalchemy_filter is not None:
103+
or_filters.append(sqlalchemy_filter(column)(or_value))
104+
return or_filters
105+
106+
107+
def _create_arithmetic_filters(column: str, op: str, value: Any) -> list[ColumnElement | None]:
108+
arithmetic_filters = []
109+
if isinstance(value, dict) and {'value', 'condition'}.issubset(value):
110+
arithmetic_value = value['value']
111+
condition = value['condition']
112+
sqlalchemy_filter = get_sqlalchemy_filter(op, arithmetic_value)
113+
if sqlalchemy_filter is not None:
114+
for cond_op, cond_value in condition.items():
115+
arithmetic_filter = get_sqlalchemy_filter(cond_op, cond_value, allow_arithmetic=False)
116+
arithmetic_filters.append(
117+
arithmetic_filter(sqlalchemy_filter(column)(arithmetic_value))(cond_value)
118+
if cond_op != 'between'
119+
else arithmetic_filter(sqlalchemy_filter(column)(arithmetic_value))(*cond_value)
120+
)
121+
return arithmetic_filters
122+
123+
124+
def _create_and_filters(column: str, op: str, value: Any) -> list[ColumnElement | None]:
125+
and_filters = []
126+
sqlalchemy_filter = get_sqlalchemy_filter(op, value)
127+
if sqlalchemy_filter is not None:
128+
and_filters.append(sqlalchemy_filter(column)(value) if op != 'between' else sqlalchemy_filter(column)(*value))
129+
return and_filters
130+
131+
90132
def parse_filters(model: Type[Model] | AliasedClass, **kwargs) -> list[ColumnElement]:
91133
filters = []
92134

135+
def process_filters(target_column: str, target_op: str, target_value: Any):
136+
# OR / MOR
137+
or_filters = _create_or_filters(target_column, target_op, target_value)
138+
if or_filters:
139+
filters.append(or_(*or_filters))
140+
141+
# ARITHMETIC
142+
arithmetic_filters = _create_arithmetic_filters(target_column, target_op, target_value)
143+
if arithmetic_filters:
144+
filters.append(and_(*arithmetic_filters))
145+
else:
146+
# AND
147+
and_filters = _create_and_filters(target_column, target_op, target_value)
148+
if and_filters:
149+
filters.append(*and_filters)
150+
93151
for key, value in kwargs.items():
94152
if '__' in key:
95153
field_name, op = key.rsplit('__', 1)
96-
column = get_column(model, field_name)
97-
if op == 'or':
98-
or_filters = [
99-
sqlalchemy_filter(column)(or_value)
100-
for or_op, or_value in value.items()
101-
if (sqlalchemy_filter := get_sqlalchemy_filter(or_op, or_value)) is not None
102-
]
103-
filters.append(or_(*or_filters))
104-
elif isinstance(value, dict) and {'value', 'condition'}.issubset(value):
105-
advanced_value = value['value']
106-
condition = value['condition']
107-
sqlalchemy_filter = get_sqlalchemy_filter(op, advanced_value)
108-
if sqlalchemy_filter is not None:
109-
condition_filters = []
110-
for cond_op, cond_value in condition.items():
111-
condition_filter = get_sqlalchemy_filter(cond_op, cond_value, allow_arithmetic=False)
112-
condition_filters.append(
113-
condition_filter(sqlalchemy_filter(column)(advanced_value))(cond_value)
114-
if cond_op != 'between'
115-
else condition_filter(sqlalchemy_filter(column)(advanced_value))(*cond_value)
116-
)
117-
filters.append(and_(*condition_filters))
154+
155+
# OR GROUP
156+
if field_name == '__gor' and op == '':
157+
_or_filters = []
158+
for field_or in value:
159+
for _key, _value in field_or.items():
160+
_field_name, _op = _key.rsplit('__', 1)
161+
_column = get_column(model, _field_name)
162+
process_filters(_column, _op, _value)
163+
if _or_filters:
164+
filters.append(or_(*_or_filters))
118165
else:
119-
sqlalchemy_filter = get_sqlalchemy_filter(op, value)
120-
if sqlalchemy_filter is not None:
121-
filters.append(
122-
sqlalchemy_filter(column)(value) if op != 'between' else sqlalchemy_filter(column)(*value)
123-
)
166+
column = get_column(model, field_name)
167+
process_filters(column, op, value)
124168
else:
169+
# NON FILTER
125170
column = get_column(model, key)
126171
filters.append(column == value)
127172

tests/test_select.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ async def test_select_model_by_column_with_ne(create_test_model, async_db_sessio
8787
async def test_select_model_by_column_with_between(create_test_model, async_db_session):
8888
async with async_db_session() as session:
8989
crud = CRUDPlus(Ins)
90-
result = await crud.select_model_by_column(session, id__between=(0, 11))
90+
result = await crud.select_model_by_column(session, id__between=(0, 10))
9191
assert result.id == 1
9292

9393

@@ -338,6 +338,29 @@ async def test_select_model_by_column_with_or(create_test_model, async_db_sessio
338338
assert result.id == 1
339339

340340

341+
@pytest.mark.asyncio
342+
async def test_select_model_by_column_with_mor(create_test_model, async_db_session):
343+
async with async_db_session() as session:
344+
crud = CRUDPlus(Ins)
345+
result = await crud.select_model_by_column(session, id__mor={'eq': [1, 2, 3, 4, 5, 6, 7, 8, 9]})
346+
assert result.id == 1
347+
348+
349+
@pytest.mark.asyncio
350+
async def test_select_model_by_column_with___gor__(create_test_model, async_db_session):
351+
async with async_db_session() as session:
352+
crud = CRUDPlus(Ins)
353+
result = await crud.select_model_by_column(
354+
session,
355+
__gor__=[
356+
{'id__eq': 1},
357+
{'name__mor': {'endswith': ['1', '2']}},
358+
{'id__mul': {'value': 1, 'condition': {'eq': 1}}},
359+
],
360+
)
361+
assert result.id == 1
362+
363+
341364
@pytest.mark.asyncio
342365
async def test_select(create_test_model):
343366
crud = CRUDPlus(Ins)

0 commit comments

Comments
 (0)