Skip to content

Commit f00b080

Browse files
authored
Merge 73def22 into 3c22c8c
2 parents 3c22c8c + 73def22 commit f00b080

File tree

6 files changed

+336
-0
lines changed

6 files changed

+336
-0
lines changed

app/database/models.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -499,6 +499,19 @@ def __repr__(self):
499499
)
500500

501501

502+
class UserMenstrualPeriodLength(Base):
503+
__tablename__ = "user_menstrual_period_length"
504+
505+
id = Column(Integer, primary_key=True, index=True)
506+
user_id = Column(
507+
Integer,
508+
ForeignKey("users.id"),
509+
nullable=False,
510+
unique=True,
511+
)
512+
period_length = Column(Integer, nullable=False)
513+
514+
502515
class SharedListItem(Base):
503516
__tablename__ = "shared_list_item"
504517

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
import datetime
2+
from datetime import timedelta
3+
from typing import List, Union
4+
5+
from fastapi import Depends
6+
from loguru import logger
7+
from sqlalchemy import asc
8+
from sqlalchemy.exc import SQLAlchemyError
9+
from sqlalchemy.orm import Session
10+
11+
from app.database.models import Event, UserMenstrualPeriodLength
12+
from app.dependencies import get_db
13+
from app.internal.security.dependencies import current_user
14+
from app.internal.security.schema import CurrentUser
15+
from app.routers.event import create_event
16+
17+
MENSTRUAL_PERIOD_CATEGORY_ID = 111
18+
19+
20+
def get_avg_period_gap(db: Session, user_id: int) -> int:
21+
GAP_IN_CASE_NO_PERIODS = 30
22+
23+
period_days = get_all_period_days(db, user_id)
24+
gaps_list = []
25+
26+
if len(period_days) <= 1:
27+
return GAP_IN_CASE_NO_PERIODS
28+
29+
for i in range(len(period_days) - 1):
30+
gap = get_date_diff(period_days[i].start, period_days[i + 1].start)
31+
gaps_list.append(gap.days)
32+
return get_list_avg(gaps_list)
33+
34+
35+
def get_date_diff(date_1: datetime, date_2: datetime) -> timedelta:
36+
return date_2 - date_1
37+
38+
39+
def get_list_avg(received_list: List) -> int:
40+
return sum(received_list) // len(received_list)
41+
42+
43+
def remove_existing_period_dates(db: Session, user_id: int) -> None:
44+
(
45+
db.query(Event)
46+
.filter(Event.owner_id == user_id)
47+
.filter(Event.category_id == MENSTRUAL_PERIOD_CATEGORY_ID)
48+
.filter(Event.start > datetime.datetime.now())
49+
.delete()
50+
)
51+
db.commit()
52+
logger.info("Removed all period predictions to create new ones")
53+
54+
55+
def generate_predicted_period_dates(
56+
db: Session,
57+
period_length: str,
58+
period_start_date: datetime,
59+
user_id: int,
60+
) -> Event:
61+
delta = datetime.timedelta(int(period_length))
62+
period_end_date = period_start_date + delta
63+
period_event = create_event(
64+
db,
65+
"period",
66+
period_start_date,
67+
period_end_date,
68+
user_id,
69+
category_id=MENSTRUAL_PERIOD_CATEGORY_ID,
70+
)
71+
return period_event
72+
73+
74+
def add_3_month_predictions(
75+
db: Session,
76+
period_length: str,
77+
period_start_date: datetime,
78+
user_id: int,
79+
) -> List[Event]:
80+
avg_gap = get_avg_period_gap(db, user_id)
81+
avg_gap_delta = datetime.timedelta(avg_gap)
82+
generated_3_months = []
83+
for _ in range(4):
84+
generated_period = generate_predicted_period_dates(
85+
db,
86+
period_length,
87+
period_start_date,
88+
user_id,
89+
)
90+
generated_3_months.append(generated_period)
91+
period_start_date += avg_gap_delta
92+
logger.info(f"Generated predictions: {generated_3_months}")
93+
return generated_3_months
94+
95+
96+
def add_prediction_events_if_valid(
97+
period_start_date: datetime,
98+
db: Session = Depends(get_db),
99+
user: CurrentUser = Depends(current_user),
100+
) -> None:
101+
current_user_id = user.user_id
102+
user_period_length = is_user_signed_up_to_menstrual_predictor(
103+
db,
104+
current_user_id,
105+
)
106+
107+
remove_existing_period_dates(db, current_user_id)
108+
if user_period_length:
109+
add_3_month_predictions(
110+
db,
111+
user_period_length,
112+
period_start_date,
113+
current_user_id,
114+
)
115+
116+
117+
def get_all_period_days(session: Session, user_id: int) -> List[Event]:
118+
"""Returns all period days filtered by user id."""
119+
120+
try:
121+
period_days = (
122+
session.query(Event)
123+
.filter(Event.owner_id == user_id)
124+
.filter(Event.category_id == MENSTRUAL_PERIOD_CATEGORY_ID)
125+
.order_by(asc(Event.start))
126+
.all()
127+
)
128+
129+
except SQLAlchemyError as err:
130+
logger.exception(err)
131+
return []
132+
133+
return period_days
134+
135+
136+
def is_user_signed_up_to_menstrual_predictor(
137+
session: Session,
138+
user_id: int,
139+
) -> Union[bool, int]:
140+
user_menstrual_period_length = (
141+
session.query(UserMenstrualPeriodLength)
142+
.filter(user_id == user_id)
143+
.first()
144+
)
145+
if user_menstrual_period_length:
146+
return user_menstrual_period_length.period_length
147+
return False

