diff --git a/app/database/database.py b/app/database/database.py index c0544c0c..b312ce99 100644 --- a/app/database/database.py +++ b/app/database/database.py @@ -2,7 +2,7 @@ from sqlalchemy import create_engine from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import sessionmaker +from sqlalchemy.orm import Session, sessionmaker from app import config @@ -25,7 +25,7 @@ def create_env_engine(psql_environment, sqlalchemy_database_url): Base = declarative_base() -def get_db(): +def get_db() -> Session: db = SessionLocal() try: yield db diff --git a/app/database/models.py b/app/database/models.py index c4bb8253..4f3308ef 100644 --- a/app/database/models.py +++ b/app/database/models.py @@ -1,12 +1,17 @@ +from __future__ import annotations + from datetime import datetime +from typing import Dict, Any from sqlalchemy import (DDL, Boolean, Column, DateTime, ForeignKey, Index, - Integer, String, event) + Integer, String, event, UniqueConstraint) from sqlalchemy.dialects.postgresql import TSVECTOR -from sqlalchemy.orm import relationship +from sqlalchemy.exc import IntegrityError, SQLAlchemyError +from sqlalchemy.orm import relationship, Session from app.config import PSQL_ENVIRONMENT from app.database.database import Base +from app.dependencies import logger class UserEvent(Base): @@ -51,11 +56,12 @@ class Event(Base): end = Column(DateTime, nullable=False) content = Column(String) location = Column(String) + color = Column(String, nullable=True) - owner = relationship("User") owner_id = Column(Integer, ForeignKey("users.id")) - color = Column(String, nullable=True) + category_id = Column(Integer, ForeignKey("categories.id")) + owner = relationship("User") participants = relationship("UserEvent", back_populates="events") # PostgreSQL @@ -65,12 +71,45 @@ class Event(Base): 'events_tsv_idx', 'events_tsv', postgresql_using='gin'), - ) + ) def __repr__(self): return f'' +class Category(Base): + __tablename__ = "categories" + + __table_args__ = ( + UniqueConstraint('user_id', 'name', 'color'), + ) + id = Column(Integer, primary_key=True, index=True) + name = Column(String, nullable=False) + color = Column(String, nullable=False) + user_id = Column(Integer, ForeignKey("users.id"), nullable=False) + + @staticmethod + def create(db_session: Session, name: str, color: str, + user_id: int) -> Category: + try: + category = Category(name=name, color=color, user_id=user_id) + db_session.add(category) + db_session.flush() + db_session.commit() + db_session.refresh(category) + except (SQLAlchemyError, IntegrityError) as e: + logger.error(f"Failed to create category: {e}") + raise e + else: + return category + + def to_dict(self) -> Dict[str, Any]: + return {c.name: getattr(self, c.name) for c in self.__table__.columns} + + def __repr__(self) -> str: + return f'' + + class PSQLEnvironmentError(Exception): pass @@ -88,7 +127,7 @@ class PSQLEnvironmentError(Exception): Event.__table__, 'after_create', trigger_snippet.execute_if(dialect='postgresql') - ) + ) class Invitation(Base): diff --git a/app/main.py b/app/main.py index aa48c487..fa778725 100644 --- a/app/main.py +++ b/app/main.py @@ -8,8 +8,8 @@ from app.dependencies import (logger, MEDIA_PATH, STATIC_PATH, templates) from app.internal.quotes import daily_quotes, load_quotes from app.routers import ( - agenda, dayview, email, event, invitation, profile, search, telegram, - whatsapp + agenda, categories, dayview, email, event, invitation, profile, search, + telegram, whatsapp ) from app.telegram.bot import telegram_bot @@ -34,21 +34,27 @@ def create_tables(engine, psql_environment): app.logger = logger -app.include_router(profile.router) -app.include_router(event.router) -app.include_router(agenda.router) -app.include_router(telegram.router) -app.include_router(dayview.router) -app.include_router(email.router) -app.include_router(invitation.router) -app.include_router(whatsapp.router) -app.include_router(search.router) +routers_to_include = [ + agenda.router, + categories.router, + dayview.router, + email.router, + event.router, + invitation.router, + profile.router, + search.router, + telegram.router, + whatsapp.router, +] + +for router in routers_to_include: + app.include_router(router) telegram_bot.set_webhook() # TODO: I add the quote day to the home page -# until the relavent calendar view will be developed. +# until the relevant calendar view will be developed. @app.get("/") @logger.catch() async def home(request: Request, db: Session = Depends(get_db)): diff --git a/app/routers/categories.py b/app/routers/categories.py new file mode 100644 index 00000000..322b02ca --- /dev/null +++ b/app/routers/categories.py @@ -0,0 +1,80 @@ +from typing import Dict, List, Any + +from fastapi import APIRouter, Depends, HTTPException, Request +from pydantic import BaseModel +from sqlalchemy.exc import IntegrityError, SQLAlchemyError +from sqlalchemy.orm import Session +from starlette import status +from starlette.datastructures import ImmutableMultiDict + +from app.database.database import get_db +from app.database.models import Category + +router = APIRouter( + prefix="/categories", + tags=["categories"], +) + + +class CategoryModel(BaseModel): + name: str + color: str + user_id: int + + +# TODO(issue#29): get current user_id from session +@router.get("/") +def get_categories(request: Request, + db_session: Session = Depends(get_db)) -> List[Category]: + if validate_request_params(request.query_params): + return get_user_categories(db_session, **request.query_params) + else: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Request {request.query_params} contains " + f"unallowed params.") + + +# TODO(issue#29): get current user_id from session +@router.post("/") +async def set_category(category: CategoryModel, + db_sess: Session = Depends(get_db)) -> Dict[str, Any]: + try: + cat = Category.create(db_sess, + name=category.name, + color=category.color, + user_id=category.user_id) + except IntegrityError: + db_sess.rollback() + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, + detail=f"category is already exists for " + f"user {category.user_id}.") + else: + return {"category": cat.to_dict()} + + +def validate_request_params(query_params: ImmutableMultiDict) -> bool: + """ + request.query_params contains not more than user_id, name, color + and not less than user_id: + Intersection must contain at least user_id. + Union must not contain fields other than user_id, name, color. + """ + all_fields = set(CategoryModel.schema()["required"]) + request_params = set(query_params) + union_set = request_params.union(all_fields) + intersection_set = request_params.intersection(all_fields) + return union_set == all_fields and "user_id" in intersection_set + + +def get_user_categories(db_session: Session, + user_id: int, **params) -> List[Category]: + """ + Returns user's categories, filtered by params. + """ + try: + categories = db_session.query(Category).filter_by( + user_id=user_id).filter_by(**params).all() + except SQLAlchemyError: + return [] + else: + return categories diff --git a/schema.md b/schema.md index 58140f95..852870e6 100644 --- a/schema.md +++ b/schema.md @@ -11,13 +11,25 @@ │ ├── internal │ ├── __init__.py │ ├── admin.py -│ ├── routers -│ ├── __init__.py -│ ├── profile.py +│ ├── agenda_events.py +│ ├── email.py │ ├── media │ ├── example.png +│ ├── fake_user.png │ ├── profile.png +│ ├── routers +│ ├── __init__.py +│ ├── agenda.py +│ ├── categories.py +│ ├── email.py +│ ├── event.py +│ ├── profile.py │ ├── static +│ ├── event +│ ├── eventedit.css +│ ├── eventview.css +│ ├── agenda_style.css +│ ├── popover.js │ ├── style.css │ ├── popover.js │ ├── templates @@ -29,6 +41,11 @@ ├── schema.md └── tests ├── __init__.py - └── conftest.py - └── test_profile.py - └── test_app.py \ No newline at end of file + ├── conftest.py + ├── test_agenda_internal.py + ├── test_agenda_route.py + ├── test_app.py + ├── test_categories.py + ├── test_email.py + ├── test_event.py + └── test_profile.py \ No newline at end of file diff --git a/tests/category_fixture.py b/tests/category_fixture.py new file mode 100644 index 00000000..08cc6b97 --- /dev/null +++ b/tests/category_fixture.py @@ -0,0 +1,13 @@ +import pytest +from sqlalchemy.orm import Session + +from app.database.models import User, Category + + +@pytest.fixture +def category(session: Session, user: User) -> Category: + category = Category.create(session, name="Guitar Lesson", color="121212", + user_id=user.id) + yield category + session.delete(category) + session.commit() diff --git a/tests/conftest.py b/tests/conftest.py index c73a3439..9c51495d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,9 +2,11 @@ from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker + from app.config import PSQL_ENVIRONMENT from app.database.database import Base + pytest_plugins = [ 'tests.user_fixture', 'tests.event_fixture', @@ -13,6 +15,7 @@ 'tests.client_fixture', 'tests.asyncio_fixture', 'tests.logger_fixture', + 'tests.category_fixture', 'smtpdfix', 'tests.quotes_fixture' ] @@ -49,6 +52,7 @@ def session(): Base.metadata.create_all(bind=test_engine) session = get_test_db() yield session + session.rollback() session.close() Base.metadata.drop_all(bind=test_engine) diff --git a/tests/test_categories.py b/tests/test_categories.py new file mode 100644 index 00000000..c3c7e3d9 --- /dev/null +++ b/tests/test_categories.py @@ -0,0 +1,109 @@ +import pytest +from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy.testing import mock +from starlette import status +from starlette.datastructures import ImmutableMultiDict + +from app.database.models import Event +from app.routers.categories import get_user_categories, validate_request_params + + +class TestCategories: + CATEGORY_ALREADY_EXISTS_MSG = "category is already exists for" + UNALLOWED_PARAMS = "contains unallowed params" + + @staticmethod + def test_get_categories_logic_succeeded(session, user, category): + assert get_user_categories(session, category.user_id) == [category] + + @staticmethod + def test_creating_new_category(client, user): + response = client.post("/categories/", + json={"user_id": user.id, "name": "Foo", + "color": "eecc11"}) + assert response.ok + assert {("user_id", user.id), ("name", "Foo"), + ("color", "eecc11")}.issubset( + set(response.json()['category'].items())) + + @staticmethod + def test_creating_not_unique_category_failed(client, user, category): + response = client.post("/categories/", json={"user_id": user.id, + "name": "Guitar Lesson", + "color": "121212"}) + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert TestCategories.CATEGORY_ALREADY_EXISTS_MSG in \ + response.json()["detail"] + + @staticmethod + def test_create_event_with_category(category): + event = Event(title="OOO", content="Guitar rocks!!", + owner_id=category.user_id, category_id=category.id) + assert event.category_id is not None + assert event.category_id == category.id + + @staticmethod + def test_get_user_categories(client, category): + response = client.get(f"/categories/?user_id={category.user_id}" + f"&name={category.name}&color={category.color}") + assert response.ok + assert set(response.json()[0].items()) == { + ("user_id", category.user_id), ("color", "121212"), + ("name", "Guitar Lesson"), ("id", category.id)} + + @staticmethod + def test_get_category_by_name(client, user, category): + response = client.get(f"/categories/?user_id={category.user_id}" + f"&name={category.name}") + assert response.ok + assert set(response.json()[0].items()) == { + ("user_id", category.user_id), ("color", "121212"), + ("name", "Guitar Lesson"), ("id", category.id)} + + @staticmethod + def test_get_category_by_color(client, user, category): + response = client.get(f"/categories/?user_id={category.user_id}&" + f"color={category.color}") + assert response.ok + assert set(response.json()[0].items()) == { + ("user_id", category.user_id), ("color", "121212"), + ("name", "Guitar Lesson"), ("id", category.id)} + + @staticmethod + def test_get_category_bad_request(client): + response = client.get("/categories/") + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert TestCategories.UNALLOWED_PARAMS in response.json()["detail"] + + @staticmethod + def test_repr(category): + assert category.__repr__() == \ + f'' + + @staticmethod + def test_to_dict(category): + assert {c.name: getattr(category, c.name) for c in + category.__table__.columns} == category.to_dict() + + @staticmethod + @pytest.mark.parametrize('params, expected_result', [ + (ImmutableMultiDict([('user_id', ''), ('name', ''), + ('color', '')]), True), + (ImmutableMultiDict([('user_id', ''), ('name', '')]), True), + (ImmutableMultiDict([('user_id', ''), ('color', '')]), True), + (ImmutableMultiDict([('user_id', '')]), True), + (ImmutableMultiDict([('name', ''), ('color', '')]), False), + (ImmutableMultiDict([]), False), + (ImmutableMultiDict([('user_id', ''), ('name', ''), ('color', ''), + ('bad_param', '')]), False), + ]) + def test_validate_request_params(params, expected_result): + assert validate_request_params(params) == expected_result + + @staticmethod + def test_get_categories_failed(session): + def raise_error(param): + raise SQLAlchemyError() + + session.query = mock.Mock(side_effect=raise_error) + assert get_user_categories(session, 1) == []