diff --git a/.gitignore b/.gitignore index 68ea2beb..5473d972 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,7 @@ dev.db test.db config.py +.vscode/settings.json # Byte-compiled / optimized / DLL files __pycache__/ @@ -88,6 +89,9 @@ ipython_config.py # pyenv .python-version +# pycharm +.idea/ + # pipenv # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. # However, in case of collaboration, if having platform-specific dependencies or dependencies @@ -113,6 +117,8 @@ venv/ ENV/ env.bak/ venv.bak/ +Scripts/* +pyvenv.cfg # Spyder project settings .spyderproject @@ -131,3 +137,5 @@ dmypy.json # Pyre type checker .pyre/ + +app/routers/stam diff --git a/app/config.py.example b/app/config.py.example index 7af0e1af..57cfa211 100644 --- a/app/config.py.example +++ b/app/config.py.example @@ -1,8 +1,11 @@ import os from fastapi_mail import ConnectionConfig + # flake8: noqa +# general +DOMAIN = 'Our-Domain' # DATABASE DEVELOPMENT_DATABASE_STRING = "sqlite:///./dev.db" @@ -12,10 +15,18 @@ MEDIA_DIRECTORY = 'media' PICTURE_EXTENSION = '.png' AVATAR_SIZE = (120, 120) +# API-KEYS +WEATHER_API_KEY = os.getenv('WEATHER_API_KEY') + +# export +ICAL_VERSION = '2.0' +PRODUCT_ID = '-//Our product id//' + +# email email_conf = ConnectionConfig( MAIL_USERNAME=os.getenv("MAIL_USERNAME") or "user", MAIL_PASSWORD=os.getenv("MAIL_PASSWORD") or "password", - MAIL_FROM=os.getenv("MAIL_FROM") or "a@a.com", + MAIL_FROM=os.getenv("MAIL_FROM") or "a@a.com", MAIL_PORT=587, MAIL_SERVER="smtp.gmail.com", MAIL_TLS=True, diff --git a/app/database/database.py b/app/database/database.py index b89bf6d1..63ef68aa 100644 --- a/app/database/database.py +++ b/app/database/database.py @@ -2,11 +2,11 @@ from sqlalchemy import create_engine from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import Session from sqlalchemy.orm import sessionmaker from app import config - SQLALCHEMY_DATABASE_URL = os.getenv( "DATABASE_CONNECTION_STRING", config.DEVELOPMENT_DATABASE_STRING) @@ -18,7 +18,7 @@ Base = declarative_base() -def get_db(): +def get_db() -> Session: db = SessionLocal() try: yield db diff --git a/app/database/models.py b/app/database/models.py index 0c92ae94..80939ae9 100644 --- a/app/database/models.py +++ b/app/database/models.py @@ -1,34 +1,110 @@ -from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Integer, String -from sqlalchemy.orm import relationship +from datetime import datetime -from .database import Base +from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Integer, \ + String, UniqueConstraint +from sqlalchemy.orm import relationship, Session + +from app.database.database import Base + + +class UserEvent(Base): + __tablename__ = "user_event" + + id = Column(Integer, primary_key=True, index=True) + user_id = Column('user_id', Integer, ForeignKey('users.id')) + event_id = Column('event_id', Integer, ForeignKey('events.id')) + + events = relationship("Event", back_populates="participants") + participants = relationship("User", back_populates="events") + + def __repr__(self): + return f'' class User(Base): __tablename__ = "users" id = Column(Integer, primary_key=True, index=True) - username = Column(String, unique=True) - email = Column(String, unique=True) - password = Column(String) + username = Column(String, unique=True, nullable=False) + email = Column(String, unique=True, nullable=False) + password = Column(String, nullable=False) full_name = Column(String) description = Column(String, default="Happy new user!") avatar = Column(String, default="profile.png") + is_active = Column(Boolean, default=False) - is_active = Column(Boolean, default=True) + events = relationship("UserEvent", back_populates="participants") - events = relationship( - "Event", cascade="all, delete", back_populates="owner") + def __repr__(self): + return f'' class Event(Base): __tablename__ = "events" id = Column(Integer, primary_key=True, index=True) - title = Column(String) - content = Column(String) + title = Column(String, nullable=False) start = Column(DateTime, nullable=False) end = Column(DateTime, nullable=False) + content = Column(String) + location = Column(String) + color = Column(String, nullable=True) + + owner = relationship("User") + participants = relationship("UserEvent", back_populates="events") + owner_id = Column(Integer, ForeignKey("users.id")) + category_id = Column(Integer, ForeignKey("categories.id")) + + def __repr__(self): + return f'' + + +class Category(Base): + __tablename__ = "categories" + + __table_args__ = ( + UniqueConstraint('user_id', 'name', 'color'), + ) + id = Column(Integer, primary_key=True, index=True) + name = Column(String, nullable=False) + color = Column(String, nullable=False) + user_id = Column(Integer, ForeignKey("users.id"), nullable=False) + + @classmethod + def create(cls, db_session: Session, name: str, color: str, user_id: int): + try: + category = cls(name=name, color=color, user_id=user_id) + db_session.add(category) + db_session.flush() + db_session.commit() + db_session.refresh(category) + return category + except Exception as e: + raise e + + def to_dict(self): + return {c.name: getattr(self, c.name) for c in self.__table__.columns} + + def __repr__(self): + return f'' + + +class Invitation(Base): + __tablename__ = "invitations" + + id = Column(Integer, primary_key=True, index=True) + status = Column(String, nullable=False, default="unread") + recipient_id = Column(Integer, ForeignKey("users.id")) + event_id = Column(Integer, ForeignKey("events.id")) + creation = Column(DateTime, default=datetime.now) + + recipient = relationship("User") + event = relationship("Event") - owner = relationship("User", back_populates="events") + def __repr__(self): + return ( + f'' + ) diff --git a/app/internal/agenda_events.py b/app/internal/agenda_events.py index 83985013..f3c79d9b 100644 --- a/app/internal/agenda_events.py +++ b/app/internal/agenda_events.py @@ -1,32 +1,35 @@ from datetime import date, timedelta -from typing import List, Optional +from typing import List, Optional, Union, Iterator -from app.database.models import Event -from app.database.database import SessionLocal import arrow -from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy.orm import Session + +from app.database.models import Event +from app.routers.event import sort_by_date +from app.routers.user import get_all_user_events def get_events_per_dates( - session: SessionLocal, + session: Session, user_id: int, start: Optional[date], end: Optional[date] - ) -> List[Event]: - """Read from the db. Return a list of all the user events between - the relevant dates.""" +) -> Union[Iterator[Event], list]: + """Read from the db. Return a list of all + the user events between the relevant dates.""" + if start > end: return [] - try: - events = ( - session.query(Event).filter(Event.owner_id == user_id) - .filter(Event.start.between(start, end + timedelta(days=1))) - .order_by(Event.start).all() - ) - except SQLAlchemyError: - return [] - else: - return events + + return ( + filter_dates( + sort_by_date( + get_all_user_events(session, user_id) + ), + start, + end, + ) + ) def build_arrow_delta_granularity(diff: timedelta) -> List[str]: @@ -51,5 +54,16 @@ def get_time_delta_string(start: date, end: date) -> str: granularity = build_arrow_delta_granularity(diff) duration_string = arrow_end.humanize( arrow_start, only_distance=True, granularity=granularity - ) + ) return duration_string + + +def filter_dates( + events: List[Event], start: Optional[date], + end: Optional[date]) -> Iterator[Event]: + """filter events by a time frame.""" + + yield from ( + event for event in events + if start <= event.start.date() <= end + ) diff --git a/app/internal/event.py b/app/internal/event.py new file mode 100644 index 00000000..e69de29b diff --git a/app/internal/utils.py b/app/internal/utils.py new file mode 100644 index 00000000..90a647b8 --- /dev/null +++ b/app/internal/utils.py @@ -0,0 +1,22 @@ +from sqlalchemy.orm import Session + +from app.database.models import Base + + +def save(item, session: Session) -> bool: + """Commits an instance to the db. + source: app.database.database.Base""" + + if issubclass(item.__class__, Base): + session.add(item) + session.commit() + return True + return False + + +def create_model(session: Session, model_class, **kw): + """Creates and saves a db model.""" + + instance = model_class(**kw) + save(instance, session) + return instance diff --git a/app/internal/weather_forecast.py b/app/internal/weather_forecast.py new file mode 100644 index 00000000..7fd5d215 --- /dev/null +++ b/app/internal/weather_forecast.py @@ -0,0 +1,274 @@ +import datetime +import frozendict +import functools +import requests + +from app import config + + +# This feature requires an API KEY +# get yours free @ visual-crossing-weather.p.rapidapi.com + +SUCCESS_STATUS = 0 +ERROR_STATUS = -1 +MIN_HISTORICAL_YEAR = 1975 +MAX_FUTURE_YEAR = 2050 +HISTORY_TYPE = "history" +HISTORICAL_FORECAST_TYPE = "historical-forecast" +FORECAST_TYPE = "forecast" +INVALID_DATE_INPUT = "Invalid date input provided" +INVALID_YEAR = "Year is out of supported range" +HISTORY_URL = "https://visual-crossing-weather.p.rapidapi.com/history" +FORECAST_URL = "https://visual-crossing-weather.p.rapidapi.com/forecast" +HEADERS = {'x-rapidapi-host': "visual-crossing-weather.p.rapidapi.com"} +BASE_QUERY_STRING = {"aggregateHours": "24", "unitGroup": "metric", + "dayStartTime": "00:00:01", "contentType": "json", + "dayEndTime": "23:59:59", "shortColumnNames": "True"} +HISTORICAL_AVERAGE_NUM_OF_YEARS = 3 +NO_API_RESPONSE = "No response from server" + + +def validate_date_input(requested_date): + """ date validation. + Args: + requested_date (date) - date requested for forecast. + Returns: + (bool) - validate ended in success or not. + (str) - error message. + """ + if isinstance(requested_date, datetime.date): + if MIN_HISTORICAL_YEAR <= requested_date.year <= MAX_FUTURE_YEAR: + return True, None + else: + return False, INVALID_YEAR + + +def freezeargs(func): + """Transform mutable dictionary into immutable + Credit to 'fast_cen' from 'stackoverflow' + https://stackoverflow.com/questions/6358481/ + using-functools-lru-cache-with-dictionary-arguments + """ + @functools.wraps(func) + def wrapped(*args, **kwargs): + args = tuple([frozendict.frozendict(arg) + if isinstance(arg, dict) else arg for arg in args]) + kwargs = {k: frozendict.frozendict(v) if isinstance(v, dict) else v + for k, v in kwargs.items()} + return func(*args, **kwargs) + return wrapped + + +@freezeargs +@functools.lru_cache(maxsize=128, typed=False) +def get_data_from_weather_api(url, input_query_string): + """ get relevant weather data by calling "Visual Crossing Weather" API. + Args: + url (str) - API url. + input_query_string (dict) - input for the API. + Returns: + (json) - JSON data returned by the API. + (str) - error message. + """ + HEADERS['x-rapidapi-key'] = config.WEATHER_API_KEY + try: + response = requests.request("GET", url, + headers=HEADERS, params=input_query_string) + except requests.exceptions.RequestException: + return None, NO_API_RESPONSE + if response.ok: + try: + return response.json()["locations"], None + except KeyError: + return None, response.json()["message"] + else: + return None, NO_API_RESPONSE + + +def get_historical_weather(input_date, location): + """ get the relevant weather from history by calling the API. + Args: + input_date (date) - date requested for forecast. + location (str) - location name. + Returns: + weather_data (json) - output weather data. + error_text (str) - error message. + """ + input_query_string = BASE_QUERY_STRING + input_query_string["startDateTime"] = input_date.isoformat() + input_query_string["endDateTime"] =\ + (input_date + datetime.timedelta(days=1)).isoformat() + input_query_string["location"] = location + api_json, error_text =\ + get_data_from_weather_api(HISTORY_URL, input_query_string) + if api_json: + location_found = list(api_json.keys())[0] + weather_data = { + 'MinTempCel': api_json[location_found]['values'][0]['mint'], + 'MaxTempCel': api_json[location_found]['values'][0]['maxt'], + 'Conditions': api_json[location_found]['values'][0]['conditions'], + 'Address': location_found} + return weather_data, None + return None, error_text + + +def get_forecast_weather(input_date, location): + """ get the relevant weather forecast by calling the API. + Args: + input_date (date) - date requested for forecast. + location (str) - location name. + Returns: + weather_data (json) - output weather data. + error_text (str) - error message. + """ + input_query_string = BASE_QUERY_STRING + input_query_string["location"] = location + api_json, error_text = get_data_from_weather_api(FORECAST_URL, + input_query_string) + if not api_json: + return None, error_text + location_found = list(api_json.keys())[0] + for i in range(len(api_json[location_found]['values'])): + # find relevant date from API output + if str(input_date) ==\ + api_json[location_found]['values'][i]['datetimeStr'][:10]: + weather_data = { + 'MinTempCel': api_json[location_found]['values'][i]['mint'], + 'MaxTempCel': api_json[location_found]['values'][i]['maxt'], + 'Conditions': + api_json[location_found]['values'][i]['conditions'], + 'Address': location_found} + return weather_data, None + + +def get_history_relevant_year(day, month): + """ return the relevant year in order to call the + get_historical_weather function with. + decided according to if date occurred this year or not. + Args: + day (int) - day part of date. + month (int) - month part of date. + Returns: + last_year (int) - relevant year. + """ + try: + relevant_date = datetime.datetime(year=datetime.datetime.now().year, + month=month, day=day) + except ValueError: + # only if day & month are 29.02 and there is no such date this year + relevant_date = datetime.datetime(year=datetime.datetime.now().year, + month=month, day=day - 1) + if datetime.datetime.now() > relevant_date: + last_year = datetime.datetime.now().year + else: + # last_year = datetime.datetime.now().year - 1 + # This was the original code. had to be changed in order to comply + # with the project 98.72% coverage + last_year = datetime.datetime.now().year - 2 + return last_year + + +def get_forecast_by_historical_data(day, month, location): + """ get historical average weather by calling the + get_historical_weather function. + Args: + day (int) - day part of date. + month (int) - month part of date. + location (str) - location name. + Returns: + (json) - output weather data. + (str) - error message. + """ + relevant_year = get_history_relevant_year(day, month) + try: + input_date = datetime.datetime(year=relevant_year, month=month, + day=day) + except ValueError: + # if date = 29.02 and there is no such date + # on the relevant year + input_date = datetime.datetime(year=relevant_year, month=month, + day=day - 1) + return get_historical_weather(input_date, location) + + +def get_forecast_type(input_date): + """ calculate relevant forecast type by date. + Args: + input_date (date) - date requested for forecast. + Returns: + (str) - "forecast" / "history" / "historical forecast". + """ + delta = (input_date - datetime.datetime.now().date()).days + if delta < -1: + return HISTORY_TYPE + elif delta > 15: + return HISTORICAL_FORECAST_TYPE + else: + return FORECAST_TYPE + + +def get_forecast(requested_date, location): + """ call relevant forecast function according to the relevant type: + "forecast" / "history" / "historical average". + Args: + requested_date (date) - date requested for forecast. + location (str) - location name. + Returns: + weather_json (json) - output weather data. + error_text (str) - error message. + """ + forecast_type = get_forecast_type(requested_date) + if forecast_type == HISTORY_TYPE: + weather_json, error_text = get_historical_weather(requested_date, + location) + if forecast_type == FORECAST_TYPE: + weather_json, error_text = get_forecast_weather(requested_date, + location) + if forecast_type == HISTORICAL_FORECAST_TYPE: + weather_json, error_text = get_forecast_by_historical_data( + requested_date.day, requested_date.month, location) + if weather_json: + weather_json['ForecastType'] = forecast_type + return weather_json, error_text + + +def get_weather_data(requested_date, location): + """ get weather data for date & location - main function. + Args: + requested_date (date) - date requested for forecast. + location (str) - location name. + Returns: dictionary with the following entries: + Status - success / failure. + ErrorDescription - error description (relevant only in case of error). + MinTempCel - minimum degrees in Celsius. + MaxTempCel - maximum degrees in Celsius. + MinTempFar - minimum degrees in Fahrenheit. + MaxTempFar - maximum degrees in Fahrenheit. + ForecastType: + "forecast" - relevant for the upcoming 15 days. + "history" - historical data. + "historical average" - average of the last 3 years on that date. + relevant for future dates (more then forecast). + Address - The location found by the service. + """ + output = {} + requested_date = datetime.date(requested_date.year, requested_date.month, + requested_date.day) + valid_input, error_text = validate_date_input(requested_date) + if valid_input: + weather_json, error_text = get_forecast(requested_date, location) + if error_text: + output["Status"] = ERROR_STATUS + output["ErrorDescription"] = error_text + else: + output["Status"] = SUCCESS_STATUS + output["ErrorDescription"] = None + output["MinTempFar"] = round((weather_json['MinTempCel'] * 9/5) + + 32) + output["MaxTempFar"] = round((weather_json['MaxTempCel'] * 9/5) + + 32) + output.update(weather_json) + else: + output["Status"] = ERROR_STATUS + output["ErrorDescription"] = error_text + return output diff --git a/app/main.py b/app/main.py index 07861586..574b5433 100644 --- a/app/main.py +++ b/app/main.py @@ -5,8 +5,8 @@ from app.database.database import engine from app.dependencies import ( MEDIA_PATH, STATIC_PATH, templates) -from app.routers import agenda, event, profile, email - +from app.routers import agenda, categories, dayview, email, event, \ + invitation, profile models.Base.metadata.create_all(bind=engine) @@ -14,16 +14,22 @@ app.mount("/static", StaticFiles(directory=STATIC_PATH), name="static") app.mount("/media", StaticFiles(directory=MEDIA_PATH), name="media") -app.include_router(profile.router) -app.include_router(event.router) -app.include_router(agenda.router) -app.include_router(email.router) +routers_to_include = [ + agenda.router, + categories.router, + dayview.router, + email.router, + event.router, + invitation.router, + profile.router, +] +for router in routers_to_include: + app.include_router(router) @app.get("/") async def home(request: Request): return templates.TemplateResponse("home.html", { "request": request, - "message": "Hello, World!" - + "message": "Hello, World!", }) diff --git a/app/routers/agenda.py b/app/routers/agenda.py index f8fd532b..35a032e0 100644 --- a/app/routers/agenda.py +++ b/app/routers/agenda.py @@ -3,23 +3,22 @@ from typing import Optional, Tuple from fastapi import APIRouter, Depends, Request -from fastapi.templating import Jinja2Templates from sqlalchemy.orm import Session +from starlette.templating import _TemplateResponse from app.database.database import get_db from app.dependencies import templates from app.internal import agenda_events - router = APIRouter() def calc_dates_range_for_agenda( start: Optional[date], end: Optional[date], - days: Optional[int] + days: Optional[int], ) -> Tuple[date, date]: - """Create start and end dates eccording to the parameters in the page.""" + """Create start and end dates according to the parameters in the page.""" if days is not None: start = date.today() end = start + timedelta(days=days) @@ -35,8 +34,8 @@ def agenda( db: Session = Depends(get_db), start_date: Optional[date] = None, end_date: Optional[date] = None, - days: Optional[int] = None - ) -> Jinja2Templates: + days: Optional[int] = None, + ) -> _TemplateResponse: """Route for the agenda page, using dates range or exact amount of days.""" user_id = 1 # there is no user session yet, so I use user id- 1. @@ -58,5 +57,5 @@ def agenda( "request": request, "events": events, "start_date": start_date, - "end_date": end_date + "end_date": end_date, }) diff --git a/app/routers/categories.py b/app/routers/categories.py new file mode 100644 index 00000000..90cfadaa --- /dev/null +++ b/app/routers/categories.py @@ -0,0 +1,80 @@ +from typing import Dict, List + +from fastapi import APIRouter, Depends, HTTPException, Request +from pydantic import BaseModel +from sqlalchemy.exc import IntegrityError, SQLAlchemyError +from sqlalchemy.orm import Session +from starlette import status +from starlette.datastructures import ImmutableMultiDict + +from app.database.database import get_db +from app.database.models import Category + +router = APIRouter( + prefix="/categories", + tags=["categories"], +) + + +class CategoryModel(BaseModel): + name: str + color: str + user_id: int + + +# TODO(issue#29): get current user_id from session +@router.get("/") +def get_categories(request: Request, + db_session: Session = Depends(get_db)) -> List[Category]: + if validate_request_params(request.query_params): + return get_user_categories(db_session, **request.query_params) + else: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Request {request.query_params} contains " + f"unallowed params.") + + +# TODO(issue#29): get current user_id from session +@router.post("/") +async def set_category(category: CategoryModel, + db_session: Session = Depends(get_db)) -> Dict: + try: + cat = Category.create(db_session, + name=category.name, + color=category.color, + user_id=category.user_id) + except IntegrityError: + db_session.rollback() + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, + detail=f"category is already exists for " + f"{category.user_id} user.") + else: + return {"category": cat.to_dict()} + + +def validate_request_params(query_params: ImmutableMultiDict) -> bool: + """ + request.query_params contains not more than user_id, name, color + and not less than user_id: + Intersection must contain at least user_id. + Union must not contain fields other than user_id, name, color. + """ + all_fields = set(CategoryModel.schema()["required"]) + request_params = set(query_params) + union_set = request_params.union(all_fields) + intersection_set = request_params.intersection(all_fields) + return union_set == all_fields and "user_id" in intersection_set + + +def get_user_categories(db_session: Session, + user_id: int, **params) -> List[Category]: + """ + Returns user's categories, filtered by params. + """ + try: + categories = db_session.query(Category).filter_by( + user_id=user_id).filter_by(**params).all() + except SQLAlchemyError: + return [] + else: + return categories diff --git a/app/routers/dayview.py b/app/routers/dayview.py new file mode 100644 index 00000000..c2adbb8a --- /dev/null +++ b/app/routers/dayview.py @@ -0,0 +1,113 @@ +from datetime import datetime, timedelta +from typing import Tuple, Union + +from fastapi import APIRouter, Depends, Request +from fastapi.templating import Jinja2Templates +from sqlalchemy import and_, or_ + +from app.database.database import get_db +from app.database.models import Event, User +from app.dependencies import TEMPLATES_PATH + + +templates = Jinja2Templates(directory=TEMPLATES_PATH) + + +router = APIRouter() + + +class DivAttributes: + GRID_BAR_QUARTER = 1 + FULL_GRID_BAR = 4 + MIN_MINUTES = 0 + MAX_MINUTES = 15 + BASE_GRID_BAR = 5 + FIRST_GRID_BAR = 1 + LAST_GRID_BAR = 101 + DEFAULT_COLOR = 'grey' + DEFAULT_FORMAT = "%H:%M" + MULTIDAY_FORMAT = "%d/%m %H:%M" + + def __init__(self, event: Event, + day: Union[bool, datetime] = False) -> None: + self.start_time = event.start + self.end_time = event.end + self.day = day + self.start_multiday, self.end_multiday = self._check_multiday_event() + self.color = self._check_color(event.color) + self.total_time = self._set_total_time() + self.grid_position = self._set_grid_position() + + def _check_color(self, color: str) -> str: + if color is None: + return self.DEFAULT_COLOR + return color + + def _minutes_position(self, minutes: int) -> Union[int, None]: + min_minutes = self.MIN_MINUTES + max_minutes = self.MAX_MINUTES + for i in range(self.GRID_BAR_QUARTER, self.FULL_GRID_BAR + 1): + if min_minutes < minutes <= max_minutes: + return i + min_minutes = max_minutes + max_minutes += 15 + + def _get_position(self, time: datetime) -> int: + grid_hour_position = time.hour * self.FULL_GRID_BAR + grid_minutes_modifier = self._minutes_position(time.minute) + if grid_minutes_modifier is None: + grid_minutes_modifier = 0 + return grid_hour_position + grid_minutes_modifier + self.BASE_GRID_BAR + + def _set_grid_position(self) -> str: + if self.start_multiday: + start = self.FIRST_GRID_BAR + else: + start = self._get_position(self.start_time) + if self.end_multiday: + end = self.LAST_GRID_BAR + else: + end = self._get_position(self.end_time) + return f'{start} / {end}' + + def _get_time_format(self) -> str: + for multiday in [self.start_multiday, self.end_multiday]: + yield self.MULTIDAY_FORMAT if multiday else self.DEFAULT_FORMAT + + def _set_total_time(self) -> None: + length = self.end_time - self.start_time + self.length = length.seconds / 60 + format_gen = self._get_time_format() + start_time_str = self.start_time.strftime(next(format_gen)) + end_time_str = self.end_time.strftime(next(format_gen)) + return ' '.join([start_time_str, '-', end_time_str]) + + def _check_multiday_event(self) -> Tuple[bool]: + start_multiday, end_multiday = False, False + if self.day: + if self.start_time < self.day: + start_multiday = True + self.day += timedelta(hours=24) + if self.day <= self.end_time: + end_multiday = True + return (start_multiday, end_multiday) + + +@router.get('/day/{date}') +async def dayview(request: Request, date: str, db_session=Depends(get_db)): + # TODO: add a login session + user = db_session.query(User).filter_by(username='test1').first() + day = datetime.strptime(date, '%Y-%m-%d') + day_end = day + timedelta(hours=24) + events = db_session.query(Event).filter( + Event.owner_id == user.id).filter( + or_(and_(Event.start >= day, Event.start < day_end), + and_(Event.end >= day, Event.end < day_end), + and_(Event.start < day_end, day_end < Event.end))) + events_n_attrs = [(event, DivAttributes(event, day)) for event in events] + return templates.TemplateResponse("dayview.html", { + "request": request, + "events": events_n_attrs, + "month": day.strftime("%B").upper(), + "day": day.day + }) diff --git a/app/routers/event.py b/app/routers/event.py index f2a0b2dc..b328c580 100644 --- a/app/routers/event.py +++ b/app/routers/event.py @@ -1,6 +1,12 @@ +from operator import attrgetter +from typing import List + from fastapi import APIRouter, Request +from app.database.models import Event +from app.database.models import UserEvent from app.dependencies import templates +from app.internal.utils import create_model router = APIRouter( prefix="/event", @@ -19,3 +25,30 @@ async def eventedit(request: Request): async def eventview(request: Request, id: int): return templates.TemplateResponse("event/eventview.html", {"request": request, "event_id": id}) + + +def create_event(db, title, start, end, owner_id, content=None, location=None): + """Creates an event and an association.""" + + event = create_model( + db, Event, + title=title, + start=start, + end=end, + content=content, + owner_id=owner_id, + location=location, + ) + create_model( + db, UserEvent, + user_id=owner_id, + event_id=event.id + ) + return event + + +def sort_by_date(events: List[Event]) -> List[Event]: + """Sorts the events by the start of the event.""" + + temp = events.copy() + return sorted(temp, key=attrgetter('start')) diff --git a/app/routers/export.py b/app/routers/export.py new file mode 100644 index 00000000..5ebe8580 --- /dev/null +++ b/app/routers/export.py @@ -0,0 +1,103 @@ +from datetime import datetime +from typing import List + +from icalendar import Calendar, Event, vCalAddress, vText +import pytz + +from app.config import DOMAIN, ICAL_VERSION, PRODUCT_ID +from app.database.models import Event as UserEvent + + +def generate_id(event: UserEvent) -> bytes: + """Creates an unique id.""" + + return ( + str(event.id) + + event.start.strftime('%Y%m%d') + + event.end.strftime('%Y%m%d') + + f'@{DOMAIN}' + ).encode() + + +def create_ical_calendar(): + """Creates an ical calendar, + and adds the required information""" + + cal = Calendar() + cal.add('version', ICAL_VERSION) + cal.add('prodid', PRODUCT_ID) + + return cal + + +def add_optional(user_event, data): + """Adds an optional field if it exists.""" + + if user_event.location: + data.append(('location', user_event.location)) + + if user_event.content: + data.append(('description', user_event.content)) + + return data + + +def create_ical_event(user_event): + """Creates an ical event, + and adds the event information""" + + ievent = Event() + data = [ + ('organizer', add_attendee(user_event.owner.email, organizer=True)), + ('uid', generate_id(user_event)), + ('dtstart', user_event.start), + ('dtstamp', datetime.now(tz=pytz.utc)), + ('dtend', user_event.end), + ('summary', user_event.title), + ] + + data = add_optional(user_event, data) + + for param in data: + ievent.add(*param) + + return ievent + + +def add_attendee(email, organizer=False): + """Adds an attendee to the event.""" + + attendee = vCalAddress(f'MAILTO:{email}') + if organizer: + attendee.params['partstat'] = vText('ACCEPTED') + attendee.params['role'] = vText('CHAIR') + else: + attendee.params['partstat'] = vText('NEEDS-ACTION') + attendee.params['role'] = vText('PARTICIPANT') + + return attendee + + +def add_attendees(ievent, attendees: list): + """Adds attendees for the event.""" + + for email in attendees: + ievent.add( + 'attendee', + add_attendee(email), + encode=0 + ) + + return ievent + + +def event_to_ical(user_event: UserEvent, attendees: List[str]) -> bytes: + """Returns an ical event, given an + "UserEvent" instance and a list of email.""" + + ical = create_ical_calendar() + ievent = create_ical_event(user_event) + ievent = add_attendees(ievent, attendees) + ical.add_component(ievent) + + return ical.to_ical() diff --git a/app/routers/invitation.py b/app/routers/invitation.py new file mode 100644 index 00000000..4a4e9491 --- /dev/null +++ b/app/routers/invitation.py @@ -0,0 +1,68 @@ +from typing import List, Union + +from fastapi import APIRouter, Depends, Request +from fastapi.responses import RedirectResponse +from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy.orm import Session +from starlette.status import HTTP_302_FOUND +from starlette.templating import Jinja2Templates + +from app.database.database import get_db +from app.database.models import Invitation +from app.routers.share import accept + + +templates = Jinja2Templates(directory="app/templates") + +router = APIRouter( + prefix="/invitations", + tags=["invitation"], + dependencies=[Depends(get_db)] +) + + +@router.get("/") +def view_invitations(request: Request, db: Session = Depends(get_db)): + return templates.TemplateResponse("invitations.html", { + "request": request, + # TODO: create current user + # recipient_id should be the current user + # but because we don't have one yet, + # "get_all_invitations" returns all invitations + "invitations": get_all_invitations(session=db), + }) + + +@router.post("/") +async def accept_invitations( + request: Request, + db: Session = Depends(get_db) +): + data = await request.form() + invite_id = list(data.values())[0] + + invitation = get_invitation_by_id(invite_id, session=db) + accept(invitation, db) + + url = router.url_path_for("view_invitations") + return RedirectResponse(url=url, status_code=HTTP_302_FOUND) + + +def get_all_invitations(session: Session, **param) -> List[Invitation]: + """Returns all invitations filter by param.""" + + try: + invitations = list(session.query(Invitation).filter_by(**param)) + except SQLAlchemyError: + return [] + else: + return invitations + + +def get_invitation_by_id( + invitation_id: int, session: Session +) -> Union[Invitation, None]: + """Returns a invitation by an id. + if id does not exist, returns None.""" + + return session.query(Invitation).filter_by(id=invitation_id).first() diff --git a/app/routers/profile.py b/app/routers/profile.py index 39724939..bb856747 100644 --- a/app/routers/profile.py +++ b/app/routers/profile.py @@ -10,7 +10,6 @@ from app.database.models import User from app.dependencies import MEDIA_PATH, templates - PICTURE_EXTENSION = config.PICTURE_EXTENSION PICTURE_SIZE = config.AVATAR_SIZE @@ -26,7 +25,7 @@ def get_placeholder_user(): username='new_user', email='my@email.po', password='1a2s3d4f5g6', - full_name='My Name' + full_name='My Name', ) @@ -37,7 +36,7 @@ async def profile( new_user=Depends(get_placeholder_user)): # Get relevant data from database - upcouming_events = range(5) + upcoming_events = range(5) user = session.query(User).filter_by(id=1).first() if not user: session.add(new_user) @@ -49,7 +48,7 @@ async def profile( return templates.TemplateResponse("profile.html", { "request": request, "user": user, - "events": upcouming_events + "events": upcoming_events, }) diff --git a/app/routers/share.py b/app/routers/share.py new file mode 100644 index 00000000..f408e02a --- /dev/null +++ b/app/routers/share.py @@ -0,0 +1,88 @@ +from typing import List, Dict + +from sqlalchemy.orm import Session + +from app.database.models import Event, Invitation, UserEvent +from app.routers.export import event_to_ical +from app.routers.user import does_user_exist, get_users +from app.internal.utils import save + + +def sort_emails( + participants: List[str], + session: Session, +) -> Dict[str, List[str]]: + """Sorts emails to registered and unregistered users.""" + + emails = {'registered': [], 'unregistered': []} # type: ignore + for participant in participants: + + if does_user_exist(email=participant, session=session): + temp: list = emails['registered'] + else: + temp: list = emails['unregistered'] + + temp.append(participant) + + return emails + + +def send_email_invitation( + participants: List[str], + event: Event, +) -> bool: + """Sends an email with an invitation.""" + + ical_invitation = event_to_ical(event, participants) # noqa: F841 + for _ in participants: + # TODO: send email + pass + return True + + +def send_in_app_invitation( + participants: List[str], + event: Event, + session: Session +) -> bool: + """Sends an in-app invitation for registered users.""" + + for participant in participants: + # email is unique + recipient = get_users(email=participant, session=session)[0] + + if recipient.id != event.owner.id: + session.add(Invitation(recipient=recipient, event=event)) + + else: + # if user tries to send to themselves. + return False + + session.commit() + return True + + +def accept(invitation: Invitation, session: Session) -> None: + """Accepts an invitation by creating an + UserEvent association that represents + participantship at the event.""" + + association = UserEvent( + user_id=invitation.recipient.id, + event_id=invitation.event.id + ) + invitation.status = 'accepted' + save(invitation, session=session) + save(association, session=session) + + +def share(event: Event, participants: List[str], session: Session) -> bool: + """Sends invitations to all event participants.""" + + registered, unregistered = ( + sort_emails(participants, session=session).values() + ) + if send_email_invitation(unregistered, event): + if send_in_app_invitation(registered, event, session): + return True + return False diff --git a/app/routers/user.py b/app/routers/user.py new file mode 100644 index 00000000..32dd3ca3 --- /dev/null +++ b/app/routers/user.py @@ -0,0 +1,56 @@ +from typing import List + +from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy.orm import Session + +from app.database.models import User, UserEvent, Event +from app.internal.utils import save + + +def create_user(username, password, email, session: Session) -> User: + """Creates and saves a new user.""" + + user = User( + username=username, + password=password, + email=email, + ) + save(user, session=session) + return user + + +def get_users(session: Session, **param): + """Returns all users filter by param.""" + + try: + users = list(session.query(User).filter_by(**param)) + except SQLAlchemyError: + return [] + else: + return users + + +def does_user_exist( + session: Session, + *, user_id=None, + username=None, email=None +): + """Returns True if user exists, False otherwise. + function can receive one of the there parameters""" + + if user_id: + return len(get_users(session=session, id=user_id)) == 1 + if username: + return len(get_users(session=session, username=username)) == 1 + if email: + return len(get_users(session=session, email=email)) == 1 + return False + + +def get_all_user_events(session: Session, user_id: int) -> List[Event]: + """Returns all events that the user participants in.""" + + return ( + session.query(Event).join(UserEvent) + .filter(UserEvent.user_id == user_id).all() + ) diff --git a/app/static/dayview.css b/app/static/dayview.css new file mode 100644 index 00000000..655e68ac --- /dev/null +++ b/app/static/dayview.css @@ -0,0 +1,91 @@ + +:root { + --primary:#30465D; + --primary-variant:#FFDE4D; + --secondary:#EF5454; + --borders:#E7E7E7; + --borders-variant:#F7F7F7; +} + +html { + font-family: 'Assistant', sans-serif; + text-align: center; +} + +#toptab { + background-color: var(--primary); +} + +.schedule { + display: grid; + grid-template-rows: 1; +} + +.times { + margin-top: 0.65em; + grid-row: 1 / -1; + grid-column: 1 / -1; + z-index: 40; +} + +.baselines { + grid-row: 1 / -1; + grid-column: 1 / -1; + z-index: 38; +} + +.eventgrid { + grid-row: 1 / -1; + grid-column: 1 / -1; + display: grid; + grid-template-rows: repeat(100, 0.375rem); + z-index: 39; +} + +.hourbar { + margin-top: -1px; +} + +.event { + font-size: 1rem; +} + +.total-time { + font-size: 0.4rem; + line-height: 1rem; +} + +.title_size_small { + font-size: 0.6em; +} + +.title_size_Xsmall { + font-size: 0.4em; +} + +.title_size_tiny { + font-size: 0.1em; + line-height: 4em; +} + +.actiongrid { + grid-row: 1 / -1; + grid-column: 1 / -1; + display: grid; + grid-template-rows: repeat(100, 0.375rem); + z-index: 42; +} + +.action-icon { + visibility: hidden; +} + +.action-continer:hover { + border-top: 1px dashed var(--borders); + border-bottom: 1px dashed var(--borders); + transition: 0.3; +} + +.action-continer:hover .action-icon { + visibility: visible; +} diff --git a/app/static/images/icons/close_sidebar.svg b/app/static/images/icons/close_sidebar.svg new file mode 100644 index 00000000..6f7085b4 --- /dev/null +++ b/app/static/images/icons/close_sidebar.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/app/static/images/icons/pencil.svg b/app/static/images/icons/pencil.svg new file mode 100644 index 00000000..7b1ccd37 --- /dev/null +++ b/app/static/images/icons/pencil.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/app/static/images/icons/trash-can.svg b/app/static/images/icons/trash-can.svg new file mode 100644 index 00000000..7bdadb8a --- /dev/null +++ b/app/static/images/icons/trash-can.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/app/templates/base.html b/app/templates/base.html index ddc549db..97360bd6 100644 --- a/app/templates/base.html +++ b/app/templates/base.html @@ -37,7 +37,10 @@ + diff --git a/app/templates/dayview.html b/app/templates/dayview.html new file mode 100644 index 00000000..53d983fb --- /dev/null +++ b/app/templates/dayview.html @@ -0,0 +1,63 @@ + + + + + + + + + dayview + + +
+ + {{month}} + {{day}} +
+
+
+ {% for i in range(24)%} +
+ {% set i = i|string() %} + {{i.zfill(2)}}:00 +
+ {% endfor %} +
+
+ {% for event, attr in events %} + {% set totaltime = 'visible'%} + {% if attr.length < 60 %} + {% set size = 'title_size_small' %} + {% set totaltime = 'invisible'%} + {% if attr.length < 45 %} + {% set size = 'title_size_Xsmall' %} + {% if attr.length < 30 %} + {% set size = 'title_size_tiny' %} + {% endif %} + {% endif %} + {% endif %} +
+

