Skip to content

Support categories for events #137

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions app/database/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import Session, sessionmaker

from app import config

Expand All @@ -25,7 +25,7 @@ def create_env_engine(psql_environment, sqlalchemy_database_url):
Base = declarative_base()


def get_db():
def get_db() -> Session:
db = SessionLocal()
try:
yield db
Expand Down
51 changes: 45 additions & 6 deletions app/database/models.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
from __future__ import annotations

from datetime import datetime
from typing import Dict, Any

from sqlalchemy import (DDL, Boolean, Column, DateTime, ForeignKey, Index,
Integer, String, event)
Integer, String, event, UniqueConstraint)
from sqlalchemy.dialects.postgresql import TSVECTOR
from sqlalchemy.orm import relationship
from sqlalchemy.exc import IntegrityError, SQLAlchemyError
from sqlalchemy.orm import relationship, Session

from app.config import PSQL_ENVIRONMENT
from app.database.database import Base
from app.dependencies import logger


class UserEvent(Base):
Expand Down Expand Up @@ -51,11 +56,12 @@ class Event(Base):
end = Column(DateTime, nullable=False)
content = Column(String)
location = Column(String)
color = Column(String, nullable=True)

owner = relationship("User")
owner_id = Column(Integer, ForeignKey("users.id"))
color = Column(String, nullable=True)
category_id = Column(Integer, ForeignKey("categories.id"))

owner = relationship("User")
participants = relationship("UserEvent", back_populates="events")

# PostgreSQL
Expand All @@ -65,12 +71,45 @@ class Event(Base):
'events_tsv_idx',
'events_tsv',
postgresql_using='gin'),
)
)

def __repr__(self):
return f'<Event {self.id}>'


class Category(Base):
__tablename__ = "categories"

__table_args__ = (
UniqueConstraint('user_id', 'name', 'color'),
)
id = Column(Integer, primary_key=True, index=True)
name = Column(String, nullable=False)
color = Column(String, nullable=False)
user_id = Column(Integer, ForeignKey("users.id"), nullable=False)

@staticmethod
def create(db_session: Session, name: str, color: str,
user_id: int) -> Category:
try:
category = Category(name=name, color=color, user_id=user_id)
db_session.add(category)
db_session.flush()
db_session.commit()
db_session.refresh(category)
except (SQLAlchemyError, IntegrityError) as e:
logger.error(f"Failed to create category: {e}")
raise e
else:
return category

def to_dict(self) -> Dict[str, Any]:
return {c.name: getattr(self, c.name) for c in self.__table__.columns}

def __repr__(self) -> str:
return f'<Category {self.id} {self.name} {self.color}>'


class PSQLEnvironmentError(Exception):
pass

Expand All @@ -88,7 +127,7 @@ class PSQLEnvironmentError(Exception):
Event.__table__,
'after_create',
trigger_snippet.execute_if(dialect='postgresql')
)
)


class Invitation(Base):
Expand Down
30 changes: 18 additions & 12 deletions app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
from app.dependencies import (logger, MEDIA_PATH, STATIC_PATH, templates)
from app.internal.quotes import daily_quotes, load_quotes
from app.routers import (
agenda, dayview, email, event, invitation, profile, search, telegram,
whatsapp
agenda, categories, dayview, email, event, invitation, profile, search,
telegram, whatsapp
)
from app.telegram.bot import telegram_bot

Expand All @@ -34,21 +34,27 @@ def create_tables(engine, psql_environment):

app.logger = logger

app.include_router(profile.router)
app.include_router(event.router)
app.include_router(agenda.router)
app.include_router(telegram.router)
app.include_router(dayview.router)
app.include_router(email.router)
app.include_router(invitation.router)
app.include_router(whatsapp.router)
app.include_router(search.router)
routers_to_include = [
agenda.router,
categories.router,
dayview.router,
email.router,
event.router,
invitation.router,
profile.router,
search.router,
telegram.router,
whatsapp.router,
]

for router in routers_to_include:
app.include_router(router)

telegram_bot.set_webhook()


# TODO: I add the quote day to the home page
# until the relavent calendar view will be developed.
# until the relevant calendar view will be developed.
@app.get("/")
@logger.catch()
async def home(request: Request, db: Session = Depends(get_db)):
Expand Down
80 changes: 80 additions & 0 deletions app/routers/categories.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
from typing import Dict, List, Any

