Skip to content

Update model primary key for dynamic retrieval #23

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Aug 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion docs/advanced/primary_key.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,22 @@
!!! note 主键参数命名

由于在 python 内部 id 的特殊性,我们设定 pk (参考 Django) 作为模型主键命名,所以在 crud 方法中,任何涉及到主键的地方,入参都为 `pk`

```py title="e.g." hl_lines="2"
async def delete(self, db: AsyncSession, primary_key: int) -> int:
return self.delete_model(db, pk=primary_key)
```

## 主键定义

!!! warning 自动主键

我们在 SQLAlchemy CRUD Plus 内部通过 [inspect()](https://docs.sqlalchemy.org/en/20/core/inspection.html) 自动搜索表主键,
而非强制绑定主键列必须命名为 id,感谢 [@DavidSche](https://github.com/DavidSche) 提供帮助

```py title="e.g." hl_lines="4"
class ModelIns(Base):
# your sqlalchemy model
# define your primary_key
custom_id: Mapped[int] = mapped_column(primary_key=True, index=True, autoincrement=True)
```
4 changes: 3 additions & 1 deletion docs/usage/delete_model.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,15 @@ from pydantic import BaseModel

from sqlalchemy_crud_plus import CRUDPlus

from sqlalchemy import Mapped, mapped_column
from sqlalchemy import DeclarativeBase as Base
from sqlalchemy.ext.asyncio import AsyncSession


class ModelIns(Base):
# your sqlalchemy model
pass
# define your primary_key
custom_id: Mapped[int] = mapped_column(primary_key=True, index=True, autoincrement=True)


class CreateIns(BaseModel):
Expand Down
6 changes: 4 additions & 2 deletions docs/usage/select_model.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,15 @@ from pydantic import BaseModel

from sqlalchemy_crud_plus import CRUDPlus

from sqlalchemy import Mapped, mapped_column
from sqlalchemy import DeclarativeBase as Base
from sqlalchemy.ext.asyncio import AsyncSession


class ModelIns(Base):
# your sqlalchemy model
pass
# define your primary_key
custom_id: Mapped[int] = mapped_column(primary_key=True, index=True, autoincrement=True)


class CreateIns(BaseModel):
Expand All @@ -30,6 +32,6 @@ class CreateIns(BaseModel):


class CRUDIns(CRUDPlus[ModelIns]):
async def create(self, db: AsyncSession, pk: int) -> ModelIns:
async def select(self, db: AsyncSession, pk: int) -> ModelIns:
return await self.select_model(db, pk)
```
4 changes: 3 additions & 1 deletion docs/usage/update_model.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,15 @@ from pydantic import BaseModel

from sqlalchemy_crud_plus import CRUDPlus

from sqlalchemy import Mapped, mapped_column
from sqlalchemy import DeclarativeBase as Base
from sqlalchemy.ext.asyncio import AsyncSession


class ModelIns(Base):
# your sqlalchemy model
pass
# define your primary_key
custom_id: Mapped[int] = mapped_column(primary_key=True, index=True, autoincrement=True)


class UpdateIns(BaseModel):
Expand Down
22 changes: 17 additions & 5 deletions sqlalchemy_crud_plus/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,29 @@
# -*- coding: utf-8 -*-
from typing import Any, Generic, Iterable, Sequence, Type

from sqlalchemy import Row, RowMapping, Select, delete, select, update
from sqlalchemy import Row, RowMapping, Select, delete, inspect, select, update
from sqlalchemy.ext.asyncio import AsyncSession

from sqlalchemy_crud_plus.errors import MultipleResultsError
from sqlalchemy_crud_plus.errors import CompositePrimaryKeysError, MultipleResultsError
from sqlalchemy_crud_plus.types import CreateSchema, Model, UpdateSchema
from sqlalchemy_crud_plus.utils import apply_sorting, count, parse_filters


class CRUDPlus(Generic[Model]):
def __init__(self, model: Type[Model]):
self.model = model
self.primary_key = self._get_primary_key()

def _get_primary_key(self):
"""
Dynamically retrieve the primary key column(s) for the model.
"""
mapper = inspect(self.model)
primary_key = mapper.primary_key
if len(primary_key) == 1:
return primary_key[0]
else:
raise CompositePrimaryKeysError('Composite primary keys are not supported')

async def create_model(
self,
Expand Down Expand Up @@ -69,7 +81,7 @@ async def select_model(self, session: AsyncSession, pk: int) -> Model | None:
:param pk: The database primary key value.
:return:
"""
stmt = select(self.model).where(self.model.id == pk)
stmt = select(self.model).where(self.primary_key == pk)
query = await session.execute(stmt)
return query.scalars().first()

Expand Down Expand Up @@ -166,7 +178,7 @@ async def update_model(
instance_data = obj
else:
instance_data = obj.model_dump(exclude_unset=True)
stmt = update(self.model).where(self.model.id == pk).values(**instance_data)
stmt = update(self.model).where(self.primary_key == pk).values(**instance_data)
result = await session.execute(stmt)
if commit:
await session.commit()
Expand Down Expand Up @@ -218,7 +230,7 @@ async def delete_model(
:param commit: If `True`, commits the transaction immediately. Default is `False`.
:return:
"""
stmt = delete(self.model).where(self.model.id == pk)
stmt = delete(self.model).where(self.primary_key == pk)
result = await session.execute(stmt)
if commit:
await session.commit()
Expand Down
7 changes: 7 additions & 0 deletions sqlalchemy_crud_plus/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,10 @@ class MultipleResultsError(SQLAlchemyCRUDPlusException):

def __init__(self, msg: str) -> None:
super().__init__(msg)


class CompositePrimaryKeysError(SQLAlchemyCRUDPlusException):
"""Error raised when a table have Composite primary keys."""

def __init__(self, msg: str) -> None:
super().__init__(msg)