diff --git a/README.md b/README.md index 2f8fa80..b3f9061 100644 --- a/README.md +++ b/README.md @@ -9,11 +9,11 @@ │   ├── deps.py │   └── v1 │   ├── __init__.py +│   ├── auth +│   │   ├── __init__.py +│   │   └── token.py │   └── users │   ├── __init__.py -│   ├── auth -│   │   ├── __init__.py -│   │   └── token.py │   ├── create.py │   └── retrieve.py ├── core @@ -23,27 +23,36 @@ │   └── settings.py ├── logic │   ├── __init__.py +│   ├── auth +│   │   ├── __init__.py +│   │   └── auth.py +│   ├── logic.py │   ├── security │   │   ├── __init__.py │   │   ├── jwt.py -│   │   └── pwd.py +│   │   ├── pwd.py +│   │   └── security.py │   └── users │   ├── __init__.py -│   ├── auth -│   │   ├── __init__.py -│   │   └── auth.py │   └── users.py ├── models │   ├── __init__.py +│   ├── auth +│   │   ├── __init__.py +│   │   └── token.py │   ├── base.py -│   ├── token.py -│   └── user.py +│   ├── types +│   │   ├── __init__.py +│   │   └── unix.py +│   └── users +│   ├── __init__.py +│   └── user.py └── repositories ├── __init__.py - ├── abstract.py + ├── base.py └── user.py -11 directories, 28 files +14 directories, 34 files ``` ## Create a `.env` file based on `.env.dist` and make all the necessary customizations diff --git a/app/api/deps.py b/app/api/deps.py index ed48ef5..e17b722 100644 --- a/app/api/deps.py +++ b/app/api/deps.py @@ -8,19 +8,20 @@ from fastapi.security import APIKeyHeader from app.logic import Logic as _Logic -from app.models.user import User as _User +from app.models.users.user import User as _User async def get_logic() -> _Logic: - return await _Logic.create() + async with Logic.create() as logic: + yield logic Logic = Annotated[_Logic, Depends(get_logic)] async def get_user( - token: Annotated[str, Depends(APIKeyHeader(name='access-token'))], - logic: Logic, + token: Annotated[str, Depends(APIKeyHeader(name='access-token'))], + logic: Logic, ) -> _User | None: return await logic.users.retrieve_by_token(token) diff --git a/app/api/v1/__init__.py b/app/api/v1/__init__.py index df18962..495ed65 100644 --- a/app/api/v1/__init__.py +++ b/app/api/v1/__init__.py @@ -6,11 +6,12 @@ from fastapi import APIRouter -from . import users +from . import auth, users FOLDER_NAME = f'{Path(__file__).parent.name}' router = APIRouter(prefix=f'/{FOLDER_NAME}', tags=[FOLDER_NAME]) +router.include_router(auth.router) router.include_router(users.router) __all__ = ['router'] diff --git a/app/api/v1/users/auth/__init__.py b/app/api/v1/auth/__init__.py similarity index 100% rename from app/api/v1/users/auth/__init__.py rename to app/api/v1/auth/__init__.py diff --git a/app/api/v1/users/auth/token.py b/app/api/v1/auth/token.py similarity index 63% rename from app/api/v1/users/auth/token.py rename to app/api/v1/auth/token.py index f45a003..e7074d2 100644 --- a/app/api/v1/users/auth/token.py +++ b/app/api/v1/auth/token.py @@ -1,8 +1,8 @@ from fastapi import APIRouter from app.api import deps -from app.models.token import AccessToken -from app.models.user import UserCreate +from app.models.auth import AccessToken +from app.models.users.user import UserCreate router = APIRouter(prefix='/token') @@ -12,7 +12,7 @@ async def token(data: UserCreate, logic: deps.Logic): """ Retrieve new access token """ - return await logic.users.auth.generate_token(**data.model_dump()) + return await logic.auth.generate_token(**data.model_dump()) __all__ = ['router'] diff --git a/app/api/v1/users/__init__.py b/app/api/v1/users/__init__.py index 81404e3..648044d 100644 --- a/app/api/v1/users/__init__.py +++ b/app/api/v1/users/__init__.py @@ -4,10 +4,9 @@ from fastapi import APIRouter -from . import auth, create, retrieve +from . import create, retrieve router = APIRouter(prefix='/users', tags=['users']) -router.include_router(auth.router) router.include_router(create.router) router.include_router(retrieve.router) diff --git a/app/api/v1/users/create.py b/app/api/v1/users/create.py index 306a238..b3f1c4d 100644 --- a/app/api/v1/users/create.py +++ b/app/api/v1/users/create.py @@ -1,7 +1,7 @@ from fastapi import APIRouter from app.api import deps -from app.models.user import UserCreate, UserRead +from app.models.users.user import UserCreate, UserRead router = APIRouter(prefix='/create') @@ -11,7 +11,7 @@ async def create(data: UserCreate, logic: deps.Logic): """ Create user """ - return await logic.users.create(**data.model_dump()) + return await logic.users.create(data) __all__ = ['router'] diff --git a/app/api/v1/users/retrieve.py b/app/api/v1/users/retrieve.py index b03c78e..832a4a7 100644 --- a/app/api/v1/users/retrieve.py +++ b/app/api/v1/users/retrieve.py @@ -5,7 +5,7 @@ from fastapi import APIRouter from app.api import deps -from app.models.user import UserRead +from app.models.users.user import UserRead router = APIRouter() diff --git a/app/core/db.py b/app/core/db.py index 8b48863..c316278 100644 --- a/app/core/db.py +++ b/app/core/db.py @@ -2,7 +2,7 @@ Database """ -from typing import NoReturn, Self +from typing import Self from sqlalchemy.ext.asyncio import (AsyncEngine, async_sessionmaker, create_async_engine) @@ -15,40 +15,40 @@ class Database: _instance = None - def __new__(cls, *args, **kwargs) -> Self: + def __new__(cls, *args, **kwargs) -> 'Database': if cls._instance is None: cls._instance = super(Database, cls).__new__(cls) return cls._instance def __init__( - self, - engine: AsyncEngine | None = None, - session: AsyncSession | None = None, + self, + engine: AsyncEngine | None = None, + session: AsyncSession | None = None, ) -> None: if not hasattr(self, 'initialized'): - self.engine = engine - self.session = session + self.__engine = engine + self.__session = session self.initialized = True - async def __set_async_engine(self) -> NoReturn: - if self.engine is None: - self.engine = create_async_engine( + async def __set_async_engine(self) -> None: + if self.__engine is None: + self.__engine = create_async_engine( settings.pg_dsn.unicode_string(), echo=False, future=True ) - async def __set_async_session(self) -> NoReturn: - if self.session is None: - self.session = async_sessionmaker( + async def __set_async_session(self) -> None: + if self.__session is None: + self.__session = async_sessionmaker( autocommit=False, autoflush=False, - bind=self.engine, + bind=self.__engine, class_=AsyncSession, expire_on_commit=False, )() - async def __set_repositories(self) -> NoReturn: - if self.session is not None: - self.user = repos.UserRepo(session=self.session) + async def __set_repositories(self) -> None: + if self.__session is not None: + self.user = repos.UserRepo(session=self.__session) async def __aenter__(self) -> Self: await self.__set_async_engine() @@ -56,6 +56,8 @@ async def __aenter__(self) -> Self: await self.__set_repositories() return self - async def __aexit__(self, exc_type, exc_value, traceback) -> NoReturn: - if self.session is not None: - await self.session.close() + async def __aexit__(self, exc_type, exc_value, traceback) -> None: + if self.__session is not None: + await self.__session.commit() + await self.__session.close() + self.__session = None diff --git a/app/logic/__init__.py b/app/logic/__init__.py index 270ee61..e543ca9 100644 --- a/app/logic/__init__.py +++ b/app/logic/__init__.py @@ -1,21 +1,3 @@ -from typing import Self - -from app.core.db import Database - -from .security import Security -from .users import Users - - -class Logic: - def __init__(self, db: Database): - self.db = db - self.security = Security() - self.users = Users(self) - - @classmethod - async def create(cls) -> Self: - async with Database() as db: - return cls(db) - +from .logic import Logic __all__ = ['Logic'] diff --git a/app/logic/users/auth/__init__.py b/app/logic/auth/__init__.py similarity index 100% rename from app/logic/users/auth/__init__.py rename to app/logic/auth/__init__.py diff --git a/app/logic/auth/auth.py b/app/logic/auth/auth.py new file mode 100644 index 0000000..aa45007 --- /dev/null +++ b/app/logic/auth/auth.py @@ -0,0 +1,24 @@ +from typing import TYPE_CHECKING + +from app.core import exps +from app.models.auth import AccessToken + +if TYPE_CHECKING: + from app.logic import Logic + + +class Auth: + def __init__(self, logic: 'Logic'): + self.logic = logic + + async def generate_token( + self, email: str, password: str + ) -> AccessToken | None: + if (user := await self.logic.db.user.retrieve_by_email(email)) is None: + raise exps.UserNotFoundException() + if not self.logic.security.pwd.checkpwd(password, user.password): + raise exps.UserIsCorrectException() + access_token = self.logic.security.jwt.encode_token( + {'id': user.id}, 1440 + ) + return AccessToken(token=access_token) diff --git a/app/logic/logic.py b/app/logic/logic.py new file mode 100644 index 0000000..fb87cd6 --- /dev/null +++ b/app/logic/logic.py @@ -0,0 +1,22 @@ +from typing import Self, AsyncGenerator +from contextlib import asynccontextmanager + +from app.core.db import Database + +from .security import Security +from .users import Users +from .auth import Auth + + +class Logic: + def __init__(self, db: Database): + self.db = db + self.security = Security() + self.users = Users(self) + self.auth = Auth(self) + + @classmethod + @asynccontextmanager + async def create(cls) -> AsyncGenerator[Self, None]: + async with Database() as db: + yield cls(db) diff --git a/app/logic/security/__init__.py b/app/logic/security/__init__.py index 16915f4..18a60f4 100644 --- a/app/logic/security/__init__.py +++ b/app/logic/security/__init__.py @@ -1,13 +1,3 @@ -from app.core.settings import settings - -from .jwt import JWT -from .pwd import PWD - - -class Security: - def __init__(self): - self.jwt = JWT(settings.APP_SECRET_KEY) - self.pwd = PWD() - +from .security import Security __all__ = ['Security'] diff --git a/app/logic/security/jwt.py b/app/logic/security/jwt.py index 1f34e05..e7a9c8f 100644 --- a/app/logic/security/jwt.py +++ b/app/logic/security/jwt.py @@ -9,7 +9,7 @@ class JWT: def __init__(self, secret_key: str): self.secret_key: str = secret_key - def decode_token(self, token: str) -> dict | None: + def decode_token(self, token: str) -> dict: try: payload = jwt.decode(token, self.secret_key, algorithms=['HS256']) except Exception: @@ -18,7 +18,9 @@ def decode_token(self, token: str) -> dict | None: exp = payload.get('exp') if exp and dt.datetime.now(dt.UTC).timestamp() > exp: raise exps.TokenExpiredException() - return payload.get('payload') + if (payload := payload.get('payload', None)) is None: + raise exps.TokenInvalidException() + return payload def encode_token(self, payload: dict, minutes: int) -> str: claims = { diff --git a/app/logic/security/pwd.py b/app/logic/security/pwd.py index 3e7762a..2b0665a 100644 --- a/app/logic/security/pwd.py +++ b/app/logic/security/pwd.py @@ -2,8 +2,10 @@ class PWD: - def hashpwd(self, password: str) -> str: + @staticmethod + def hashpwd(password: str) -> str: return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode() - def checkpwd(self, password: str, hashed_password: str) -> bool: + @staticmethod + def checkpwd(password: str, hashed_password: str) -> bool: return bcrypt.checkpw(password.encode(), hashed_password.encode()) diff --git a/app/logic/security/security.py b/app/logic/security/security.py new file mode 100644 index 0000000..33bc2bb --- /dev/null +++ b/app/logic/security/security.py @@ -0,0 +1,10 @@ +from app.core.settings import settings + +from .jwt import JWT +from .pwd import PWD + + +class Security: + def __init__(self): + self.jwt = JWT(settings.APP_SECRET_KEY) + self.pwd = PWD() diff --git a/app/logic/users/auth/auth.py b/app/logic/users/auth/auth.py deleted file mode 100644 index 264440b..0000000 --- a/app/logic/users/auth/auth.py +++ /dev/null @@ -1,24 +0,0 @@ -from typing import TYPE_CHECKING - -from app.core import exps -from app.models.token import AccessToken - -if TYPE_CHECKING: - from app.logic import Logic - - -class Auth: - def __init__(self, logic: 'Logic'): - self.logic = logic - - async def generate_token( - self, email: str, password: str - ) -> AccessToken | None: - if user := await self.logic.db.user.retrieve_by_email(email): - if not self.logic.security.pwd.checkpwd(password, user.password): - raise exps.UserIsCorrectException() - access_token = self.logic.security.jwt.encode_token( - {'id': user.id}, 1440 - ) - return AccessToken(token=access_token) - raise exps.UserNotFoundException() diff --git a/app/logic/users/users.py b/app/logic/users/users.py index bef1447..16073dc 100644 --- a/app/logic/users/users.py +++ b/app/logic/users/users.py @@ -1,9 +1,7 @@ from typing import TYPE_CHECKING from app.core import exps -from app.models.user import User - -from .auth import Auth +from app.models.users.user import User, UserCreate if TYPE_CHECKING: from app.logic import Logic @@ -12,24 +10,21 @@ class Users: def __init__(self, logic: 'Logic'): self.logic = logic - self.auth = Auth(self.logic) - async def create(self, email: str, password: str) -> User | None: - if await self.logic.db.user.retrieve_by_email(email): + async def create(self, model: UserCreate) -> User | None: + if await self.logic.db.user.retrieve_by_email(model.email): raise exps.UserExistsException() - password_hash = self.logic.security.pwd.hashpwd(password) - model = User(email=email, password=password_hash) + model.password = self.logic.security.pwd.hashpwd(model.password) user = await self.logic.db.user.create(model) return user async def retrieve_by_token(self, token: str) -> User | None: - if payload := self.logic.security.jwt.decode_token(token): - if not ( - user := await self.logic.db.user.retrieve_one( - ident=payload.get('id') - ) - ): - raise exps.UserNotFoundException() - else: - return user + payload = self.logic.security.jwt.decode_token(token) + if not ( + user := await self.logic.db.user.retrieve_one( + ident=payload.get('id') + ) + ): + raise exps.UserNotFoundException() + return user diff --git a/app/models/__init__.py b/app/models/__init__.py index fa250a3..b564529 100644 --- a/app/models/__init__.py +++ b/app/models/__init__.py @@ -2,6 +2,6 @@ Database Models Module """ -from . import token, user +from . import users -__all__ = ['user', 'token'] +__all__ = ['users',] diff --git a/app/models/auth/__init__.py b/app/models/auth/__init__.py new file mode 100644 index 0000000..aaa9306 --- /dev/null +++ b/app/models/auth/__init__.py @@ -0,0 +1,3 @@ +from .token import AccessToken + +__all__ = ['AccessToken'] diff --git a/app/models/token.py b/app/models/auth/token.py similarity index 100% rename from app/models/token.py rename to app/models/auth/token.py diff --git a/app/models/base.py b/app/models/base.py index e6bf683..66b7e52 100644 --- a/app/models/base.py +++ b/app/models/base.py @@ -2,34 +2,11 @@ import uuid from functools import partial -from sqlalchemy.types import BigInteger, TypeDecorator from sqlmodel import Field, SQLModel -datetime_utc_now = partial(dt.datetime.now, tz=dt.UTC) - - -class UnixType(TypeDecorator): - impl = BigInteger - cache_ok = True +from .types import UnixType - def load_dialect_impl(self, dialect): - return dialect.type_descriptor(BigInteger()) - - def process_bind_param( - self, value: dt.date | dt.datetime | str | None, dialect - ) -> int | None: - if isinstance(value, dt.datetime): - return int(value.timestamp()) - elif isinstance(value, dt.date): - return int(dt.datetime.combine(value, dt.time.min).timestamp()) - elif isinstance(value, str): - return int(dt.datetime.fromisoformat(value).timestamp()) - - def process_result_value( - self, value: int | None, dialect - ) -> dt.datetime | None: - if isinstance(value, int): - return dt.datetime.fromtimestamp(value, dt.UTC) +datetime_utc_now = partial(dt.datetime.now, tz=dt.UTC) class IDModel(SQLModel): diff --git a/app/models/types/__init__.py b/app/models/types/__init__.py new file mode 100644 index 0000000..2554b30 --- /dev/null +++ b/app/models/types/__init__.py @@ -0,0 +1,5 @@ +from .unix import UnixType + +__all__ = [ + 'UnixType' +] diff --git a/app/models/types/unix.py b/app/models/types/unix.py new file mode 100644 index 0000000..38ea912 --- /dev/null +++ b/app/models/types/unix.py @@ -0,0 +1,27 @@ +import datetime as dt + +from sqlalchemy.types import BigInteger, TypeDecorator + + +class UnixType(TypeDecorator): + impl = BigInteger + cache_ok = True + + def load_dialect_impl(self, dialect): + return dialect.type_descriptor(BigInteger()) + + def process_bind_param( + self, value: dt.date | dt.datetime | str | None, dialect + ) -> int | None: + if isinstance(value, dt.datetime): + return int(value.timestamp()) + elif isinstance(value, dt.date): + return int(dt.datetime.combine(value, dt.time.min).timestamp()) + elif isinstance(value, str): + return int(dt.datetime.fromisoformat(value).timestamp()) + + def process_result_value( + self, value: int | None, dialect + ) -> dt.datetime | None: + if isinstance(value, int): + return dt.datetime.fromtimestamp(value, dt.UTC) diff --git a/app/models/users/__init__.py b/app/models/users/__init__.py new file mode 100644 index 0000000..bd7ce04 --- /dev/null +++ b/app/models/users/__init__.py @@ -0,0 +1,4 @@ +from . import user +from .. import auth + +__all__ = ['auth', 'user'] diff --git a/app/models/user.py b/app/models/users/user.py similarity index 88% rename from app/models/user.py rename to app/models/users/user.py index cbcb7e7..b95a8bd 100644 --- a/app/models/user.py +++ b/app/models/users/user.py @@ -4,7 +4,7 @@ from sqlmodel import Field, SQLModel -from .base import IDModel +from app.models.base import IDModel class UserBase(SQLModel): diff --git a/app/repositories/abstract.py b/app/repositories/base.py similarity index 77% rename from app/repositories/abstract.py rename to app/repositories/base.py index 29c6ee9..597996c 100644 --- a/app/repositories/abstract.py +++ b/app/repositories/base.py @@ -1,7 +1,15 @@ import abc -from typing import (Any, Generic, List, NoReturn, Optional, Sequence, Type, - TypeAlias, TypeVar) - +from typing import ( + Any, + Generic, + List, + NoReturn, + Optional, + Sequence, + Type, + TypeAlias, + TypeVar +) import sqlmodel as sm from sqlmodel.ext.asyncio.session import AsyncSession @@ -17,15 +25,14 @@ def __init__(self, model: Type[AbstractModel], session: AsyncSession): async def create(self, model: AbstractModel) -> AbstractModel: model = self.model.model_validate(model) self.session.add(model) - await self.session.commit() - await self.session.refresh(model) + await self.session.flush() return model async def retrieve_one( - self, - *, - ident: Optional[int] = None, - where_clauses: WhereClauses = None, + self, + *, + ident: Optional[int] = None, + where_clauses: WhereClauses = None, ) -> Optional[AbstractModel]: if ident is not None: return await self.session.get(self.model, ident) @@ -36,10 +43,10 @@ async def retrieve_one( return entity.first() async def retrieve_many( - self, - where_clauses: WhereClauses = None, - limit: Optional[int] = None, - order_by: Optional[Any] = None, + self, + where_clauses: WhereClauses = None, + limit: Optional[int] = None, + order_by: Optional[Any] = None, ) -> Optional[Sequence[AbstractModel]]: stmt = sm.select(self.model) if where_clauses is not None: diff --git a/app/repositories/user.py b/app/repositories/user.py index e0cdc5e..7c183cf 100644 --- a/app/repositories/user.py +++ b/app/repositories/user.py @@ -5,9 +5,9 @@ from sqlmodel.ext.asyncio.session import AsyncSession -from app.models.user import User +from app.models.users.user import User -from .abstract import Repository +from .base import Repository class UserRepo(Repository[User]): diff --git a/poetry.lock b/poetry.lock index 2978ac9..c9f24ae 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1669,4 +1669,4 @@ files = [ [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "47064c924c901a9b719c463eb50a21c3df4e49112eb8b5c86559a7088fa14a0c" +content-hash = "5e7e53277bed76535a8fc9dbfa1bb129b159b40357f6380f87dc815dff535753" diff --git a/pyproject.toml b/pyproject.toml index 40a59cc..ac877c3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,6 +16,7 @@ bcrypt = "^4.1.3" fastapi = {extras = ["all"], version = "^0.111.0"} sqlmodel = "^0.0.18" pyjwt = "^2.8.0" +greenlet = "^3.0.3" [tool.poetry.group.dev] optional = true