Skip to content

Commit 0b75d51

Browse files
authored
Add select and sort constructors (#19)
1 parent 5e18162 commit 0b75d51

File tree

7 files changed

+157
-26
lines changed

7 files changed

+157
-26
lines changed

docs/advanced/filter.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ items = await item_crud.select_models(
5252

5353
## 算术运算符
5454

55-
此过滤器使用方法需查看:[算数](#_8)
55+
此过滤器使用方法需查看:[算数](#_7)
5656

5757
- `__add`: Python `+` 运算符
5858
- `__radd`: Python `+` 反向运算

docs/usage/select.md

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
```py
2+
async def select(
3+
self,
4+
**kwargs
5+
) -> Select:
6+
```
7+
8+
此方法用于构造 SQLAlchemy
9+
Select,在一些特定场景将会很有用,比如,配合 [fastapi-pagination](https://github.com/uriyyo/fastapi-pagination) 使用
10+
11+
## 示例
12+
13+
```py hl_lines="28"
14+
from typing import Any, Annotated
15+
16+
from fastapi import Depends, FastAPI, Query
17+
from pydantic import BaseModel
18+
19+
from sqlalchemy_crud_plus import CRUDPlus
20+
21+
from sqlalchemy import select, Select
22+
from sqlalchemy import DeclarativeBase as Base
23+
from sqlalchemy.ext.asyncio import AsyncSession
24+
25+
from fastapi_pagination import LimitOffsetPage, Page, add_pagination
26+
from fastapi_pagination.ext.sqlalchemy import paginate
27+
28+
29+
class ModelIns(Base):
30+
# your sqlalchemy model
31+
pass
32+
33+
34+
class UserOut(BaseModel):
35+
# your pydantic schema
36+
pass
37+
38+
39+
class CRUDIns(CRUDPlus[ModelIns]):
40+
async def get_list(self, name: str = None, method: str = None) -> Select:
41+
return await self.select(name__like=f'%{name}%', method=method)
42+
43+
44+
crud_ins = CRUDIns(ModelIns)
45+
46+
47+
app = FastAPI()
48+
add_pagination(app)
49+
50+
51+
@app.get("/users", response_model=Page[UserOut])
52+
async def get_users(
53+
db: AsyncSession = Depends(get_db),
54+
name: Annotated[str | None, Query()] = None,
55+
method: Annotated[str | None, Query()] = None,
56+
) -> Any:
57+
select = await crud_ins.get_list()
58+
return await paginate(db, select)
59+
```

docs/usage/select_order.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
```py
2+
async def select_order(
3+
self,
4+
sort_columns: str | list[str],
5+
sort_orders: str | list[str] | None = None,
6+
**kwargs,
7+
) -> Select:
8+
```
9+
10+
此方法与 [select](./select.md) 方法类似,但增加了 [排序](./select_models_order.md/#_1) 功能

mkdocs.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ nav:
1313
- 新增 - 多条: usage/create_models.md
1414
- 查询 - 主键 ID: usage/select_model.md
1515
- 查询 - 条件过滤: usage/select_model_by_column.md
16+
- Select: usage/select.md
17+
- Select - 排序: usage/select_order.md
1618
- 查询 - 列表: usage/select_models.md
1719
- 查询 - 列表排序: usage/select_models_order.md
1820
- 更新 - 主键 ID: usage/update_model.md

sqlalchemy_crud_plus/crud.py

Lines changed: 67 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,7 @@
22
# -*- coding: utf-8 -*-
33
from typing import Any, Generic, Iterable, Sequence, Type
44

5-
from sqlalchemy import Row, RowMapping, select
6-
from sqlalchemy import delete as sa_delete
7-
from sqlalchemy import update as sa_update
5+
from sqlalchemy import Row, RowMapping, Select, delete, select, update
86
from sqlalchemy.ext.asyncio import AsyncSession
97

108
from sqlalchemy_crud_plus.errors import MultipleResultsError
@@ -16,7 +14,13 @@ class CRUDPlus(Generic[Model]):
1614
def __init__(self, model: Type[Model]):
1715
self.model = model
1816

19-
async def create_model(self, session: AsyncSession, obj: CreateSchema, commit: bool = False, **kwargs) -> Model:
17+
async def create_model(
18+
self,
19+
session: AsyncSession,
20+
obj: CreateSchema,
21+
commit: bool = False,
22+
**kwargs,
23+
) -> Model:
2024
"""
2125
Create a new instance of a model
2226
@@ -36,7 +40,10 @@ async def create_model(self, session: AsyncSession, obj: CreateSchema, commit: b
3640
return ins
3741

3842
async def create_models(
39-
self, session: AsyncSession, obj: Iterable[CreateSchema], commit: bool = False
43+
self,
44+
session: AsyncSession,
45+
obj: Iterable[CreateSchema],
46+
commit: bool = False,
4047
) -> list[Model]:
4148
"""
4249
Create new instances of a model
@@ -79,6 +86,35 @@ async def select_model_by_column(self, session: AsyncSession, **kwargs) -> Model
7986
query = await session.execute(stmt)
8087
return query.scalars().first()
8188

89+
async def select(self, **kwargs) -> Select:
90+
"""
91+
Construct the SQLAlchemy selection
92+
93+
:param kwargs: Query expressions.
94+
:return:
95+
"""
96+
filters = parse_filters(self.model, **kwargs)
97+
stmt = select(self.model).where(*filters)
98+
return stmt
99+
100+
async def select_order(
101+
self,
102+
sort_columns: str | list[str],
103+
sort_orders: str | list[str] | None = None,
104+
**kwargs,
105+
) -> Select:
106+
"""
107+
Constructing SQLAlchemy selection with sorting
108+
109+
:param kwargs: Query expressions.
110+
:param sort_columns: more details see apply_sorting
111+
:param sort_orders: more details see apply_sorting
112+
:return:
113+
"""
114+
stmt = await self.select(**kwargs)
115+
sorted_stmt = apply_sorting(self.model, stmt, sort_columns, sort_orders)
116+
return sorted_stmt
117+
82118
async def select_models(self, session: AsyncSession, **kwargs) -> Sequence[Row[Any] | RowMapping | Any]:
83119
"""
84120
Query all rows
@@ -87,13 +123,16 @@ async def select_models(self, session: AsyncSession, **kwargs) -> Sequence[Row[A
87123
:param kwargs: Query expressions.
88124
:return:
89125
"""
90-
filters = parse_filters(self.model, **kwargs)
91-
stmt = select(self.model).where(*filters)
126+
stmt = await self.select(**kwargs)
92127
query = await session.execute(stmt)
93128
return query.scalars().all()
94129

95130
async def select_models_order(
96-
self, session: AsyncSession, sort_columns: str | list[str], sort_orders: str | list[str] | None = None, **kwargs
131+
self,
132+
session: AsyncSession,
133+
sort_columns: str | list[str],
134+
sort_orders: str | list[str] | None = None,
135+
**kwargs,
97136
) -> Sequence[Row | RowMapping | Any] | None:
98137
"""
99138
Query all rows and sort by columns
@@ -103,14 +142,16 @@ async def select_models_order(
103142
:param sort_orders: more details see apply_sorting
104143
:return:
105144
"""
106-
filters = parse_filters(self.model, **kwargs)
107-
stmt = select(self.model).where(*filters)
108-
stmt_sort = apply_sorting(self.model, stmt, sort_columns, sort_orders)
109-
query = await session.execute(stmt_sort)
145+
stmt = await self.select_order(sort_columns, sort_orders, **kwargs)
146+
query = await session.execute(stmt)
110147
return query.scalars().all()
111148

112149
async def update_model(
113-
self, session: AsyncSession, pk: int, obj: UpdateSchema | dict[str, Any], commit: bool = False
150+
self,
151+
session: AsyncSession,
152+
pk: int,
153+
obj: UpdateSchema | dict[str, Any],
154+
commit: bool = False,
114155
) -> int:
115156
"""
116157
Update an instance by model's primary key
@@ -125,7 +166,7 @@ async def update_model(
125166
instance_data = obj
126167
else:
127168
instance_data = obj.model_dump(exclude_unset=True)
128-
stmt = sa_update(self.model).where(self.model.id == pk).values(**instance_data)
169+
stmt = update(self.model).where(self.model.id == pk).values(**instance_data)
129170
result = await session.execute(stmt)
130171
if commit:
131172
await session.commit()
@@ -157,13 +198,18 @@ async def update_model_by_column(
157198
instance_data = obj
158199
else:
159200
instance_data = obj.model_dump(exclude_unset=True)
160-
stmt = sa_update(self.model).where(*filters).values(**instance_data) # type: ignore
201+
stmt = update(self.model).where(*filters).values(**instance_data) # type: ignore
161202
result = await session.execute(stmt)
162203
if commit:
163204
await session.commit()
164205
return result.rowcount # type: ignore
165206

166-
async def delete_model(self, session: AsyncSession, pk: int, commit: bool = False) -> int:
207+
async def delete_model(
208+
self,
209+
session: AsyncSession,
210+
pk: int,
211+
commit: bool = False,
212+
) -> int:
167213
"""
168214
Delete an instance by model's primary key
169215
@@ -172,7 +218,7 @@ async def delete_model(self, session: AsyncSession, pk: int, commit: bool = Fals
172218
:param commit: If `True`, commits the transaction immediately. Default is `False`.
173219
:return:
174220
"""
175-
stmt = sa_delete(self.model).where(self.model.id == pk)
221+
stmt = delete(self.model).where(self.model.id == pk)
176222
result = await session.execute(stmt)
177223
if commit:
178224
await session.commit()
@@ -204,10 +250,10 @@ async def delete_model_by_column(
204250
raise MultipleResultsError(f'Only one record is expected to be delete, found {total_count} records.')
205251
if logical_deletion:
206252
deleted_flag = {deleted_flag_column: True}
207-
stmt = sa_update(self.model).where(*filters).values(**deleted_flag)
253+
stmt = update(self.model).where(*filters).values(**deleted_flag)
208254
else:
209-
stmt = sa_delete(self.model).where(*filters)
210-
await session.execute(stmt)
255+
stmt = delete(self.model).where(*filters)
256+
result = await session.execute(stmt)
211257
if commit:
212258
await session.commit()
213-
return total_count
259+
return result.rowcount # type: ignore

sqlalchemy_crud_plus/utils.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,7 @@
5757
}
5858

5959

60-
def get_sqlalchemy_filter(
61-
operator: str, value: Any, allow_arithmetic: bool = True
62-
) -> Callable[[str], Callable] | None:
60+
def get_sqlalchemy_filter(operator: str, value: Any, allow_arithmetic: bool = True) -> Callable[[str], Callable] | None:
6361
if operator in ['in', 'not_in', 'between']:
6462
if not isinstance(value, (tuple, list, set)):
6563
raise SelectOperatorError(f'The value of the <{operator}> filter must be tuple, list or set')

tests/test_select.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
# -*- coding: utf-8 -*-
33
import pytest
44

5+
from sqlalchemy import Select
6+
57
from sqlalchemy_crud_plus import CRUDPlus
68
from tests.model import Ins
79

@@ -336,14 +338,28 @@ async def test_select_model_by_column_with_or(create_test_model, async_db_sessio
336338
assert result.id == 1
337339

338340

341+
@pytest.mark.asyncio
342+
async def test_select(create_test_model):
343+
crud = CRUDPlus(Ins)
344+
result = await crud.select()
345+
assert isinstance(result, Select)
346+
347+
339348
@pytest.mark.asyncio
340349
async def test_select_models(create_test_model, async_db_session):
341-
async with async_db_session.begin() as session:
350+
async with async_db_session() as session:
342351
crud = CRUDPlus(Ins)
343352
result = await crud.select_models(session)
344353
assert len(result) == 9
345354

346355

356+
@pytest.mark.asyncio
357+
async def test_select_order(create_test_model):
358+
crud = CRUDPlus(Ins)
359+
result = await crud.select_order(sort_columns='name')
360+
assert isinstance(result, Select)
361+
362+
347363
@pytest.mark.asyncio
348364
async def test_select_models_order_default_asc(create_test_model, async_db_session):
349365
async with async_db_session() as session:

0 commit comments

Comments
 (0)