From 65e947fff254476a79393463dbc4c73361e1e23d Mon Sep 17 00:00:00 2001 From: davidche Date: Wed, 30 Apr 2025 13:49:44 +0800 Subject: [PATCH 1/2] =?UTF-8?q?=E4=BC=98=E5=8C=96=E8=A1=A8=E7=9A=84?= =?UTF-8?q?=E4=B8=BB=E9=94=AE=E6=94=AF=E6=8C=81=EF=BC=8Cselect=5Fmodel?= =?UTF-8?q?=EF=BC=8Cupdate=5Fmodel,delete=5Fmodel=20=E6=96=B9=E6=B3=95?= =?UTF-8?q?=E6=94=AF=E6=8C=81=E9=9D=9Eint=E7=B1=BB=E5=9E=8B=E4=B8=BB?= =?UTF-8?q?=E9=94=AE=E5=92=8C=E5=A4=8D=E5=90=88=E4=B8=BB=E9=94=AE=E6=93=8D?= =?UTF-8?q?=E4=BD=9C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs/usage/delete_model.md | 33 ++++++++++++++++++- docs/usage/select_model.md | 34 ++++++++++++++++++- docs/usage/update_model.md | 37 +++++++++++++++++++-- sqlalchemy_crud_plus/crud.py | 64 +++++++++++++++++++++++------------- 4 files changed, 141 insertions(+), 27 deletions(-) diff --git a/docs/usage/delete_model.md b/docs/usage/delete_model.md index 0cbc0cf..ffbc2c4 100644 --- a/docs/usage/delete_model.md +++ b/docs/usage/delete_model.md @@ -24,7 +24,7 @@ class CRUDIns(CRUDPlus[ModelIns]): async def delete_model( self, session: AsyncSession, - pk: int, + pk: Union[Any, Dict[str, Any]], flush: bool = False, commit: bool = False, ) -> int: @@ -44,3 +44,34 @@ async def delete_model( | Type | Description | |------|-------------| | int | 删除数量 | + + +## example + +```python +# Model with composite primary key +class UserComposite(Base): + __tablename__ = "users_composite" + id = Column(String, primary_key=True) + name = Column(String, primary_key=True) + email = Column(String) + +class UserCreate(BaseModel): + id: str + name: str | None + email: str + +async def example(session: AsyncSession): + # Composite primary key model + crud_composite = CRUDPlus(UserComposite) + + # Create + await crud_composite.create_model( + session, UserCreate(id="123", name="John", email="composite@example.com"), commit=True + ) + + + # Delete by composite primary key (dictionary) + await crud_composite.delete_model(session, {"id": "123", "name": "John"}, commit=True) + +``` \ No newline at end of file diff --git a/docs/usage/select_model.md b/docs/usage/select_model.md index 68f5ebe..5aa87e3 100644 --- a/docs/usage/select_model.md +++ b/docs/usage/select_model.md @@ -24,7 +24,7 @@ class CRUDIns(CRUDPlus[ModelIns]): async def select_model( self, session: AsyncSession, - pk: int, + pk: Union[Any, Dict[str, Any]], *whereclause: ColumnExpressionArgument[bool], ) -> Model | None: ``` @@ -42,3 +42,35 @@ async def select_model( | Type | Description | |---------------------|-------------| | `TypeVar `\|` None` | 模型实例 | + + +## example + +```python +# Model with composite primary key +class UserComposite(Base): + __tablename__ = "users_composite" + id = Column(String, primary_key=True) + name = Column(String, primary_key=True) + email = Column(String) + +class UserCreate(BaseModel): + id: str + name: str | None + email: str + +async def example(session: AsyncSession): + # Composite primary key model + crud_composite = CRUDPlus(UserComposite) + + # Create + await crud_composite.create_model( + session, UserCreate(id="123", name="John", email="composite@example.com"), commit=True + ) + + # Select by composite primary key (dictionary) + user = await crud_composite.select_model(session, {"id": "123", "name": "John"}) + print(user.email) # composite@example.com + + +``` \ No newline at end of file diff --git a/docs/usage/update_model.md b/docs/usage/update_model.md index c5a34e6..6ac26e0 100644 --- a/docs/usage/update_model.md +++ b/docs/usage/update_model.md @@ -21,7 +21,7 @@ class UpdateIns(BaseModel): class CRUDIns(CRUDPlus[ModelIns]): - async def create(self, db: AsyncSession, pk: int, obj: UpdateIns) -> int: + async def update(self, db: AsyncSession, pk: Union[Any, Dict[str, Any]], obj: UpdateIns) -> int: return await self.update_model(db, pk, obj) ``` @@ -31,7 +31,7 @@ class CRUDIns(CRUDPlus[ModelIns]): async def update_model( self, session: AsyncSession, - pk: int, + pk: Union[Any, Dict[str, Any]], obj: UpdateSchema | dict[str, Any], flush: bool = False, commit: bool = False, @@ -70,3 +70,36 @@ async def update_model( | Type | Description | |------|-------------| | int | 更新数量 | + + +## example + +```python +# Model with composite primary key +class UserComposite(Base): + __tablename__ = "users_composite" + id = Column(String, primary_key=True) + name = Column(String, primary_key=True) + email = Column(String) + +class UserCreate(BaseModel): + id: str + name: str | None + email: str + +async def example(session: AsyncSession): + # Composite primary key model + crud_composite = CRUDPlus(UserComposite) + + # Create + await crud_composite.create_model( + session, UserCreate(id="123", name="John", email="composite@example.com"), commit=True + ) + + # Update by composite primary key (dictionary) + await crud_composite.update_model( + session, {"id": "123", "name": "John"}, {"email": "updated_composite@example.com"}, commit=True + ) + + +``` \ No newline at end of file diff --git a/sqlalchemy_crud_plus/crud.py b/sqlalchemy_crud_plus/crud.py index cecf393..602f0ad 100644 --- a/sqlalchemy_crud_plus/crud.py +++ b/sqlalchemy_crud_plus/crud.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -from typing import Any, Generic, Iterable, Sequence, Type +from typing import Any, Generic, Iterable, Sequence, Type, Union, Dict from sqlalchemy import ( Column, @@ -26,16 +26,36 @@ def __init__(self, model: Type[Model]): self.model = model self.primary_key = self._get_primary_key() - def _get_primary_key(self) -> Column: + def _get_primary_keys(self) -> list[Column]: """ - Dynamically retrieve the primary key column(s) for the model. + Retrieve the primary key columns for the model. """ mapper = inspect(self.model) - primary_key = mapper.primary_key - if len(primary_key) == 1: - return primary_key[0] + return list(mapper.primary_key) + + def _validate_pk_input(self, pk: Union[Any, Dict[str, Any]]) -> Dict[str, Any]: + """ + Validate and normalize primary key input to a dictionary mapping column names to values. + + :param pk: A single value for single primary key, or a dictionary for composite primary keys. + :return: Dictionary mapping primary key column names to their values. + """ + pk_columns = [pk_col.name for pk_col in self.primary_keys] + if len(self.primary_keys) == 1: + if isinstance(pk, dict): + if pk_columns[0] not in pk: + raise ValueError(f"Primary key column '{pk_columns[0]}' missing in dictionary") + return {pk_columns[0]: pk[pk_columns[0]]} + return {pk_columns[0]: pk} else: - raise CompositePrimaryKeysError('Composite primary keys are not supported') + if not isinstance(pk, dict): + raise ValueError( + f"Composite primary keys require a dictionary with keys {pk_columns}, got {type(pk)}" + ) + missing = set(pk_columns) - set(pk.keys()) + if missing: + raise ValueError(f"Missing primary key columns: {missing}") + return {k: v for k, v in pk.items() if k in pk_columns} async def create_model( self, @@ -154,7 +174,7 @@ async def exists( async def select_model( self, session: AsyncSession, - pk: int, + pk: Union[Any, Dict[str, Any]], *whereclause: ColumnExpressionArgument[bool], ) -> Model | None: """ @@ -165,10 +185,9 @@ async def select_model( :param whereclause: The WHERE clauses to apply to the query. :return: """ - filter_list = list(whereclause) - _filters = [self.primary_key == pk] - _filters.extend(filter_list) - stmt = select(self.model).where(*_filters) + pk_dict = self._validate_pk_input(pk) + filters = [getattr(self.model, col) == val for col, val in pk_dict.items()] + list(whereclause) + stmt = select(self.model).where(*filters) query = await session.execute(stmt) return query.scalars().first() @@ -270,7 +289,7 @@ async def select_models_order( async def update_model( self, session: AsyncSession, - pk: int, + pk: Union[Any, Dict[str, Any]], obj: UpdateSchema | dict[str, Any], flush: bool = False, commit: bool = False, @@ -287,14 +306,11 @@ async def update_model( :param kwargs: Additional model data not included in the pydantic schema. :return: """ - if isinstance(obj, dict): - instance_data = obj - else: - instance_data = obj.model_dump(exclude_unset=True) - if kwargs: - instance_data.update(kwargs) - - stmt = update(self.model).where(self.primary_key == pk).values(**instance_data) + pk_dict = self._validate_pk_input(pk) + instance_data = obj if isinstance(obj, dict) else obj.model_dump(exclude_unset=True) + instance_data.update(kwargs) + filters = [getattr(self.model, col) == val for col, val in pk_dict.items()] + stmt = update(self.model).where(*filters).values(**instance_data) result = await session.execute(stmt) if flush: @@ -346,7 +362,7 @@ async def update_model_by_column( async def delete_model( self, session: AsyncSession, - pk: int, + pk: Union[Any, Dict[str, Any]], flush: bool = False, commit: bool = False, ) -> int: @@ -359,7 +375,9 @@ async def delete_model( :param commit: If `True`, commits the transaction immediately. Default is `False`. :return: """ - stmt = delete(self.model).where(self.primary_key == pk) + pk_dict = self._validate_pk_input(pk) + filters = [getattr(self.model, col) == val for col, val in pk_dict.items()] + stmt = delete(self.model).where(*filters) result = await session.execute(stmt) if flush: From d9acb401d6ccc61ee667ef41d8c3c536d68747c5 Mon Sep 17 00:00:00 2001 From: davidche Date: Wed, 30 Apr 2025 14:44:28 +0800 Subject: [PATCH 2/2] update use pk in Tuple --- docs/usage/delete_model.md | 15 ++++++-- docs/usage/select_model.md | 10 ++++-- docs/usage/update_model.md | 10 ++++-- sqlalchemy_crud_plus/crud.py | 69 ++++++++++++++++++++++++------------ 4 files changed, 73 insertions(+), 31 deletions(-) diff --git a/docs/usage/delete_model.md b/docs/usage/delete_model.md index ffbc2c4..5977ad1 100644 --- a/docs/usage/delete_model.md +++ b/docs/usage/delete_model.md @@ -63,15 +63,24 @@ class UserCreate(BaseModel): async def example(session: AsyncSession): # Composite primary key model - crud_composite = CRUDPlus(UserComposite) + crud = CRUDPlus(UserComposite) # Create - await crud_composite.create_model( + await crud.create_model( session, UserCreate(id="123", name="John", email="composite@example.com"), commit=True ) # Delete by composite primary key (dictionary) - await crud_composite.delete_model(session, {"id": "123", "name": "John"}, commit=True) + await crud.delete_model(session, {"id": "123", "name": "John"}, commit=True) + + # Create + await crud.create_model( + session, UserCreate(id="456", name="Jack", email="Jack@example.com"), commit=True + ) + + # Delete by composite primary key (tuple) + await crud.delete_model(session, ("456", "Jack"), commit=True) + ``` \ No newline at end of file diff --git a/docs/usage/select_model.md b/docs/usage/select_model.md index 5aa87e3..cceba02 100644 --- a/docs/usage/select_model.md +++ b/docs/usage/select_model.md @@ -61,16 +61,20 @@ class UserCreate(BaseModel): async def example(session: AsyncSession): # Composite primary key model - crud_composite = CRUDPlus(UserComposite) + crud = CRUDPlus(UserComposite) # Create - await crud_composite.create_model( + await crud.create_model( session, UserCreate(id="123", name="John", email="composite@example.com"), commit=True ) # Select by composite primary key (dictionary) - user = await crud_composite.select_model(session, {"id": "123", "name": "John"}) + user = await crud.select_model(session, {"id": "123", "name": "John"}) print(user.email) # composite@example.com + # Select by composite primary key (tuple) + user = await crud.select_model(session, ("123", "John")) + print(user.email) # composite@example.com + ``` \ No newline at end of file diff --git a/docs/usage/update_model.md b/docs/usage/update_model.md index 6ac26e0..1ec553f 100644 --- a/docs/usage/update_model.md +++ b/docs/usage/update_model.md @@ -89,17 +89,21 @@ class UserCreate(BaseModel): async def example(session: AsyncSession): # Composite primary key model - crud_composite = CRUDPlus(UserComposite) + crud = CRUDPlus(UserComposite) # Create - await crud_composite.create_model( + await crud.create_model( session, UserCreate(id="123", name="John", email="composite@example.com"), commit=True ) # Update by composite primary key (dictionary) - await crud_composite.update_model( + await crud.update_model( session, {"id": "123", "name": "John"}, {"email": "updated_composite@example.com"}, commit=True ) + # Update by composite primary key (tuple) + await crud.update_model( + session, ("123", "John"), {"email": "new_tuple@example.com"}, commit=True + ) ``` \ No newline at end of file diff --git a/sqlalchemy_crud_plus/crud.py b/sqlalchemy_crud_plus/crud.py index 602f0ad..57f8af6 100644 --- a/sqlalchemy_crud_plus/crud.py +++ b/sqlalchemy_crud_plus/crud.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -from typing import Any, Generic, Iterable, Sequence, Type, Union, Dict +from typing import Any, Generic, Iterable, Sequence, Type, Union, Dict, Tuple from sqlalchemy import ( Column, @@ -25,6 +25,7 @@ class CRUDPlus(Generic[Model]): def __init__(self, model: Type[Model]): self.model = model self.primary_key = self._get_primary_key() + self._pk_column_names = [pk_col.name for pk_col in self.primary_keys] # Cache column names def _get_primary_keys(self) -> list[Column]: """ @@ -33,29 +34,47 @@ def _get_primary_keys(self) -> list[Column]: mapper = inspect(self.model) return list(mapper.primary_key) - def _validate_pk_input(self, pk: Union[Any, Dict[str, Any]]) -> Dict[str, Any]: + @property + def primary_key_columns(self) -> list[str]: + """ + Return the names of the primary key columns in order. + """ + return self._pk_column_names + + def _validate_pk_input(self, pk: Union[Any, Dict[str, Any], Tuple[Any, ...]]) -> Dict[str, Any]: """ Validate and normalize primary key input to a dictionary mapping column names to values. - :param pk: A single value for single primary key, or a dictionary for composite primary keys. + :param pk: A single value for single primary key, a dictionary, or a tuple for composite primary keys. :return: Dictionary mapping primary key column names to their values. + :raises ValueError: If the input format is invalid or missing required primary key columns. """ - pk_columns = [pk_col.name for pk_col in self.primary_keys] if len(self.primary_keys) == 1: + pk_col = self._pk_column_names[0] if isinstance(pk, dict): - if pk_columns[0] not in pk: - raise ValueError(f"Primary key column '{pk_columns[0]}' missing in dictionary") - return {pk_columns[0]: pk[pk_columns[0]]} - return {pk_columns[0]: pk} + if pk_col not in pk: + raise ValueError(f"Primary key column '{pk_col}' missing in dictionary") + return {pk_col: pk[pk_col]} + return {pk_col: pk} else: - if not isinstance(pk, dict): - raise ValueError( - f"Composite primary keys require a dictionary with keys {pk_columns}, got {type(pk)}" - ) - missing = set(pk_columns) - set(pk.keys()) - if missing: - raise ValueError(f"Missing primary key columns: {missing}") - return {k: v for k, v in pk.items() if k in pk_columns} + if isinstance(pk, dict): + missing = set(self._pk_column_names) - set(pk.keys()) + if missing: + raise ValueError( + f"Missing primary key columns: {missing}. Expected keys: {self._pk_column_names}" + ) + return {k: v for k, v in pk.items() if k in self._pk_column_names} + elif isinstance(pk, tuple): + if len(pk) != len(self.primary_keys): + raise ValueError( + f"Expected {len(self.primary_keys)} primary key values, got {len(pk)}. " + f"Expected columns: {self._pk_column_names}" + ) + return dict(zip(self._pk_column_names, pk)) + raise ValueError( + f"Composite primary keys require a dictionary or tuple with keys/values for {self._pk_column_names}, " + f"got {type(pk)}" + ) async def create_model( self, @@ -174,14 +193,16 @@ async def exists( async def select_model( self, session: AsyncSession, - pk: Union[Any, Dict[str, Any]], + pk: Union[Any, Dict[str, Any], Tuple[Any, ...]], *whereclause: ColumnExpressionArgument[bool], ) -> Model | None: """ Query by ID :param session: The SQLAlchemy async session. - :param pk: The database primary key value. + :param pk: A single value for a single primary key (e.g., int, str), a dictionary + mapping column names to values, or a tuple of values (in column order) for + composite primary keys. :param whereclause: The WHERE clauses to apply to the query. :return: """ @@ -289,7 +310,7 @@ async def select_models_order( async def update_model( self, session: AsyncSession, - pk: Union[Any, Dict[str, Any]], + pk: Union[Any, Dict[str, Any], Tuple[Any, ...]], obj: UpdateSchema | dict[str, Any], flush: bool = False, commit: bool = False, @@ -299,7 +320,9 @@ async def update_model( Update an instance by model's primary key :param session: The SQLAlchemy async session. - :param pk: The database primary key value. + :param pk: A single value for a single primary key (e.g., int, str), a dictionary + mapping column names to values, or a tuple of values (in column order) for + composite primary keys. :param obj: A pydantic schema or dictionary containing the update data :param flush: If `True`, flush all object changes to the database. Default is `False`. :param commit: If `True`, commits the transaction immediately. Default is `False`. @@ -362,7 +385,7 @@ async def update_model_by_column( async def delete_model( self, session: AsyncSession, - pk: Union[Any, Dict[str, Any]], + pk: Union[Any, Dict[str, Any], Tuple[Any, ...]], flush: bool = False, commit: bool = False, ) -> int: @@ -370,7 +393,9 @@ async def delete_model( Delete an instance by model's primary key :param session: The SQLAlchemy async session. - :param pk: The database primary key value. + :param pk: A single value for a single primary key (e.g., int, str), a dictionary + mapping column names to values, or a tuple of values (in column order) for + composite primary keys. :param flush: If `True`, flush all object changes to the database. Default is `False`. :param commit: If `True`, commits the transaction immediately. Default is `False`. :return: