2
2
3
3
import pytest
4
4
from sqlalchemy .exc import OperationalError
5
- from sqlalchemy .orm .scoping import ScopedSession
6
- from sqlmodel import Field , SQLModel
5
+ from sqlalchemy .orm import Session
7
6
8
- from database_setup_tools .session_manager import SessionManager
9
7
from database_setup_tools .setup import DatabaseSetup
10
8
from tests .integration .database_config import DATABASE_URIS
11
9
from tests .sample_model import Customer , model_metadata
12
10
13
11
14
12
@pytest .mark .parametrize ("database_uri" , DATABASE_URIS )
15
13
class TestDatabaseIntegration :
14
+ #
15
+ # Fixtures
16
+ #
17
+
16
18
@pytest .fixture
17
19
def database_setup (self , database_uri : str ) -> DatabaseSetup :
18
20
setup = DatabaseSetup (model_metadata = model_metadata , database_uri = database_uri )
@@ -22,31 +24,12 @@ def database_setup(self, database_uri: str) -> DatabaseSetup:
22
24
setup .drop_database ()
23
25
24
26
@pytest .fixture
25
- def database_session (self , database_setup : DatabaseSetup ) -> Iterator [ScopedSession ]:
27
+ def session (self , database_setup : DatabaseSetup ) -> Iterator [Session ]:
26
28
"""Get a database session"""
27
29
return next (database_setup .session_manager .get_session ())
28
30
29
- def test_create_database_and_tables (self , database_setup : DatabaseSetup , database_session : ScopedSession ):
30
- """Test that the tables are created correctly"""
31
- database_session .execute (f"SELECT * FROM { Customer .__tablename__ } " )
32
-
33
- def test_create_database_multiple_times (self , database_setup : DatabaseSetup , database_session : ScopedSession ):
34
- """Test that creating the database multiple times does not cause problems"""
35
- database_setup .create_database ()
36
- database_session .execute (f"SELECT * FROM { Customer .__tablename__ } " )
37
-
38
- def test_drop_database (self , database_setup : DatabaseSetup , database_session : ScopedSession ):
39
- """Test that the database is dropped correctly"""
40
- assert database_setup .drop_database () is True
41
-
42
- with pytest .raises (OperationalError ):
43
- database_session .execute (f"SELECT * FROM { Customer .__tablename__ } " )
44
-
45
- assert database_setup .drop_database () is False
46
-
47
- def test_truncate_all_tables (self , database_setup : DatabaseSetup , database_session : ScopedSession ):
48
- """Test that all tables are truncated correctly"""
49
-
31
+ @pytest .fixture
32
+ def delivery_table (self , session : Session ) -> str :
50
33
setup_statements = [
51
34
f"""CREATE TABLE delivery (
52
35
id INTEGER,
@@ -63,44 +46,68 @@ def test_truncate_all_tables(self, database_setup: DatabaseSetup, database_sessi
63
46
f"INSERT INTO \" { Customer .__tablename__ } \" VALUES (1, 'John Doe')" ,
64
47
"INSERT INTO \" delivery\" VALUES (1, 'Delivery 1', 1)" ,
65
48
]
49
+
66
50
for statement in setup_statements :
67
- database_session .execute (statement )
51
+ session .execute (statement )
68
52
69
- assert database_session .execute (f"SELECT * FROM { Customer .__tablename__ } " ).rowcount == 1
70
- assert database_session .execute (f'SELECT * FROM "delivery"' ).rowcount == 1
71
- database_session .commit ()
53
+ return "delivery"
72
54
73
- database_setup .truncate ()
55
+ @pytest .fixture
56
+ def standalone_table (self , session : Session ) -> str :
57
+ setup_statements = [
58
+ "CREATE TABLE standalone (id INTEGER, PRIMARY KEY(id))" ,
59
+ 'SELECT * FROM "standalone"' ,
60
+ 'INSERT INTO "standalone" VALUES (1)' ,
61
+ ]
74
62
75
- assert database_session . execute ( f"SELECT * FROM { Customer . __tablename__ } " ). rowcount == 0
76
- assert database_session . execute ( f'SELECT * FROM "delivery"' ). rowcount == 0
63
+ for statement in setup_statements :
64
+ session . execute ( statement )
77
65
78
- def test_truncate_custom_tables (self , database_uri : str ):
79
- """Test that only specified tables are truncated correctly"""
66
+ return "standalone"
80
67
81
- class TableToTruncate ( SQLModel , table = True ):
82
- id : int = Field ( index = True , primary_key = True )
83
- name : str
68
+ #
69
+ # Tests
70
+ #
84
71
85
- setup = DatabaseSetup (model_metadata = model_metadata , database_uri = database_uri )
86
- setup .drop_database ()
87
- setup .create_database ()
88
- database_session = next (setup .session_manager .get_session ())
72
+ def test_create_database_and_tables (self , database_setup : DatabaseSetup , session : Session ):
73
+ """Test that the tables are created correctly"""
74
+ session .execute (f"SELECT * FROM { Customer .__tablename__ } " )
89
75
90
- setup_statements = [
91
- f'SELECT * FROM "{ Customer .__tablename__ } "' ,
92
- f'SELECT * FROM "{ TableToTruncate .__tablename__ } "' ,
93
- f"INSERT INTO \" { Customer .__tablename__ } \" VALUES (1, 'John Doe')" ,
94
- f"INSERT INTO \" { TableToTruncate .__tablename__ } \" VALUES (1, 'Test')" ,
95
- ]
96
- for statement in setup_statements :
97
- database_session .execute (statement )
76
+ def test_create_database_multiple_times (self , database_setup : DatabaseSetup , session : Session ):
77
+ """Test that creating the database multiple times does not cause problems"""
78
+ database_setup .create_database ()
79
+ session .execute (f"SELECT * FROM { Customer .__tablename__ } " )
80
+
81
+ def test_drop_database (self , database_setup : DatabaseSetup , session : Session ):
82
+ """Test that the database is dropped correctly"""
83
+ assert database_setup .drop_database () is True
84
+
85
+ with pytest .raises (OperationalError ):
86
+ session .execute (f"SELECT * FROM { Customer .__tablename__ } " )
87
+
88
+ assert database_setup .drop_database () is False
89
+
90
+ def test_truncate_all_tables (self , database_setup : DatabaseSetup , session : Session , delivery_table : str ):
91
+ """Test that all tables are truncated correctly"""
92
+
93
+ assert session .execute (f"SELECT * FROM { Customer .__tablename__ } " ).rowcount == 1
94
+ assert session .execute (f'SELECT * FROM "{ delivery_table } "' ).rowcount == 1
95
+ session .commit ()
96
+
97
+ database_setup .truncate ()
98
+
99
+ assert session .execute (f"SELECT * FROM { Customer .__tablename__ } " ).rowcount == 0
100
+ assert session .execute (f'SELECT * FROM "{ delivery_table } "' ).rowcount == 0
101
+
102
+ def test_truncate_custom_tables (self , database_setup : DatabaseSetup , session : Session , delivery_table : str , standalone_table : str ):
103
+ """Test that only specified tables are truncated correctly"""
98
104
99
- assert database_session .execute (f"SELECT * FROM { Customer .__tablename__ } " ).rowcount == 1
100
- assert database_session .execute (f'SELECT * FROM "{ TableToTruncate . __tablename__ } "' ).rowcount == 1
101
- database_session .commit ()
105
+ assert session .execute (f"SELECT * FROM { Customer .__tablename__ } " ).rowcount == 1
106
+ assert session .execute (f'SELECT * FROM "{ delivery_table } "' ).rowcount == 1
107
+ session .commit ()
102
108
103
- setup .truncate (tables = [TableToTruncate ])
109
+ database_setup .truncate (tables = [Customer ])
104
110
105
- assert database_session .execute (f"SELECT * FROM { Customer .__tablename__ } " ).rowcount == 1
106
- assert database_session .execute (f'SELECT * FROM "{ TableToTruncate .__tablename__ } "' ).rowcount == 0
111
+ assert session .execute (f"SELECT * FROM { Customer .__tablename__ } " ).rowcount == 0
112
+ assert session .execute (f'SELECT * FROM "{ delivery_table } "' ).rowcount == 0
113
+ assert session .execute (f'SELECT * FROM "{ standalone_table } "' ).rowcount == 1
0 commit comments