Skip to content

Commit 84a324d

Browse files
authored
Merge 81590a2 into 53c1b08
2 parents 53c1b08 + 81590a2 commit 84a324d

File tree

8 files changed

+295
-29
lines changed

8 files changed

+295
-29
lines changed

app/database/database.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from sqlalchemy import create_engine
44
from sqlalchemy.ext.declarative import declarative_base
5-
from sqlalchemy.orm import sessionmaker
5+
from sqlalchemy.orm import Session, sessionmaker
66

77
from app import config
88

@@ -25,7 +25,7 @@ def create_env_engine(psql_environment, sqlalchemy_database_url):
2525
Base = declarative_base()
2626

2727

28-
def get_db():
28+
def get_db() -> Session:
2929
db = SessionLocal()
3030
try:
3131
yield db

app/database/models.py

Lines changed: 48 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,17 @@
1+
from __future__ import annotations
2+
13
from datetime import datetime
4+
from typing import Dict, Any
25

3-
from app.config import PSQL_ENVIRONMENT
4-
from app.database.database import Base
56
from sqlalchemy import (DDL, Boolean, Column, DateTime, ForeignKey, Index,
6-
Integer, String, event)
7+
Integer, String, event, UniqueConstraint)
78
from sqlalchemy.dialects.postgresql import TSVECTOR
8-
from sqlalchemy.orm import relationship
9+
from sqlalchemy.exc import IntegrityError, SQLAlchemyError
10+
from sqlalchemy.orm import relationship, Session
11+
12+
from app.config import PSQL_ENVIRONMENT
13+
from app.database.database import Base
14+
from app.dependencies import logger
915

1016

1117
class UserEvent(Base):
@@ -50,11 +56,12 @@ class Event(Base):
5056
end = Column(DateTime, nullable=False)
5157
content = Column(String)
5258
location = Column(String)
59+
color = Column(String, nullable=True)
5360

54-
owner = relationship("User")
5561
owner_id = Column(Integer, ForeignKey("users.id"))
56-
color = Column(String, nullable=True)
62+
category_id = Column(Integer, ForeignKey("categories.id"))
5763

64+
owner = relationship("User")
5865
participants = relationship("UserEvent", back_populates="events")
5966

6067
# PostgreSQL
@@ -64,12 +71,45 @@ class Event(Base):
6471
'events_tsv_idx',
6572
'events_tsv',
6673
postgresql_using='gin'),
67-
)
74+
)
6875

6976
def __repr__(self):
7077
return f'<Event {self.id}>'
7178

7279

80+
class Category(Base):
81+
__tablename__ = "categories"
82+
83+
__table_args__ = (
84+
UniqueConstraint('user_id', 'name', 'color'),
85+
)
86+
id = Column(Integer, primary_key=True, index=True)
87+
name = Column(String, nullable=False)
88+
color = Column(String, nullable=False)
89+
user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
90+
91+
@staticmethod
92+
def create(db_session: Session, name: str, color: str,
93+
user_id: int) -> Category:
94+
try:
95+
category = Category(name=name, color=color, user_id=user_id)
96+
db_session.add(category)
97+
db_session.flush()
98+
db_session.commit()
99+
db_session.refresh(category)
100+
except (SQLAlchemyError, IntegrityError) as e:
101+
logger.error(f"Failed to create category: {e}")
102+
raise e
103+
else:
104+
return category
105+
106+
def to_dict(self) -> Dict[str, Any]:
107+
return {c.name: getattr(self, c.name) for c in self.__table__.columns}
108+
109+
def __repr__(self) -> str:
110+
return f'<Category {self.id} {self.name} {self.color}>'
111+
112+
73113
class PSQLEnvironmentError(Exception):
74114
pass
75115

@@ -87,7 +127,7 @@ class PSQLEnvironmentError(Exception):
87127
Event.__table__,
88128
'after_create',
89129
trigger_snippet.execute_if(dialect='postgresql')
90-
)
130+
)
91131

92132

93133
class Invitation(Base):

app/main.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88
from app.dependencies import (logger, MEDIA_PATH, STATIC_PATH, templates)
99
from app.internal.quotes import daily_quotes, load_quotes
1010
from app.routers import (
11-
agenda, dayview, email, event, invitation, profile, search, telegram,
12-
whatsapp
11+
agenda, categories, dayview, email, event, invitation, profile, search,
12+
telegram, whatsapp
1313
)
1414
from app.telegram.bot import telegram_bot
1515

@@ -34,21 +34,27 @@ def create_tables(engine, psql_environment):
3434

3535
app.logger = logger
3636

37-
app.include_router(profile.router)
38-
app.include_router(event.router)
39-
app.include_router(agenda.router)
40-
app.include_router(telegram.router)
41-
app.include_router(dayview.router)
42-
app.include_router(email.router)
43-
app.include_router(invitation.router)
44-
app.include_router(whatsapp.router)
45-
app.include_router(search.router)
37+
routers_to_include = [
38+
agenda.router,
39+
categories.router,
40+
dayview.router,
41+
email.router,
42+
event.router,
43+
invitation.router,
44+
profile.router,
45+
search.router,
46+
telegram.router,
47+
whatsapp.router,
48+
]
49+
50+
for router in routers_to_include:
51+
app.include_router(router)
4652

4753
telegram_bot.set_webhook()
4854

4955

