From d69c4800af4eb1f8cdf0264c0590bb57a5b5532a Mon Sep 17 00:00:00 2001 From: i Date: Wed, 27 Jan 2021 20:35:33 +0200 Subject: [PATCH 1/4] Support categories for events --- app/database/database.py | 4 +- app/database/models.py | 43 ++++++++++++--- app/main.py | 25 +++++---- app/routers/categories.py | 80 ++++++++++++++++++++++++++++ schema.md | 29 +++++++--- tests/category_fixture.py | 13 +++++ tests/conftest.py | 3 +- tests/test_categories.py | 108 ++++++++++++++++++++++++++++++++++++++ 8 files changed, 281 insertions(+), 24 deletions(-) create mode 100644 app/routers/categories.py create mode 100644 tests/category_fixture.py create mode 100644 tests/test_categories.py diff --git a/app/database/database.py b/app/database/database.py index c0544c0c..b312ce99 100644 --- a/app/database/database.py +++ b/app/database/database.py @@ -2,7 +2,7 @@ from sqlalchemy import create_engine from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import sessionmaker +from sqlalchemy.orm import Session, sessionmaker from app import config @@ -25,7 +25,7 @@ def create_env_engine(psql_environment, sqlalchemy_database_url): Base = declarative_base() -def get_db(): +def get_db() -> Session: db = SessionLocal() try: yield db diff --git a/app/database/models.py b/app/database/models.py index 4d243fef..741015f9 100644 --- a/app/database/models.py +++ b/app/database/models.py @@ -3,9 +3,9 @@ from app.config import PSQL_ENVIRONMENT from app.database.database import Base from sqlalchemy import (DDL, Boolean, Column, DateTime, ForeignKey, Index, - Integer, String, event) + Integer, String, event, UniqueConstraint) from sqlalchemy.dialects.postgresql import TSVECTOR -from sqlalchemy.orm import relationship +from sqlalchemy.orm import relationship, Session class UserEvent(Base): @@ -50,11 +50,12 @@ class Event(Base): end = Column(DateTime, nullable=False) content = Column(String) location = Column(String) + color = Column(String, nullable=True) - owner = relationship("User") owner_id = Column(Integer, ForeignKey("users.id")) - color = Column(String, nullable=True) + category_id = Column(Integer, ForeignKey("categories.id")) + owner = relationship("User") participants = relationship("UserEvent", back_populates="events") # PostgreSQL @@ -64,12 +65,42 @@ class Event(Base): 'events_tsv_idx', 'events_tsv', postgresql_using='gin'), - ) + ) def __repr__(self): return f'' +class Category(Base): + __tablename__ = "categories" + + __table_args__ = ( + UniqueConstraint('user_id', 'name', 'color'), + ) + id = Column(Integer, primary_key=True, index=True) + name = Column(String, nullable=False) + color = Column(String, nullable=False) + user_id = Column(Integer, ForeignKey("users.id"), nullable=False) + + @staticmethod + def create(db_session: Session, name: str, color: str, user_id: int): + try: + category = Category(name=name, color=color, user_id=user_id) + db_session.add(category) + db_session.flush() + db_session.commit() + db_session.refresh(category) + return category + except Exception as e: + raise e + + def to_dict(self): + return {c.name: getattr(self, c.name) for c in self.__table__.columns} + + def __repr__(self): + return f'' + + class PSQLEnvironmentError(Exception): pass @@ -87,7 +118,7 @@ class PSQLEnvironmentError(Exception): Event.__table__, 'after_create', trigger_snippet.execute_if(dialect='postgresql') - ) + ) class Invitation(Base): diff --git a/app/main.py b/app/main.py index 9f562d34..3f5197e6 100644 --- a/app/main.py +++ b/app/main.py @@ -7,7 +7,8 @@ from app.dependencies import ( MEDIA_PATH, STATIC_PATH, templates) from app.routers import ( - agenda, dayview, email, event, invitation, profile, search, telegram) + agenda, dayview, email, event, invitation, profile, search, telegram, + categories) from app.telegram.bot import telegram_bot @@ -27,14 +28,20 @@ def create_tables(engine, psql_environment): app.mount("/static", StaticFiles(directory=STATIC_PATH), name="static") app.mount("/media", StaticFiles(directory=MEDIA_PATH), name="media") -app.include_router(profile.router) -app.include_router(event.router) -app.include_router(agenda.router) -app.include_router(telegram.router) -app.include_router(dayview.router) -app.include_router(email.router) -app.include_router(invitation.router) -app.include_router(search.router) +routers_to_include = [ + agenda.router, + categories.router, + dayview.router, + email.router, + event.router, + invitation.router, + profile.router, + search.router, + telegram.router, +] + +for router in routers_to_include: + app.include_router(router) telegram_bot.set_webhook() diff --git a/app/routers/categories.py b/app/routers/categories.py new file mode 100644 index 00000000..90cfadaa --- /dev/null +++ b/app/routers/categories.py @@ -0,0 +1,80 @@ +from typing import Dict, List + +from fastapi import APIRouter, Depends, HTTPException, Request +from pydantic import BaseModel +from sqlalchemy.exc import IntegrityError, SQLAlchemyError +from sqlalchemy.orm import Session +from starlette import status +from starlette.datastructures import ImmutableMultiDict + +from app.database.database import get_db +from app.database.models import Category + +router = APIRouter( + prefix="/categories", + tags=["categories"], +) + + +class CategoryModel(BaseModel): + name: str + color: str + user_id: int + + +# TODO(issue#29): get current user_id from session +@router.get("/") +def get_categories(request: Request, + db_session: Session = Depends(get_db)) -> List[Category]: + if validate_request_params(request.query_params): + return get_user_categories(db_session, **request.query_params) + else: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Request {request.query_params} contains " + f"unallowed params.") + + +# TODO(issue#29): get current user_id from session +@router.post("/") +async def set_category(category: CategoryModel, + db_session: Session = Depends(get_db)) -> Dict: + try: + cat = Category.create(db_session, + name=category.name, + color=category.color, + user_id=category.user_id) + except IntegrityError: + db_session.rollback() + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, + detail=f"category is already exists for " + f"{category.user_id} user.") + else: + return {"category": cat.to_dict()} + + +def validate_request_params(query_params: ImmutableMultiDict) -> bool: + """ + request.query_params contains not more than user_id, name, color + and not less than user_id: + Intersection must contain at least user_id. + Union must not contain fields other than user_id, name, color. + """ + all_fields = set(CategoryModel.schema()["required"]) + request_params = set(query_params) + union_set = request_params.union(all_fields) + intersection_set = request_params.intersection(all_fields) + return union_set == all_fields and "user_id" in intersection_set + + +def get_user_categories(db_session: Session, + user_id: int, **params) -> List[Category]: + """ + Returns user's categories, filtered by params. + """ + try: + categories = db_session.query(Category).filter_by( + user_id=user_id).filter_by(**params).all() + except SQLAlchemyError: + return [] + else: + return categories diff --git a/schema.md b/schema.md index 58140f95..852870e6 100644 --- a/schema.md +++ b/schema.md @@ -11,13 +11,25 @@ │ ├── internal │ ├── __init__.py │ ├── admin.py -│ ├── routers -│ ├── __init__.py -│ ├── profile.py +│ ├── agenda_events.py +│ ├── email.py │ ├── media │ ├── example.png +│ ├── fake_user.png │ ├── profile.png +│ ├── routers +│ ├── __init__.py +│ ├── agenda.py +│ ├── categories.py +│ ├── email.py +│ ├── event.py +│ ├── profile.py │ ├── static +│ ├── event +│ ├── eventedit.css +│ ├── eventview.css +│ ├── agenda_style.css +│ ├── popover.js │ ├── style.css │ ├── popover.js │ ├── templates @@ -29,6 +41,11 @@ ├── schema.md └── tests ├── __init__.py - └── conftest.py - └── test_profile.py - └── test_app.py \ No newline at end of file + ├── conftest.py + ├── test_agenda_internal.py + ├── test_agenda_route.py + ├── test_app.py + ├── test_categories.py + ├── test_email.py + ├── test_event.py + └── test_profile.py \ No newline at end of file diff --git a/tests/category_fixture.py b/tests/category_fixture.py new file mode 100644 index 00000000..08cc6b97 --- /dev/null +++ b/tests/category_fixture.py @@ -0,0 +1,13 @@ +import pytest +from sqlalchemy.orm import Session + +from app.database.models import User, Category + + +@pytest.fixture +def category(session: Session, user: User) -> Category: + category = Category.create(session, name="Guitar Lesson", color="121212", + user_id=user.id) + yield category + session.delete(category) + session.commit() diff --git a/tests/conftest.py b/tests/conftest.py index 1100fc70..c7b27549 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,7 +4,6 @@ from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker - pytest_plugins = [ 'tests.user_fixture', 'tests.event_fixture', @@ -12,6 +11,7 @@ 'tests.association_fixture', 'tests.client_fixture', 'tests.asyncio_fixture', + 'tests.category_fixture', 'smtpdfix', ] @@ -47,6 +47,7 @@ def session(): Base.metadata.create_all(bind=test_engine) session = get_test_db() yield session + session.rollback() session.close() Base.metadata.drop_all(bind=test_engine) diff --git a/tests/test_categories.py b/tests/test_categories.py new file mode 100644 index 00000000..5f9a12e0 --- /dev/null +++ b/tests/test_categories.py @@ -0,0 +1,108 @@ +import pytest +from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy.testing import mock +from starlette import status +from starlette.datastructures import ImmutableMultiDict + +from app.database.models import Event +from app.routers.categories import get_user_categories, validate_request_params + + +class TestCategories: + CATEGORY_ALREADY_EXISTS_MSG = "category is already exists for" + UNALLOWED_PARAMS = "contains unallowed params" + + @staticmethod + def test_get_categories_logic_succeeded(session, user, category): + assert get_user_categories(session, category.user_id) == [category] + + @staticmethod + def test_creating_new_category(client, user): + response = client.post("/categories/", + json={"user_id": user.id, "name": "Foo", + "color": "eecc11"}) + assert response.status_code == status.HTTP_200_OK + assert {"user_id": user.id, "name": "Foo", "color": "eecc11"}.items() \ + <= response.json()['category'].items() + + @staticmethod + def test_creating_not_unique_category_failed(client, user, category): + response = client.post("/categories/", json={"user_id": user.id, + "name": "Guitar Lesson", + "color": "121212"}) + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert TestCategories.CATEGORY_ALREADY_EXISTS_MSG in \ + response.json()["detail"] + + @staticmethod + def test_create_event_with_category(category): + event = Event(title="OOO", content="Guitar rocks!!", + owner_id=category.user_id, category_id=category.id) + assert event.category_id is not None + assert event.category_id == category.id + + @staticmethod + def test_get_user_categories(client, category): + response = client.get(f"/categories/?user_id={category.user_id}" + f"&name={category.name}&color={category.color}") + assert response.status_code == status.HTTP_200_OK + assert response.json() == [ + {"user_id": category.user_id, "color": "121212", + "name": "Guitar Lesson", "id": category.id}] + + @staticmethod + def test_get_category_by_name(client, user, category): + response = client.get(f"/categories/?user_id={category.user_id}" + f"&name={category.name}") + assert response.status_code == status.HTTP_200_OK + assert response.json() == [ + {"user_id": category.user_id, "color": "121212", + "name": "Guitar Lesson", "id": category.id}] + + @staticmethod + def test_get_category_by_color(client, user, category): + response = client.get(f"/categories/?user_id={category.user_id}&" + f"color={category.color}") + assert response.status_code == status.HTTP_200_OK + assert response.json() == [ + {"user_id": category.user_id, "color": "121212", + "name": "Guitar Lesson", "id": category.id}] + + @staticmethod + def test_get_category_bad_request(client): + response = client.get("/categories/") + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert TestCategories.UNALLOWED_PARAMS in response.json()["detail"] + + @staticmethod + def test_repr(category): + assert category.__repr__() == \ + f'' + + @staticmethod + def test_to_dict(category): + assert {c.name: getattr(category, c.name) for c in + category.__table__.columns} == category.to_dict() + + @staticmethod + @pytest.mark.parametrize('params, expected_result', [ + (ImmutableMultiDict([('user_id', ''), ('name', ''), + ('color', '')]), True), + (ImmutableMultiDict([('user_id', ''), ('name', '')]), True), + (ImmutableMultiDict([('user_id', ''), ('color', '')]), True), + (ImmutableMultiDict([('user_id', '')]), True), + (ImmutableMultiDict([('name', ''), ('color', '')]), False), + (ImmutableMultiDict([]), False), + (ImmutableMultiDict([('user_id', ''), ('name', ''), ('color', ''), + ('bad_param', '')]), False), + ]) + def test_validate_request_params(params, expected_result): + assert validate_request_params(params) == expected_result + + @staticmethod + def test_get_categories_failed(session): + def raise_error(param): + raise SQLAlchemyError() + + session.query = mock.Mock(side_effect=raise_error) + assert get_user_categories(session, 1) == [] From 176fbb58648623e8f058e9e0892881d7712c6f86 Mon Sep 17 00:00:00 2001 From: i Date: Thu, 28 Jan 2021 18:40:33 +0200 Subject: [PATCH 2/4] PR comments --- app/database/models.py | 18 ++++++++++++------ app/main.py | 4 ++-- app/routers/categories.py | 10 +++++----- tests/test_categories.py | 13 +++++++------ 4 files changed, 26 insertions(+), 19 deletions(-) diff --git a/app/database/models.py b/app/database/models.py index ce680616..fba04145 100644 --- a/app/database/models.py +++ b/app/database/models.py @@ -1,12 +1,16 @@ +from __future__ import annotations + from datetime import datetime +from typing import Dict, Any -from app.config import PSQL_ENVIRONMENT -from app.database.database import Base from sqlalchemy import (DDL, Boolean, Column, DateTime, ForeignKey, Index, Integer, String, event, UniqueConstraint) from sqlalchemy.dialects.postgresql import TSVECTOR from sqlalchemy.orm import relationship, Session +from app.config import PSQL_ENVIRONMENT +from app.database.database import Base + class UserEvent(Base): __tablename__ = "user_event" @@ -83,21 +87,23 @@ class Category(Base): user_id = Column(Integer, ForeignKey("users.id"), nullable=False) @staticmethod - def create(db_session: Session, name: str, color: str, user_id: int): + def create(db_session: Session, name: str, color: str, + user_id: int) -> Category: try: category = Category(name=name, color=color, user_id=user_id) db_session.add(category) db_session.flush() db_session.commit() db_session.refresh(category) - return category except Exception as e: raise e + else: + return category - def to_dict(self): + def to_dict(self) -> Dict[str, Any]: return {c.name: getattr(self, c.name) for c in self.__table__.columns} - def __repr__(self): + def __repr__(self) -> str: return f'' diff --git a/app/main.py b/app/main.py index 28ccbf21..c0d47573 100644 --- a/app/main.py +++ b/app/main.py @@ -2,19 +2,19 @@ from fastapi.staticfiles import StaticFiles from sqlalchemy.orm import Session +from app import config from app.config import PSQL_ENVIRONMENT from app.database import models from app.database.database import engine, get_db from app.dependencies import ( MEDIA_PATH, STATIC_PATH, templates) +from app.internal.logger_customizer import LoggerCustomizer from app.internal.quotes import load_quotes, daily_quotes from app.routers import ( agenda, categories, dayview, email, event, invitation, profile, search, telegram, whatsapp ) from app.telegram.bot import telegram_bot -from app.internal.logger_customizer import LoggerCustomizer -from app import config def create_tables(engine, psql_environment): diff --git a/app/routers/categories.py b/app/routers/categories.py index 90cfadaa..322b02ca 100644 --- a/app/routers/categories.py +++ b/app/routers/categories.py @@ -1,4 +1,4 @@ -from typing import Dict, List +from typing import Dict, List, Any from fastapi import APIRouter, Depends, HTTPException, Request from pydantic import BaseModel @@ -37,17 +37,17 @@ def get_categories(request: Request, # TODO(issue#29): get current user_id from session @router.post("/") async def set_category(category: CategoryModel, - db_session: Session = Depends(get_db)) -> Dict: + db_sess: Session = Depends(get_db)) -> Dict[str, Any]: try: - cat = Category.create(db_session, + cat = Category.create(db_sess, name=category.name, color=category.color, user_id=category.user_id) except IntegrityError: - db_session.rollback() + db_sess.rollback() raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"category is already exists for " - f"{category.user_id} user.") + f"user {category.user_id}.") else: return {"category": cat.to_dict()} diff --git a/tests/test_categories.py b/tests/test_categories.py index 5f9a12e0..16721983 100644 --- a/tests/test_categories.py +++ b/tests/test_categories.py @@ -21,9 +21,10 @@ def test_creating_new_category(client, user): response = client.post("/categories/", json={"user_id": user.id, "name": "Foo", "color": "eecc11"}) - assert response.status_code == status.HTTP_200_OK - assert {"user_id": user.id, "name": "Foo", "color": "eecc11"}.items() \ - <= response.json()['category'].items() + assert response.ok + assert {"user_id": user.id, "name": "Foo", + "color": "eecc11"}.items() <= response.json()[ + 'category'].items() @staticmethod def test_creating_not_unique_category_failed(client, user, category): @@ -45,7 +46,7 @@ def test_create_event_with_category(category): def test_get_user_categories(client, category): response = client.get(f"/categories/?user_id={category.user_id}" f"&name={category.name}&color={category.color}") - assert response.status_code == status.HTTP_200_OK + assert response.ok assert response.json() == [ {"user_id": category.user_id, "color": "121212", "name": "Guitar Lesson", "id": category.id}] @@ -54,7 +55,7 @@ def test_get_user_categories(client, category): def test_get_category_by_name(client, user, category): response = client.get(f"/categories/?user_id={category.user_id}" f"&name={category.name}") - assert response.status_code == status.HTTP_200_OK + assert response.ok assert response.json() == [ {"user_id": category.user_id, "color": "121212", "name": "Guitar Lesson", "id": category.id}] @@ -63,7 +64,7 @@ def test_get_category_by_name(client, user, category): def test_get_category_by_color(client, user, category): response = client.get(f"/categories/?user_id={category.user_id}&" f"color={category.color}") - assert response.status_code == status.HTTP_200_OK + assert response.ok assert response.json() == [ {"user_id": category.user_id, "color": "121212", "name": "Guitar Lesson", "id": category.id}] From 81590a2d46801a51c0977994cd957f4f7507d979 Mon Sep 17 00:00:00 2001 From: i Date: Mon, 1 Feb 2021 22:01:35 +0200 Subject: [PATCH 3/4] PR comments --- app/database/models.py | 5 ++++- tests/test_categories.py | 24 ++++++++++++------------ 2 files changed, 16 insertions(+), 13 deletions(-) diff --git a/app/database/models.py b/app/database/models.py index fba04145..4f3308ef 100644 --- a/app/database/models.py +++ b/app/database/models.py @@ -6,10 +6,12 @@ from sqlalchemy import (DDL, Boolean, Column, DateTime, ForeignKey, Index, Integer, String, event, UniqueConstraint) from sqlalchemy.dialects.postgresql import TSVECTOR +from sqlalchemy.exc import IntegrityError, SQLAlchemyError from sqlalchemy.orm import relationship, Session from app.config import PSQL_ENVIRONMENT from app.database.database import Base +from app.dependencies import logger class UserEvent(Base): @@ -95,7 +97,8 @@ def create(db_session: Session, name: str, color: str, db_session.flush() db_session.commit() db_session.refresh(category) - except Exception as e: + except (SQLAlchemyError, IntegrityError) as e: + logger.error(f"Failed to create category: {e}") raise e else: return category diff --git a/tests/test_categories.py b/tests/test_categories.py index 16721983..c3c7e3d9 100644 --- a/tests/test_categories.py +++ b/tests/test_categories.py @@ -22,9 +22,9 @@ def test_creating_new_category(client, user): json={"user_id": user.id, "name": "Foo", "color": "eecc11"}) assert response.ok - assert {"user_id": user.id, "name": "Foo", - "color": "eecc11"}.items() <= response.json()[ - 'category'].items() + assert {("user_id", user.id), ("name", "Foo"), + ("color", "eecc11")}.issubset( + set(response.json()['category'].items())) @staticmethod def test_creating_not_unique_category_failed(client, user, category): @@ -47,27 +47,27 @@ def test_get_user_categories(client, category): response = client.get(f"/categories/?user_id={category.user_id}" f"&name={category.name}&color={category.color}") assert response.ok - assert response.json() == [ - {"user_id": category.user_id, "color": "121212", - "name": "Guitar Lesson", "id": category.id}] + assert set(response.json()[0].items()) == { + ("user_id", category.user_id), ("color", "121212"), + ("name", "Guitar Lesson"), ("id", category.id)} @staticmethod def test_get_category_by_name(client, user, category): response = client.get(f"/categories/?user_id={category.user_id}" f"&name={category.name}") assert response.ok - assert response.json() == [ - {"user_id": category.user_id, "color": "121212", - "name": "Guitar Lesson", "id": category.id}] + assert set(response.json()[0].items()) == { + ("user_id", category.user_id), ("color", "121212"), + ("name", "Guitar Lesson"), ("id", category.id)} @staticmethod def test_get_category_by_color(client, user, category): response = client.get(f"/categories/?user_id={category.user_id}&" f"color={category.color}") assert response.ok - assert response.json() == [ - {"user_id": category.user_id, "color": "121212", - "name": "Guitar Lesson", "id": category.id}] + assert set(response.json()[0].items()) == { + ("user_id", category.user_id), ("color", "121212"), + ("name", "Guitar Lesson"), ("id", category.id)} @staticmethod def test_get_category_bad_request(client): From 269652fbaff97e3ca9e4b26072b0db845f2a0866 Mon Sep 17 00:00:00 2001 From: i Date: Tue, 2 Feb 2021 12:28:07 +0200 Subject: [PATCH 4/4] Remove redundant --- app/database/models.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/app/database/models.py b/app/database/models.py index 45ce9963..4f3308ef 100644 --- a/app/database/models.py +++ b/app/database/models.py @@ -13,9 +13,6 @@ from app.database.database import Base from app.dependencies import logger -from app.config import PSQL_ENVIRONMENT -from app.database.database import Base - class UserEvent(Base): __tablename__ = "user_event"