Skip to content

Commit cae608c

Browse files
SSL Support (#17)
* enable SSL for test postgres instance * test SSL support * Update docker-compose.yaml * certificates from image * some fixes * fix wrong type import --------- Co-authored-by: Yannic Schröer <yannicschroeer@outlook.de> Co-authored-by: Yannic Schröer <yannic@schroeer.tech>
1 parent 869309e commit cae608c

File tree

6 files changed

+80
-66
lines changed

6 files changed

+80
-66
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ __pycache__/
66
# IDE/OS
77
.idea
88
.DS_Store
9+
.vscode/PythonImportHelper-v2-Completion.json
910

1011
test.db
1112
test.db-journal

database_setup_tools/session_manager.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import threading
22
from functools import cached_property
3-
from typing import Generator, Optional
3+
from typing import Any, Generator, Optional
44

55
from sqlalchemy import create_engine
66
from sqlalchemy.engine import Engine
@@ -21,14 +21,14 @@ def __new__(cls, *args, **kwargs):
2121
cls._instances.append((super(cls, cls).__new__(cls), (args, kwargs)))
2222
return cls._get_cached_instance(args, kwargs)
2323

24-
def __init__(self, database_uri: str, **kwargs):
24+
def __init__(self, database_uri: str, **engine_options: dict[str, Any]):
2525
"""Session Manager constructor
2626
2727
Args:
2828
database_uri (str): The URI of the database to manage sessions for
2929
3030
Keyword Args:
31-
**kwargs: Keyword arguments to pass to the engine
31+
**engine_options: Keyword arguments to pass to the engine
3232
3333
postgresql:
3434
pool_size (int): The maximum number of connections to the database
@@ -39,7 +39,7 @@ def __init__(self, database_uri: str, **kwargs):
3939
raise TypeError("database_uri must be a string")
4040

4141
self._database_uri = database_uri
42-
self._engine = self._get_engine(**kwargs)
42+
self._engine = self._get_engine(**engine_options)
4343
self._session_factory = sessionmaker(self.engine)
4444
self._Session = scoped_session(self._session_factory)
4545

database_setup_tools/setup.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import threading
2-
from typing import List, Optional
2+
from typing import Any, List, Optional, Union
33

44
import sqlalchemy_utils
55
from sqlalchemy import MetaData, Table
@@ -21,13 +21,15 @@ def __new__(cls, *args, **kwargs):
2121
cls._instances.append((super(cls, cls).__new__(cls), (args, kwargs)))
2222
return cls._get_cached_instance(args, kwargs)
2323

24-
def __init__(self, model_metadata: MetaData, database_uri: str):
24+
def __init__(self, model_metadata: MetaData, database_uri: str, **engine_options: dict[str, Any]):
2525
"""Set up a database based on its URI and metadata. Will not overwrite existing data.
2626
2727
Args:
2828
model_metadata (Metadata): The metadata of the models to create the tables for
2929
database_uri (str): The URI of the database to create the tables for
3030
31+
Keyword Args:
32+
**engine_options: Keyword arguments to pass to the engine
3133
"""
3234
if not isinstance(model_metadata, MetaData):
3335
raise TypeError("model_metadata must be a MetaData")
@@ -37,6 +39,7 @@ def __init__(self, model_metadata: MetaData, database_uri: str):
3739

3840
self._model_metadata = model_metadata
3941
self._database_uri = database_uri
42+
self._engine_options = engine_options
4043
self.create_database()
4144

4245
@property
@@ -55,7 +58,7 @@ def session_manager(self) -> SessionManager:
5558
Returns:
5659
SessionManager: The session manager
5760
"""
58-
return SessionManager(database_uri=self.database_uri)
61+
return SessionManager(database_uri=self.database_uri, **self._engine_options)
5962

6063
@property
6164
def database_uri(self) -> str:
@@ -85,7 +88,7 @@ def create_database(self) -> bool:
8588
return True
8689
return False
8790

88-
def truncate(self, tables: Optional[List[SQLModel | SQLModelMetaclass]] = None):
91+
def truncate(self, tables: Optional[List[Union[SQLModel, SQLModelMetaclass]]] = None):
8992
"""Truncate all tables in the database"""
9093
tables_to_truncate: List[Table] = self.model_metadata.sorted_tables
9194
if tables is not None:

tests/docker-compose.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@ services:
44
postgresql-database:
55
image: postgres:15
66
container_name: "database-setup-tools-test-postgres-database"
7+
command: >
8+
-c ssl=on
9+
-c ssl_cert_file=/etc/ssl/certs/ssl-cert-snakeoil.pem
10+
-c ssl_key_file=/etc/ssl/private/ssl-cert-snakeoil.key
711
ports:
812
- "5432:5432"
913
environment:

tests/integration/database_config.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
POSTGRESQL_DATABASE_URI = "postgresql+psycopg2://postgres:postgres@localhost:5432/test"
2-
31
DATABASE_URIS = [
4-
POSTGRESQL_DATABASE_URI,
2+
"postgresql+psycopg2://postgres:postgres@localhost:5432/test", # PostgreSQL
3+
"postgresql+psycopg2://postgres:postgres@localhost:5432/test?sslmode=require", # PostgreSQL with SSL
54
]

tests/integration/test_database_integration.py

Lines changed: 62 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,19 @@
22

33
import pytest
44
from sqlalchemy.exc import OperationalError
5-
from sqlalchemy.orm.scoping import ScopedSession
6-
from sqlmodel import Field, SQLModel
5+
from sqlalchemy.orm import Session
76

8-
from database_setup_tools.session_manager import SessionManager
97
from database_setup_tools.setup import DatabaseSetup
108
from tests.integration.database_config import DATABASE_URIS
119
from tests.sample_model import Customer, model_metadata
1210

1311

1412
@pytest.mark.parametrize("database_uri", DATABASE_URIS)
1513
class TestDatabaseIntegration:
14+
#
15+
# Fixtures
16+
#
17+
1618
@pytest.fixture
1719
def database_setup(self, database_uri: str) -> DatabaseSetup:
1820
setup = DatabaseSetup(model_metadata=model_metadata, database_uri=database_uri)
@@ -22,31 +24,12 @@ def database_setup(self, database_uri: str) -> DatabaseSetup:
2224
setup.drop_database()
2325

2426
@pytest.fixture
25-
def database_session(self, database_setup: DatabaseSetup) -> Iterator[ScopedSession]:
27+
def session(self, database_setup: DatabaseSetup) -> Iterator[Session]:
2628
"""Get a database session"""
2729
return next(database_setup.session_manager.get_session())
2830

29-
def test_create_database_and_tables(self, database_setup: DatabaseSetup, database_session: ScopedSession):
30-
"""Test that the tables are created correctly"""
31-
database_session.execute(f"SELECT * FROM {Customer.__tablename__}")
32-
33-
def test_create_database_multiple_times(self, database_setup: DatabaseSetup, database_session: ScopedSession):
34-
"""Test that creating the database multiple times does not cause problems"""
35-
database_setup.create_database()
36-
database_session.execute(f"SELECT * FROM {Customer.__tablename__}")
37-
38-
def test_drop_database(self, database_setup: DatabaseSetup, database_session: ScopedSession):
39-
"""Test that the database is dropped correctly"""
40-
assert database_setup.drop_database() is True
41-
42-
with pytest.raises(OperationalError):
43-
database_session.execute(f"SELECT * FROM {Customer.__tablename__}")
44-
45-
assert database_setup.drop_database() is False
46-
47-
def test_truncate_all_tables(self, database_setup: DatabaseSetup, database_session: ScopedSession):
48-
"""Test that all tables are truncated correctly"""
49-
31+
@pytest.fixture
32+
def delivery_table(self, session: Session) -> str:
5033
setup_statements = [
5134
f"""CREATE TABLE delivery (
5235
id INTEGER,
@@ -63,44 +46,68 @@ def test_truncate_all_tables(self, database_setup: DatabaseSetup, database_sessi
6346
f"INSERT INTO \"{Customer.__tablename__}\" VALUES (1, 'John Doe')",
6447
"INSERT INTO \"delivery\" VALUES (1, 'Delivery 1', 1)",
6548
]
49+
6650
for statement in setup_statements:
67-
database_session.execute(statement)
51+
session.execute(statement)
6852

69-
assert database_session.execute(f"SELECT * FROM {Customer.__tablename__}").rowcount == 1
70-
assert database_session.execute(f'SELECT * FROM "delivery"').rowcount == 1
71-
database_session.commit()
53+
return "delivery"
7254

73-
database_setup.truncate()
55+
@pytest.fixture
56+
def standalone_table(self, session: Session) -> str:
57+
setup_statements = [
58+
"CREATE TABLE standalone (id INTEGER, PRIMARY KEY(id))",
59+
'SELECT * FROM "standalone"',
60+
'INSERT INTO "standalone" VALUES (1)',
61+
]
7462

75-
assert database_session.execute(f"SELECT * FROM {Customer.__tablename__}").rowcount == 0
76-
assert database_session.execute(f'SELECT * FROM "delivery"').rowcount == 0
63+
for statement in setup_statements:
64+
session.execute(statement)
7765

78-
def test_truncate_custom_tables(self, database_uri: str):
79-
"""Test that only specified tables are truncated correctly"""
66+
return "standalone"
8067

81-
class TableToTruncate(SQLModel, table=True):
82-
id: int = Field(index=True, primary_key=True)
83-
name: str
68+
#
69+
# Tests
70+
#
8471

85-
setup = DatabaseSetup(model_metadata=model_metadata, database_uri=database_uri)
86-
setup.drop_database()
87-
setup.create_database()
88-
database_session = next(setup.session_manager.get_session())
72+
def test_create_database_and_tables(self, database_setup: DatabaseSetup, session: Session):
73+
"""Test that the tables are created correctly"""
74+
session.execute(f"SELECT * FROM {Customer.__tablename__}")
8975

90-
setup_statements = [
91-
f'SELECT * FROM "{Customer.__tablename__}"',
92-
f'SELECT * FROM "{TableToTruncate.__tablename__}"',
93-
f"INSERT INTO \"{Customer.__tablename__}\" VALUES (1, 'John Doe')",
94-
f"INSERT INTO \"{TableToTruncate.__tablename__}\" VALUES (1, 'Test')",
95-
]
96-
for statement in setup_statements:
97-
database_session.execute(statement)
76+
def test_create_database_multiple_times(self, database_setup: DatabaseSetup, session: Session):
77+
"""Test that creating the database multiple times does not cause problems"""
78+
database_setup.create_database()
79+
session.execute(f"SELECT * FROM {Customer.__tablename__}")
80+
81+
def test_drop_database(self, database_setup: DatabaseSetup, session: Session):
82+
"""Test that the database is dropped correctly"""
83+
assert database_setup.drop_database() is True
84+
85+
with pytest.raises(OperationalError):
86+
session.execute(f"SELECT * FROM {Customer.__tablename__}")
87+
88+
assert database_setup.drop_database() is False
89+
90+
def test_truncate_all_tables(self, database_setup: DatabaseSetup, session: Session, delivery_table: str):
91+
"""Test that all tables are truncated correctly"""
92+
93+
assert session.execute(f"SELECT * FROM {Customer.__tablename__}").rowcount == 1
94+
assert session.execute(f'SELECT * FROM "{delivery_table}"').rowcount == 1
95+
session.commit()
96+
97+
database_setup.truncate()
98+
99+
assert session.execute(f"SELECT * FROM {Customer.__tablename__}").rowcount == 0
100+
assert session.execute(f'SELECT * FROM "{delivery_table}"').rowcount == 0
101+
102+
def test_truncate_custom_tables(self, database_setup: DatabaseSetup, session: Session, delivery_table: str, standalone_table: str):
103+
"""Test that only specified tables are truncated correctly"""
98104

99-
assert database_session.execute(f"SELECT * FROM {Customer.__tablename__}").rowcount == 1
100-
assert database_session.execute(f'SELECT * FROM "{TableToTruncate.__tablename__}"').rowcount == 1
101-
database_session.commit()
105+
assert session.execute(f"SELECT * FROM {Customer.__tablename__}").rowcount == 1
106+
assert session.execute(f'SELECT * FROM "{delivery_table}"').rowcount == 1
107+
session.commit()
102108

103-
setup.truncate(tables=[TableToTruncate])
109+
database_setup.truncate(tables=[Customer])
104110

105-
assert database_session.execute(f"SELECT * FROM {Customer.__tablename__}").rowcount == 1
106-
assert database_session.execute(f'SELECT * FROM "{TableToTruncate.__tablename__}"').rowcount == 0
111+
assert session.execute(f"SELECT * FROM {Customer.__tablename__}").rowcount == 0
112+
assert session.execute(f'SELECT * FROM "{delivery_table}"').rowcount == 0
113+
assert session.execute(f'SELECT * FROM "{standalone_table}"').rowcount == 1

0 commit comments

Comments
 (0)