app/main.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ def create_tables(engine, psql_environment):
7373
login,
7474
logout,
7575
meds,
76+
menstrual_predictor,
7677
notification,
7778
profile,
7879
register,
@@ -125,6 +126,7 @@ async def swagger_ui_redirect():
125126
login.router,
126127
logout.router,
127128
meds.router,
129+
menstrual_predictor.router,
128130
notification.router,
129131
profile.router,
130132
register.router,

app/routers/menstrual_predictor.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
import datetime
2+
3+
from fastapi import APIRouter, Depends, HTTPException, Request
4+
from fastapi.responses import RedirectResponse, Response
5+
from loguru import logger
6+
from sqlalchemy.exc import SQLAlchemyError
7+
from sqlalchemy.orm import Session
8+
from starlette.status import HTTP_302_FOUND, HTTP_400_BAD_REQUEST
9+
10+
from app.database.models import UserMenstrualPeriodLength
11+
from app.dependencies import get_db, templates
12+
from app.internal.menstrual_predictor_utils import (
13+
add_prediction_events_if_valid,
14+
generate_predicted_period_dates,
15+
is_user_signed_up_to_menstrual_predictor,
16+
)
17+
from app.internal.security.dependencies import current_user
18+
from app.internal.security.schema import CurrentUser
19+
from app.internal.utils import create_model
20+
21+
router = APIRouter(
22+
prefix="/menstrual_predictor",
23+
tags=["menstrual_predictor"],
24+
dependencies=[Depends(get_db)],
25+
)
26+
27+
MENSTRUAL_PERIOD_CATEGORY_ID = 111
28+
29+
30+
@router.get("/")
31+
def join_menstrual_predictor(
32+
request: Request,
33+
db: Session = Depends(get_db),
34+
user: CurrentUser = Depends(current_user),
35+
) -> Response:
36+
current_user_id = user.user_id
37+
38+
if not is_user_signed_up_to_menstrual_predictor(db, current_user_id):
39+
return templates.TemplateResponse(
40+
"join_menstrual_predictor.html",
41+
{
42+
"request": request,
43+
},
44+
)
45+
return RedirectResponse(url="/", status_code=HTTP_302_FOUND)
46+
47+
48+
@router.get("/add-period-start/{start_date}")
49+
def add_period_start(
50+
request: Request,
51+
start_date: str,
52+
db: Session = Depends(get_db),
53+
user: CurrentUser = Depends(current_user),
54+
) -> RedirectResponse:
55+
try:
56+
period_start_date = datetime.datetime.strptime(start_date, "%Y-%m-%d")
57+
except ValueError as err:
58+
logger.exception(err)
59+
raise HTTPException(
60+
status_code=HTTP_400_BAD_REQUEST,
61+
detail="The given date doesn't match a date format YYYY-MM-DD",
62+
)
63+
else:
64+
add_prediction_events_if_valid(period_start_date, db, user)
65+
logger.info("Adding menstrual start date")
66+
return RedirectResponse("/", status_code=HTTP_302_FOUND)
67+
68+
69+
@router.post("/")
70+
async def submit_join_form(
71+
request: Request,
72+
db: Session = Depends(get_db),
73+
user: CurrentUser = Depends(current_user),
74+
) -> RedirectResponse:
75+
76+
data = await request.form()
77+
78+
user_menstrual_period_length = {
79+
"user_id": user.user_id,
80+
"period_length": data["avg-period-length"],
81+
}
82+
last_period_date = datetime.datetime.strptime(
83+
data["last-period-date"],
84+
"%Y-%m-%d",
85+
)
86+
try:
87+
create_model(
88+
session=db,
89+
model_class=UserMenstrualPeriodLength,
90+
**user_menstrual_period_length,
91+
)
92+
except SQLAlchemyError:
93+
logger.info("Current user already signed up to the service, hurray")
94+
db.rollback()
95+
url = "/"
96+
generate_predicted_period_dates(
97+
db,
98+
data["avg-period-length"],
99+
last_period_date,
100+
user.user_id,
101+
)
102+
103+
return RedirectResponse(url=url, status_code=HTTP_302_FOUND)
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
{% extends "base.html" %}
2+
{% block content %}
3+
<div class="container">
4+
<h3>Please fill in your details</h3>
5+
<form method="post">
6+
<div class="row">
7+
<label for="avg-period-length" class="col-sm-2 col-form-label">Average period length</label>
8+
<div class="col-auto">
9+
<input type="number" class="form-control" name="avg-period-length" id="avg-period-length" min=1 placeholder="Required" autofocus required><br>
10+
</div>
11+
<div class="col-auto">
12+
<span id="PeriodLengthHelpInline" class="form-text">
13+
Must be above 1 day.
14+
</span>
15+
</div>
16+
</div>
17+
<div class="row">
18+
<label for="last-period-date" class="col-sm-2 col-form-label">Last period date</label>
19+
<div class="col-auto">
20+
<input type="date" class="form-control" name="last-period-date" id="last-period-date" placeholder="Required" required><br>
21+
</div>
22+
</div>
23+
24+
<input type="submit" class="btn-sm btn btn-outline-primary" value="Sign Up">
25+
</form>
26+
</div>
27+
<script>
28+
function change_max_to_today_date(el){
29+
today = new Date();
30+
today_str = today.toISOString().substring(0,10);
31+
el.max = today_str;
32+
}
33+
function validate_date_older_than_today(received_date){
34+
return received_date < new Date();
35+
}
36+
let last_period_date_element = document.getElementById("last-period-date");
37+
change_max_to_today_date(last_period_date_element);
38+
</script>
39+
{% endblock %}

