Skip to content

Commit 24528e6

Browse files
authored
Updated refresh token storage logic (#403)
1 parent 7853233 commit 24528e6

File tree

9 files changed

+120
-91
lines changed

9 files changed

+120
-91
lines changed

backend/app/admin/api/v1/auth/auth.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# -*- coding: utf-8 -*-
33
from typing import Annotated
44

5-
from fastapi import APIRouter, Depends, Query, Request
5+
from fastapi import APIRouter, Depends, Request, Response
66
from fastapi.security import HTTPBasicCredentials
77
from fastapi_limiter.depends import RateLimiter
88
from starlette.background import BackgroundTasks
@@ -28,18 +28,20 @@ async def swagger_login(obj: Annotated[HTTPBasicCredentials, Depends()]) -> GetS
2828
description='json 格式登录, 仅支持在第三方api工具调试, 例如: postman',
2929
dependencies=[Depends(RateLimiter(times=5, minutes=1))],
3030
)
31-
async def user_login(request: Request, obj: AuthLoginParam, background_tasks: BackgroundTasks) -> ResponseModel:
32-
data = await auth_service.login(request=request, obj=obj, background_tasks=background_tasks)
31+
async def user_login(
32+
request: Request, response: Response, obj: AuthLoginParam, background_tasks: BackgroundTasks
33+
) -> ResponseModel:
34+
data = await auth_service.login(request=request, response=response, obj=obj, background_tasks=background_tasks)
3335
return response_base.success(data=data)
3436

3537

3638
@router.post('/token/new', summary='创建新 token', dependencies=[DependsJwtAuth])
37-
async def create_new_token(request: Request, refresh_token: Annotated[str, Query(...)]) -> ResponseModel:
38-
data = await auth_service.new_token(request=request, refresh_token=refresh_token)
39+
async def create_new_token(request: Request, response: Response) -> ResponseModel:
40+
data = await auth_service.new_token(request=request, response=response)
3941
return response_base.success(data=data)
4042

4143

4244
@router.post('/logout', summary='用户登出', dependencies=[DependsJwtAuth])
43-
async def user_logout(request: Request) -> ResponseModel:
44-
await auth_service.logout(request=request)
45+
async def user_logout(request: Request, response: Response) -> ResponseModel:
46+
await auth_service.logout(request=request, response=response)
4547
return response_base.success()

backend/app/admin/schema/token.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,8 @@ class AccessTokenBase(SchemaBase):
1919

2020

2121
class GetNewToken(AccessTokenBase):
22-
refresh_token: str
23-
refresh_token_type: str = 'Bearer'
24-
refresh_token_expire_time: datetime
22+
pass
2523

2624

27-
class GetLoginToken(GetNewToken):
25+
class GetLoginToken(AccessTokenBase):
2826
user: GetUserInfoNoRelationDetail

backend/app/admin/service/auth_service.py

Lines changed: 38 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#!/usr/bin/env python3
22
# -*- coding: utf-8 -*-
3-
from fastapi import Request
3+
from fastapi import Request, Response
44
from fastapi.security import HTTPBasicCredentials
55
from starlette.background import BackgroundTask, BackgroundTasks
66

@@ -38,12 +38,14 @@ async def swagger_login(*, obj: HTTPBasicCredentials) -> tuple[str, User]:
3838
raise errors.AuthorizationError(msg='用户名或密码有误')
3939
elif not current_user.status:
4040
raise errors.AuthorizationError(msg='用户已被锁定, 请联系统管理员')
41-
access_token, _ = await create_access_token(str(current_user.id), multi_login=current_user.is_multi_login)
41+
access_token = await create_access_token(str(current_user.id), current_user.is_multi_login)
4242
await user_dao.update_login_time(db, obj.username)
43-
return access_token, current_user
43+
return access_token.access_token, current_user
4444

