diff --git a/docs/usage/delete_model.md b/docs/usage/delete_model.md index 0cbc0cf..5977ad1 100644 --- a/docs/usage/delete_model.md +++ b/docs/usage/delete_model.md @@ -24,7 +24,7 @@ class CRUDIns(CRUDPlus[ModelIns]): async def delete_model( self, session: AsyncSession, - pk: int, + pk: Union[Any, Dict[str, Any]], flush: bool = False, commit: bool = False, ) -> int: @@ -44,3 +44,43 @@ async def delete_model( | Type | Description | |------|-------------| | int | 删除数量 | + + +## example + +```python +# Model with composite primary key +class UserComposite(Base): + __tablename__ = "users_composite" + id = Column(String, primary_key=True) + name = Column(String, primary_key=True) + email = Column(String) + +class UserCreate(BaseModel): + id: str + name: str | None + email: str + +async def example(session: AsyncSession): + # Composite primary key model + crud = CRUDPlus(UserComposite) + + # Create + await crud.create_model( + session, UserCreate(id="123", name="John", email="composite@example.com"), commit=True + ) + + + # Delete by composite primary key (dictionary) + await crud.delete_model(session, {"id": "123", "name": "John"}, commit=True) + + # Create + await crud.create_model( + session, UserCreate(id="456", name="Jack", email="Jack@example.com"), commit=True + ) + + # Delete by composite primary key (tuple) + await crud.delete_model(session, ("456", "Jack"), commit=True) + + +``` \ No newline at end of file diff --git a/docs/usage/select_model.md b/docs/usage/select_model.md index 68f5ebe..cceba02 100644 --- a/docs/usage/select_model.md +++ b/docs/usage/select_model.md @@ -24,7 +24,7 @@ class CRUDIns(CRUDPlus[ModelIns]): async def select_model( self, session: AsyncSession, - pk: int, + pk: Union[Any, Dict[str, Any]], *whereclause: ColumnExpressionArgument[bool], ) -> Model | None: ``` @@ -42,3 +42,39 @@ async def select_model( | Type | Description | |---------------------|-------------| | `TypeVar `\|` None` | 模型实例 | + + +## example + +```python +# Model with composite primary key +class UserComposite(Base): + __tablename__ = "users_composite" + id = Column(String, primary_key=True) + name = Column(String, primary_key=True) + email = Column(String) + +class UserCreate(BaseModel): + id: str + name: str | None + email: str + +async def example(session: AsyncSession): + # Composite primary key model + crud = CRUDPlus(UserComposite) + + # Create + await crud.create_model( + session, UserCreate(id="123", name="John", email="composite@example.com"), commit=True + ) + + # Select by composite primary key (dictionary) + user = await crud.select_model(session, {"id": "123", "name": "John"}) + print(user.email) # composite@example.com + + # Select by composite primary key (tuple) + user = await crud.select_model(session, ("123", "John")) + print(user.email) # composite@example.com + + +``` \ No newline at end of file diff --git a/docs/usage/update_model.md b/docs/usage/update_model.md index c5a34e6..1ec553f 100644 --- a/docs/usage/update_model.md +++ b/docs/usage/update_model.md @@ -21,7 +21,7 @@ class UpdateIns(BaseModel): class CRUDIns(CRUDPlus[ModelIns]): - async def create(self, db: AsyncSession, pk: int, obj: UpdateIns) -> int: + async def update(self, db: AsyncSession, pk: Union[Any, Dict[str, Any]], obj: UpdateIns) -> int: return await self.update_model(db, pk, obj) ``` @@ -31,7 +31,7 @@ class CRUDIns(CRUDPlus[ModelIns]): async def update_model( self, session: AsyncSession, - pk: int, + pk: Union[Any, Dict[str, Any]], obj: UpdateSchema | dict[str, Any], flush: bool = False, commit: bool = False, @@ -70,3 +70,40 @@ async def update_model( | Type | Description | |------|-------------| | int | 更新数量 | + + +## example + +```python +# Model with composite primary key +class UserComposite(Base): + __tablename__ = "users_composite" + id = Column(String, primary_key=True) + name = Column(String, primary_key=True) + email = Column(String) + +class UserCreate(BaseModel): + id: str + name: str | None + email: str + +async def example(session: AsyncSession): + # Composite primary key model + crud = CRUDPlus(UserComposite) + + # Create + await crud.create_model( + session, UserCreate(id="123", name="John", email="composite@example.com"), commit=True + ) + + # Update by composite primary key (dictionary) + await crud.update_model( + session, {"id": "123", "name": "John"}, {"email": "updated_composite@example.com"}, commit=True + ) + + # Update by composite primary key (tuple) + await crud.update_model( + session, ("123", "John"), {"email": "new_tuple@example.com"}, commit=True + ) + +``` \ No newline at end of file diff --git a/sqlalchemy_crud_plus/crud.py b/sqlalchemy_crud_plus/crud.py index cecf393..57f8af6 100644 --- a/sqlalchemy_crud_plus/crud.py +++ b/sqlalchemy_crud_plus/crud.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -from typing import Any, Generic, Iterable, Sequence, Type +from typing import Any, Generic, Iterable, Sequence, Type, Union, Dict, Tuple from sqlalchemy import ( Column, @@ -25,17 +25,56 @@ class CRUDPlus(Generic[Model]): def __init__(self, model: Type[Model]): self.model = model self.primary_key = self._get_primary_key() + self._pk_column_names = [pk_col.name for pk_col in self.primary_keys] # Cache column names - def _get_primary_key(self) -> Column: + def _get_primary_keys(self) -> list[Column]: """ - Dynamically retrieve the primary key column(s) for the model. + Retrieve the primary key columns for the model. """ mapper = inspect(self.model) - primary_key = mapper.primary_key - if len(primary_key) == 1: - return primary_key[0] + return list(mapper.primary_key) + + @property + def primary_key_columns(self) -> list[str]: + """ + Return the names of the primary key columns in order. + """ + return self._pk_column_names + + def _validate_pk_input(self, pk: Union[Any, Dict[str, Any], Tuple[Any, ...]]) -> Dict[str, Any]: + """ + Validate and normalize primary key input to a dictionary mapping column names to values. + + :param pk: A single value for single primary key, a dictionary, or a tuple for composite primary keys. + :return: Dictionary mapping primary key column names to their values. + :raises ValueError: If the input format is invalid or missing required primary key columns. + """ + if len(self.primary_keys) == 1: + pk_col = self._pk_column_names[0] + if isinstance(pk, dict): + if pk_col not in pk: + raise ValueError(f"Primary key column '{pk_col}' missing in dictionary") + return {pk_col: pk[pk_col]} + return {pk_col: pk} else: - raise CompositePrimaryKeysError('Composite primary keys are not supported') + if isinstance(pk, dict): + missing = set(self._pk_column_names) - set(pk.keys()) + if missing: + raise ValueError( + f"Missing primary key columns: {missing}. Expected keys: {self._pk_column_names}" + ) + return {k: v for k, v in pk.items() if k in self._pk_column_names} + elif isinstance(pk, tuple): + if len(pk) != len(self.primary_keys): + raise ValueError( + f"Expected {len(self.primary_keys)} primary key values, got {len(pk)}. " + f"Expected columns: {self._pk_column_names}" + ) + return dict(zip(self._pk_column_names, pk)) + raise ValueError( + f"Composite primary keys require a dictionary or tuple with keys/values for {self._pk_column_names}, " + f"got {type(pk)}" + ) async def create_model( self, @@ -154,21 +193,22 @@ async def exists( async def select_model( self, session: AsyncSession, - pk: int, + pk: Union[Any, Dict[str, Any], Tuple[Any, ...]], *whereclause: ColumnExpressionArgument[bool], ) -> Model | None: """ Query by ID :param session: The SQLAlchemy async session. - :param pk: The database primary key value. + :param pk: A single value for a single primary key (e.g., int, str), a dictionary + mapping column names to values, or a tuple of values (in column order) for + composite primary keys. :param whereclause: The WHERE clauses to apply to the query. :return: """ - filter_list = list(whereclause) - _filters = [self.primary_key == pk] - _filters.extend(filter_list) - stmt = select(self.model).where(*_filters) + pk_dict = self._validate_pk_input(pk) + filters = [getattr(self.model, col) == val for col, val in pk_dict.items()] + list(whereclause) + stmt = select(self.model).where(*filters) query = await session.execute(stmt) return query.scalars().first() @@ -270,7 +310,7 @@ async def select_models_order( async def update_model( self, session: AsyncSession, - pk: int, + pk: Union[Any, Dict[str, Any], Tuple[Any, ...]], obj: UpdateSchema | dict[str, Any], flush: bool = False, commit: bool = False, @@ -280,21 +320,20 @@ async def update_model( Update an instance by model's primary key :param session: The SQLAlchemy async session. - :param pk: The database primary key value. + :param pk: A single value for a single primary key (e.g., int, str), a dictionary + mapping column names to values, or a tuple of values (in column order) for + composite primary keys. :param obj: A pydantic schema or dictionary containing the update data :param flush: If `True`, flush all object changes to the database. Default is `False`. :param commit: If `True`, commits the transaction immediately. Default is `False`. :param kwargs: Additional model data not included in the pydantic schema. :return: """ - if isinstance(obj, dict): - instance_data = obj - else: - instance_data = obj.model_dump(exclude_unset=True) - if kwargs: - instance_data.update(kwargs) - - stmt = update(self.model).where(self.primary_key == pk).values(**instance_data) + pk_dict = self._validate_pk_input(pk) + instance_data = obj if isinstance(obj, dict) else obj.model_dump(exclude_unset=True) + instance_data.update(kwargs) + filters = [getattr(self.model, col) == val for col, val in pk_dict.items()] + stmt = update(self.model).where(*filters).values(**instance_data) result = await session.execute(stmt) if flush: @@ -346,7 +385,7 @@ async def update_model_by_column( async def delete_model( self, session: AsyncSession, - pk: int, + pk: Union[Any, Dict[str, Any], Tuple[Any, ...]], flush: bool = False, commit: bool = False, ) -> int: @@ -354,12 +393,16 @@ async def delete_model( Delete an instance by model's primary key :param session: The SQLAlchemy async session. - :param pk: The database primary key value. + :param pk: A single value for a single primary key (e.g., int, str), a dictionary + mapping column names to values, or a tuple of values (in column order) for + composite primary keys. :param flush: If `True`, flush all object changes to the database. Default is `False`. :param commit: If `True`, commits the transaction immediately. Default is `False`. :return: """ - stmt = delete(self.model).where(self.primary_key == pk) + pk_dict = self._validate_pk_input(pk) + filters = [getattr(self.model, col) == val for col, val in pk_dict.items()] + stmt = delete(self.model).where(*filters) result = await session.execute(stmt) if flush: