2
2
# -*- coding: utf-8 -*-
3
3
from typing import Any , Generic , Iterable , Sequence , Type
4
4
5
- from sqlalchemy import Row , RowMapping , select
6
- from sqlalchemy import delete as sa_delete
7
- from sqlalchemy import update as sa_update
5
+ from sqlalchemy import Row , RowMapping , Select , delete , select , update
8
6
from sqlalchemy .ext .asyncio import AsyncSession
9
7
10
8
from sqlalchemy_crud_plus .errors import MultipleResultsError
@@ -16,7 +14,13 @@ class CRUDPlus(Generic[Model]):
16
14
def __init__ (self , model : Type [Model ]):
17
15
self .model = model
18
16
19
- async def create_model (self , session : AsyncSession , obj : CreateSchema , commit : bool = False , ** kwargs ) -> Model :
17
+ async def create_model (
18
+ self ,
19
+ session : AsyncSession ,
20
+ obj : CreateSchema ,
21
+ commit : bool = False ,
22
+ ** kwargs ,
23
+ ) -> Model :
20
24
"""
21
25
Create a new instance of a model
22
26
@@ -36,7 +40,10 @@ async def create_model(self, session: AsyncSession, obj: CreateSchema, commit: b
36
40
return ins
37
41
38
42
async def create_models (
39
- self , session : AsyncSession , obj : Iterable [CreateSchema ], commit : bool = False
43
+ self ,
44
+ session : AsyncSession ,
45
+ obj : Iterable [CreateSchema ],
46
+ commit : bool = False ,
40
47
) -> list [Model ]:
41
48
"""
42
49
Create new instances of a model
@@ -79,6 +86,35 @@ async def select_model_by_column(self, session: AsyncSession, **kwargs) -> Model
79
86
query = await session .execute (stmt )
80
87
return query .scalars ().first ()
81
88
89
+ async def select (self , ** kwargs ) -> Select :
90
+ """
91
+ Construct the SQLAlchemy selection
92
+
93
+ :param kwargs: Query expressions.
94
+ :return:
95
+ """
96
+ filters = parse_filters (self .model , ** kwargs )
97
+ stmt = select (self .model ).where (* filters )
98
+ return stmt
99
+
100
+ async def select_order (
101
+ self ,
102
+ sort_columns : str | list [str ],
103
+ sort_orders : str | list [str ] | None = None ,
104
+ ** kwargs ,
105
+ ) -> Select :
106
+ """
107
+ Constructing SQLAlchemy selection with sorting
108
+
109
+ :param kwargs: Query expressions.
110
+ :param sort_columns: more details see apply_sorting
111
+ :param sort_orders: more details see apply_sorting
112
+ :return:
113
+ """
114
+ stmt = await self .select (** kwargs )
115
+ sorted_stmt = apply_sorting (self .model , stmt , sort_columns , sort_orders )
116
+ return sorted_stmt
117
+
82
118
async def select_models (self , session : AsyncSession , ** kwargs ) -> Sequence [Row [Any ] | RowMapping | Any ]:
83
119
"""
84
120
Query all rows
@@ -87,13 +123,16 @@ async def select_models(self, session: AsyncSession, **kwargs) -> Sequence[Row[A
87
123
:param kwargs: Query expressions.
88
124
:return:
89
125
"""
90
- filters = parse_filters (self .model , ** kwargs )
91
- stmt = select (self .model ).where (* filters )
126
+ stmt = await self .select (** kwargs )
92
127
query = await session .execute (stmt )
93
128
return query .scalars ().all ()
94
129
95
130
async def select_models_order (
96
- self , session : AsyncSession , sort_columns : str | list [str ], sort_orders : str | list [str ] | None = None , ** kwargs
131
+ self ,
132
+ session : AsyncSession ,
133
+ sort_columns : str | list [str ],
134
+ sort_orders : str | list [str ] | None = None ,
135
+ ** kwargs ,
97
136
) -> Sequence [Row | RowMapping | Any ] | None :
98
137
"""
99
138
Query all rows and sort by columns
@@ -103,14 +142,16 @@ async def select_models_order(
103
142
:param sort_orders: more details see apply_sorting
104
143
:return:
105
144
"""
106
- filters = parse_filters (self .model , ** kwargs )
107
- stmt = select (self .model ).where (* filters )
108
- stmt_sort = apply_sorting (self .model , stmt , sort_columns , sort_orders )
109
- query = await session .execute (stmt_sort )
145
+ stmt = await self .select_order (sort_columns , sort_orders , ** kwargs )
146
+ query = await session .execute (stmt )
110
147
return query .scalars ().all ()
111
148
112
149
async def update_model (
113
- self , session : AsyncSession , pk : int , obj : UpdateSchema | dict [str , Any ], commit : bool = False
150
+ self ,
151
+ session : AsyncSession ,
152
+ pk : int ,
153
+ obj : UpdateSchema | dict [str , Any ],
154
+ commit : bool = False ,
114
155
) -> int :
115
156
"""
116
157
Update an instance by model's primary key
@@ -125,7 +166,7 @@ async def update_model(
125
166
instance_data = obj
126
167
else :
127
168
instance_data = obj .model_dump (exclude_unset = True )
128
- stmt = sa_update (self .model ).where (self .model .id == pk ).values (** instance_data )
169
+ stmt = update (self .model ).where (self .model .id == pk ).values (** instance_data )
129
170
result = await session .execute (stmt )
130
171
if commit :
131
172
await session .commit ()
@@ -157,13 +198,18 @@ async def update_model_by_column(
157
198
instance_data = obj
158
199
else :
159
200
instance_data = obj .model_dump (exclude_unset = True )
160
- stmt = sa_update (self .model ).where (* filters ).values (** instance_data ) # type: ignore
201
+ stmt = update (self .model ).where (* filters ).values (** instance_data ) # type: ignore
161
202
result = await session .execute (stmt )
162
203
if commit :
163
204
await session .commit ()
164
205
return result .rowcount # type: ignore
165
206
166
- async def delete_model (self , session : AsyncSession , pk : int , commit : bool = False ) -> int :
207
+ async def delete_model (
208
+ self ,
209
+ session : AsyncSession ,
210
+ pk : int ,
211
+ commit : bool = False ,
212
+ ) -> int :
167
213
"""
168
214
Delete an instance by model's primary key
169
215
@@ -172,7 +218,7 @@ async def delete_model(self, session: AsyncSession, pk: int, commit: bool = Fals
172
218
:param commit: If `True`, commits the transaction immediately. Default is `False`.
173
219
:return:
174
220
"""
175
- stmt = sa_delete (self .model ).where (self .model .id == pk )
221
+ stmt = delete (self .model ).where (self .model .id == pk )
176
222
result = await session .execute (stmt )
177
223
if commit :
178
224
await session .commit ()
@@ -204,10 +250,10 @@ async def delete_model_by_column(
204
250
raise MultipleResultsError (f'Only one record is expected to be delete, found { total_count } records.' )
205
251
if logical_deletion :
206
252
deleted_flag = {deleted_flag_column : True }
207
- stmt = sa_update (self .model ).where (* filters ).values (** deleted_flag )
253
+ stmt = update (self .model ).where (* filters ).values (** deleted_flag )
208
254
else :
209
- stmt = sa_delete (self .model ).where (* filters )
210
- await session .execute (stmt )
255
+ stmt = delete (self .model ).where (* filters )
256
+ result = await session .execute (stmt )
211
257
if commit :
212
258
await session .commit ()
213
- return total_count
259
+ return result . rowcount # type: ignore
0 commit comments