Skip to content

SSL Support #17

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Apr 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ __pycache__/
# IDE/OS
.idea
.DS_Store
.vscode/PythonImportHelper-v2-Completion.json

test.db
test.db-journal
Expand Down
8 changes: 4 additions & 4 deletions database_setup_tools/session_manager.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import threading
from functools import cached_property
from typing import Generator, Optional
from typing import Any, Generator, Optional

from sqlalchemy import create_engine
from sqlalchemy.engine import Engine
Expand All @@ -21,14 +21,14 @@ def __new__(cls, *args, **kwargs):
cls._instances.append((super(cls, cls).__new__(cls), (args, kwargs)))
return cls._get_cached_instance(args, kwargs)

def __init__(self, database_uri: str, **kwargs):
def __init__(self, database_uri: str, **engine_options: dict[str, Any]):
"""Session Manager constructor

Args:
database_uri (str): The URI of the database to manage sessions for

Keyword Args:
**kwargs: Keyword arguments to pass to the engine
**engine_options: Keyword arguments to pass to the engine

postgresql:
pool_size (int): The maximum number of connections to the database
Expand All @@ -39,7 +39,7 @@ def __init__(self, database_uri: str, **kwargs):
raise TypeError("database_uri must be a string")

self._database_uri = database_uri
self._engine = self._get_engine(**kwargs)
self._engine = self._get_engine(**engine_options)
self._session_factory = sessionmaker(self.engine)
self._Session = scoped_session(self._session_factory)

Expand Down
11 changes: 7 additions & 4 deletions database_setup_tools/setup.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import threading
from typing import List, Optional
from typing import Any, List, Optional, Union

import sqlalchemy_utils
from sqlalchemy import MetaData, Table
Expand All @@ -21,13 +21,15 @@ def __new__(cls, *args, **kwargs):
cls._instances.append((super(cls, cls).__new__(cls), (args, kwargs)))
return cls._get_cached_instance(args, kwargs)

def __init__(self, model_metadata: MetaData, database_uri: str):
def __init__(self, model_metadata: MetaData, database_uri: str, **engine_options: dict[str, Any]):
"""Set up a database based on its URI and metadata. Will not overwrite existing data.

Args:
model_metadata (Metadata): The metadata of the models to create the tables for
database_uri (str): The URI of the database to create the tables for

Keyword Args:
**engine_options: Keyword arguments to pass to the engine
"""
if not isinstance(model_metadata, MetaData):
raise TypeError("model_metadata must be a MetaData")
Expand All @@ -37,6 +39,7 @@ def __init__(self, model_metadata: MetaData, database_uri: str):

self._model_metadata = model_metadata
self._database_uri = database_uri
self._engine_options = engine_options
self.create_database()

@property
Expand All @@ -55,7 +58,7 @@ def session_manager(self) -> SessionManager:
Returns:
SessionManager: The session manager
"""
return SessionManager(database_uri=self.database_uri)
return SessionManager(database_uri=self.database_uri, **self._engine_options)

@property
def database_uri(self) -> str:
Expand Down Expand Up @@ -85,7 +88,7 @@ def create_database(self) -> bool:
return True
return False

def truncate(self, tables: Optional[List[SQLModel | SQLModelMetaclass]] = None):
def truncate(self, tables: Optional[List[Union[SQLModel, SQLModelMetaclass]]] = None):
"""Truncate all tables in the database"""
tables_to_truncate: List[Table] = self.model_metadata.sorted_tables
if tables is not None:
Expand Down
4 changes: 4 additions & 0 deletions tests/docker-compose.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@ services:
postgresql-database:
image: postgres:15
container_name: "database-setup-tools-test-postgres-database"
command: >
-c ssl=on
-c ssl_cert_file=/etc/ssl/certs/ssl-cert-snakeoil.pem
-c ssl_key_file=/etc/ssl/private/ssl-cert-snakeoil.key
ports:
- "5432:5432"
environment:
Expand Down
5 changes: 2 additions & 3 deletions tests/integration/database_config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
POSTGRESQL_DATABASE_URI = "postgresql+psycopg2://postgres:postgres@localhost:5432/test"

DATABASE_URIS = [
POSTGRESQL_DATABASE_URI,
"postgresql+psycopg2://postgres:postgres@localhost:5432/test", # PostgreSQL
"postgresql+psycopg2://postgres:postgres@localhost:5432/test?sslmode=require", # PostgreSQL with SSL
]
117 changes: 62 additions & 55 deletions tests/integration/test_database_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,19 @@

import pytest
from sqlalchemy.exc import OperationalError
from sqlalchemy.orm.scoping import ScopedSession
from sqlmodel import Field, SQLModel
from sqlalchemy.orm import Session

from database_setup_tools.session_manager import SessionManager
from database_setup_tools.setup import DatabaseSetup
from tests.integration.database_config import DATABASE_URIS
from tests.sample_model import Customer, model_metadata


@pytest.mark.parametrize("database_uri", DATABASE_URIS)
class TestDatabaseIntegration:
#
# Fixtures
#

@pytest.fixture
def database_setup(self, database_uri: str) -> DatabaseSetup:
setup = DatabaseSetup(model_metadata=model_metadata, database_uri=database_uri)
Expand All @@ -22,31 +24,12 @@ def database_setup(self, database_uri: str) -> DatabaseSetup:
setup.drop_database()

@pytest.fixture
def database_session(self, database_setup: DatabaseSetup) -> Iterator[ScopedSession]:
def session(self, database_setup: DatabaseSetup) -> Iterator[Session]:
"""Get a database session"""
return next(database_setup.session_manager.get_session())