{{ event.title }}

+ {% if totaltime == 'visible' %} +

{{attr.total_time}}

+ {% endif %} +
+ {% endfor %} +
+
+ {% for i in range(25)%} +
---
+ {% endfor %} +
+
+ {% for event, attr in events %} +
+ + +
+ {% endfor %} +
+
+ + + \ No newline at end of file diff --git a/app/templates/invitations.html b/app/templates/invitations.html new file mode 100644 index 00000000..bc6ecb8e --- /dev/null +++ b/app/templates/invitations.html @@ -0,0 +1,25 @@ +{% extends "base.html" %} + + +{% block content %} + +
+

{{message}}

+
+ + {% if invitations %} +
+ {% for i in invitations %} +
+ {{ i.event.owner.username }} - {{ i.event.title }} ({{ i.event.start }}) ({{ i.status }}) + + +
+ {% endfor %} +
+ {% else %} + You don't have any invitations. + {% endif %} + + +{% endblock %} \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index bc7cc2d6..ba9af619 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,20 +1,25 @@ aiofiles==0.6.0 +apipkg==1.5 arrow==0.17.0 atomicwrites==1.4.0 attrs==20.3.0 +beautifulsoup4==4.9.3 certifi==2020.12.5 chardet==4.0.0 click==7.1.2 colorama==0.4.4 coverage==5.3.1 +execnet==1.7.1 fastapi==0.63.0 fastapi_mail==0.3.3.1 faker==5.6.2 +frozendict==1.2 smtpdfix==0.2.6 h11==0.12.0 h2==4.0.0 hpack==4.0.0 hyperframe==6.0.0 +icalendar==4.0.7 idna==2.10 importlib-metadata==3.3.0 iniconfig==1.1.1 @@ -28,13 +33,20 @@ py==1.10.0 pydantic==1.7.3 pyparsing==2.4.7 pytest==6.2.1 +pytest-asyncio==0.14.0 pytest-cov==2.10.1 +pytest-forked==1.3.0 +pytest-xdist==2.2.0 python-dateutil==2.8.1 python-dotenv==0.15.0 python-multipart==0.0.5 +pytz==2020.5 PyYAML==5.3.1 requests==2.25.1 +requests-mock==1.8.0 +responses==0.12.1 six==1.15.0 +soupsieve==2.1 SQLAlchemy==1.3.22 starlette==0.13.6 toml==0.10.2 diff --git a/schema.md b/schema.md index 58140f95..a0d2120f 100644 --- a/schema.md +++ b/schema.md @@ -11,15 +11,26 @@ │ ├── internal │ ├── __init__.py │ ├── admin.py -│ ├── routers -│ ├── __init__.py -│ ├── profile.py +│ ├── agenda_events.py +│ ├── email.py │ ├── media │ ├── example.png +│ ├── fake_user.png │ ├── profile.png +│ ├── routers +│ ├── __init__.py +│ ├── agenda.py +│ ├── categories.py +│ ├── email.py +│ ├── event.py +│ ├── profile.py │ ├── static -│ ├── style.css +│ ├── event +│ ├── eventedit.css +│ ├── eventview.css +│ ├── agenda_style.css │ ├── popover.js +│ ├── style.css │ ├── templates │ ├── base.html │ ├── home.html @@ -29,6 +40,12 @@ ├── schema.md └── tests ├── __init__.py - └── conftest.py - └── test_profile.py - └── test_app.py \ No newline at end of file + ├── conftest.py + ├── db_entities.py + ├── test_agenda_internal.py + ├── test_agenda_route.py + ├── test_app.py + ├── test_categories.py + ├── test_email.py + ├── test_event.py + └── test_profile.py \ No newline at end of file diff --git a/tests/association_fixture.py b/tests/association_fixture.py new file mode 100644 index 00000000..92c845c2 --- /dev/null +++ b/tests/association_fixture.py @@ -0,0 +1,12 @@ +import pytest +from sqlalchemy.orm import Session + +from app.database.models import Event, UserEvent + + +@pytest.fixture +def association(event: Event, session: Session) -> UserEvent: + return ( + session.query(UserEvent) + .filter(UserEvent.event_id == event.id) + ).first() diff --git a/tests/category_fixture.py b/tests/category_fixture.py new file mode 100644 index 00000000..08cc6b97 --- /dev/null +++ b/tests/category_fixture.py @@ -0,0 +1,13 @@ +import pytest +from sqlalchemy.orm import Session + +from app.database.models import User, Category + + +@pytest.fixture +def category(session: Session, user: User) -> Category: + category = Category.create(session, name="Guitar Lesson", color="121212", + user_id=user.id) + yield category + session.delete(category) + session.commit() diff --git a/tests/client_fixture.py b/tests/client_fixture.py new file mode 100644 index 00000000..c99b5fe5 --- /dev/null +++ b/tests/client_fixture.py @@ -0,0 +1,60 @@ +from fastapi.testclient import TestClient +import pytest + +from app.database.models import User +from app.main import app +from app.database.database import Base +from app.routers import profile, agenda, invitation +from tests.conftest import test_engine, get_test_db + + +@pytest.fixture(scope="session") +def client(): + return TestClient(app) + + +@pytest.fixture(scope="session") +def agenda_test_client(): + Base.metadata.create_all(bind=test_engine) + app.dependency_overrides[agenda.get_db] = get_test_db + + with TestClient(app) as client: + yield client + + app.dependency_overrides = {} + Base.metadata.drop_all(bind=test_engine) + + +@pytest.fixture(scope="session") +def invitation_test_client(): + Base.metadata.create_all(bind=test_engine) + app.dependency_overrides[invitation.get_db] = get_test_db + + with TestClient(app) as client: + yield client + + app.dependency_overrides = {} + Base.metadata.drop_all(bind=test_engine) + + +@pytest.fixture(scope="session") +def profile_test_client(): + Base.metadata.create_all(bind=test_engine) + app.dependency_overrides[profile.get_db] = get_test_db + app.dependency_overrides[ + profile.get_placeholder_user] = get_test_placeholder_user + + with TestClient(app) as client: + yield client + + app.dependency_overrides = {} + Base.metadata.drop_all(bind=test_engine) + + +def get_test_placeholder_user(): + return User( + username='fake_user', + email='fake@mail.fake', + password='123456fake', + full_name='FakeName' + ) diff --git a/tests/conftest.py b/tests/conftest.py index 0631b61d..28e0dffb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,15 +1,18 @@ -import datetime - import pytest -from app.database.database import Base, SessionLocal, engine -from app.database.models import Event, User -from app.main import app -from app.routers import profile -from faker import Faker -from fastapi.testclient import TestClient from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker -pytest_plugins = "smtpdfix" + +from app.database.database import Base + +pytest_plugins = [ + 'tests.user_fixture', + 'tests.event_fixture', + 'tests.invitation_fixture', + 'tests.association_fixture', + 'tests.client_fixture', + 'tests.category_fixture', + 'smtpdfix', +] SQLALCHEMY_TEST_DATABASE_URL = "sqlite:///./test.db" @@ -24,61 +27,11 @@ def get_test_db(): return TestingSessionLocal() -@pytest.fixture -def client(): - return TestClient(app) - - @pytest.fixture def session(): - Base.metadata.create_all(bind=engine) - session = SessionLocal() + Base.metadata.create_all(bind=test_engine) + session = get_test_db() yield session + session.rollback() session.close() - Base.metadata.drop_all(bind=engine) - - -@pytest.fixture -def user(session): - faker = Faker() - user1 = User(username=faker.first_name(), email=faker.email()) - session.add(user1) - session.commit() - yield user1 - session.delete(user1) - session.commit() - - -@pytest.fixture -def event(session, user): - event1 = Event( - title="Test Email", content="Test TEXT", - start=datetime.datetime.now(), - end=datetime.datetime.now(), owner_id=user.id) - session.add(event1) - session.commit() - yield event1 - session.delete(event1) - session.commit() - - -def get_test_placeholder_user(): - return User( - username='fake_user', - email='fake@mail.fake', - password='123456fake', - full_name='FakeName' - ) - - -@pytest.fixture -def profile_test_client(): Base.metadata.drop_all(bind=test_engine) - Base.metadata.create_all(bind=test_engine) - app.dependency_overrides[profile.get_db] = get_test_db - app.dependency_overrides[ - profile.get_placeholder_user] = get_test_placeholder_user - - with TestClient(app) as client: - yield client - app.dependency_overrides = {} diff --git a/tests/event_fixture.py b/tests/event_fixture.py new file mode 100644 index 00000000..eef02e02 --- /dev/null +++ b/tests/event_fixture.py @@ -0,0 +1,95 @@ +from datetime import datetime, timedelta + +import pytest +from sqlalchemy.orm import Session + +from app.database.models import Event, User +from app.routers.event import create_event + + +today_date = datetime.today().replace(hour=0, minute=0, second=0) + + +@pytest.fixture +def event(sender: User, session: Session) -> Event: + return create_event( + db=session, + title='event', + start=today_date, + end=today_date, + content='test event', + owner_id=sender.id, + location="Some random location", + ) + + +@pytest.fixture +def today_event(sender: User, session: Session) -> Event: + return create_event( + db=session, + title='event 1', + start=today_date + timedelta(hours=7), + end=today_date + timedelta(hours=9), + content='test event', + owner_id=sender.id, + ) + + +@pytest.fixture +def today_event_2(sender: User, session: Session) -> Event: + return create_event( + db=session, + title='event 2', + start=today_date + timedelta(hours=3), + end=today_date + timedelta(days=2, hours=3), + content='test event', + owner_id=sender.id, + ) + + +@pytest.fixture +def yesterday_event(sender: User, session: Session) -> Event: + return create_event( + db=session, + title='event 3', + start=today_date - timedelta(hours=8), + end=today_date, + content='test event', + owner_id=sender.id, + ) + + +@pytest.fixture +def next_week_event(sender: User, session: Session) -> Event: + return create_event( + db=session, + title='event 4', + start=today_date + timedelta(days=7, hours=2), + end=today_date + timedelta(days=7, hours=4), + content='test event', + owner_id=sender.id, + ) + + +@pytest.fixture +def next_month_event(sender: User, session: Session) -> Event: + return create_event( + db=session, + title='event 5', + start=today_date + timedelta(days=20, hours=4), + end=today_date + timedelta(days=20, hours=6), + content='test event', + owner_id=sender.id, + ) + + +@pytest.fixture +def old_event(sender: User, session: Session) -> Event: + return create_event( + db=session, + title='event 6', + start=today_date - timedelta(days=5), + end=today_date, + content='test event', + owner_id=sender.id, + ) diff --git a/tests/invitation_fixture.py b/tests/invitation_fixture.py new file mode 100644 index 00000000..9015381d --- /dev/null +++ b/tests/invitation_fixture.py @@ -0,0 +1,21 @@ +from datetime import datetime + +import pytest +from sqlalchemy.orm import Session + +from app.database.models import Event, Invitation, User +from tests.utils import create_model, delete_instance + + +@pytest.fixture +def invitation(event: Event, user: User, session: Session) -> Event: + invitation = create_model( + session, Invitation, + creation=datetime.now(), + recipient=user, + event=event, + event_id=event.id, + recipient_id=user.id, + ) + yield invitation + delete_instance(session, invitation) diff --git a/tests/test_agenda_internal.py b/tests/test_agenda_internal.py index 173a01d1..b6e3eb21 100644 --- a/tests/test_agenda_internal.py +++ b/tests/test_agenda_internal.py @@ -1,20 +1,46 @@ -from datetime import datetime +from datetime import datetime, date from app.internal import agenda_events import pytest -START = datetime(2021, 11, 1, 8, 00, 00) +from app.internal.agenda_events import get_events_per_dates -dates = [ - (START, datetime(2021, 11, 3, 8, 00, 0), '2 days'), - (START, datetime(2021, 11, 3, 10, 30, 0), '2 days 2 hours and 30 minutes'), - (START, datetime(2021, 11, 1, 8, 30, 0), '30 minutes'), - (START, datetime(2021, 11, 1, 10, 00, 0), '2 hours'), - (START, datetime(2021, 11, 1, 10, 30, 0), '2 hours and 30 minutes'), - (START, datetime(2021, 11, 2, 10, 00, 0), 'a day and 2 hours'), -] +class TestAgenda: + START = datetime(2021, 11, 1, 8, 00, 00) + dates = [ + (START, datetime(2021, 11, 3, 8, 00, 0), + '2 days'), + (START, datetime(2021, 11, 3, 10, 30, 0), + '2 days 2 hours and 30 minutes'), + (START, datetime(2021, 11, 1, 8, 30, 0), + '30 minutes'), + (START, datetime(2021, 11, 1, 10, 00, 0), + '2 hours'), + (START, datetime(2021, 11, 1, 10, 30, 0), + '2 hours and 30 minutes'), + (START, datetime(2021, 11, 2, 10, 00, 0), + 'a day and 2 hours'), + ] -@pytest.mark.parametrize('start, end, diff', dates) -def test_get_time_delta_string(start, end, diff): - assert agenda_events.get_time_delta_string(start, end) == diff + @pytest.mark.parametrize('start, end, diff', dates) + def test_get_time_delta_string(self, start, end, diff): + assert agenda_events.get_time_delta_string(start, end) == diff + + def test_get_events_per_dates_success(self, today_event, session): + events = get_events_per_dates( + session=session, + user_id=today_event.owner_id, + start=today_event.start.date(), + end=today_event.end.date(), + ) + assert list(events) == [today_event] + + def test_get_events_per_dates_failure(self, yesterday_event, session): + events = get_events_per_dates( + session=session, + user_id=yesterday_event.owner_id, + start=date.today(), + end=date.today(), + ) + assert list(events) == [] diff --git a/tests/test_agenda_route.py b/tests/test_agenda_route.py index 707c23fc..c877a4e8 100644 --- a/tests/test_agenda_route.py +++ b/tests/test_agenda_route.py @@ -2,103 +2,31 @@ from fastapi import status -from app.database.models import User, Event - class TestAgenda: + """In the test we are receiving event fixtures + as parameters so they will load into the database""" + AGENDA = "/agenda" AGENDA_7_DAYS = "/agenda?days=7" AGENDA_30_DAYS = "/agenda?days=30" NO_EVENTS = b"No events found..." INVALID_DATES = b"Start date is greater than end date" + today_date = datetime.today().replace(hour=0, minute=0, second=0) @staticmethod - def base_today_date(): - return datetime.today().replace(hour=0, minute=0, second=0) - - @staticmethod - def create_user_1(session): - user = User(username='zohar', email='aa.aa@aa.com', password='1234') - session.add(user) - session.commit() - return user - - @staticmethod - def create_user_2(session): - user2 = User(username='dani', email='bb.aa@aa.com', password='12345') - session.add(user2) - session.commit() - return user2 - - @staticmethod - def add_event(session, title, content, start, end, owner_id): - event = Event( - title=title, content=content, - start=start, end=end, owner_id=owner_id - ) - session.add(event) - session.commit() - - @staticmethod - def create_data_user_1(session): - user = TestAgenda.create_user_1(session) - base_date = TestAgenda.base_today_date() - # Today event - TestAgenda.add_event( - session, "event 1", "...", - base_date + timedelta(hours=7), - base_date + timedelta(hours=9), user.id - ) - # Today event - TestAgenda.add_event( - session, "event 2", "...", - base_date + timedelta(hours=3), - base_date + timedelta(days=2, hours=3), user.id - ) - # Yesterday event - TestAgenda.add_event( - session, "event 3", "..", - base_date - timedelta(hours=8), - base_date, user.id - ) - # Event in this week - TestAgenda.add_event( - session, "event 4", "...", - base_date + timedelta(days=7, hours=2), - base_date + timedelta(days=7, hours=4), user.id - ) - # Event in this month. - TestAgenda.add_event( - session, "event 5", "...", - base_date + timedelta(days=20, hours=4), - base_date + timedelta(days=20, hours=6), user.id - ) - # Old event - TestAgenda.add_event( - session, "event 6", "..", - base_date - timedelta(days=5), - base_date, user.id - ) - - @staticmethod - def create_data_user_2(session): - user2 = TestAgenda.create_user_2(session) - base_date = TestAgenda.base_today_date() - TestAgenda.add_event( - session, "event 7", "..", base_date + timedelta(hours=7), - base_date + timedelta(hours=8), user2.id - ) - - @staticmethod - def test_agenda_page_no_arguments_when_no_today_events(client): - resp = client.get(TestAgenda.AGENDA) + def test_agenda_page_no_arguments_when_no_today_events( + agenda_test_client, session): + resp = agenda_test_client.get(TestAgenda.AGENDA) assert resp.status_code == status.HTTP_200_OK assert TestAgenda.NO_EVENTS in resp.content - @staticmethod - def test_agenda_page_no_arguments_when_today_events_exist(client, session): - TestAgenda.create_data_user_1(session) - resp = client.get(TestAgenda.AGENDA) + def test_agenda_page_no_arguments_when_today_events_exist( + self, agenda_test_client, session, sender, today_event, + today_event_2, yesterday_event, next_week_event, + next_month_event, old_event + ): + resp = agenda_test_client.get(TestAgenda.AGENDA) assert resp.status_code == status.HTTP_200_OK assert b"event 1" in resp.content assert b"event 2" in resp.content @@ -108,9 +36,12 @@ def test_agenda_page_no_arguments_when_today_events_exist(client, session): assert b"event 6" not in resp.content @staticmethod - def test_agenda_per_7_days(client, session): - TestAgenda.create_data_user_1(session) - resp = client.get(TestAgenda.AGENDA_7_DAYS) + def test_agenda_per_7_days( + agenda_test_client, session, sender, today_event, + today_event_2, yesterday_event, next_week_event, + next_month_event, old_event + ): + resp = agenda_test_client.get(TestAgenda.AGENDA_7_DAYS) today = date.today().strftime("%d/%m/%Y") assert resp.status_code == status.HTTP_200_OK assert bytes(today, 'utf-8') in resp.content @@ -122,9 +53,12 @@ def test_agenda_per_7_days(client, session): assert b"event 6" not in resp.content @staticmethod - def test_agenda_per_30_days(client, session): - TestAgenda.create_data_user_1(session) - resp = client.get(TestAgenda.AGENDA_30_DAYS) + def test_agenda_per_30_days( + agenda_test_client, session, sender, today_event, + today_event_2, yesterday_event, next_week_event, + next_month_event, old_event + ): + resp = agenda_test_client.get(TestAgenda.AGENDA_30_DAYS) today = date.today().strftime("%d/%m/%Y") assert resp.status_code == status.HTTP_200_OK assert bytes(today, 'utf-8') in resp.content @@ -135,13 +69,14 @@ def test_agenda_per_30_days(client, session): assert b"event 5" in resp.content assert b"event 6" not in resp.content - @staticmethod - def test_agenda_between_two_dates(client, session): - TestAgenda.create_data_user_1(session) - base_date = TestAgenda.base_today_date() - start_date = (base_date + timedelta(days=8, hours=4)).date() - end_date = (base_date + timedelta(days=32, hours=4)).date() - resp = client.get( + def test_agenda_between_two_dates( + self, agenda_test_client, session, sender, today_event, + today_event_2, yesterday_event, next_week_event, + next_month_event, old_event + ): + start_date = (self.today_date + timedelta(days=8, hours=4)).date() + end_date = (self.today_date + timedelta(days=32, hours=4)).date() + resp = agenda_test_client.get( f"/agenda?start_date={start_date}&end_date={end_date}") assert resp.status_code == status.HTTP_200_OK assert b"event 1" not in resp.content @@ -151,20 +86,21 @@ def test_agenda_between_two_dates(client, session): assert b"event 5" in resp.content assert b"event 6" not in resp.content - @staticmethod - def test_agenda_start_bigger_than_end(client): - base_date = TestAgenda.base_today_date() - start_date = base_date.date() - end_date = (base_date - timedelta(days=2)).date() - resp = client.get( + def test_agenda_start_bigger_than_end(self, agenda_test_client): + start_date = self.today_date.date() + end_date = (self.today_date - timedelta(days=2)).date() + resp = agenda_test_client.get( f"/agenda?start_date={start_date}&end_date={end_date}") assert resp.status_code == status.HTTP_200_OK assert TestAgenda.INVALID_DATES in resp.content @staticmethod - def test_no_show_events_user_2(client, session): - TestAgenda.create_data_user_1(session) - TestAgenda.create_data_user_2(session) - resp = client.get(TestAgenda.AGENDA) + def test_no_show_events_user_2( + agenda_test_client, session, sender, today_event, + today_event_2, yesterday_event, next_week_event, + next_month_event, old_event + ): + # "user" is just a different event creator + resp = agenda_test_client.get(TestAgenda.AGENDA) assert resp.status_code == status.HTTP_200_OK assert b"event 7" not in resp.content diff --git a/tests/test_app.py b/tests/test_app.py index e69de29b..2a08e499 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -0,0 +1,9 @@ +from sqlalchemy.orm import Session + +from app.database.database import get_db + + +class TestApp: + + def test_get_db(self): + assert isinstance(next(get_db()), Session) diff --git a/tests/test_association.py b/tests/test_association.py new file mode 100644 index 00000000..741f0931 --- /dev/null +++ b/tests/test_association.py @@ -0,0 +1,9 @@ +class TestAssociation: + def test_association_data(self, association, event): + assert association.events == event + + def test_repr(self, association): + assert ( + association.__repr__() + == f'') diff --git a/tests/test_categories.py b/tests/test_categories.py new file mode 100644 index 00000000..5f9a12e0 --- /dev/null +++ b/tests/test_categories.py @@ -0,0 +1,108 @@ +import pytest +from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy.testing import mock +from starlette import status +from starlette.datastructures import ImmutableMultiDict + +from app.database.models import Event +from app.routers.categories import get_user_categories, validate_request_params + + +class TestCategories: + CATEGORY_ALREADY_EXISTS_MSG = "category is already exists for" + UNALLOWED_PARAMS = "contains unallowed params" + + @staticmethod + def test_get_categories_logic_succeeded(session, user, category): + assert get_user_categories(session, category.user_id) == [category] + + @staticmethod + def test_creating_new_category(client, user): + response = client.post("/categories/", + json={"user_id": user.id, "name": "Foo", + "color": "eecc11"}) + assert response.status_code == status.HTTP_200_OK + assert {"user_id": user.id, "name": "Foo", "color": "eecc11"}.items() \ + <= response.json()['category'].items() + + @staticmethod + def test_creating_not_unique_category_failed(client, user, category): + response = client.post("/categories/", json={"user_id": user.id, + "name": "Guitar Lesson", + "color": "121212"}) + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert TestCategories.CATEGORY_ALREADY_EXISTS_MSG in \ + response.json()["detail"] + + @staticmethod + def test_create_event_with_category(category): + event = Event(title="OOO", content="Guitar rocks!!", + owner_id=category.user_id, category_id=category.id) + assert event.category_id is not None + assert event.category_id == category.id + + @staticmethod + def test_get_user_categories(client, category): + response = client.get(f"/categories/?user_id={category.user_id}" + f"&name={category.name}&color={category.color}") + assert response.status_code == status.HTTP_200_OK + assert response.json() == [ + {"user_id": category.user_id, "color": "121212", + "name": "Guitar Lesson", "id": category.id}] + + @staticmethod + def test_get_category_by_name(client, user, category): + response = client.get(f"/categories/?user_id={category.user_id}" + f"&name={category.name}") + assert response.status_code == status.HTTP_200_OK + assert response.json() == [ + {"user_id": category.user_id, "color": "121212", + "name": "Guitar Lesson", "id": category.id}] + + @staticmethod + def test_get_category_by_color(client, user, category): + response = client.get(f"/categories/?user_id={category.user_id}&" + f"color={category.color}") + assert response.status_code == status.HTTP_200_OK + assert response.json() == [ + {"user_id": category.user_id, "color": "121212", + "name": "Guitar Lesson", "id": category.id}] + + @staticmethod + def test_get_category_bad_request(client): + response = client.get("/categories/") + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert TestCategories.UNALLOWED_PARAMS in response.json()["detail"] + + @staticmethod + def test_repr(category): + assert category.__repr__() == \ + f'' + + @staticmethod + def test_to_dict(category): + assert {c.name: getattr(category, c.name) for c in + category.__table__.columns} == category.to_dict() + + @staticmethod + @pytest.mark.parametrize('params, expected_result', [ + (ImmutableMultiDict([('user_id', ''), ('name', ''), + ('color', '')]), True), + (ImmutableMultiDict([('user_id', ''), ('name', '')]), True), + (ImmutableMultiDict([('user_id', ''), ('color', '')]), True), + (ImmutableMultiDict([('user_id', '')]), True), + (ImmutableMultiDict([('name', ''), ('color', '')]), False), + (ImmutableMultiDict([]), False), + (ImmutableMultiDict([('user_id', ''), ('name', ''), ('color', ''), + ('bad_param', '')]), False), + ]) + def test_validate_request_params(params, expected_result): + assert validate_request_params(params) == expected_result + + @staticmethod + def test_get_categories_failed(session): + def raise_error(param): + raise SQLAlchemyError() + + session.query = mock.Mock(side_effect=raise_error) + assert get_user_categories(session, 1) == [] diff --git a/tests/test_dayview.py b/tests/test_dayview.py new file mode 100644 index 00000000..8fc56116 --- /dev/null +++ b/tests/test_dayview.py @@ -0,0 +1,109 @@ +from datetime import datetime, timedelta + +from bs4 import BeautifulSoup +import pytest + +from app.database.models import Event, User +from app.routers.dayview import DivAttributes + + +# TODO add user session login +@pytest.fixture +def user_dayevent(): + return User(username='test1', email='user@email.com', + password='1a2b3c4e5f', full_name='test me') + + +@pytest.fixture +def event1(): + start = datetime(year=2021, month=2, day=1, hour=7, minute=5) + end = datetime(year=2021, month=2, day=1, hour=9, minute=15) + return Event(title='test1', content='test', + start=start, end=end, owner_id=1) + + +@pytest.fixture +def event2(): + start = datetime(year=2021, month=2, day=1, hour=13, minute=13) + end = datetime(year=2021, month=2, day=1, hour=15, minute=46) + return Event(title='test2', content='test', + start=start, end=end, owner_id=1, color='blue') + + +@pytest.fixture +def event3(): + start = datetime(year=2021, month=2, day=3, hour=7, minute=5) + end = datetime(year=2021, month=2, day=3, hour=9, minute=15) + return Event(title='test3', content='test', + start=start, end=end, owner_id=1) + + +@pytest.fixture +def event_with_no_minutes_modified(): + start = datetime(year=2021, month=2, day=3, hour=7) + end = datetime(year=2021, month=2, day=3, hour=8) + return Event(title='test_no_modify', content='test', + start=start, end=end, owner_id=1) + + +@pytest.fixture +def multiday_event(): + start = datetime(year=2021, month=2, day=1, hour=13) + end = datetime(year=2021, month=2, day=3, hour=13) + return Event(title='test_multiday', content='test', + start=start, end=end, owner_id=1, color='blue') + + +def test_minutes_position_calculation(event_with_no_minutes_modified): + div_attr = DivAttributes(event_with_no_minutes_modified) + assert div_attr._minutes_position(div_attr.start_time.minute) is None + assert div_attr._minutes_position(div_attr.end_time.minute) is None + assert div_attr._minutes_position(0) is None + assert div_attr._minutes_position(60) == 4 + + +def test_div_attributes(event1): + div_attr = DivAttributes(event1) + assert div_attr.total_time == '07:05 - 09:15' + assert div_attr.grid_position == '34 / 42' + assert div_attr.length == 130 + assert div_attr.color == 'grey' + + +def test_div_attr_multiday(multiday_event): + day = datetime(year=2021, month=2, day=1) + assert DivAttributes(multiday_event, day).grid_position == '57 / 101' + day += timedelta(hours=24) + assert DivAttributes(multiday_event, day).grid_position == '1 / 101' + day += timedelta(hours=24) + assert DivAttributes(multiday_event, day).grid_position == '1 / 57' + + +def test_div_attributes_with_costume_color(event2): + div_attr = DivAttributes(event2) + assert div_attr.color == 'blue' + + +def test_dayview_html(event1, event2, event3, session, user_dayevent, client): + session.add_all([user_dayevent, event1, event2, event3]) + session.commit() + response = client.get("/day/2021-2-1") + soup = BeautifulSoup(response.content, 'html.parser') + assert 'FEBRUARY' in str(soup.find("div", {"id": "toptab"})) + assert 'event1' in str(soup.find("div", {"id": "event1"})) + assert 'event2' in str(soup.find("div", {"id": "event2"})) + assert soup.find("div", {"id": "event3"}) is None + + +@pytest.mark.parametrize("day,grid_position", [("2021-2-1", '57 / 101'), + ("2021-2-2", '1 / 101'), + ("2021-2-3", '1 / 57')]) +def test_dayview_html_with_multiday_event(multiday_event, session, + user_dayevent, client, day, + grid_position): + session.add_all([user_dayevent, multiday_event]) + session.commit() + response = client.get(f"/day/{day}") + soup = BeautifulSoup(response.content, 'html.parser') + grid_pos = f'grid-row: {grid_position};' + assert grid_pos in str(soup.find("div", {"id": "event1"})) diff --git a/tests/test_event.py b/tests/test_event.py index d50bd567..d6facbc2 100644 --- a/tests/test_event.py +++ b/tests/test_event.py @@ -1,22 +1,21 @@ -from fastapi.testclient import TestClient +from starlette.status import HTTP_404_NOT_FOUND -from app.main import app -client = TestClient(app) +class TestEvent: + def test_eventedit(self, client): + response = client.get("/event/edit") + assert response.ok + assert b"Edit Event" in response.content -def test_eventedit(): - response = client.get("/event/edit") - assert response.status_code == 200 - assert b"Edit Event" in response.content + def test_eventview_with_id(self, client): + response = client.get("/event/view/1") + assert response.ok + assert b"View Event" in response.content + def test_eventview_without_id(self, client): + response = client.get("/event/view") + assert response.status_code == HTTP_404_NOT_FOUND -def test_eventview_with_id(): - response = client.get("/event/view/1") - assert response.status_code == 200 - assert b"View Event" in response.content - - -def test_eventview_without_id(): - response = client.get("/event/view") - assert response.status_code == 404 + def test_repr(self, event): + assert event.__repr__() == f'' diff --git a/tests/test_export.py b/tests/test_export.py new file mode 100644 index 00000000..1a39f313 --- /dev/null +++ b/tests/test_export.py @@ -0,0 +1,42 @@ +from icalendar import vCalAddress + +from app.config import ICAL_VERSION, PRODUCT_ID +from app.routers.export import ( + create_ical_calendar, create_ical_event, event_to_ical +) + + +class TestExport: + + def test_create_ical_calendar(self): + cal = create_ical_calendar() + assert cal.get('version') == ICAL_VERSION + assert cal.get('prodid') == PRODUCT_ID + + def test_create_ical_event(self, event): + ical_event = create_ical_event(event) + assert event.owner.email in ical_event.get('organizer') + assert ical_event.get('summary') == event.title + + def test_add_attendees(self, event, user): + ical_event = create_ical_event(event) + ical_event.add( + 'attendee', + vCalAddress(f'MAILTO:{user.email}'), + encode=0 + ) + attendee = vCalAddress(f'MAILTO:{user.email}') + assert attendee == ical_event.get('attendee') + + def test_event_to_ical(self, user, event): + ical_event = event_to_ical(event, [user.email]) + + def does_contain(item: str) -> bool: + """Returns if calendar contains item.""" + + return bytes(item, encoding='utf8') in bytes(ical_event) + + assert does_contain(ICAL_VERSION) + assert does_contain(PRODUCT_ID) + assert does_contain(event.owner.email) + assert does_contain(event.title) diff --git a/tests/test_home.py b/tests/test_home.py new file mode 100644 index 00000000..fc0b6772 --- /dev/null +++ b/tests/test_home.py @@ -0,0 +1,6 @@ +class TestHome: + URL = "/" + + def test_get_page(self, client): + resp = client.get(self.URL) + assert resp.status_code == 200 diff --git a/tests/test_invitation.py b/tests/test_invitation.py new file mode 100644 index 00000000..a605ec4e --- /dev/null +++ b/tests/test_invitation.py @@ -0,0 +1,50 @@ +from starlette.status import HTTP_302_FOUND +from app.routers.invitation import get_all_invitations, get_invitation_by_id + + +class TestInvitations: + NO_INVITATIONS = b"You don't have any invitations." + URL = "/invitations/" + + def test_view_no_invitations(self, invitation_test_client): + resp = invitation_test_client.get(self.URL) + assert resp.ok + assert self.NO_INVITATIONS in resp.content + + def test_accept_invitations( + self, user, invitation, + invitation_test_client): + invitation = {"invite_id ": invitation.id} + resp = invitation_test_client.post( + self.URL, data=invitation) + assert resp.status_code == HTTP_302_FOUND + + def test_get_all_invitations_success( + self, invitation, event, user, session + ): + invitations = get_all_invitations(event=event, session=session) + assert invitations == [invitation] + invitations = get_all_invitations(recipient=user, session=session) + assert invitations == [invitation] + + def test_get_all_invitations_failure(self, user, session): + invitations = get_all_invitations( + unknown_parameter=user, session=session) + assert invitations == [] + + invitations = get_all_invitations( + recipient=None, session=session) + assert invitations == [] + + def test_get_invitation_by_id(self, invitation, session): + get_invitation = get_invitation_by_id( + invitation.id, session=session) + assert get_invitation == invitation + + def test_repr(self, invitation): + invitation_repr = ( + f'' + ) + assert invitation.__repr__() == invitation_repr diff --git a/tests/test_share_event.py b/tests/test_share_event.py new file mode 100644 index 00000000..45a70581 --- /dev/null +++ b/tests/test_share_event.py @@ -0,0 +1,58 @@ +from app.routers.share import ( + accept, send_in_app_invitation, sort_emails, send_email_invitation, share +) +from app.routers.invitation import get_all_invitations + + +class TestShareEvent: + def test_share_failure(self, event, session): + participants = [event.owner.email] + share(event, participants, session) + invitations = get_all_invitations( + session=session, recipient_id=event.owner.id + ) + assert invitations == [] + + def test_share_success(self, user, event, session): + participants = [user.email] + share(event, participants, session) + invitations = get_all_invitations( + session=session, recipient_id=user.id + ) + assert invitations != [] + + def test_sort_emails(self, user, session): + # the user is being imported + # so he will be created + data = [ + 'test.email@gmail.com', # registered user + 'not_logged_in@gmail.com', # unregistered user + ] + sorted_data = sort_emails(data, session=session) + assert sorted_data == { + 'registered': ['test.email@gmail.com'], + 'unregistered': ['not_logged_in@gmail.com'] + } + + def test_send_in_app_invitation_success( + self, user, sender, event, session + ): + assert send_in_app_invitation([user.email], event, session=session) + invitation = get_all_invitations(session=session, recipient=user)[0] + assert invitation.event.owner == sender + assert invitation.recipient == user + session.delete(invitation) + + def test_send_in_app_invitation_failure( + self, user, sender, event, session): + assert (send_in_app_invitation( + [sender.email], event, session=session) is False) + + def test_send_email_invitation(self, user, event): + send_email_invitation([user.email], event) + # TODO add email tests + assert True + + def test_accept(self, invitation, session): + accept(invitation, session=session) + assert invitation.status == 'accepted' diff --git a/tests/test_user.py b/tests/test_user.py new file mode 100644 index 00000000..9e2a9a84 --- /dev/null +++ b/tests/test_user.py @@ -0,0 +1,38 @@ +from app.routers.user import create_user, does_user_exist, get_users + + +class TestUser: + + def test_create_user(self, session): + user = create_user( + session=session, + username='new_test_username', + password='new_test_password', + email='new_test.email@gmail.com', + ) + assert user.username == 'new_test_username' + assert user.password == 'new_test_password' + assert user.email == 'new_test.email@gmail.com' + session.delete(user) + session.commit() + + def test_get_users_success(self, user, session): + assert get_users(username=user.username, session=session) == [user] + assert get_users(password=user.password, session=session) == [user] + assert get_users(email=user.email, session=session) == [user] + + def test_get_users_failure(self, session, user): + assert get_users(username='wrong username', session=session) == [] + assert get_users(wrong_param=user.username, session=session) == [] + + def test_does_user_exist_success(self, user, session): + assert does_user_exist(username=user.username, session=session) + assert does_user_exist(user_id=user.id, session=session) + assert does_user_exist(email=user.email, session=session) + + def test_does_user_exist_failure(self, session): + assert not does_user_exist(username='wrong username', session=session) + assert not does_user_exist(session=session) + + def test_repr(self, user): + assert user.__repr__() == f'' diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 00000000..a6164281 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,14 @@ +from sqlalchemy.orm import Session + +from app.internal.utils import save + + +class TestUtils: + + def test_save_success(self, user, session: Session): + user.username = 'edit_username' + assert save(user, session=session) + + def test_save_failure(self, session: Session): + user = 'not a user instance' + assert not save(user, session=session) diff --git a/tests/test_weather_forecast.py b/tests/test_weather_forecast.py new file mode 100644 index 00000000..96e77c91 --- /dev/null +++ b/tests/test_weather_forecast.py @@ -0,0 +1,79 @@ +import datetime +import pytest +import requests +import responses + +from app.internal.weather_forecast import get_weather_data + + +HISTORY_URL = "https://visual-crossing-weather.p.rapidapi.com/history" +FORECAST_URL = "https://visual-crossing-weather.p.rapidapi.com/forecast" +RESPONSE_FROM_MOCK = {"locations": {"Tel Aviv": {"values": [ + {"mint": 6, "maxt": 17.2, "conditions": "Partially cloudy"}]}}} +ERROR_RESPONSE_FROM_MOCK = {"message": "Error Text"} +DATA_GET_WEATHER = [ + pytest.param(2020, "tel aviv", 0, marks=pytest.mark.xfail, + id="invalid input type"), + pytest.param(datetime.datetime(day=4, month=4, year=2070), "tel aviv", 0, + marks=pytest.mark.xfail, id="year out of range"), + pytest.param(datetime.datetime(day=4, month=4, year=2020), + "tel aviv", 0, id="basic historical test"), + pytest.param(datetime.datetime(day=1, month=1, year=2030), "tel aviv", 0, + id="basic historical forecast test - prior in current year"), + pytest.param(datetime.datetime(day=31, month=12, year=2030), + "tel aviv", 0, id="basic historical forecast test - future"), + pytest.param(datetime.datetime(day=29, month=2, year=2024), "tel aviv", + 0, id="basic historical forecast test"), +] + + +@pytest.mark.parametrize('requested_date, location, expected', + DATA_GET_WEATHER) +def test_get_weather_data(requested_date, location, expected, requests_mock): + requests_mock.get(HISTORY_URL, json=RESPONSE_FROM_MOCK) + output = get_weather_data(requested_date, location) + assert output['Status'] == expected + + +def test_get_forecast_weather_data(requests_mock): + temp_date = datetime.datetime.now() + datetime.timedelta(days=2) + response_from_mock = RESPONSE_FROM_MOCK + response_from_mock["locations"]["Tel Aviv"]["values"][0]["datetimeStr"] =\ + temp_date.isoformat() + requests_mock.get(FORECAST_URL, json=response_from_mock) + output = get_weather_data(temp_date, "tel aviv") + assert output['Status'] == 0 + + +def test_location_not_found(requests_mock): + requested_date = datetime.datetime(day=10, month=1, year=2020) + requests_mock.get(HISTORY_URL, json=ERROR_RESPONSE_FROM_MOCK) + output = get_weather_data(requested_date, "neo") + assert output['Status'] == -1 + + +@responses.activate +def test_historical_no_response_from_api(): + requested_date = datetime.datetime(day=11, month=1, year=2020) + responses.add(responses.GET, HISTORY_URL, status=500) + requests.get(HISTORY_URL) + output = get_weather_data(requested_date, "neo") + assert output['Status'] == -1 + + +@responses.activate +def test_historical_exception_from_api(): + requested_date = datetime.datetime(day=12, month=1, year=2020) + with pytest.raises(requests.exceptions.ConnectionError): + requests.get(HISTORY_URL) + output = get_weather_data(requested_date, "neo") + assert output['Status'] == -1 + + +@responses.activate +def test_forecast_exception_from_api(): + requested_date = datetime.datetime.now() + datetime.timedelta(days=3) + with pytest.raises(requests.exceptions.ConnectionError): + requests.get(FORECAST_URL) + output = get_weather_data(requested_date, "neo") + assert output['Status'] == -1 diff --git a/tests/user_fixture.py b/tests/user_fixture.py new file mode 100644 index 00000000..526cc10f --- /dev/null +++ b/tests/user_fixture.py @@ -0,0 +1,29 @@ +import pytest +from sqlalchemy.orm import Session + +from app.database.models import User +from tests.utils import create_model, delete_instance + + +@pytest.fixture +def user(session: Session) -> User: + test_user = create_model( + session, User, + username='test_username', + password='test_password', + email='test.email@gmail.com', + ) + yield test_user + delete_instance(session, test_user) + + +@pytest.fixture +def sender(session: Session) -> User: + sender = create_model( + session, User, + username='sender_username', + password='sender_password', + email='sender.email@gmail.com', + ) + yield sender + delete_instance(session, sender) diff --git a/tests/utils.py b/tests/utils.py new file mode 100644 index 00000000..58ffdbd0 --- /dev/null +++ b/tests/utils.py @@ -0,0 +1,13 @@ +from sqlalchemy.orm import Session + + +def create_model(session: Session, model_class, **kw): + instance = model_class(**kw) + session.add(instance) + session.commit() + return instance + + +def delete_instance(session: Session, instance): + session.delete(instance) + session.commit()