Skip to content

Add commit option to CRUD operations #8

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 4 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,22 +1,17 @@
# sqlalchemy-crud-plus

基于 SQLAlChemy2 模型的异步 CRUD 操作
Asynchronous CRUD operations based on SQLAlChemy 2.0

## 下载
## Download

```shell
pip install sqlalchemy-crud-plus
```

## TODO

- [ ] ...

## Use

以下仅为简易示例

```python
# example:
from sqlalchemy.orm import declarative_base
from sqlalchemy_crud_plus import CRUDPlus

Expand All @@ -34,7 +29,7 @@ class CRUDIns(CRUDPlus[ModelIns]):


# singleton
ins_dao = CRUDIns(ModelIns)
ins_dao: CRUDIns = CRUDIns(ModelIns)
```

## 互动
Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ select = [
"I"
]
preview = true
ignore-init-module-imports = true

[tool.ruff.lint.isort]
lines-between-types = 1
Expand Down
100 changes: 68 additions & 32 deletions sqlalchemy_crud_plus/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,33 +19,43 @@ class CRUDPlus(Generic[_Model]):
def __init__(self, model: Type[_Model]):
self.model = model

async def create_model(self, session: AsyncSession, obj: _CreateSchema, **kwargs) -> None:
async def create_model(self, session: AsyncSession, obj: _CreateSchema, commit: bool = False, **kwargs) -> _Model:
"""
Create a new instance of a model

:param session:
:param obj:
:param commit:
:param kwargs:
:return:
"""
if kwargs:
instance = self.model(**obj.model_dump(), **kwargs)
ins = self.model(**obj.model_dump(), **kwargs)
else:
instance = self.model(**obj.model_dump())
session.add(instance)
ins = self.model(**obj.model_dump())
session.add(ins)
if commit:
await session.commit()
return ins

async def create_models(self, session: AsyncSession, obj: Iterable[_CreateSchema]) -> None:
async def create_models(
self, session: AsyncSession, obj: Iterable[_CreateSchema], commit: bool = False
) -> list[_Model]:
"""
Create new instances of a model

:param session:
:param obj:
:param commit:
:return:
"""
instance_list = []
ins_list = []
for i in obj:
instance_list.append(self.model(**i.model_dump()))
session.add_all(instance_list)
ins_list.append(self.model(**i.model_dump()))
session.add_all(ins_list)
if commit:
await session.commit()
return ins_list

async def select_model_by_id(self, session: AsyncSession, pk: int) -> _Model | None:
"""
Expand All @@ -55,7 +65,8 @@ async def select_model_by_id(self, session: AsyncSession, pk: int) -> _Model | N
:param pk:
:return:
"""
query = await session.execute(select(self.model).where(self.model.id == pk))
stmt = select(self.model).where(self.model.id == pk)
query = await session.execute(stmt)
return query.scalars().first()

