Skip to content

Add select and sort constructors #19

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/advanced/filter.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ items = await item_crud.select_models(

## 算术运算符

此过滤器使用方法需查看:[算数](#_8)
此过滤器使用方法需查看:[算数](#_7)

- `__add`: Python `+` 运算符
- `__radd`: Python `+` 反向运算
Expand Down
59 changes: 59 additions & 0 deletions docs/usage/select.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
```py
async def select(
self,
**kwargs
) -> Select:
```

此方法用于构造 SQLAlchemy
Select,在一些特定场景将会很有用,比如,配合 [fastapi-pagination](https://github.com/uriyyo/fastapi-pagination) 使用

## 示例

```py hl_lines="28"
from typing import Any, Annotated

from fastapi import Depends, FastAPI, Query
from pydantic import BaseModel

from sqlalchemy_crud_plus import CRUDPlus

from sqlalchemy import select, Select
from sqlalchemy import DeclarativeBase as Base
from sqlalchemy.ext.asyncio import AsyncSession

from fastapi_pagination import LimitOffsetPage, Page, add_pagination
from fastapi_pagination.ext.sqlalchemy import paginate


class ModelIns(Base):
# your sqlalchemy model
pass


class UserOut(BaseModel):
# your pydantic schema
pass


class CRUDIns(CRUDPlus[ModelIns]):
async def get_list(self, name: str = None, method: str = None) -> Select:
return await self.select(name__like=f'%{name}%', method=method)


crud_ins = CRUDIns(ModelIns)


app = FastAPI()
add_pagination(app)


@app.get("/users", response_model=Page[UserOut])
async def get_users(
db: AsyncSession = Depends(get_db),
name: Annotated[str | None, Query()] = None,
method: Annotated[str | None, Query()] = None,
) -> Any:
select = await crud_ins.get_list()
return await paginate(db, select)
```
10 changes: 10 additions & 0 deletions docs/usage/select_order.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
```py
async def select_order(
self,
sort_columns: str | list[str],
sort_orders: str | list[str] | None = None,
**kwargs,
) -> Select:
```

此方法与 [select](./select.md) 方法类似,但增加了 [排序](./select_models_order.md/#_1) 功能
2 changes: 2 additions & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ nav:
- 新增 - 多条: usage/create_models.md
- 查询 - 主键 ID: usage/select_model.md
- 查询 - 条件过滤: usage/select_model_by_column.md
- Select: usage/select.md
- Select - 排序: usage/select_order.md
- 查询 - 列表: usage/select_models.md
- 查询 - 列表排序: usage/select_models_order.md
- 更新 - 主键 ID: usage/update_model.md
Expand Down
88 changes: 67 additions & 21 deletions sqlalchemy_crud_plus/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@
# -*- coding: utf-8 -*-
from typing import Any, Generic, Iterable, Sequence, Type

from sqlalchemy import Row, RowMapping, select
from sqlalchemy import delete as sa_delete
from sqlalchemy import update as sa_update
from sqlalchemy import Row, RowMapping, Select, delete, select, update
from sqlalchemy.ext.asyncio import AsyncSession

from sqlalchemy_crud_plus.errors import MultipleResultsError
Expand All @@ -16,7 +14,13 @@ class CRUDPlus(Generic[Model]):
def __init__(self, model: Type[Model]):
self.model = model

async def create_model(self, session: AsyncSession, obj: CreateSchema, commit: bool = False, **kwargs) -> Model:
async def create_model(
self,
session: AsyncSession,
obj: CreateSchema,
commit: bool = False,
**kwargs,
) -> Model:
"""
Create a new instance of a model

Expand All @@ -36,7 +40,10 @@ async def create_model(self, session: AsyncSession, obj: CreateSchema, commit: b
return ins

async def create_models(
self, session: AsyncSession, obj: Iterable[CreateSchema], commit: bool = False
self,
session: AsyncSession,
obj: Iterable[CreateSchema],
commit: bool = False,
) -> list[Model]:
"""
Create new instances of a model
Expand Down Expand Up @@ -79,6 +86,35 @@ async def select_model_by_column(self, session: AsyncSession, **kwargs) -> Model
query = await session.execute(stmt)
return query.scalars().first()

async def select(self, **kwargs) -> Select:
"""
Construct the SQLAlchemy selection

:param kwargs: Query expressions.
:return:
"""
filters = parse_filters(self.model, **kwargs)
stmt = select(self.model).where(*filters)
return stmt

async def select_order(
self,
sort_columns: str | list[str],
sort_orders: str | list[str] | None = None,
**kwargs,
) -> Select:
"""
Constructing SQLAlchemy selection with sorting

:param kwargs: Query expressions.
:param sort_columns: more details see apply_sorting
:param sort_orders: more details see apply_sorting
:return:
"""
stmt = await self.select(**kwargs)
sorted_stmt = apply_sorting(self.model, stmt, sort_columns, sort_orders)
return sorted_stmt

async def select_models(self, session: AsyncSession, **kwargs) -> Sequence[Row[Any] | RowMapping | Any]:
"""
Query all rows
Expand All @@ -87,13 +123,16 @@ async def select_models(self, session: AsyncSession, **kwargs) -> Sequence[Row[A
:param kwargs: Query expressions.
:return:
"""
filters = parse_filters(self.model, **kwargs)
stmt = select(self.model).where(*filters)
stmt = await self.select(**kwargs)
query = await session.execute(stmt)
return query.scalars().all()

async def select_models_order(
self, session: AsyncSession, sort_columns: str | list[str], sort_orders: str | list[str] | None = None, **kwargs
self,
session: AsyncSession,
sort_columns: str | list[str],
sort_orders: str | list[str] | None = None,
**kwargs,
) -> Sequence[Row | RowMapping | Any] | None:
"""
Query all rows and sort by columns
Expand All @@ -103,14 +142,16 @@ async def select_models_order(
:param sort_orders: more details see apply_sorting
:return:
"""
filters = parse_filters(self.model, **kwargs)
stmt = select(self.model).where(*filters)
stmt_sort = apply_sorting(self.model, stmt, sort_columns, sort_orders)
query = await session.execute(stmt_sort)
stmt = await self.select_order(sort_columns, sort_orders, **kwargs)
query = await session.execute(stmt)
return query.scalars().all()

async def update_model(
self, session: AsyncSession, pk: int, obj: UpdateSchema | dict[str, Any], commit: bool = False
self,
session: AsyncSession,
pk: int,
obj: UpdateSchema | dict[str, Any],
commit: bool = False,
) -> int:
"""
Update an instance by model's primary key
Expand All @@ -125,7 +166,7 @@ async def update_model(
instance_data = obj
else:
instance_data = obj.model_dump(exclude_unset=True)
stmt = sa_update(self.model).where(self.model.id == pk).values(**instance_data)
stmt = update(self.model).where(self.model.id == pk).values(**instance_data)
result = await session.execute(stmt)
if commit:
await session.commit()
Expand Down Expand Up @@ -157,13 +198,18 @@ async def update_model_by_column(
instance_data = obj
else:
instance_data = obj.model_dump(exclude_unset=True)
stmt = sa_update(self.model).where(*filters).values(**instance_data) # type: ignore
stmt = update(self.model).where(*filters).values(**instance_data) # type: ignore
result = await session.execute(stmt)
if commit:
await session.commit()
return result.rowcount # type: ignore

async def delete_model(self, session: AsyncSession, pk: int, commit: bool = False) -> int:
async def delete_model(
self,
session: AsyncSession,
pk: int,
commit: bool = False,
) -> int:
"""
Delete an instance by model's primary key

Expand All @@ -172,7 +218,7 @@ async def delete_model(self, session: AsyncSession, pk: int, commit: bool = Fals
:param commit: If `True`, commits the transaction immediately. Default is `False`.
:return:
"""
stmt = sa_delete(self.model).where(self.model.id == pk)
stmt = delete(self.model).where(self.model.id == pk)
result = await session.execute(stmt)
if commit:
await session.commit()
Expand Down Expand Up @@ -204,10 +250,10 @@ async def delete_model_by_column(
raise MultipleResultsError(f'Only one record is expected to be delete, found {total_count} records.')
if logical_deletion:
deleted_flag = {deleted_flag_column: True}
stmt = sa_update(self.model).where(*filters).values(**deleted_flag)
stmt = update(self.model).where(*filters).values(**deleted_flag)
else:
stmt = sa_delete(self.model).where(*filters)
await session.execute(stmt)
stmt = delete(self.model).where(*filters)
result = await session.execute(stmt)
if commit:
await session.commit()
return total_count
return result.rowcount # type: ignore
4 changes: 1 addition & 3 deletions sqlalchemy_crud_plus/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,7 @@
}


def get_sqlalchemy_filter(
operator: str, value: Any, allow_arithmetic: bool = True
) -> Callable[[str], Callable] | None:
def get_sqlalchemy_filter(operator: str, value: Any, allow_arithmetic: bool = True) -> Callable[[str], Callable] | None:
if operator in ['in', 'not_in', 'between']:
if not isinstance(value, (tuple, list, set)):
raise SelectOperatorError(f'The value of the <{operator}> filter must be tuple, list or set')
Expand Down
18 changes: 17 additions & 1 deletion tests/test_select.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
# -*- coding: utf-8 -*-
import pytest

from sqlalchemy import Select

from sqlalchemy_crud_plus import CRUDPlus
from tests.model import Ins

Expand Down Expand Up @@ -336,14 +338,28 @@ async def test_select_model_by_column_with_or(create_test_model, async_db_sessio
assert result.id == 1


@pytest.mark.asyncio
async def test_select(create_test_model):
crud = CRUDPlus(Ins)
result = await crud.select()
assert isinstance(result, Select)


@pytest.mark.asyncio
async def test_select_models(create_test_model, async_db_session):
async with async_db_session.begin() as session:
async with async_db_session() as session:
crud = CRUDPlus(Ins)
result = await crud.select_models(session)
assert len(result) == 9


@pytest.mark.asyncio
async def test_select_order(create_test_model):
crud = CRUDPlus(Ins)
result = await crud.select_order(sort_columns='name')
assert isinstance(result, Select)


@pytest.mark.asyncio
async def test_select_models_order_default_asc(create_test_model, async_db_session):
async with async_db_session() as session:
Expand Down