Skip to content

Commit 78fd33e

Browse files
authored
Add Truncate Functionality (#14)
1 parent 8b9f344 commit 78fd33e

File tree

7 files changed

+112
-26
lines changed

7 files changed

+112
-26
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ __pycache__/
88
.DS_Store
99

1010
test.db
11+
test.db-journal
1112

1213
# C extensions
1314
*.so

database_setup_tools/setup.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
import threading
2-
from typing import Optional
2+
from typing import List, Optional
33

44
import sqlalchemy_utils
5-
from sqlalchemy import MetaData
5+
from sqlalchemy import MetaData, Table
6+
from sqlmodel import SQLModel
67

78
from database_setup_tools.session_manager import SessionManager
89

@@ -47,6 +48,15 @@ def model_metadata(self) -> MetaData:
4748
"""
4849
return self._model_metadata
4950

51+
@property
52+
def session_manager(self) -> SessionManager:
53+
"""Getter for the session manager
54+
55+
Returns:
56+
SessionManager: The session manager
57+
"""
58+
return SessionManager(database_uri=self.database_uri)
59+
5060
@property
5161
def database_uri(self) -> str:
5262
"""Getter for the database URI
@@ -71,11 +81,26 @@ def create_database(self) -> bool:
7181
"""Create the database and the tables if not done yet"""
7282
if not sqlalchemy_utils.database_exists(self.database_uri):
7383
sqlalchemy_utils.create_database(self.database_uri)
74-
session_manager = SessionManager(self.database_uri)
75-
self.model_metadata.create_all(session_manager.engine)
84+
self.model_metadata.create_all(self.session_manager.engine)
7685
return True
7786
return False
7887

88+
def truncate(self, tables: Optional[List[SQLModel]] = None):
89+
"""Truncate all tables in the database"""
90+
tables_to_truncate: List[Table] = self.model_metadata.sorted_tables
91+
if tables is not None:
92+
table_names = [table.__tablename__ for table in tables]
93+
tables_to_truncate = filter(lambda table: table.name in table_names, tables_to_truncate)
94+
95+
session = next(self.session_manager.get_session())
96+
97+
try:
98+
tables_with_schema = [f"{table.schema or 'public'}.\"{table.name}\"" for table in tables_to_truncate]
99+
session.execute(f"TRUNCATE TABLE {', '.join(tables_with_schema)} CASCADE;")
100+
session.commit()
101+
finally:
102+
session.close()
103+
79104
@classmethod
80105
def _get_cached_instance(cls, args: tuple, kwargs: dict) -> Optional[object]:
81106
"""Provides a cached instance of the SessionManager class if existing"""

poetry.lock

Lines changed: 4 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/integration/database_config.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
1-
SQLITE_DATABASE_URI = "sqlite:///test.db"
21
POSTGRESQL_DATABASE_URI = "postgresql+psycopg2://postgres:postgres@localhost:5432/test"
32

43
DATABASE_URIS = [
5-
SQLITE_DATABASE_URI,
64
POSTGRESQL_DATABASE_URI,
75
]

tests/integration/test_database_integration.py

Lines changed: 70 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,45 +3,104 @@
33
import pytest
44
from sqlalchemy.exc import OperationalError
55
from sqlalchemy.orm.scoping import ScopedSession
6+
from sqlmodel import Field, SQLModel
67

78
from database_setup_tools.session_manager import SessionManager
89
from database_setup_tools.setup import DatabaseSetup
910
from tests.integration.database_config import DATABASE_URIS
10-
from tests.sample_model import User, model_metadata
11+
from tests.sample_model import Customer, model_metadata
1112

1213

1314
@pytest.mark.parametrize("database_uri", DATABASE_URIS)
1415
class TestDatabaseIntegration:
1516
@pytest.fixture
1617
def database_setup(self, database_uri: str) -> DatabaseSetup:
1718
setup = DatabaseSetup(model_metadata=model_metadata, database_uri=database_uri)
19+
setup.drop_database()
20+
setup.create_database()
1821
yield setup
1922
setup.drop_database()
2023

2124
@pytest.fixture
22-
def database_session(self, database_uri: str) -> Iterator[ScopedSession]:
25+
def database_session(self, database_setup: DatabaseSetup) -> Iterator[ScopedSession]:
2326
"""Get a database session"""
24-
session_manager = SessionManager(database_uri)
25-
return next(session_manager.get_session())
27+
return next(database_setup.session_manager.get_session())
2628

2729
def test_create_database_and_tables(self, database_setup: DatabaseSetup, database_session: ScopedSession):
2830
"""Test that the tables are created correctly"""
29-
# noinspection SqlInjection,SqlDialectInspection
30-
database_session.execute(f"SELECT * FROM {User.__tablename__}")
31+
database_session.execute(f"SELECT * FROM {Customer.__tablename__}")
3132

3233
def test_create_database_multiple_times(self, database_setup: DatabaseSetup, database_session: ScopedSession):
3334
"""Test that creating the database multiple times does not cause problems"""
3435
database_setup.create_database()
35-
# noinspection SqlInjection,SqlDialectInspection
36-
database_session.execute(f"SELECT * FROM {User.__tablename__}")
36+
database_session.execute(f"SELECT * FROM {Customer.__tablename__}")
3737

3838
def test_drop_database(self, database_setup: DatabaseSetup, database_session: ScopedSession):
3939
"""Test that the database is dropped correctly"""
40-
database_setup.create_database()
4140
assert database_setup.drop_database() is True
4241

4342
with pytest.raises(OperationalError):
44-
# noinspection SqlDialectInspection
45-
database_session.execute(f"SELECT * FROM {User.__tablename__}")
43+
database_session.execute(f"SELECT * FROM {Customer.__tablename__}")
4644

4745
assert database_setup.drop_database() is False
46+
47+
def test_truncate_all_tables(self, database_setup: DatabaseSetup, database_session: ScopedSession):
48+
"""Test that all tables are truncated correctly"""
49+
50+
setup_statements = [
51+
f"""CREATE TABLE delivery (
52+
id INTEGER,
53+
name TEXT NOT NULL,
54+
customer_id INTEGER,
55+
PRIMARY KEY(id),
56+
CONSTRAINT fk_user
57+
FOREIGN KEY(customer_id)
58+
REFERENCES "{Customer.__tablename__}"(id)
59+
)
60+
""",
61+
f'SELECT * FROM "{Customer.__tablename__}"',
62+
'SELECT * FROM "delivery"',
63+
f"INSERT INTO \"{Customer.__tablename__}\" VALUES (1, 'John Doe')",
64+
"INSERT INTO \"delivery\" VALUES (1, 'Delivery 1', 1)",
65+
]
66+
for statement in setup_statements:
67+
database_session.execute(statement)
68+
69+
assert database_session.execute(f"SELECT * FROM {Customer.__tablename__}").rowcount == 1
70+
assert database_session.execute(f'SELECT * FROM "delivery"').rowcount == 1
71+
database_session.commit()
72+
73+
database_setup.truncate()
74+
75+
assert database_session.execute(f"SELECT * FROM {Customer.__tablename__}").rowcount == 0
76+
assert database_session.execute(f'SELECT * FROM "delivery"').rowcount == 0
77+
78+
def test_truncate_custom_tables(self, database_uri: str):
79+
"""Test that only specified tables are truncated correctly"""
80+
81+
class TableToTruncate(SQLModel, table=True):
82+
id: int = Field(index=True, primary_key=True)
83+
name: str
84+
85+
setup = DatabaseSetup(model_metadata=model_metadata, database_uri=database_uri)
86+
setup.drop_database()
87+
setup.create_database()
88+
database_session = next(setup.session_manager.get_session())
89+
90+
setup_statements = [
91+
f'SELECT * FROM "{Customer.__tablename__}"',
92+
f'SELECT * FROM "{TableToTruncate.__tablename__}"',
93+
f"INSERT INTO \"{Customer.__tablename__}\" VALUES (1, 'John Doe')",
94+
f"INSERT INTO \"{TableToTruncate.__tablename__}\" VALUES (1, 'Test')",
95+
]
96+
for statement in setup_statements:
97+
database_session.execute(statement)
98+
99+
assert database_session.execute(f"SELECT * FROM {Customer.__tablename__}").rowcount == 1
100+
assert database_session.execute(f'SELECT * FROM "{TableToTruncate.__tablename__}"').rowcount == 1
101+
database_session.commit()
102+
103+
setup.truncate(tables=[TableToTruncate])
104+
105+
assert database_session.execute(f"SELECT * FROM {Customer.__tablename__}").rowcount == 1
106+
assert database_session.execute(f'SELECT * FROM "{TableToTruncate.__tablename__}"').rowcount == 0

tests/sample_model.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
from sqlmodel import SQLModel, Field
1+
from sqlmodel import Field, SQLModel
22

33

4-
class User(SQLModel, table=True):
5-
"""User model"""
4+
class Customer(SQLModel, table=True):
5+
"""Customer model"""
66

77
id: int = Field(index=True, primary_key=True)
88
name: str

tests/unit/test_setup.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,12 +43,15 @@ def test_create_database_setup_fail_database_uri_invalid_type(invalid_database_u
4343
with pytest.raises(TypeError):
4444
DatabaseSetup(model_metadata=model_metadata, database_uri=invalid_database_uri)
4545

46-
def test_database_uri(self, database_setup: DatabaseSetup, database_uri: str):
46+
def test_database_uri_property(self, database_setup: DatabaseSetup, database_uri: str):
4747
assert database_setup.database_uri == database_uri
4848

49-
def test_model_metadata(self, database_setup: DatabaseSetup):
49+
def test_model_metadata_property(self, database_setup: DatabaseSetup):
5050
assert database_setup.model_metadata == model_metadata
5151

52+
def test_session_manager_property(self, database_setup: DatabaseSetup, database_uri: str):
53+
assert database_setup.session_manager == SessionManager(database_uri=database_uri)
54+
5255
def test_create_database(self, database_setup: DatabaseSetup, database_uri: str, when: Callable, expect: Callable):
5356
unstub() # remove stub for create_database method
5457

0 commit comments

Comments
 (0)