Skip to content

Commit 377e26a

Browse files
authored
Add bulk creation and column update (#4)
1 parent 90827fa commit 377e26a

File tree

6 files changed

+199
-173
lines changed

6 files changed

+199
-173
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,5 @@ __pycache__/
55
venv/
66
.idea/
77
dist/
8+
.pytest_cache/
9+
.ruff_cache/

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ repos:
77
- id: check-toml
88

99
- repo: https://github.com/charliermarsh/ruff-pre-commit
10-
rev: v0.4.0
10+
rev: v0.4.5
1111
hooks:
1212
- id: ruff
1313
args:

pdm.lock

Lines changed: 116 additions & 156 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

requirements.txt

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,15 @@ nodeenv==1.8.0
1515
packaging==24.0
1616
platformdirs==4.2.0
1717
pluggy==1.5.0
18-
pre-commit==3.5.0
19-
pydantic==2.7.0
20-
pydantic-core==2.18.1
21-
pytest==8.1.1
22-
pytest-asyncio==0.23.6
18+
pre-commit==3.7.1
19+
pydantic==2.7.1
20+
pydantic-core==2.18.2
21+
pytest==8.2.1
22+
pytest-asyncio==0.23.7
2323
pyyaml==6.0.1
24-
ruff==0.4.0
24+
ruff==0.4.5
2525
setuptools==69.5.1
26-
sqlalchemy==2.0.29
26+
sqlalchemy==2.0.30
2727
tomli==2.0.1; python_version < "3.11"
2828
typing-extensions==4.10.0
2929
virtualenv==20.25.3

sqlalchemy_crud_plus/crud.py

Lines changed: 46 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#!/usr/bin/env python3
22
# -*- coding: utf-8 -*-
3-
from typing import Any, Generic, Literal, Sequence, Type, TypeVar
3+
from typing import Any, Generic, Iterable, Literal, Sequence, Type, TypeVar
44

55
from pydantic import BaseModel
66
from sqlalchemy import Row, RowMapping, and_, asc, desc, or_, select
@@ -34,6 +34,19 @@ async def create_model(self, session: AsyncSession, obj: _CreateSchema, **kwargs
3434
instance = self.model(**obj.model_dump())
3535
session.add(instance)
3636

37+
async def create_models(self, session: AsyncSession, obj: Iterable[_CreateSchema]) -> None:
38+
"""
39+
Create new instances of a model
40+
41+
:param session:
42+
:param obj:
43+
:return:
44+
"""
45+
instance_list = []
46+
for i in obj:
47+
instance_list.append(self.model(**i.model_dump()))
48+
session.add_all(instance_list)
49+
3750
async def select_model_by_id(self, session: AsyncSession, pk: int) -> _Model | None:
3851
"""
3952
Query by ID
@@ -102,7 +115,7 @@ async def select_models_order(
102115
self,
103116
session: AsyncSession,
104117
*columns,
105-
model_sort: Literal['skip', 'asc', 'desc'] = 'skip',
118+
model_sort: Literal['default', 'asc', 'desc'] = 'default',
106119
) -> Sequence[Row | RowMapping | Any] | None:
107120
"""
108121
Query all rows asc or desc
@@ -112,9 +125,6 @@ async def select_models_order(
112125
:param model_sort:
113126
:return:
114127
"""
115-
if model_sort != 'skip':
116-
if len(columns) != 1:
117-
raise SelectExpressionError('ACS and DESC only allow you to specify one column for sorting')
118128
sort_list = []
119129
for column in columns:
120130
if hasattr(self.model, column):
@@ -123,7 +133,7 @@ async def select_models_order(
123133
else:
124134
raise ModelColumnError(f'Model column {column} is not found')
125135
match model_sort:
126-
case 'skip':
136+
case 'default':
127137
query = await session.execute(select(self.model).order_by(*sort_list))
128138
case 'asc':
129139
query = await session.execute(select(self.model).order_by(asc(*sort_list)))
@@ -135,7 +145,7 @@ async def select_models_order(
135145

136146
async def update_model(self, session: AsyncSession, pk: int, obj: _UpdateSchema | dict[str, Any], **kwargs) -> int:
137147
"""
138-
Update an instance of a model
148+
Update an instance of model's primary key
139149
140150
:param session:
141151
:param pk:
@@ -152,13 +162,41 @@ async def update_model(self, session: AsyncSession, pk: int, obj: _UpdateSchema
152162
result = await session.execute(sa_update(self.model).where(self.model.id == pk).values(**instance_data))
153163
return result.rowcount # type: ignore
154164

165+
async def update_model_by_column(
166+
self, session: AsyncSession, column: str, column_value: Any, obj: _UpdateSchema | dict[str, Any], **kwargs
167+
) -> int:
168+
"""
169+
Update an instance of model column
170+
171+
:param session:
172+
:param column:
173+
:param column_value:
174+
:param obj:
175+
:param kwargs:
176+
:return:
177+
"""
178+
if isinstance(obj, dict):
179+
instance_data = obj
180+
else:
181+
instance_data = obj.model_dump(exclude_unset=True)
182+
if kwargs:
183+
instance_data.update(kwargs)
184+
if hasattr(self.model, column):
185+
model_column = getattr(self.model, column)
186+
else:
187+
raise ModelColumnError(f'Model column {column} is not found')
188+
result = await session.execute(
189+
sa_update(self.model).where(model_column == column_value).values(**instance_data)
190+
)
191+
return result.rowcount # type: ignore
192+
155193
async def delete_model(self, session: AsyncSession, pk: int, **kwargs) -> int:
156194
"""
157195
Delete an instance of a model
158196
159197
:param session:
160198
:param pk:
161-
:param kwargs:
199+
:param kwargs: for soft deletion only
162200
:return:
163201
"""
164202
if not kwargs:

tests/test_crud.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,20 @@ async def test_create_model():
3434
assert query.name == f'test_name_{i}'
3535

3636

37+
@pytest.mark.asyncio
38+
async def test_create_models():
39+
async with async_db_session.begin() as session:
40+
crud = CRUDPlus(Ins)
41+
data = []
42+
for i in range(1, 10):
43+
data.append(ModelSchema(name=f'test_name_{i}'))
44+
await crud.create_models(session, data)
45+
async with async_db_session() as session:
46+
for i in range(1, 10):
47+
query = await session.scalar(select(Ins).where(Ins.id == i))
48+
assert query.name == f'test_name_{i}'
49+
50+
3751
@pytest.mark.asyncio
3852
async def test_select_model_by_id():
3953
await create_test_model()
@@ -96,7 +110,19 @@ async def test_update_model():
96110
data = ModelSchema(name='test_name_update_1')
97111
result = await crud.update_model(session, 1, data)
98112
assert result == 1
99-
result = await crud.select_model_by_id(session, 1)
113+
result = await session.get(Ins, 1)
114+
assert result.name == 'test_name_update_1'
115+
116+
117+
@pytest.mark.asyncio
118+
async def test_update_model_by_column():
119+
await create_test_model()
120+
async with async_db_session.begin() as session:
121+
crud = CRUDPlus(Ins)
122+
data = ModelSchema(name='test_name_update_1')
123+
result = await crud.update_model_by_column(session, 'name', 'test_name_1', data)
124+
assert result == 1
125+
result = await session.get(Ins, 1)
100126
assert result.name == 'test_name_update_1'
101127

102128

0 commit comments

Comments
 (0)