1
1
#!/usr/bin/env python3
2
2
# -*- coding: utf-8 -*-
3
- from typing import Any , Generic , Iterable , Sequence , Type
3
+ from typing import Any , Generic , Iterable , Sequence , Type , Union , Dict
4
4
5
5
from sqlalchemy import (
6
6
Column ,
@@ -26,16 +26,36 @@ def __init__(self, model: Type[Model]):
26
26
self .model = model
27
27
self .primary_key = self ._get_primary_key ()
28
28
29
- def _get_primary_key (self ) -> Column :
29
+ def _get_primary_keys (self ) -> list [ Column ] :
30
30
"""
31
- Dynamically retrieve the primary key column(s) for the model.
31
+ Retrieve the primary key columns for the model.
32
32
"""
33
33
mapper = inspect (self .model )
34
- primary_key = mapper .primary_key
35
- if len (primary_key ) == 1 :
36
- return primary_key [0 ]
34
+ return list (mapper .primary_key )
35
+
36
+ def _validate_pk_input (self , pk : Union [Any , Dict [str , Any ]]) -> Dict [str , Any ]:
37
+ """
38
+ Validate and normalize primary key input to a dictionary mapping column names to values.
39
+
40
+ :param pk: A single value for single primary key, or a dictionary for composite primary keys.
41
+ :return: Dictionary mapping primary key column names to their values.
42
+ """
43
+ pk_columns = [pk_col .name for pk_col in self .primary_keys ]
44
+ if len (self .primary_keys ) == 1 :
45
+ if isinstance (pk , dict ):
46
+ if pk_columns [0 ] not in pk :
47
+ raise ValueError (f"Primary key column '{ pk_columns [0 ]} ' missing in dictionary" )
48
+ return {pk_columns [0 ]: pk [pk_columns [0 ]]}
49
+ return {pk_columns [0 ]: pk }
37
50
else :
38
- raise CompositePrimaryKeysError ('Composite primary keys are not supported' )
51
+ if not isinstance (pk , dict ):
52
+ raise ValueError (
53
+ f"Composite primary keys require a dictionary with keys { pk_columns } , got { type (pk )} "
54
+ )
55
+ missing = set (pk_columns ) - set (pk .keys ())
56
+ if missing :
57
+ raise ValueError (f"Missing primary key columns: { missing } " )
58
+ return {k : v for k , v in pk .items () if k in pk_columns }
39
59
40
60
async def create_model (
41
61
self ,
@@ -154,7 +174,7 @@ async def exists(
154
174
async def select_model (
155
175
self ,
156
176
session : AsyncSession ,
157
- pk : int ,
177
+ pk : Union [ Any , Dict [ str , Any ]] ,
158
178
* whereclause : ColumnExpressionArgument [bool ],
159
179
) -> Model | None :
160
180
"""
@@ -165,10 +185,9 @@ async def select_model(
165
185
:param whereclause: The WHERE clauses to apply to the query.
166
186
:return:
167
187
"""
168
- filter_list = list (whereclause )
169
- _filters = [self .primary_key == pk ]
170
- _filters .extend (filter_list )
171
- stmt = select (self .model ).where (* _filters )
188
+ pk_dict = self ._validate_pk_input (pk )
189
+ filters = [getattr (self .model , col ) == val for col , val in pk_dict .items ()] + list (whereclause )
190
+ stmt = select (self .model ).where (* filters )
172
191
query = await session .execute (stmt )
173
192
return query .scalars ().first ()
174
193
@@ -270,7 +289,7 @@ async def select_models_order(
270
289
async def update_model (
271
290
self ,
272
291
session : AsyncSession ,
273
- pk : int ,
292
+ pk : Union [ Any , Dict [ str , Any ]] ,
274
293
obj : UpdateSchema | dict [str , Any ],
275
294
flush : bool = False ,
276
295
commit : bool = False ,
@@ -287,14 +306,11 @@ async def update_model(
287
306
:param kwargs: Additional model data not included in the pydantic schema.
288
307
:return:
289
308
"""
290
- if isinstance (obj , dict ):
291
- instance_data = obj
292
- else :
293
- instance_data = obj .model_dump (exclude_unset = True )
294
- if kwargs :
295
- instance_data .update (kwargs )
296
-
297
- stmt = update (self .model ).where (self .primary_key == pk ).values (** instance_data )
309
+ pk_dict = self ._validate_pk_input (pk )
310
+ instance_data = obj if isinstance (obj , dict ) else obj .model_dump (exclude_unset = True )
311
+ instance_data .update (kwargs )
312
+ filters = [getattr (self .model , col ) == val for col , val in pk_dict .items ()]
313
+ stmt = update (self .model ).where (* filters ).values (** instance_data )
298
314
result = await session .execute (stmt )
299
315
300
316
if flush :
@@ -346,7 +362,7 @@ async def update_model_by_column(
346
362
async def delete_model (
347
363
self ,
348
364
session : AsyncSession ,
349
- pk : int ,
365
+ pk : Union [ Any , Dict [ str , Any ]] ,
350
366
flush : bool = False ,
351
367
commit : bool = False ,
352
368
) -> int :
@@ -359,7 +375,9 @@ async def delete_model(
359
375
:param commit: If `True`, commits the transaction immediately. Default is `False`.
360
376
:return:
361
377
"""
362
- stmt = delete (self .model ).where (self .primary_key == pk )
378
+ pk_dict = self ._validate_pk_input (pk )
379
+ filters = [getattr (self .model , col ) == val for col , val in pk_dict .items ()]
380
+ stmt = delete (self .model ).where (* filters )
363
381
result = await session .execute (stmt )
364
382
365
383
if flush :
0 commit comments