tests/test_menstrual_predictor.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
class TestMenstrualPredictor:
2+
PREDICTOR_PREFIX = "/menstrual_predictor"
3+
ADD_PERIOD_START = "/add-period-start"
4+
5+
@staticmethod
6+
def test_menstrual_predictor_page_not_signed_up(client, session):
7+
resp = client.get(TestMenstrualPredictor.PREDICTOR_PREFIX)
8+
assert resp.ok
9+
10+
@staticmethod
11+
def test_menstrual_predictor_sign_up(client, session):
12+
resp = client.post(
13+
TestMenstrualPredictor.PREDICTOR_PREFIX,
14+
json={"avg-period-length": 8, "last-period-date": "2020-11-07"},
15+
)
16+
assert resp.ok
17+
18+
resp = client.get(
19+
TestMenstrualPredictor.PREDICTOR_PREFIX
20+
+ TestMenstrualPredictor.ADD_PERIOD_START
21+
+ "/2020-12-11",
22+
)
23+
assert resp.ok
24+
25+
@staticmethod
26+
def test_add_period_date(client, session):
27+
resp = client.get(
28+
TestMenstrualPredictor.PREDICTOR_PREFIX
29+
+ TestMenstrualPredictor.ADD_PERIOD_START
30+
+ "/2020-12-11",
31+
)
32+
assert resp.ok

0 commit comments

Comments
 (0)