Skip to content

Commit 81c0d2c

Browse files
authored
Fix pytest interface unit tests (#233)
1 parent 3bb28a3 commit 81c0d2c

File tree

8 files changed

+28
-26
lines changed

8 files changed

+28
-26
lines changed

README.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -174,14 +174,14 @@ Click [fastapi_best_architecture_ui](https://github.com/fastapi-practices/fastap
174174
Execute unittests via pytest
175175

176176
1. Create the test database `fba_test`, select utf8mb4 encoding
177-
2. Enter the app directory
177+
2. Using `backend/sql/create_tables.sql` file to create database tables
178+
3. Initialize the test data using the `backend/sql/init_pytest_data.sql` file
179+
4. Enter the app directory
178180

179181
```shell
180182
cd backend/app/
181183
```
182184

183-
3. Using `backend/sql/create_tables.sql` file to create database tables
184-
4. Initialize the test data using the `backend/sql/init_pytest_data.sql` file
185185
5. Execute the test command
186186

187187
```shell

README.zh-CN.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -168,14 +168,14 @@ TODO:
168168
通过 pytest 执行单元测试
169169

170170
1. 创建测试数据库 `fba_test`,选择 utf8mb4 编码
171-
2. 进入app目录
171+
2. 使用 `backend/sql/create_tables.sql` 文件创建数据库表
172+
3. 使用 `backend/sql/init_pytest_data.sql` 文件初始化测试数据
173+
4. 进入app目录
172174

173175
```shell
174176
cd backend/app/
175177
```
176178

177-
3. 使用 `backend/sql/create_tables.sql` 文件创建数据库表
178-
4. 使用 `backend/sql/init_pytest_data.sql` 文件初始化测试数据
179179
5. 执行测试命令
180180

181181
```shell

backend/app/api/v1/auth/auth.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
@router.post('/swagger_login', summary='swagger 表单登录', description='form 格式登录,仅用于 swagger 文档调试接口')
2020
async def swagger_user_login(form_data: OAuth2PasswordRequestForm = Depends()) -> GetSwaggerToken:
2121
token, user = await AuthService().swagger_login(form_data=form_data)
22-
return GetSwaggerToken(access_token=token, user=user)
22+
return GetSwaggerToken(access_token=token, user=user) # type: ignore
2323

2424

2525
@router.post(
@@ -37,7 +37,7 @@ async def user_login(request: Request, obj: AuthLogin, background_tasks: Backgro
3737
refresh_token=refresh_token,
3838
access_token_expire_time=access_expire,
3939
refresh_token_expire_time=refresh_expire,
40-
user=user,
40+
user=user, # type: ignore
4141
)
4242
return await response_base.success(data=data)
4343

backend/app/tests/api_v1/test_auth.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,17 @@
33
from starlette.testclient import TestClient
44

55
from backend.app.core.conf import settings
6+
from backend.app.tests.conftest import PYTEST_USERNAME, PYTEST_PASSWORD
67

78

89
def test_login(client: TestClient) -> None:
910
data = {
10-
'username': 'admin',
11-
'password': '123456',
11+
'username': PYTEST_USERNAME,
12+
'password': PYTEST_PASSWORD,
1213
}
13-
response = client.post(f'{settings.API_V1_STR}/auth/login', json=data)
14+
response = client.post(f'{settings.API_V1_STR}/auth/swagger_login', data=data)
1415
assert response.status_code == 200
15-
assert response.json()['data']['access_token_type'] == 'Bearer'
16+
assert response.json()['token_type'] == 'Bearer'
1617

1718

1819
def test_logout(client: TestClient, token_headers: dict[str, str]) -> None:

backend/app/tests/conftest.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,11 @@
1717
app.dependency_overrides[get_db] = override_get_db
1818

1919

20+
# Test user
21+
PYTEST_USERNAME = 'admin'
22+
PYTEST_PASSWORD = '123456'
23+
24+
2025
@pytest.fixture(scope='module')
2126
def client() -> Generator:
2227
with TestClient(app) as c:
@@ -25,4 +30,4 @@ def client() -> Generator:
2530

2631
@pytest.fixture(scope='module')
2732
def token_headers(client: TestClient) -> Dict[str, str]:
28-
return get_token_headers(client=client, username='admin', password='123456')
33+
return get_token_headers(client=client, username=PYTEST_USERNAME, password=PYTEST_PASSWORD)

backend/app/tests/utils/db_mysql.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,32 +3,25 @@
33
from sqlalchemy.ext.asyncio import AsyncSession
44

55
from backend.app.core.conf import settings
6-
from backend.app.models.base import MappedBase
76
from backend.app.database.db_mysql import create_engine_and_session
87

98
TEST_DB_DATABASE = settings.DB_DATABASE + '_test'
109

11-
SQLALCHEMY_DATABASE_URL = (
10+
TEST_SQLALCHEMY_DATABASE_URL = (
1211
f'mysql+asyncmy://{settings.DB_USER}:{settings.DB_PASSWORD}@{settings.DB_HOST}:'
1312
f'{settings.DB_PORT}/{TEST_DB_DATABASE}?charset={settings.DB_CHARSET}'
1413
)
1514

16-
async_engine, async_db_session = create_engine_and_session(SQLALCHEMY_DATABASE_URL)
15+
test_async_engine, test_async_db_session = create_engine_and_session(TEST_SQLALCHEMY_DATABASE_URL)
1716

1817

1918
async def override_get_db() -> AsyncSession:
2019
"""session 生成器"""
21-
session = async_db_session()
20+
session = test_async_db_session()
2221
try:
2322
yield session
2423
except Exception as se:
2524
await session.rollback()
2625
raise se
2726
finally:
2827
await session.close()
29-
30-
31-
async def create_table():
32-
"""创建数据库表"""
33-
async with async_engine.begin() as coon:
34-
await coon.run_sync(MappedBase.metadata.create_all)

backend/app/tests/utils/get_headers.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@ def get_token_headers(client: TestClient, username: str, password: str) -> Dict[
1212
'username': username,
1313
'password': password,
1414
}
15-
response = client.post(f'{settings.API_V1_STR}/auth/login', json=data)
16-
token_type = response.json()['data']['access_token_type']
17-
access_token = response.json()['data']['access_token']
15+
response = client.post(f'{settings.API_V1_STR}/auth/swagger_login', data=data)
16+
token_type = response.json()['token_type']
17+
access_token = response.json()['access_token']
1818
headers = {'Authorization': f'{token_type} {access_token}'}
1919
return headers

backend/app/utils/request_parse.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@ def get_request_ip(request: Request) -> str:
2424
ip = forwarded.split(',')[0]
2525
else:
2626
ip = request.client.host
27+
# 忽略 pytest
28+
if ip == 'testclient':
29+
ip = '127.0.0.1'
2730
return ip
2831

2932

0 commit comments

Comments
 (0)