diff --git a/lms/lmsdb/bootstrap.py b/lms/lmsdb/bootstrap.py index 740c1f20..9331ba52 100644 --- a/lms/lmsdb/bootstrap.py +++ b/lms/lmsdb/bootstrap.py @@ -287,6 +287,15 @@ def _add_exercise_course_id_and_number_columns_constraint() -> bool: db_config.database.commit() +def _add_user_course_constaint() -> bool: + migrator = db_config.get_migrator_instance() + with db_config.database.transaction(): + migrate( + migrator.add_index('usercourse', ('user_id', 'course_id'), True), + ) + db_config.database.commit() + + def _last_status_view_migration() -> bool: Solution = models.Solution _migrate_column_in_table_if_needed(Solution, Solution.last_status_view) @@ -312,6 +321,9 @@ def main(): _last_course_viewed_migration() _uuid_migration() + if models.database.table_exists(models.UserCourse.__name__.lower()): + _add_user_course_constaint() + models.database.create_tables(models.ALL_MODELS, safe=True) if models.Role.select().count() == 0: diff --git a/lms/lmsdb/models.py b/lms/lmsdb/models.py index cf3c8bc1..f7efa8b2 100644 --- a/lms/lmsdb/models.py +++ b/lms/lmsdb/models.py @@ -16,7 +16,9 @@ BooleanField, Case, CharField, Check, DateTimeField, ForeignKeyField, IntegerField, JOIN, ManyToManyField, TextField, UUIDField, fn, ) -from playhouse.signals import Model, post_save, pre_save # type: ignore +from playhouse.signals import ( # type: ignore + Model, post_delete, post_save, pre_save, +) from werkzeug.security import ( check_password_hash, generate_password_hash, ) @@ -252,6 +254,11 @@ class UserCourse(BaseModel): course = ForeignKeyField(Course, backref='usercourses') date = DateTimeField(default=datetime.now) + class Meta: + indexes = ( + (('user_id', 'course_id'), True), + ) + @classmethod def is_user_registered(cls, user_id: int, course_id: int) -> bool: return ( @@ -265,6 +272,24 @@ def is_user_registered(cls, user_id: int, course_id: int) -> bool: ) +@post_save(sender=UserCourse) +def on_save_user_course(model_class, instance, created): + """Changes user's last course viewed.""" + if instance.user.last_course_viewed is None: + instance.user.last_course_viewed = instance.course + instance.user.save() + + +@post_delete(sender=UserCourse) +def on_delete_user_course(model_class, instance): + """Changes user's last course viewed.""" + if instance.user.last_course_viewed == instance.course: + instance.user.last_course_viewed = ( + Course.fetch(instance.user).limit(1).scalar() + ) + instance.user.save() + + class Notification(BaseModel): ID_FIELD_NAME = 'id' MAX_PER_USER = 10 diff --git a/tests/test_users.py b/tests/test_users.py index 18ae9c1c..63823ee5 100644 --- a/tests/test_users.py +++ b/tests/test_users.py @@ -185,3 +185,22 @@ def test_user_registered_to_course(student_user: User, course: Course): course2 = conftest.create_course(index=1) assert not course2.has_user(student_user) + @staticmethod + def test_usercourse_on_delete(student_user: User, course: Course): + usercourse = conftest.create_usercourse(student_user, course) + assert student_user.last_course_viewed == course + + usercourse.delete_instance() + assert student_user.last_course_viewed is None + + @staticmethod + def test_usercourse_on_save(student_user: User, course: Course): + course2 = conftest.create_course(index=1) + usercourse = conftest.create_usercourse(student_user, course) + assert student_user.last_course_viewed == course + + conftest.create_usercourse(student_user, course2) + assert student_user.last_course_viewed == course + + usercourse.delete_instance() + assert student_user.last_course_viewed == course2