Skip to content

Commit 65e947f

Browse files
committed
优化表的主键支持,select_model,update_model,delete_model 方法支持非int类型主键和复合主键操作
1 parent 7d6b511 commit 65e947f

File tree

4 files changed

+141
-27
lines changed

4 files changed

+141
-27
lines changed

docs/usage/delete_model.md

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ class CRUDIns(CRUDPlus[ModelIns]):
2424
async def delete_model(
2525
self,
2626
session: AsyncSession,
27-
pk: int,
27+
pk: Union[Any, Dict[str, Any]],
2828
flush: bool = False,
2929
commit: bool = False,
3030
) -> int:
@@ -44,3 +44,34 @@ async def delete_model(
4444
| Type | Description |
4545
|------|-------------|
4646
| int | 删除数量 |
47+
48+
49+
## example
50+
51+
```python
52+
# Model with composite primary key
53+
class UserComposite(Base):
54+
__tablename__ = "users_composite"
55+
id = Column(String, primary_key=True)
56+
name = Column(String, primary_key=True)
57+
email = Column(String)
58+
59+
class UserCreate(BaseModel):
60+
id: str
61+
name: str | None
62+
email: str
63+
64+
async def example(session: AsyncSession):
65+
# Composite primary key model
66+
crud_composite = CRUDPlus(UserComposite)
67+
68+
# Create
69+
await crud_composite.create_model(
70+
session, UserCreate(id="123", name="John", email="composite@example.com"), commit=True
71+
)
72+
73+
74+
# Delete by composite primary key (dictionary)
75+
await crud_composite.delete_model(session, {"id": "123", "name": "John"}, commit=True)
76+
77+
```

docs/usage/select_model.md

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ class CRUDIns(CRUDPlus[ModelIns]):
2424
async def select_model(
2525
self,
2626
session: AsyncSession,
27-
pk: int,
27+
pk: Union[Any, Dict[str, Any]],
2828
*whereclause: ColumnExpressionArgument[bool],
2929
) -> Model | None:
3030
```
@@ -42,3 +42,35 @@ async def select_model(
4242
| Type | Description |
4343
|---------------------|-------------|
4444
| `TypeVar `\|` None` | 模型实例 |
45+
46+
47+
## example
48+
49+
```python
50+
# Model with composite primary key
51+
class UserComposite(Base):
52+
__tablename__ = "users_composite"
53+
id = Column(String, primary_key=True)
54+
name = Column(String, primary_key=True)
55+
email = Column(String)
56+
57+
class UserCreate(BaseModel):
58+
id: str
59+
name: str | None
60+
email: str
61+
62+
async def example(session: AsyncSession):
63+
# Composite primary key model
64+
crud_composite = CRUDPlus(UserComposite)
65+
66+
# Create
67+
await crud_composite.create_model(
68+
session, UserCreate(id="123", name="John", email="composite@example.com"), commit=True
69+
)
70+
71+
# Select by composite primary key (dictionary)
72+
user = await crud_composite.select_model(session, {"id": "123", "name": "John"})
73+
print(user.email) # composite@example.com
74+
75+
76+
```

docs/usage/update_model.md

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ class UpdateIns(BaseModel):
2121

2222

2323
class CRUDIns(CRUDPlus[ModelIns]):
24-
async def create(self, db: AsyncSession, pk: int, obj: UpdateIns) -> int:
24+
async def update(self, db: AsyncSession, pk: Union[Any, Dict[str, Any]], obj: UpdateIns) -> int:
2525
return await self.update_model(db, pk, obj)
2626
```
2727

@@ -31,7 +31,7 @@ class CRUDIns(CRUDPlus[ModelIns]):
3131
async def update_model(
3232
self,
3333
session: AsyncSession,
34-
pk: int,
34+
pk: Union[Any, Dict[str, Any]],
3535
obj: UpdateSchema | dict[str, Any],
3636
flush: bool = False,
3737
commit: bool = False,
@@ -70,3 +70,36 @@ async def update_model(
7070
| Type | Description |
7171
|------|-------------|
7272
| int | 更新数量 |
73+
74+
75+
## example
76+
77+
```python
78+
# Model with composite primary key
79+
class UserComposite(Base):
80+
__tablename__ = "users_composite"
81+
id = Column(String, primary_key=True)
82+
name = Column(String, primary_key=True)
83+
email = Column(String)
84+
85+
class UserCreate(BaseModel):
86+
id: str
87+
name: str | None
88+
email: str
89+
90+
async def example(session: AsyncSession):
91+
# Composite primary key model
92+
crud_composite = CRUDPlus(UserComposite)
93+
94+
# Create
95+
await crud_composite.create_model(
96+
session, UserCreate(id="123", name="John", email="composite@example.com"), commit=True
97+
)
98+
99+
# Update by composite primary key (dictionary)
100+
await crud_composite.update_model(
101+
session, {"id": "123", "name": "John"}, {"email": "updated_composite@example.com"}, commit=True
102+
)
103+
104+
105+
```

sqlalchemy_crud_plus/crud.py

Lines changed: 41 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#!/usr/bin/env python3
22
# -*- coding: utf-8 -*-
3-
from typing import Any, Generic, Iterable, Sequence, Type
3+
from typing import Any, Generic, Iterable, Sequence, Type, Union, Dict
44

55
from sqlalchemy import (
66
Column,
@@ -26,16 +26,36 @@ def __init__(self, model: Type[Model]):
2626
self.model = model
2727
self.primary_key = self._get_primary_key()
2828

29-
def _get_primary_key(self) -> Column:
29+
def _get_primary_keys(self) -> list[Column]:
3030
"""
31-
Dynamically retrieve the primary key column(s) for the model.
31+
Retrieve the primary key columns for the model.
3232
"""
3333
mapper = inspect(self.model)
34-
primary_key = mapper.primary_key
35-
if len(primary_key) == 1:
36-
return primary_key[0]
34+
return list(mapper.primary_key)
35+
36+
def _validate_pk_input(self, pk: Union[Any, Dict[str, Any]]) -> Dict[str, Any]:
37+
"""
38+
Validate and normalize primary key input to a dictionary mapping column names to values.
39+
40+
:param pk: A single value for single primary key, or a dictionary for composite primary keys.
41+
:return: Dictionary mapping primary key column names to their values.
42+
"""
43+
pk_columns = [pk_col.name for pk_col in self.primary_keys]
44+
if len(self.primary_keys) == 1:
45+
if isinstance(pk, dict):
46+
if pk_columns[0] not in pk:
47+
raise ValueError(f"Primary key column '{pk_columns[0]}' missing in dictionary")
48+
return {pk_columns[0]: pk[pk_columns[0]]}
49+
return {pk_columns[0]: pk}
3750
else:
38-
raise CompositePrimaryKeysError('Composite primary keys are not supported')
51+
if not isinstance(pk, dict):
52+
raise ValueError(
53+
f"Composite primary keys require a dictionary with keys {pk_columns}, got {type(pk)}"
54+
)
55+
missing = set(pk_columns) - set(pk.keys())
56+
if missing:
57+
raise ValueError(f"Missing primary key columns: {missing}")
58+
return {k: v for k, v in pk.items() if k in pk_columns}
3959

4060
async def create_model(
4161
self,
@@ -154,7 +174,7 @@ async def exists(
154174
async def select_model(
155175
self,
156176
session: AsyncSession,
157-
pk: int,
177+
pk: Union[Any, Dict[str, Any]],
158178
*whereclause: ColumnExpressionArgument[bool],
159179
) -> Model | None:
160180
"""
@@ -165,10 +185,9 @@ async def select_model(
165185
:param whereclause: The WHERE clauses to apply to the query.
166186
:return:
167187
"""
168-
filter_list = list(whereclause)
169-
_filters = [self.primary_key == pk]
170-
_filters.extend(filter_list)
171-
stmt = select(self.model).where(*_filters)
188+
pk_dict = self._validate_pk_input(pk)
189+
filters = [getattr(self.model, col) == val for col, val in pk_dict.items()] + list(whereclause)
190+
stmt = select(self.model).where(*filters)
172191
query = await session.execute(stmt)
173192
return query.scalars().first()
174193

@@ -270,7 +289,7 @@ async def select_models_order(
270289
async def update_model(
271290
self,
272291
session: AsyncSession,
273-
pk: int,
292+
pk: Union[Any, Dict[str, Any]],
274293
obj: UpdateSchema | dict[str, Any],
275294
flush: bool = False,
276295
commit: bool = False,
@@ -287,14 +306,11 @@ async def update_model(
287306
:param kwargs: Additional model data not included in the pydantic schema.
288307
:return:
289308
"""
290-
if isinstance(obj, dict):
291-
instance_data = obj
292-
else:
293-
instance_data = obj.model_dump(exclude_unset=True)
294-
if kwargs:
295-
instance_data.update(kwargs)
296-
297-
stmt = update(self.model).where(self.primary_key == pk).values(**instance_data)
309+
pk_dict = self._validate_pk_input(pk)
310+
instance_data = obj if isinstance(obj, dict) else obj.model_dump(exclude_unset=True)
311+
instance_data.update(kwargs)
312+
filters = [getattr(self.model, col) == val for col, val in pk_dict.items()]
313+
stmt = update(self.model).where(*filters).values(**instance_data)
298314
result = await session.execute(stmt)
299315

300316
if flush:
@@ -346,7 +362,7 @@ async def update_model_by_column(
346362
async def delete_model(
347363
self,
348364
session: AsyncSession,
349-
pk: int,
365+
pk: Union[Any, Dict[str, Any]],
350366
flush: bool = False,
351367
commit: bool = False,
352368
) -> int:
@@ -359,7 +375,9 @@ async def delete_model(
359375
:param commit: If `True`, commits the transaction immediately. Default is `False`.
360376
:return:
361377
"""
362-
stmt = delete(self.model).where(self.primary_key == pk)
378+
pk_dict = self._validate_pk_input(pk)
379+
filters = [getattr(self.model, col) == val for col, val in pk_dict.items()]
380+
stmt = delete(self.model).where(*filters)
363381
result = await session.execute(stmt)
364382

365383
if flush:

0 commit comments

Comments
 (0)