diff --git a/docs/advanced/filter.md b/docs/advanced/filter.md index ecf05ec..0c1aee6 100644 --- a/docs/advanced/filter.md +++ b/docs/advanced/filter.md @@ -52,7 +52,7 @@ items = await item_crud.select_models( ## 算术运算符 -此过滤器使用方法需查看:[算数](#_8) +此过滤器使用方法需查看:[算数](#_7) - `__add`: Python `+` 运算符 - `__radd`: Python `+` 反向运算 diff --git a/docs/usage/select.md b/docs/usage/select.md new file mode 100644 index 0000000..c85a57e --- /dev/null +++ b/docs/usage/select.md @@ -0,0 +1,59 @@ +```py +async def select( + self, + **kwargs +) -> Select: +``` + +此方法用于构造 SQLAlchemy +Select,在一些特定场景将会很有用,比如,配合 [fastapi-pagination](https://github.com/uriyyo/fastapi-pagination) 使用 + +## 示例 + +```py hl_lines="28" +from typing import Any, Annotated + +from fastapi import Depends, FastAPI, Query +from pydantic import BaseModel + +from sqlalchemy_crud_plus import CRUDPlus + +from sqlalchemy import select, Select +from sqlalchemy import DeclarativeBase as Base +from sqlalchemy.ext.asyncio import AsyncSession + +from fastapi_pagination import LimitOffsetPage, Page, add_pagination +from fastapi_pagination.ext.sqlalchemy import paginate + + +class ModelIns(Base): + # your sqlalchemy model + pass + + +class UserOut(BaseModel): + # your pydantic schema + pass + + +class CRUDIns(CRUDPlus[ModelIns]): + async def get_list(self, name: str = None, method: str = None) -> Select: + return await self.select(name__like=f'%{name}%', method=method) + + +crud_ins = CRUDIns(ModelIns) + + +app = FastAPI() +add_pagination(app) + + +@app.get("/users", response_model=Page[UserOut]) +async def get_users( + db: AsyncSession = Depends(get_db), + name: Annotated[str | None, Query()] = None, + method: Annotated[str | None, Query()] = None, +) -> Any: + select = await crud_ins.get_list() + return await paginate(db, select) +``` diff --git a/docs/usage/select_order.md b/docs/usage/select_order.md new file mode 100644 index 0000000..87dd66d --- /dev/null +++ b/docs/usage/select_order.md @@ -0,0 +1,10 @@ +```py +async def select_order( + self, + sort_columns: str | list[str], + sort_orders: str | list[str] | None = None, + **kwargs, +) -> Select: +``` + +此方法与 [select](./select.md) 方法类似,但增加了 [排序](./select_models_order.md/#_1) 功能 diff --git a/mkdocs.yml b/mkdocs.yml index dba20bb..29f2a81 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -13,6 +13,8 @@ nav: - 新增 - 多条: usage/create_models.md - 查询 - 主键 ID: usage/select_model.md - 查询 - 条件过滤: usage/select_model_by_column.md + - Select: usage/select.md + - Select - 排序: usage/select_order.md - 查询 - 列表: usage/select_models.md - 查询 - 列表排序: usage/select_models_order.md - 更新 - 主键 ID: usage/update_model.md diff --git a/sqlalchemy_crud_plus/crud.py b/sqlalchemy_crud_plus/crud.py index 6f541ad..10abc80 100644 --- a/sqlalchemy_crud_plus/crud.py +++ b/sqlalchemy_crud_plus/crud.py @@ -2,9 +2,7 @@ # -*- coding: utf-8 -*- from typing import Any, Generic, Iterable, Sequence, Type -from sqlalchemy import Row, RowMapping, select -from sqlalchemy import delete as sa_delete -from sqlalchemy import update as sa_update +from sqlalchemy import Row, RowMapping, Select, delete, select, update from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy_crud_plus.errors import MultipleResultsError @@ -16,7 +14,13 @@ class CRUDPlus(Generic[Model]): def __init__(self, model: Type[Model]): self.model = model - async def create_model(self, session: AsyncSession, obj: CreateSchema, commit: bool = False, **kwargs) -> Model: + async def create_model( + self, + session: AsyncSession, + obj: CreateSchema, + commit: bool = False, + **kwargs, + ) -> Model: """ Create a new instance of a model @@ -36,7 +40,10 @@ async def create_model(self, session: AsyncSession, obj: CreateSchema, commit: b return ins async def create_models( - self, session: AsyncSession, obj: Iterable[CreateSchema], commit: bool = False + self, + session: AsyncSession, + obj: Iterable[CreateSchema], + commit: bool = False, ) -> list[Model]: """ Create new instances of a model @@ -79,6 +86,35 @@ async def select_model_by_column(self, session: AsyncSession, **kwargs) -> Model query = await session.execute(stmt) return query.scalars().first() + async def select(self, **kwargs) -> Select: + """ + Construct the SQLAlchemy selection + + :param kwargs: Query expressions. + :return: + """ + filters = parse_filters(self.model, **kwargs) + stmt = select(self.model).where(*filters) + return stmt + + async def select_order( + self, + sort_columns: str | list[str], + sort_orders: str | list[str] | None = None, + **kwargs, + ) -> Select: + """ + Constructing SQLAlchemy selection with sorting + + :param kwargs: Query expressions. + :param sort_columns: more details see apply_sorting + :param sort_orders: more details see apply_sorting + :return: + """ + stmt = await self.select(**kwargs) + sorted_stmt = apply_sorting(self.model, stmt, sort_columns, sort_orders) + return sorted_stmt + async def select_models(self, session: AsyncSession, **kwargs) -> Sequence[Row[Any] | RowMapping | Any]: """ Query all rows @@ -87,13 +123,16 @@ async def select_models(self, session: AsyncSession, **kwargs) -> Sequence[Row[A :param kwargs: Query expressions. :return: """ - filters = parse_filters(self.model, **kwargs) - stmt = select(self.model).where(*filters) + stmt = await self.select(**kwargs) query = await session.execute(stmt) return query.scalars().all() async def select_models_order( - self, session: AsyncSession, sort_columns: str | list[str], sort_orders: str | list[str] | None = None, **kwargs + self, + session: AsyncSession, + sort_columns: str | list[str], + sort_orders: str | list[str] | None = None, + **kwargs, ) -> Sequence[Row | RowMapping | Any] | None: """ Query all rows and sort by columns @@ -103,14 +142,16 @@ async def select_models_order( :param sort_orders: more details see apply_sorting :return: """ - filters = parse_filters(self.model, **kwargs) - stmt = select(self.model).where(*filters) - stmt_sort = apply_sorting(self.model, stmt, sort_columns, sort_orders) - query = await session.execute(stmt_sort) + stmt = await self.select_order(sort_columns, sort_orders, **kwargs) + query = await session.execute(stmt) return query.scalars().all() async def update_model( - self, session: AsyncSession, pk: int, obj: UpdateSchema | dict[str, Any], commit: bool = False + self, + session: AsyncSession, + pk: int, + obj: UpdateSchema | dict[str, Any], + commit: bool = False, ) -> int: """ Update an instance by model's primary key @@ -125,7 +166,7 @@ async def update_model( instance_data = obj else: instance_data = obj.model_dump(exclude_unset=True) - stmt = sa_update(self.model).where(self.model.id == pk).values(**instance_data) + stmt = update(self.model).where(self.model.id == pk).values(**instance_data) result = await session.execute(stmt) if commit: await session.commit() @@ -157,13 +198,18 @@ async def update_model_by_column( instance_data = obj else: instance_data = obj.model_dump(exclude_unset=True) - stmt = sa_update(self.model).where(*filters).values(**instance_data) # type: ignore + stmt = update(self.model).where(*filters).values(**instance_data) # type: ignore result = await session.execute(stmt) if commit: await session.commit() return result.rowcount # type: ignore - async def delete_model(self, session: AsyncSession, pk: int, commit: bool = False) -> int: + async def delete_model( + self, + session: AsyncSession, + pk: int, + commit: bool = False, + ) -> int: """ Delete an instance by model's primary key @@ -172,7 +218,7 @@ async def delete_model(self, session: AsyncSession, pk: int, commit: bool = Fals :param commit: If `True`, commits the transaction immediately. Default is `False`. :return: """ - stmt = sa_delete(self.model).where(self.model.id == pk) + stmt = delete(self.model).where(self.model.id == pk) result = await session.execute(stmt) if commit: await session.commit() @@ -204,10 +250,10 @@ async def delete_model_by_column( raise MultipleResultsError(f'Only one record is expected to be delete, found {total_count} records.') if logical_deletion: deleted_flag = {deleted_flag_column: True} - stmt = sa_update(self.model).where(*filters).values(**deleted_flag) + stmt = update(self.model).where(*filters).values(**deleted_flag) else: - stmt = sa_delete(self.model).where(*filters) - await session.execute(stmt) + stmt = delete(self.model).where(*filters) + result = await session.execute(stmt) if commit: await session.commit() - return total_count + return result.rowcount # type: ignore diff --git a/sqlalchemy_crud_plus/utils.py b/sqlalchemy_crud_plus/utils.py index 4471ddd..b1302ec 100644 --- a/sqlalchemy_crud_plus/utils.py +++ b/sqlalchemy_crud_plus/utils.py @@ -57,9 +57,7 @@ } -def get_sqlalchemy_filter( - operator: str, value: Any, allow_arithmetic: bool = True -) -> Callable[[str], Callable] | None: +def get_sqlalchemy_filter(operator: str, value: Any, allow_arithmetic: bool = True) -> Callable[[str], Callable] | None: if operator in ['in', 'not_in', 'between']: if not isinstance(value, (tuple, list, set)): raise SelectOperatorError(f'The value of the <{operator}> filter must be tuple, list or set') diff --git a/tests/test_select.py b/tests/test_select.py index a83433a..a709cb9 100644 --- a/tests/test_select.py +++ b/tests/test_select.py @@ -2,6 +2,8 @@ # -*- coding: utf-8 -*- import pytest +from sqlalchemy import Select + from sqlalchemy_crud_plus import CRUDPlus from tests.model import Ins @@ -336,14 +338,28 @@ async def test_select_model_by_column_with_or(create_test_model, async_db_sessio assert result.id == 1 +@pytest.mark.asyncio +async def test_select(create_test_model): + crud = CRUDPlus(Ins) + result = await crud.select() + assert isinstance(result, Select) + + @pytest.mark.asyncio async def test_select_models(create_test_model, async_db_session): - async with async_db_session.begin() as session: + async with async_db_session() as session: crud = CRUDPlus(Ins) result = await crud.select_models(session) assert len(result) == 9 +@pytest.mark.asyncio +async def test_select_order(create_test_model): + crud = CRUDPlus(Ins) + result = await crud.select_order(sort_columns='name') + assert isinstance(result, Select) + + @pytest.mark.asyncio async def test_select_models_order_default_asc(create_test_model, async_db_session): async with async_db_session() as session: