Skip to content

Add bulk creation and column update #4

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,5 @@ __pycache__/
venv/
.idea/
dist/
.pytest_cache/
.ruff_cache/
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ repos:
- id: check-toml

- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: v0.4.0
rev: v0.4.5
hooks:
- id: ruff
args:
Expand Down
272 changes: 116 additions & 156 deletions pdm.lock

Large diffs are not rendered by default.

14 changes: 7 additions & 7 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@ nodeenv==1.8.0
packaging==24.0
platformdirs==4.2.0
pluggy==1.5.0
pre-commit==3.5.0
pydantic==2.7.0
pydantic-core==2.18.1
pytest==8.1.1
pytest-asyncio==0.23.6
pre-commit==3.7.1
pydantic==2.7.1
pydantic-core==2.18.2
pytest==8.2.1
pytest-asyncio==0.23.7
pyyaml==6.0.1
ruff==0.4.0
ruff==0.4.5
setuptools==69.5.1
sqlalchemy==2.0.29
sqlalchemy==2.0.30
tomli==2.0.1; python_version < "3.11"
typing-extensions==4.10.0
virtualenv==20.25.3
54 changes: 46 additions & 8 deletions sqlalchemy_crud_plus/crud.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from typing import Any, Generic, Literal, Sequence, Type, TypeVar
from typing import Any, Generic, Iterable, Literal, Sequence, Type, TypeVar

from pydantic import BaseModel
from sqlalchemy import Row, RowMapping, and_, asc, desc, or_, select
Expand Down Expand Up @@ -34,6 +34,19 @@ async def create_model(self, session: AsyncSession, obj: _CreateSchema, **kwargs
instance = self.model(**obj.model_dump())
session.add(instance)

async def create_models(self, session: AsyncSession, obj: Iterable[_CreateSchema]) -> None:
"""
Create new instances of a model

:param session:
:param obj:
:return:
"""
instance_list = []
for i in obj:
instance_list.append(self.model(**i.model_dump()))
session.add_all(instance_list)

async def select_model_by_id(self, session: AsyncSession, pk: int) -> _Model | None:
"""
Query by ID
Expand Down Expand Up @@ -102,7 +115,7 @@ async def select_models_order(
self,
session: AsyncSession,
*columns,
model_sort: Literal['skip', 'asc', 'desc'] = 'skip',
model_sort: Literal['default', 'asc', 'desc'] = 'default',
) -> Sequence[Row | RowMapping | Any] | None:
"""
Query all rows asc or desc
Expand All @@ -112,9 +125,6 @@ async def select_models_order(
:param model_sort:
:return:
"""
if model_sort != 'skip':
if len(columns) != 1:
raise SelectExpressionError('ACS and DESC only allow you to specify one column for sorting')
sort_list = []
for column in columns:
if hasattr(self.model, column):
Expand All @@ -123,7 +133,7 @@ async def select_models_order(
else:
raise ModelColumnError(f'Model column {column} is not found')
match model_sort:
case 'skip':
case 'default':
query = await session.execute(select(self.model).order_by(*sort_list))
case 'asc':
query = await session.execute(select(self.model).order_by(asc(*sort_list)))
Expand All @@ -135,7 +145,7 @@ async def select_models_order(

async def update_model(self, session: AsyncSession, pk: int, obj: _UpdateSchema | dict[str, Any], **kwargs) -> int:
"""
Update an instance of a model
Update an instance of model's primary key

:param session:
:param pk:
Expand All @@ -152,13 +162,41 @@ async def update_model(self, session: AsyncSession, pk: int, obj: _UpdateSchema
result = await session.execute(sa_update(self.model).where(self.model.id == pk).values(**instance_data))
return result.rowcount # type: ignore

async def update_model_by_column(
self, session: AsyncSession, column: str, column_value: Any, obj: _UpdateSchema | dict[str, Any], **kwargs
) -> int:
"""
Update an instance of model column

:param session:
:param column:
:param column_value:
:param obj:
:param kwargs:
:return:
"""
if isinstance(obj, dict):
instance_data = obj
else:
instance_data = obj.model_dump(exclude_unset=True)
if kwargs:
instance_data.update(kwargs)
if hasattr(self.model, column):
model_column = getattr(self.model, column)
else:
raise ModelColumnError(f'Model column {column} is not found')
result = await session.execute(
sa_update(self.model).where(model_column == column_value).values(**instance_data)
)
return result.rowcount # type: ignore

async def delete_model(self, session: AsyncSession, pk: int, **kwargs) -> int:
"""
Delete an instance of a model

:param session:
:param pk:
:param kwargs:
:param kwargs: for soft deletion only
:return:
"""
if not kwargs:
Expand Down
28 changes: 27 additions & 1 deletion tests/test_crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,20 @@ async def test_create_model():
assert query.name == f'test_name_{i}'


@pytest.mark.asyncio
async def test_create_models():
async with async_db_session.begin() as session:
crud = CRUDPlus(Ins)
data = []
for i in range(1, 10):
data.append(ModelSchema(name=f'test_name_{i}'))
await crud.create_models(session, data)
async with async_db_session() as session:
for i in range(1, 10):
query = await session.scalar(select(Ins).where(Ins.id == i))
assert query.name == f'test_name_{i}'


@pytest.mark.asyncio
async def test_select_model_by_id():
await create_test_model()
Expand Down Expand Up @@ -96,7 +110,19 @@ async def test_update_model():
data = ModelSchema(name='test_name_update_1')
result = await crud.update_model(session, 1, data)
assert result == 1
result = await crud.select_model_by_id(session, 1)
result = await session.get(Ins, 1)
assert result.name == 'test_name_update_1'


@pytest.mark.asyncio
async def test_update_model_by_column():
await create_test_model()
async with async_db_session.begin() as session:
crud = CRUDPlus(Ins)
data = ModelSchema(name='test_name_update_1')
result = await crud.update_model_by_column(session, 'name', 'test_name_1', data)
assert result == 1
result = await session.get(Ins, 1)
assert result.name == 'test_name_update_1'


Expand Down