def test_create_database_and_tables(self, database_setup: DatabaseSetup, database_session: ScopedSession):
"""Test that the tables are created correctly"""
database_session.execute(f"SELECT * FROM {Customer.__tablename__}")

def test_create_database_multiple_times(self, database_setup: DatabaseSetup, database_session: ScopedSession):
"""Test that creating the database multiple times does not cause problems"""
database_setup.create_database()
database_session.execute(f"SELECT * FROM {Customer.__tablename__}")

def test_drop_database(self, database_setup: DatabaseSetup, database_session: ScopedSession):
"""Test that the database is dropped correctly"""
assert database_setup.drop_database() is True

with pytest.raises(OperationalError):
database_session.execute(f"SELECT * FROM {Customer.__tablename__}")

assert database_setup.drop_database() is False

def test_truncate_all_tables(self, database_setup: DatabaseSetup, database_session: ScopedSession):
"""Test that all tables are truncated correctly"""

@pytest.fixture
def delivery_table(self, session: Session) -> str:
setup_statements = [
f"""CREATE TABLE delivery (
id INTEGER,
Expand All @@ -63,44 +46,68 @@ def test_truncate_all_tables(self, database_setup: DatabaseSetup, database_sessi
f"INSERT INTO \"{Customer.__tablename__}\" VALUES (1, 'John Doe')",
"INSERT INTO \"delivery\" VALUES (1, 'Delivery 1', 1)",
]

for statement in setup_statements:
database_session.execute(statement)
session.execute(statement)

assert database_session.execute(f"SELECT * FROM {Customer.__tablename__}").rowcount == 1
assert database_session.execute(f'SELECT * FROM "delivery"').rowcount == 1
database_session.commit()
return "delivery"

database_setup.truncate()
@pytest.fixture
def standalone_table(self, session: Session) -> str:
setup_statements = [
"CREATE TABLE standalone (id INTEGER, PRIMARY KEY(id))",
'SELECT * FROM "standalone"',
'INSERT INTO "standalone" VALUES (1)',
]

assert database_session.execute(f"SELECT * FROM {Customer.__tablename__}").rowcount == 0
assert database_session.execute(f'SELECT * FROM "delivery"').rowcount == 0
for statement in setup_statements:
session.execute(statement)

def test_truncate_custom_tables(self, database_uri: str):
"""Test that only specified tables are truncated correctly"""
return "standalone"

class TableToTruncate(SQLModel, table=True):
id: int = Field(index=True, primary_key=True)
name: str
#
# Tests
#

setup = DatabaseSetup(model_metadata=model_metadata, database_uri=database_uri)
setup.drop_database()
setup.create_database()
database_session = next(setup.session_manager.get_session())
def test_create_database_and_tables(self, database_setup: DatabaseSetup, session: Session):
"""Test that the tables are created correctly"""
session.execute(f"SELECT * FROM {Customer.__tablename__}")

setup_statements = [
f'SELECT * FROM "{Customer.__tablename__}"',
f'SELECT * FROM "{TableToTruncate.__tablename__}"',
f"INSERT INTO \"{Customer.__tablename__}\" VALUES (1, 'John Doe')",
f"INSERT INTO \"{TableToTruncate.__tablename__}\" VALUES (1, 'Test')",
]
for statement in setup_statements:
database_session.execute(statement)
def test_create_database_multiple_times(self, database_setup: DatabaseSetup, session: Session):
"""Test that creating the database multiple times does not cause problems"""
database_setup.create_database()
session.execute(f"SELECT * FROM {Customer.__tablename__}")

def test_drop_database(self, database_setup: DatabaseSetup, session: Session):
"""Test that the database is dropped correctly"""
assert database_setup.drop_database() is True

with pytest.raises(OperationalError):
session.execute(f"SELECT * FROM {Customer.__tablename__}")

assert database_setup.drop_database() is False

def test_truncate_all_tables(self, database_setup: DatabaseSetup, session: Session, delivery_table: str):
"""Test that all tables are truncated correctly"""

assert session.execute(f"SELECT * FROM {Customer.__tablename__}").rowcount == 1
assert session.execute(f'SELECT * FROM "{delivery_table}"').rowcount == 1
session.commit()

database_setup.truncate()

assert session.execute(f"SELECT * FROM {Customer.__tablename__}").rowcount == 0
assert session.execute(f'SELECT * FROM "{delivery_table}"').rowcount == 0

def test_truncate_custom_tables(self, database_setup: DatabaseSetup, session: Session, delivery_table: str, standalone_table: str):
"""Test that only specified tables are truncated correctly"""

assert database_session.execute(f"SELECT * FROM {Customer.__tablename__}").rowcount == 1
assert database_session.execute(f'SELECT * FROM "{TableToTruncate.__tablename__}"').rowcount == 1
database_session.commit()
assert session.execute(f"SELECT * FROM {Customer.__tablename__}").rowcount == 1
assert session.execute(f'SELECT * FROM "{delivery_table}"').rowcount == 1
session.commit()

setup.truncate(tables=[TableToTruncate])
database_setup.truncate(tables=[Customer])

assert database_session.execute(f"SELECT * FROM {Customer.__tablename__}").rowcount == 1
assert database_session.execute(f'SELECT * FROM "{TableToTruncate.__tablename__}"').rowcount == 0
assert session.execute(f"SELECT * FROM {Customer.__tablename__}").rowcount == 0
assert session.execute(f'SELECT * FROM "{delivery_table}"').rowcount == 0
assert session.execute(f'SELECT * FROM "{standalone_table}"').rowcount == 1