diff --git a/README.md b/README.md index d7137ae..67041b0 100644 --- a/README.md +++ b/README.md @@ -1,22 +1,17 @@ # sqlalchemy-crud-plus -基于 SQLAlChemy2 模型的异步 CRUD 操作 +Asynchronous CRUD operations based on SQLAlChemy 2.0 -## 下载 +## Download ```shell pip install sqlalchemy-crud-plus ``` -## TODO - -- [ ] ... - ## Use -以下仅为简易示例 - ```python +# example: from sqlalchemy.orm import declarative_base from sqlalchemy_crud_plus import CRUDPlus @@ -34,7 +29,7 @@ class CRUDIns(CRUDPlus[ModelIns]): # singleton -ins_dao = CRUDIns(ModelIns) +ins_dao: CRUDIns = CRUDIns(ModelIns) ``` ## 互动 diff --git a/pyproject.toml b/pyproject.toml index 4f92917..49f93c3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,7 +49,6 @@ select = [ "I" ] preview = true -ignore-init-module-imports = true [tool.ruff.lint.isort] lines-between-types = 1 diff --git a/sqlalchemy_crud_plus/crud.py b/sqlalchemy_crud_plus/crud.py index 3a24e97..11b6c1f 100644 --- a/sqlalchemy_crud_plus/crud.py +++ b/sqlalchemy_crud_plus/crud.py @@ -19,33 +19,43 @@ class CRUDPlus(Generic[_Model]): def __init__(self, model: Type[_Model]): self.model = model - async def create_model(self, session: AsyncSession, obj: _CreateSchema, **kwargs) -> None: + async def create_model(self, session: AsyncSession, obj: _CreateSchema, commit: bool = False, **kwargs) -> _Model: """ Create a new instance of a model :param session: :param obj: + :param commit: :param kwargs: :return: """ if kwargs: - instance = self.model(**obj.model_dump(), **kwargs) + ins = self.model(**obj.model_dump(), **kwargs) else: - instance = self.model(**obj.model_dump()) - session.add(instance) + ins = self.model(**obj.model_dump()) + session.add(ins) + if commit: + await session.commit() + return ins - async def create_models(self, session: AsyncSession, obj: Iterable[_CreateSchema]) -> None: + async def create_models( + self, session: AsyncSession, obj: Iterable[_CreateSchema], commit: bool = False + ) -> list[_Model]: """ Create new instances of a model :param session: :param obj: + :param commit: :return: """ - instance_list = [] + ins_list = [] for i in obj: - instance_list.append(self.model(**i.model_dump())) - session.add_all(instance_list) + ins_list.append(self.model(**i.model_dump())) + session.add_all(ins_list) + if commit: + await session.commit() + return ins_list async def select_model_by_id(self, session: AsyncSession, pk: int) -> _Model | None: """ @@ -55,7 +65,8 @@ async def select_model_by_id(self, session: AsyncSession, pk: int) -> _Model | N :param pk: :return: """ - query = await session.execute(select(self.model).where(self.model.id == pk)) + stmt = select(self.model).where(self.model.id == pk) + query = await session.execute(stmt) return query.scalars().first() async def select_model_by_column(self, session: AsyncSession, column: str, column_value: Any) -> _Model | None: @@ -69,10 +80,11 @@ async def select_model_by_column(self, session: AsyncSession, column: str, colum """ if hasattr(self.model, column): model_column = getattr(self.model, column) - query = await session.execute(select(self.model).where(model_column == column_value)) # type: ignore + stmt = select(self.model).where(model_column == column_value) # type: ignore + query = await session.execute(stmt) return query.scalars().first() else: - raise ModelColumnError(f'Model column {column} is not found') + raise ModelColumnError(f'Column {column} is not found in {self.model}') async def select_model_by_columns( self, session: AsyncSession, expression: Literal['and', 'or'] = 'and', **conditions @@ -91,31 +103,36 @@ async def select_model_by_columns( model_column = getattr(self.model, column) where_list.append(model_column == value) else: - raise ModelColumnError(f'Model column {column} is not found') + raise ModelColumnError(f'Column {column} is not found in {self.model}') match expression: case 'and': - query = await session.execute(select(self.model).where(and_(*where_list))) + stmt = select(self.model).where(and_(*where_list)) + query = await session.execute(stmt) case 'or': - query = await session.execute(select(self.model).where(or_(*where_list))) + stmt = select(self.model).where(or_(*where_list)) + query = await session.execute(stmt) case _: - raise SelectExpressionError(f'select expression {expression} is not supported') + raise SelectExpressionError( + f'Select expression {expression} is not supported, only supports `and`, `or`' + ) return query.scalars().first() - async def select_models(self, session: AsyncSession) -> Sequence[Row | RowMapping | Any] | None: + async def select_models(self, session: AsyncSession) -> Sequence[Row[Any] | RowMapping | Any]: """ Query all rows :param session: :return: """ - query = await session.execute(select(self.model)) + stmt = select(self.model) + query = await session.execute(stmt) return query.scalars().all() async def select_models_order( self, session: AsyncSession, *columns, - model_sort: Literal['default', 'asc', 'desc'] = 'default', + model_sort: Literal['asc', 'desc'] = 'desc', ) -> Sequence[Row | RowMapping | Any] | None: """ Query all rows asc or desc @@ -131,25 +148,28 @@ async def select_models_order( model_column = getattr(self.model, column) sort_list.append(model_column) else: - raise ModelColumnError(f'Model column {column} is not found') + raise ModelColumnError(f'Column {column} is not found in {self.model}') match model_sort: - case 'default': - query = await session.execute(select(self.model).order_by(*sort_list)) case 'asc': query = await session.execute(select(self.model).order_by(asc(*sort_list))) case 'desc': query = await session.execute(select(self.model).order_by(desc(*sort_list))) case _: - raise SelectExpressionError(f'select sort expression {model_sort} is not supported') + raise SelectExpressionError( + f'Select sort expression {model_sort} is not supported, only supports `asc`, `desc`' + ) return query.scalars().all() - async def update_model(self, session: AsyncSession, pk: int, obj: _UpdateSchema | dict[str, Any], **kwargs) -> int: + async def update_model( + self, session: AsyncSession, pk: int, obj: _UpdateSchema | dict[str, Any], commit: bool = False, **kwargs + ) -> int: """ Update an instance of model's primary key :param session: :param pk: :param obj: + :param commit: :param kwargs: :return: """ @@ -159,11 +179,20 @@ async def update_model(self, session: AsyncSession, pk: int, obj: _UpdateSchema instance_data = obj.model_dump(exclude_unset=True) if kwargs: instance_data.update(kwargs) - result = await session.execute(sa_update(self.model).where(self.model.id == pk).values(**instance_data)) + stmt = sa_update(self.model).where(self.model.id == pk).values(**instance_data) + result = await session.execute(stmt) + if commit: + await session.commit() return result.rowcount # type: ignore async def update_model_by_column( - self, session: AsyncSession, column: str, column_value: Any, obj: _UpdateSchema | dict[str, Any], **kwargs + self, + session: AsyncSession, + column: str, + column_value: Any, + obj: _UpdateSchema | dict[str, Any], + commit: bool = False, + **kwargs, ) -> int: """ Update an instance of model column @@ -172,6 +201,7 @@ async def update_model_by_column( :param column: :param column_value: :param obj: + :param commit: :param kwargs: :return: """ @@ -184,23 +214,29 @@ async def update_model_by_column( if hasattr(self.model, column): model_column = getattr(self.model, column) else: - raise ModelColumnError(f'Model column {column} is not found') - result = await session.execute( - sa_update(self.model).where(model_column == column_value).values(**instance_data) - ) + raise ModelColumnError(f'Column {column} is not found in {self.model}') + stmt = sa_update(self.model).where(model_column == column_value).values(**instance_data) # type: ignore + result = await session.execute(stmt) + if commit: + await session.commit() return result.rowcount # type: ignore - async def delete_model(self, session: AsyncSession, pk: int, **kwargs) -> int: + async def delete_model(self, session: AsyncSession, pk: int, commit: bool = False, **kwargs) -> int: """ Delete an instance of a model :param session: :param pk: + :param commit: :param kwargs: for soft deletion only :return: """ if not kwargs: - result = await session.execute(sa_delete(self.model).where(self.model.id == pk)) + stmt = sa_delete(self.model).where(self.model.id == pk) + result = await session.execute(stmt) else: - result = await session.execute(sa_update(self.model).where(self.model.id == pk).values(**kwargs)) + stmt = sa_update(self.model).where(self.model.id == pk).values(**kwargs) + result = await session.execute(stmt) + if commit: + await session.commit() return result.rowcount # type: ignore