4545
@staticmethod
46-
async def login(*, request: Request, obj: AuthLoginParam, background_tasks: BackgroundTasks) -> GetLoginToken:
46+
async def login(
47+
*, request: Request, response: Response, obj: AuthLoginParam, background_tasks: BackgroundTasks
48+
) -> GetLoginToken:
4749
async with async_db_session.begin() as db:
4850
try:
4951
current_user = await user_dao.get_by_username(db, obj.username)
@@ -61,14 +63,8 @@ async def login(*, request: Request, obj: AuthLoginParam, background_tasks: Back
6163
if captcha_code.lower() != obj.captcha.lower():
6264
raise errors.CustomError(error=CustomErrorCode.CAPTCHA_ERROR)
6365
current_user_id = current_user.id
64-
access_token, access_token_expire_time = await create_access_token(
65-
str(current_user_id), multi_login=current_user.is_multi_login
66-
)
67-
refresh_token, refresh_token_expire_time = await create_refresh_token(
68-
sub=str(current_user_id),
69-
expire_time=access_token_expire_time,
70-
multi_login=current_user.is_multi_login,
71-
)
66+
access_token = await create_access_token(str(current_user_id), current_user.is_multi_login)
67+
refresh_token = await create_refresh_token(str(current_user_id), current_user.is_multi_login)
7268
except errors.NotFoundError as e:
7369
raise errors.NotFoundError(msg=e.msg)
7470
except (errors.AuthorizationError, errors.CustomError) as e:
@@ -102,19 +98,29 @@ async def login(*, request: Request, obj: AuthLoginParam, background_tasks: Back
10298
)
10399
await redis_client.delete(f'{admin_settings.CAPTCHA_LOGIN_REDIS_PREFIX}:{request.state.ip}')
104100
await user_dao.update_login_time(db, obj.username)
101+
response.set_cookie(
102+
settings.COOKIE_REFRESH_TOKEN_KEY,
103+
refresh_token.refresh_token,
104+
settings.COOKIE_REFRESH_TOKEN_EXPIRE_SECONDS,
105+
refresh_token.refresh_token_expire_time,
106+
)
105107
await db.refresh(current_user)
106108
data = GetLoginToken(
107-
access_token=access_token,
108-
refresh_token=refresh_token,
109-
access_token_expire_time=access_token_expire_time,
110-
refresh_token_expire_time=refresh_token_expire_time,
109+
access_token=access_token.access_token,
110+
access_token_expire_time=access_token.access_token_expire_time,
111111
user=current_user, # type: ignore
112112
)
113113
return data
114114

115115
@staticmethod
116-
async def new_token(*, request: Request, refresh_token: str) -> GetNewToken:
117-
user_id = jwt_decode(refresh_token)
116+
async def new_token(*, request: Request, response: Response) -> GetNewToken:
117+
refresh_token = request.cookies.get(settings.COOKIE_REFRESH_TOKEN_KEY)
118+
if not refresh_token:
119+
raise errors.TokenError(msg='Refresh Token 丢失,请重新登录')
120+
try:
121+
user_id = jwt_decode(refresh_token)
122+
except Exception:
123+
raise errors.TokenError(msg='Refresh Token 无效')
118124
if request.user.id != user_id:
119125
raise errors.TokenError(msg='Refresh Token 无效')
120126
async with async_db_session() as db:
@@ -130,23 +136,34 @@ async def new_token(*, request: Request, refresh_token: str) -> GetNewToken:
130136
refresh_token=refresh_token,
131137
multi_login=current_user.is_multi_login,
132138
)
139+
response.set_cookie(
140+
settings.COOKIE_REFRESH_TOKEN_KEY,
141+
new_token.new_refresh_token,
142+
settings.COOKIE_REFRESH_TOKEN_EXPIRE_SECONDS,
143+
new_token.new_refresh_token_expire_time,
144+
)
133145
data = GetNewToken(
134146
access_token=new_token.new_access_token,
135147
access_token_expire_time=new_token.new_access_token_expire_time,
136-
refresh_token=new_token.new_refresh_token,
137-
refresh_token_expire_time=new_token.new_refresh_token_expire_time,
138148
)
139149
return data
140150

141151
@staticmethod
142-
async def logout(*, request: Request) -> None:
152+
async def logout(*, request: Request, response: Response) -> None:
143153
token = await get_token(request)
154+
refresh_token = request.cookies.get(settings.COOKIE_REFRESH_TOKEN_KEY)
155+
response.delete_cookie(settings.COOKIE_REFRESH_TOKEN_KEY)
144156
if request.user.is_multi_login:
145157
key = f'{settings.TOKEN_REDIS_PREFIX}:{request.user.id}:{token}'
146158
await redis_client.delete(key)
159+
if refresh_token:
160+
key = f'{settings.TOKEN_REFRESH_REDIS_PREFIX}:{request.user.id}:{refresh_token}'
161+
await redis_client.delete(key)
147162
else:
148163
key_prefix = f'{settings.TOKEN_REDIS_PREFIX}:{request.user.id}:'
149164
await redis_client.delete_prefix(key_prefix)
165+
key_prefix = f'{settings.TOKEN_REFRESH_REDIS_PREFIX}:{request.user.id}:'
166+
await redis_client.delete_prefix(key_prefix)
150167