async def select_model_by_column(self, session: AsyncSession, column: str, column_value: Any) -> _Model | None:
Expand All @@ -69,10 +80,11 @@ async def select_model_by_column(self, session: AsyncSession, column: str, colum
"""
if hasattr(self.model, column):
model_column = getattr(self.model, column)
query = await session.execute(select(self.model).where(model_column == column_value)) # type: ignore
stmt = select(self.model).where(model_column == column_value) # type: ignore
query = await session.execute(stmt)
return query.scalars().first()
else:
raise ModelColumnError(f'Model column {column} is not found')
raise ModelColumnError(f'Column {column} is not found in {self.model}')

async def select_model_by_columns(
self, session: AsyncSession, expression: Literal['and', 'or'] = 'and', **conditions
Expand All @@ -91,31 +103,36 @@ async def select_model_by_columns(
model_column = getattr(self.model, column)
where_list.append(model_column == value)
else:
raise ModelColumnError(f'Model column {column} is not found')
raise ModelColumnError(f'Column {column} is not found in {self.model}')
match expression:
case 'and':
query = await session.execute(select(self.model).where(and_(*where_list)))
stmt = select(self.model).where(and_(*where_list))
query = await session.execute(stmt)
case 'or':
query = await session.execute(select(self.model).where(or_(*where_list)))
stmt = select(self.model).where(or_(*where_list))
query = await session.execute(stmt)
case _:
raise SelectExpressionError(f'select expression {expression} is not supported')
raise SelectExpressionError(
f'Select expression {expression} is not supported, only supports `and`, `or`'
)
return query.scalars().first()

async def select_models(self, session: AsyncSession) -> Sequence[Row | RowMapping | Any] | None:
async def select_models(self, session: AsyncSession) -> Sequence[Row[Any] | RowMapping | Any]:
"""
Query all rows

:param session:
:return:
"""
query = await session.execute(select(self.model))
stmt = select(self.model)
query = await session.execute(stmt)
return query.scalars().all()

async def select_models_order(
self,
session: AsyncSession,
*columns,
model_sort: Literal['default', 'asc', 'desc'] = 'default',
model_sort: Literal['asc', 'desc'] = 'desc',
) -> Sequence[Row | RowMapping | Any] | None:
"""
Query all rows asc or desc
Expand All @@ -131,25 +148,28 @@ async def select_models_order(
model_column = getattr(self.model, column)
sort_list.append(model_column)
else:
raise ModelColumnError(f'Model column {column} is not found')
raise ModelColumnError(f'Column {column} is not found in {self.model}')
match model_sort:
case 'default':
query = await session.execute(select(self.model).order_by(*sort_list))
case 'asc':
query = await session.execute(select(self.model).order_by(asc(*sort_list)))
case 'desc':
query = await session.execute(select(self.model).order_by(desc(*sort_list)))
case _:
raise SelectExpressionError(f'select sort expression {model_sort} is not supported')
raise SelectExpressionError(
f'Select sort expression {model_sort} is not supported, only supports `asc`, `desc`'
)
return query.scalars().all()

async def update_model(self, session: AsyncSession, pk: int, obj: _UpdateSchema | dict[str, Any], **kwargs) -> int:
async def update_model(
self, session: AsyncSession, pk: int, obj: _UpdateSchema | dict[str, Any], commit: bool = False, **kwargs
) -> int:
"""
Update an instance of model's primary key

:param session:
:param pk:
:param obj:
:param commit:
:param kwargs:
:return:
"""
Expand All @@ -159,11 +179,20 @@ async def update_model(self, session: AsyncSession, pk: int, obj: _UpdateSchema
instance_data = obj.model_dump(exclude_unset=True)
if kwargs:
instance_data.update(kwargs)
result = await session.execute(sa_update(self.model).where(self.model.id == pk).values(**instance_data))
stmt = sa_update(self.model).where(self.model.id == pk).values(**instance_data)
result = await session.execute(stmt)
if commit:
await session.commit()
return result.rowcount # type: ignore

async def update_model_by_column(
self, session: AsyncSession, column: str, column_value: Any, obj: _UpdateSchema | dict[str, Any], **kwargs
self,
session: AsyncSession,
column: str,
column_value: Any,
obj: _UpdateSchema | dict[str, Any],
commit: bool = False,
**kwargs,
) -> int:
"""
Update an instance of model column
Expand All @@ -172,6 +201,7 @@ async def update_model_by_column(
:param column:
:param column_value:
:param obj:
:param commit:
:param kwargs:
:return:
"""
Expand All @@ -184,23 +214,29 @@ async def update_model_by_column(
if hasattr(self.model, column):
model_column = getattr(self.model, column)
else:
raise ModelColumnError(f'Model column {column} is not found')
result = await session.execute(
sa_update(self.model).where(model_column == column_value).values(**instance_data)
)
raise ModelColumnError(f'Column {column} is not found in {self.model}')
stmt = sa_update(self.model).where(model_column == column_value).values(**instance_data) # type: ignore
result = await session.execute(stmt)
if commit:
await session.commit()
return result.rowcount # type: ignore

async def delete_model(self, session: AsyncSession, pk: int, **kwargs) -> int:
async def delete_model(self, session: AsyncSession, pk: int, commit: bool = False, **kwargs) -> int:
"""
Delete an instance of a model

:param session:
:param pk:
:param commit:
:param kwargs: for soft deletion only
:return:
"""
if not kwargs:
result = await session.execute(sa_delete(self.model).where(self.model.id == pk))
stmt = sa_delete(self.model).where(self.model.id == pk)
result = await session.execute(stmt)
else:
result = await session.execute(sa_update(self.model).where(self.model.id == pk).values(**kwargs))
stmt = sa_update(self.model).where(self.model.id == pk).values(**kwargs)
result = await session.execute(stmt)
if commit:
await session.commit()
return result.rowcount # type: ignore