From e7bccb2793c9de724b3c60f600ec33994f454fac Mon Sep 17 00:00:00 2001 From: Wu Clan Date: Mon, 28 Apr 2025 00:17:43 +0800 Subject: [PATCH] Add where clause support to select --- docs/index.md | 2 +- docs/usage/count.md | 14 ++-- docs/usage/exists.md | 14 ++-- docs/usage/select.md | 10 ++- docs/usage/select_model.md | 16 +++-- docs/usage/select_model_by_column.md | 19 ++--- docs/usage/select_models.md | 19 ++--- docs/usage/select_models_order.md | 37 +++++----- docs/usage/select_order.md | 15 ++-- sqlalchemy_crud_plus/crud.py | 102 ++++++++++++++++++--------- tests/test_select.py | 15 +++- 11 files changed, 157 insertions(+), 106 deletions(-) diff --git a/docs/index.md b/docs/index.md index bbea818..aa547f5 100644 --- a/docs/index.md +++ b/docs/index.md @@ -97,7 +97,7 @@ ## 互动 -[TG / Discord](https://wu-clan.github.io/homepage/) +[Discord](https://wu-clan.github.io/homepage/) ## 赞助 diff --git a/docs/usage/count.md b/docs/usage/count.md index c95b53a..33ff50e 100644 --- a/docs/usage/count.md +++ b/docs/usage/count.md @@ -23,21 +23,19 @@ class CRUDIns(CRUDPlus[ModelIns]): async def count( self, session: AsyncSession, - filters: ColumnElement | list[ColumnElement] | None = None, + *whereclause: ColumnExpressionArgument[bool], **kwargs, ) -> int: ``` **Parameters:** -| Name | Type | Description | Default | -|---------|----------------------------------------------------|------------------|---------| -| session | AsyncSession | 数据库会话 | 必填 | -| filters | `ColumnElement `\|` list[ColumnElement] `\|` None` | 要应用于查询的 WHERE 子句 | `None` | +| Name | Type | Description | Default | +|--------------|----------------------------------|------------------------------------------------------------------------------------------------------|---------| +| session | AsyncSession | 数据库会话 | 必填 | +| *whereclause | `ColumnExpressionArgument[bool]` | 等同于 [SQLAlchemy where](https://docs.sqlalchemy.org/en/20/tutorial/data_select.html#the-where-clause) | | +| **kwargs | | [条件过滤](../advanced/filter.md),将创建条件查询 SQL | | -!!! note "**kwargs" - - [条件过滤](../advanced/filter.md),将创建条件查询 SQL **Returns:** diff --git a/docs/usage/exists.md b/docs/usage/exists.md index 276aea8..91bdb11 100644 --- a/docs/usage/exists.md +++ b/docs/usage/exists.md @@ -23,21 +23,19 @@ class CRUDIns(CRUDPlus[ModelIns]): async def exists( self, session: AsyncSession, - filters: ColumnElement | list[ColumnElement] | None = None, + *whereclause: ColumnExpressionArgument[bool], **kwargs, ) -> bool: ``` **Parameters:** -| Name | Type | Description | Default | -|---------|----------------------------------------------------|------------------|---------| -| session | AsyncSession | 数据库会话 | 必填 | -| filters | `ColumnElement `\|` list[ColumnElement] `\|` None` | 要应用于查询的 WHERE 子句 | `None` | +| Name | Type | Description | Default | +|--------------|----------------------------------|------------------------------------------------------------------------------------------------------|---------| +| session | AsyncSession | 数据库会话 | 必填 | +| *whereclause | `ColumnExpressionArgument[bool]` | 等同于 [SQLAlchemy where](https://docs.sqlalchemy.org/en/20/tutorial/data_select.html#the-where-clause) | | +| **kwargs | | [条件过滤](../advanced/filter.md),将创建条件查询 SQL | | -!!! note "**kwargs" - - [条件过滤](../advanced/filter.md),将创建条件查询 SQL **Returns:** diff --git a/docs/usage/select.md b/docs/usage/select.md index cf3d9ff..968cc84 100644 --- a/docs/usage/select.md +++ b/docs/usage/select.md @@ -50,16 +50,20 @@ async def get_users( ## API ```py -async def select(self, **kwargs) -> Select: +async def select(self, *whereclause: ColumnExpressionArgument[bool], **kwargs) -> Select: ... ``` 此方法用于构造 SQLAlchemy Select,在一些特定场景将会很有用,例如,配合 [fastapi-pagination](https://github.com/uriyyo/fastapi-pagination) 使用 -!!! note "**kwargs" +**Parameters:** - [条件过滤](../advanced/filter.md),将创建条件查询 SQL +| Name | Type | Description | Default | +|--------------|----------------------------------|------------------------------------------------------------------------------------------------------|---------| +| session | AsyncSession | 数据库会话 | 必填 | +| *whereclause | `ColumnExpressionArgument[bool]` | 等同于 [SQLAlchemy where](https://docs.sqlalchemy.org/en/20/tutorial/data_select.html#the-where-clause) | | +| **kwargs | | [条件过滤](../advanced/filter.md),将创建条件查询 SQL | | **Returns:** diff --git a/docs/usage/select_model.md b/docs/usage/select_model.md index 45118a1..68f5ebe 100644 --- a/docs/usage/select_model.md +++ b/docs/usage/select_model.md @@ -21,15 +21,21 @@ class CRUDIns(CRUDPlus[ModelIns]): ## API ```py -async def select_model(self, session: AsyncSession, pk: int) -> Model | None: +async def select_model( + self, + session: AsyncSession, + pk: int, + *whereclause: ColumnExpressionArgument[bool], +) -> Model | None: ``` **Parameters:** -| Name | Type | Description | Default | -|---------|--------------|----------------------------------|---------| -| session | AsyncSession | 数据库会话 | 必填 | -| pk | int | [主键](../advanced/primary_key.md) | 必填 | +| Name | Type | Description | Default | +|--------------|----------------------------------|------------------------------------------------------------------------------------------------------|---------| +| session | AsyncSession | 数据库会话 | 必填 | +| pk | int | [主键](../advanced/primary_key.md) | 必填 | +| *whereclause | `ColumnExpressionArgument[bool]` | 等同于 [SQLAlchemy where](https://docs.sqlalchemy.org/en/20/tutorial/data_select.html#the-where-clause) | | **Returns:** diff --git a/docs/usage/select_model_by_column.md b/docs/usage/select_model_by_column.md index bee0c4e..9e7694f 100644 --- a/docs/usage/select_model_by_column.md +++ b/docs/usage/select_model_by_column.md @@ -20,18 +20,21 @@ class CRUDIns(CRUDPlus[ModelIns]): ## API ```py -async def select_model_by_column(self, session: AsyncSession, **kwargs) -> Model | None: +async def select_model_by_column( + self, + session: AsyncSession, + *whereclause: ColumnExpressionArgument[bool], + **kwargs, +) -> Model | None: ``` **Parameters:** -| Name | Type | Description | Default | -|---------|--------------|----------------------------------|---------| -| session | AsyncSession | 数据库会话 | 必填 | - -!!! note "**kwargs" - - [条件过滤](../advanced/filter.md),将创建条件查询 SQL +| Name | Type | Description | Default | +|--------------|----------------------------------|------------------------------------------------------------------------------------------------------|---------| +| session | AsyncSession | 数据库会话 | 必填 | +| *whereclause | `ColumnExpressionArgument[bool]` | 等同于 [SQLAlchemy where](https://docs.sqlalchemy.org/en/20/tutorial/data_select.html#the-where-clause) | | +| **kwargs | | [条件过滤](../advanced/filter.md),将创建条件查询 SQL | | **Returns:** diff --git a/docs/usage/select_models.md b/docs/usage/select_models.md index 38b2b42..f9cc344 100644 --- a/docs/usage/select_models.md +++ b/docs/usage/select_models.md @@ -22,18 +22,21 @@ class CRUDIns(CRUDPlus[ModelIns]): ## API ```py -async def select_models(self, session: AsyncSession, **kwargs) -> Sequence[Row[Any] | RowMapping | Any]: +async def select_models( + self, + session: AsyncSession, + *whereclause: ColumnExpressionArgument[bool], + **kwargs, +) -> Sequence[Row[Any] | RowMapping | Any]: ``` **Parameters:** -| Name | Type | Description | Default | -|---------|--------------|----------------------------------|---------| -| session | AsyncSession | 数据库会话 | 必填 | - -!!! note "**kwargs" - - [条件过滤](../advanced/filter.md),将创建条件查询 SQL +| Name | Type | Description | Default | +|--------------|----------------------------------|------------------------------------------------------------------------------------------------------|---------| +| session | AsyncSession | 数据库会话 | 必填 | +| *whereclause | `ColumnExpressionArgument[bool]` | 等同于 [SQLAlchemy where](https://docs.sqlalchemy.org/en/20/tutorial/data_select.html#the-where-clause) | | +| **kwargs | | [条件过滤](../advanced/filter.md),将创建条件查询 SQL | | **Returns:** diff --git a/docs/usage/select_models_order.md b/docs/usage/select_models_order.md index ba3c08c..2f51ccd 100644 --- a/docs/usage/select_models_order.md +++ b/docs/usage/select_models_order.md @@ -10,38 +10,37 @@ from sqlalchemy.ext.asyncio import AsyncSession class ModelIns(Base): - # your sqlalchemy model - pass + # your sqlalchemy model + pass class CRUDIns(CRUDPlus[ModelIns]): - async def create(self, db: AsyncSession) -> Sequence[ModelIns]: - return await self.select_models_order(db, sort_columns=['name', 'age'], sort_orders=['asc', 'desc']) + async def create(self, db: AsyncSession) -> Sequence[ModelIns]: + return await self.select_models_order(db, sort_columns=['name', 'age'], sort_orders=['asc', 'desc']) ``` ## API ```py - async def select_models_order( - self, - session: AsyncSession, - sort_columns: str | list[str], - sort_orders: str | list[str] | None = None, - **kwargs, +async def select_models_order( + self, + session: AsyncSession, + sort_columns: str | list[str], + sort_orders: str | list[str] | None = None, + *whereclause: ColumnExpressionArgument[bool], + **kwargs, ) -> Sequence[Row | RowMapping | Any] | None: ``` **Parameters:** -| Name | Type | Description | Default | -|--------------|--------------------------------|------------------------------------------------------------------------|---------| -| session | AsyncSession | 数据库会话 | 必填 | -| sort_columns | `str `\|` list[str]` | 应用排序的单个列名或列名列表 | 必填 | -| sort_orders | `str `\|` list[str] `\|` None` | 单个排序顺序(asc 或 desc)或与 sort_columns 中的列相对应的排序顺序列表。 如果未提供,则默认每列的排序顺序为 asc | `None` | - -!!! note "**kwargs" - - [条件过滤](../advanced/filter.md),将创建条件查询 SQL +| Name | Type | Description | Default | +|--------------|----------------------------------|------------------------------------------------------------------------------------------------------|---------| +| session | AsyncSession | 数据库会话 | 必填 | +| sort_columns | `str `\|` list[str]` | 应用排序的单个列名或列名列表 | 必填 | +| sort_orders | `str `\|` list[str] `\|` None` | 单个排序顺序(asc 或 desc)或与 sort_columns 中的列相对应的排序顺序列表。 如果未提供,则默认每列的排序顺序为 asc | `None` | +| *whereclause | `ColumnExpressionArgument[bool]` | 等同于 [SQLAlchemy where](https://docs.sqlalchemy.org/en/20/tutorial/data_select.html#the-where-clause) | | +| **kwargs | | [条件过滤](../advanced/filter.md),将创建条件查询 SQL | | **Returns:** diff --git a/docs/usage/select_order.md b/docs/usage/select_order.md index bc91d6b..d3fdbb3 100644 --- a/docs/usage/select_order.md +++ b/docs/usage/select_order.md @@ -5,20 +5,19 @@ async def select_order( self, sort_columns: str | list[str], sort_orders: str | list[str] | None = None, + *whereclause: ColumnExpressionArgument[bool], **kwargs, ) -> Select: ``` **Parameters:** -| Name | Type | Description | Default | -|--------------|--------------------------------|------------------------------------------------------------------------|---------| -| sort_columns | `str `\|` list[str]` | 应用排序的单个列名或列名列表 | 必填 | -| sort_orders | `str `\|` list[str] `\|` None` | 单个排序顺序(asc 或 desc)或与 sort_columns 中的列相对应的排序顺序列表。 如果未提供,则默认每列的排序顺序为 asc | `None` | - -!!! note "**kwargs" - - [条件过滤](../advanced/filter.md),将创建条件查询 SQL +| Name | Type | Description | Default | +|--------------|----------------------------------|------------------------------------------------------------------------------------------------------|---------| +| sort_columns | `str `\|` list[str]` | 应用排序的单个列名或列名列表 | 必填 | +| sort_orders | `str `\|` list[str] `\|` None` | 单个排序顺序(asc 或 desc)或与 sort_columns 中的列相对应的排序顺序列表。 如果未提供,则默认每列的排序顺序为 asc | `None` | +| *whereclause | `ColumnExpressionArgument[bool]` | 等同于 [SQLAlchemy where](https://docs.sqlalchemy.org/en/20/tutorial/data_select.html#the-where-clause) | | +| **kwargs | | [条件过滤](../advanced/filter.md),将创建条件查询 SQL | | **Returns:** diff --git a/sqlalchemy_crud_plus/crud.py b/sqlalchemy_crud_plus/crud.py index f30314f..cecf393 100644 --- a/sqlalchemy_crud_plus/crud.py +++ b/sqlalchemy_crud_plus/crud.py @@ -2,7 +2,18 @@ # -*- coding: utf-8 -*- from typing import Any, Generic, Iterable, Sequence, Type -from sqlalchemy import ColumnElement, Row, RowMapping, Select, delete, func, inspect, select, update +from sqlalchemy import ( + Column, + ColumnExpressionArgument, + Row, + RowMapping, + Select, + delete, + func, + inspect, + select, + update, +) from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy_crud_plus.errors import CompositePrimaryKeysError, MultipleResultsError @@ -15,7 +26,7 @@ def __init__(self, model: Type[Model]): self.model = model self.primary_key = self._get_primary_key() - def _get_primary_key(self): + def _get_primary_key(self) -> Column: """ Dynamically retrieve the primary key column(s) for the model. """ @@ -96,28 +107,23 @@ async def create_models( async def count( self, session: AsyncSession, - filters: ColumnElement | list[ColumnElement] | None = None, + *whereclause: ColumnExpressionArgument[bool], **kwargs, ) -> int: """ Counts records that match specified filters. :param session: The sqlalchemy session to use for the operation. - :param filters: The WHERE clauses to apply to the query. + :param whereclause: The WHERE clauses to apply to the query. :param kwargs: Query expressions. :return: """ - if filters is None: - filters = [] - - if not isinstance(filters, list): - filters = [filters] + filter_list = list(whereclause) if kwargs: - filters.extend(parse_filters(self.model, **kwargs)) + filter_list.extend(parse_filters(self.model, **kwargs)) - stmt = select(func.count()).select_from(self.model) - stmt = stmt.where(*filters) + stmt = select(func.count()).select_from(self.model).where(*filter_list) query = await session.execute(stmt) total_count = query.scalar() return total_count if total_count is not None else 0 @@ -125,70 +131,87 @@ async def count( async def exists( self, session: AsyncSession, - filters: ColumnElement | list[ColumnElement] | None = None, + *whereclause: ColumnExpressionArgument[bool], **kwargs, ) -> bool: """ Whether the records that match the specified filter exist. :param session: The sqlalchemy session to use for the operation. - :param filters: The WHERE clauses to apply to the query. + :param whereclause: The WHERE clauses to apply to the query. :param kwargs: Query expressions. :return: """ - if filters is None: - filters = [] - - if not isinstance(filters, list): - filters = [filters] + filter_list = list(whereclause) if kwargs: - filters.extend(parse_filters(self.model, **kwargs)) + filter_list.extend(parse_filters(self.model, **kwargs)) - stmt = select(self.model).where(*filters).limit(1) + stmt = select(self.model).where(*filter_list).limit(1) query = await session.execute(stmt) return query.scalars().first() is not None - async def select_model(self, session: AsyncSession, pk: int) -> Model | None: + async def select_model( + self, + session: AsyncSession, + pk: int, + *whereclause: ColumnExpressionArgument[bool], + ) -> Model | None: """ Query by ID :param session: The SQLAlchemy async session. :param pk: The database primary key value. + :param whereclause: The WHERE clauses to apply to the query. :return: """ - stmt = select(self.model).where(self.primary_key == pk) + filter_list = list(whereclause) + _filters = [self.primary_key == pk] + _filters.extend(filter_list) + stmt = select(self.model).where(*_filters) query = await session.execute(stmt) return query.scalars().first() - async def select_model_by_column(self, session: AsyncSession, **kwargs) -> Model | None: + async def select_model_by_column( + self, + session: AsyncSession, + *whereclause: ColumnExpressionArgument[bool], + **kwargs, + ) -> Model | None: """ Query by column :param session: The SQLAlchemy async session. + :param whereclause: The WHERE clauses to apply to the query. :param kwargs: Query expressions. :return: """ - filters = parse_filters(self.model, **kwargs) - stmt = select(self.model).where(*filters) + filter_list = list(whereclause) + _filters = parse_filters(self.model, **kwargs) + _filters.extend(filter_list) + stmt = select(self.model).where(*_filters) query = await session.execute(stmt) return query.scalars().first() - async def select(self, **kwargs) -> Select: + async def select(self, *whereclause: ColumnExpressionArgument[bool], **kwargs) -> Select: """ Construct the SQLAlchemy selection + :param whereclause: The WHERE clauses to apply to the query. :param kwargs: Query expressions. :return: """ - filters = parse_filters(self.model, **kwargs) - stmt = select(self.model).where(*filters) + filter_list = list(whereclause) + _filters = parse_filters(self.model, **kwargs) + _filters.extend(filter_list) + stmt = select(self.model).where(*_filters) return stmt async def select_order( self, sort_columns: str | list[str], sort_orders: str | list[str] | None = None, + *whereclause: ColumnExpressionArgument[bool], **kwargs, ) -> Select: """ @@ -196,22 +219,29 @@ async def select_order( :param sort_columns: more details see apply_sorting :param sort_orders: more details see apply_sorting + :param whereclause: The WHERE clauses to apply to the query. :param kwargs: Query expressions. :return: """ - stmt = await self.select(**kwargs) + stmt = await self.select(*whereclause, **kwargs) sorted_stmt = apply_sorting(self.model, stmt, sort_columns, sort_orders) return sorted_stmt - async def select_models(self, session: AsyncSession, **kwargs) -> Sequence[Row[Any] | RowMapping | Any]: + async def select_models( + self, + session: AsyncSession, + *whereclause: ColumnExpressionArgument[bool], + **kwargs, + ) -> Sequence[Row[Any] | RowMapping | Any]: """ Query all rows :param session: The SQLAlchemy async session. + :param whereclause: The WHERE clauses to apply to the query. :param kwargs: Query expressions. :return: """ - stmt = await self.select(**kwargs) + stmt = await self.select(*whereclause, **kwargs) query = await session.execute(stmt) return query.scalars().all() @@ -220,6 +250,7 @@ async def select_models_order( session: AsyncSession, sort_columns: str | list[str], sort_orders: str | list[str] | None = None, + *whereclause: ColumnExpressionArgument[bool], **kwargs, ) -> Sequence[Row | RowMapping | Any] | None: """ @@ -228,10 +259,11 @@ async def select_models_order( :param session: The SQLAlchemy async session. :param sort_columns: more details see apply_sorting :param sort_orders: more details see apply_sorting + :param whereclause: The WHERE clauses to apply to the query. :param kwargs: Query expressions. :return: """ - stmt = await self.select_order(sort_columns, sort_orders, **kwargs) + stmt = await self.select_order(sort_columns, sort_orders, *whereclause, **kwargs) query = await session.execute(stmt) return query.scalars().all() @@ -293,7 +325,7 @@ async def update_model_by_column( :return: """ filters = parse_filters(self.model, **kwargs) - total_count = await self.count(session, filters) + total_count = await self.count(session, *filters) if not allow_multiple and total_count > 1: raise MultipleResultsError(f'Only one record is expected to be update, found {total_count} records.') if isinstance(obj, dict): @@ -360,7 +392,7 @@ async def delete_model_by_column( :return: """ filters = parse_filters(self.model, **kwargs) - total_count = await self.count(session, filters) + total_count = await self.count(session, *filters) if not allow_multiple and total_count > 1: raise MultipleResultsError(f'Only one record is expected to be delete, found {total_count} records.') if logical_deletion: diff --git a/tests/test_select.py b/tests/test_select.py index f2d51cd..9f7dc9c 100644 --- a/tests/test_select.py +++ b/tests/test_select.py @@ -18,11 +18,20 @@ async def test_count(create_test_model, async_db_session): @pytest.mark.asyncio -async def test_count_filters(create_test_model, async_db_session): +async def test_count_filters_one(create_test_model, async_db_session): async with async_db_session() as session: crud = CRUDPlus(Ins) for i in range(1, 10): - result = await crud.count(session, [Ins.id == i]) + result = await crud.count(session, Ins.id == i) + assert result == 1 + + +@pytest.mark.asyncio +async def test_count_filters_list(create_test_model, async_db_session): + async with async_db_session() as session: + crud = CRUDPlus(Ins) + for i in range(1, 10): + result = await crud.count(session, Ins.id == i) assert result == 1 @@ -40,7 +49,7 @@ async def test_exists_filters(create_test_model, async_db_session): async with async_db_session() as session: crud = CRUDPlus(Ins) for i in range(1, 10): - result = await crud.exists(session, [Ins.id == i]) + result = await crud.exists(session, Ins.id == i) assert result is True