5056
# TODO: I add the quote day to the home page
51-
# until the relavent calendar view will be developed.
57+
# until the relevant calendar view will be developed.
5258
@app.get("/")
5359
@logger.catch()
5460
async def home(request: Request, db: Session = Depends(get_db)):

app/routers/categories.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
from typing import Dict, List, Any
2+
3+
from fastapi import APIRouter, Depends, HTTPException, Request
4+
from pydantic import BaseModel
5+
from sqlalchemy.exc import IntegrityError, SQLAlchemyError
6+
from sqlalchemy.orm import Session
7+
from starlette import status
8+
from starlette.datastructures import ImmutableMultiDict
9+
10+
from app.database.database import get_db
11+
from app.database.models import Category
12+
13+
router = APIRouter(
14+
prefix="/categories",
15+
tags=["categories"],
16+
)
17+
18+
19+
class CategoryModel(BaseModel):
20+
name: str
21+
color: str
22+
user_id: int
23+
24+
25+
# TODO(issue#29): get current user_id from session
26+
@router.get("/")
27+
def get_categories(request: Request,
28+
db_session: Session = Depends(get_db)) -> List[Category]:
29+
if validate_request_params(request.query_params):
30+
return get_user_categories(db_session, **request.query_params)
31+
else:
32+
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST,
33+
detail=f"Request {request.query_params} contains "
34+
f"unallowed params.")
35+
36+
37+
# TODO(issue#29): get current user_id from session
38+
@router.post("/")
39+
async def set_category(category: CategoryModel,
40+
db_sess: Session = Depends(get_db)) -> Dict[str, Any]:
41+
try:
42+
cat = Category.create(db_sess,
43+
name=category.name,
44+
color=category.color,
45+
user_id=category.user_id)
46+
except IntegrityError:
47+
db_sess.rollback()
48+
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST,
49+
detail=f"category is already exists for "
50+
f"user {category.user_id}.")
51+
else:
52+
return {"category": cat.to_dict()}
53+
54+
55+
def validate_request_params(query_params: ImmutableMultiDict) -> bool:
56+
"""
57+
request.query_params contains not more than user_id, name, color
58+
and not less than user_id:
59+
Intersection must contain at least user_id.
60+
Union must not contain fields other than user_id, name, color.
61+
"""
62+
all_fields = set(CategoryModel.schema()["required"])
63+
request_params = set(query_params)
64+
union_set = request_params.union(all_fields)
65+
intersection_set = request_params.intersection(all_fields)
66+
return union_set == all_fields and "user_id" in intersection_set
67+
68+
69+
def get_user_categories(db_session: Session,
70+
user_id: int, **params) -> List[Category]:
71+
"""
72+
Returns user's categories, filtered by params.
73+
"""
74+
try:
75+
categories = db_session.query(Category).filter_by(
76+
user_id=user_id).filter_by(**params).all()
77+
except SQLAlchemyError:
78+
return []
79+
else:
80+
return categories

schema.md

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,25 @@
1111
│ ├── internal
1212
│ ├── __init__.py
1313
│ ├── admin.py
14-
│ ├── routers
15-
│ ├── __init__.py
16-
│ ├── profile.py
14+
│ ├── agenda_events.py
15+
│ ├── email.py
1716
│ ├── media
1817
│ ├── example.png
18+
│ ├── fake_user.png
1919
│ ├── profile.png
20+
│ ├── routers
21+
│ ├── __init__.py
22+
│ ├── agenda.py
23+
│ ├── categories.py
24+
│ ├── email.py
25+
│ ├── event.py
26+
│ ├── profile.py
2027
│ ├── static
28+
│ ├── event
29+
│ ├── eventedit.css
30+
│ ├── eventview.css
31+
│ ├── agenda_style.css
32+
│ ├── popover.js
2133
│ ├── style.css
2234
│ ├── popover.js
2335
│ ├── templates
@@ -29,6 +41,11 @@
2941
├── schema.md
3042
└── tests
3143
├── __init__.py
32-
└── conftest.py
33-
└── test_profile.py
34-
└── test_app.py
44+
├── conftest.py
45+
├── test_agenda_internal.py
46+
├── test_agenda_route.py
47+
├── test_app.py
48+
├── test_categories.py
49+
├── test_email.py
50+
├── test_event.py
51+
└── test_profile.py

tests/category_fixture.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import pytest
2+
from sqlalchemy.orm import Session
3+
4+
from app.database.models import User, Category
5+
6+
7+
@pytest.fixture
8+
def category(session: Session, user: User) -> Category:
9+
category = Category.create(session, name="Guitar Lesson", color="121212",
10+
user_id=user.id)
11+
yield category
12+
session.delete(category)
13+
session.commit()

tests/conftest.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from sqlalchemy import create_engine
55
from sqlalchemy.orm import sessionmaker
66

7-
87
pytest_plugins = [
98
'tests.user_fixture',
109
'tests.event_fixture',
@@ -13,6 +12,7 @@
1312
'tests.client_fixture',
1413
'tests.asyncio_fixture',
1514
'tests.logger_fixture',
15+
'tests.category_fixture',
1616
'smtpdfix',
1717
'tests.quotes_fixture'
1818
]
@@ -49,6 +49,7 @@ def session():
4949
Base.metadata.create_all(bind=test_engine)
5050
session = get_test_db()
5151
yield session
52+
session.rollback()
5253
session.close()
5354
Base.metadata.drop_all(bind=test_engine)
5455

0 commit comments

Comments
 (0)