from fastapi import APIRouter, Depends, HTTPException, Request
from pydantic import BaseModel
from sqlalchemy.exc import IntegrityError, SQLAlchemyError
from sqlalchemy.orm import Session
from starlette import status
from starlette.datastructures import ImmutableMultiDict

from app.database.database import get_db
from app.database.models import Category

router = APIRouter(
prefix="/categories",
tags=["categories"],
)


class CategoryModel(BaseModel):
name: str
color: str
user_id: int


# TODO(issue#29): get current user_id from session
@router.get("/")
def get_categories(request: Request,
db_session: Session = Depends(get_db)) -> List[Category]:
if validate_request_params(request.query_params):
return get_user_categories(db_session, **request.query_params)
else:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Request {request.query_params} contains "
f"unallowed params.")


# TODO(issue#29): get current user_id from session
@router.post("/")
async def set_category(category: CategoryModel,
db_sess: Session = Depends(get_db)) -> Dict[str, Any]:
try:
cat = Category.create(db_sess,
name=category.name,
color=category.color,
user_id=category.user_id)
except IntegrityError:
db_sess.rollback()
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST,
detail=f"category is already exists for "
f"user {category.user_id}.")
else:
return {"category": cat.to_dict()}


def validate_request_params(query_params: ImmutableMultiDict) -> bool:
"""
request.query_params contains not more than user_id, name, color
and not less than user_id:
Intersection must contain at least user_id.
Union must not contain fields other than user_id, name, color.
"""
all_fields = set(CategoryModel.schema()["required"])
request_params = set(query_params)
union_set = request_params.union(all_fields)
intersection_set = request_params.intersection(all_fields)
return union_set == all_fields and "user_id" in intersection_set


def get_user_categories(db_session: Session,
user_id: int, **params) -> List[Category]:
"""
Returns user's categories, filtered by params.
"""
try:
categories = db_session.query(Category).filter_by(
user_id=user_id).filter_by(**params).all()
except SQLAlchemyError:
return []
else:
return categories
29 changes: 23 additions & 6 deletions schema.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,25 @@
│ ├── internal
│ ├── __init__.py
│ ├── admin.py
│ ├── routers
│ ├── __init__.py
│ ├── profile.py
│ ├── agenda_events.py
│ ├── email.py
│ ├── media
│ ├── example.png
│ ├── fake_user.png
│ ├── profile.png
│ ├── routers
│ ├── __init__.py
│ ├── agenda.py
│ ├── categories.py
│ ├── email.py
│ ├── event.py
│ ├── profile.py
│ ├── static
│ ├── event
│ ├── eventedit.css
│ ├── eventview.css
│ ├── agenda_style.css
│ ├── popover.js
│ ├── style.css
│ ├── popover.js
│ ├── templates
Expand All @@ -29,6 +41,11 @@
├── schema.md
└── tests
├── __init__.py
└── conftest.py
└── test_profile.py
└── test_app.py
├── conftest.py
├── test_agenda_internal.py
├── test_agenda_route.py
├── test_app.py
├── test_categories.py
├── test_email.py
├── test_event.py
└── test_profile.py
13 changes: 13 additions & 0 deletions tests/category_fixture.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import pytest
from sqlalchemy.orm import Session

from app.database.models import User, Category


@pytest.fixture
def category(session: Session, user: User) -> Category:
category = Category.create(session, name="Guitar Lesson", color="121212",
user_id=user.id)
yield category
session.delete(category)
session.commit()
4 changes: 4 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker


from app.config import PSQL_ENVIRONMENT
from app.database.database import Base


pytest_plugins = [
'tests.user_fixture',
'tests.event_fixture',
Expand All @@ -13,6 +15,7 @@
'tests.client_fixture',
'tests.asyncio_fixture',
'tests.logger_fixture',
'tests.category_fixture',
'smtpdfix',
'tests.quotes_fixture'
]
Expand Down Expand Up @@ -49,6 +52,7 @@ def session():
Base.metadata.create_all(bind=test_engine)
session = get_test_db()
yield session
session.rollback()
session.close()
Base.metadata.drop_all(bind=test_engine)

Expand Down
Loading