151168

152169
auth_service = AuthService()

backend/app/admin/service/oauth2_service.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#!/usr/bin/env python3
22
# -*- coding: utf-8 -*-
33
from fast_captcha import text_captcha
4-
from fastapi import BackgroundTasks, Request
4+
from fastapi import BackgroundTasks, Request, Response
55

66
from backend.app.admin.conf import admin_settings
77
from backend.app.admin.crud.crud_user import user_dao
@@ -13,6 +13,7 @@
1313
from backend.common.enums import LoginLogStatusType, UserSocialType
1414
from backend.common.exception.errors import AuthorizationError
1515
from backend.common.security import jwt
16+
from backend.core.conf import settings
1617
from backend.database.db_mysql import async_db_session
1718
from backend.database.db_redis import redis_client
1819
from backend.utils.timezone import timezone
@@ -21,7 +22,7 @@
2122
class OAuth2Service:
2223
@staticmethod
2324
async def create_with_login(
24-
*, request: Request, background_tasks: BackgroundTasks, user: dict, social: UserSocialType
25+
*, request: Request, response: Response, background_tasks: BackgroundTasks, user: dict, social: UserSocialType
2526
) -> GetLoginToken | None:
2627
async with async_db_session.begin() as db:
2728
# 获取 OAuth2 平台用户信息
@@ -54,12 +55,8 @@ async def create_with_login(
5455
new_user_social = CreateUserSocialParam(source=social.value, uid=str(_id), user_id=sys_user.id)
5556
await user_social_dao.create(db, new_user_social)
5657
# 创建 token
57-
access_token, access_token_expire_time = await jwt.create_access_token(
58-
str(sys_user.id), multi_login=sys_user.is_multi_login
59-
)
60-
refresh_token, refresh_token_expire_time = await jwt.create_refresh_token(
61-
str(sys_user.id), access_token_expire_time, multi_login=sys_user.is_multi_login
62-
)
58+
access_token = await jwt.create_access_token(str(sys_user.id), sys_user.is_multi_login)
59+
refresh_token = await jwt.create_refresh_token(str(sys_user.id), multi_login=sys_user.is_multi_login)
6360
await user_dao.update_login_time(db, sys_user.username)
6461
await db.refresh(sys_user)
6562
login_log = dict(
@@ -72,11 +69,15 @@ async def create_with_login(
7269
)
7370
background_tasks.add_task(LoginLogService.create, **login_log)
7471
await redis_client.delete(f'{admin_settings.CAPTCHA_LOGIN_REDIS_PREFIX}:{request.state.ip}')
72+
response.set_cookie(
73+
settings.COOKIE_REFRESH_TOKEN_KEY,
74+
refresh_token.refresh_token,
75+
settings.COOKIE_REFRESH_TOKEN_EXPIRE_SECONDS,
76+
refresh_token.refresh_token_expire_time,
77+
)
7578
data = GetLoginToken(
76-
access_token=access_token,
77-
refresh_token=refresh_token,
78-
access_token_expire_time=access_token_expire_time,
79-
refresh_token_expire_time=refresh_token_expire_time,
79+
access_token=access_token.access_token,
80+
access_token_expire_time=access_token.access_token_expire_time,
8081
user=sys_user, # type: ignore
8182
)
8283
return data

backend/common/dataclasses.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ class UserAgentInfo:
2626

2727

2828
@dataclasses.dataclass
29-
class RequestCallNextReturn:
29+
class RequestCallNext:
3030
code: str
3131
msg: str
3232
status: StatusType
@@ -35,8 +35,20 @@ class RequestCallNextReturn:
3535

3636

3737
@dataclasses.dataclass
38-
class NewTokenReturn:
38+
class NewToken:
3939
new_access_token: str
40-
new_refresh_token: str
4140
new_access_token_expire_time: datetime
41+
new_refresh_token: str
4242
new_refresh_token_expire_time: datetime
43+
44+
45+
@dataclasses.dataclass
46+
class AccessToken:
47+
access_token: str
48+
access_token_expire_time: datetime
49+
50+
51+
@dataclasses.dataclass
52+
class RefreshToken:
53+
refresh_token: str
54+
refresh_token_expire_time: datetime

backend/common/response/response_schema.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ class ResponseBase:
5757
5858
@router.get('/test')
5959
def test() -> ResponseModel:
60-
return await response_base.success(data={'test': 'test'})
60+
return response_base.success(data={'test': 'test'})
6161
"""
6262

6363
@staticmethod

0 commit comments

Comments
 (0)