diff --git a/.gitignore b/.gitignore index 7afe6c5..4d4aa8a 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,7 @@ __pycache__/ # IDE/OS .idea .DS_Store +.vscode/PythonImportHelper-v2-Completion.json test.db test.db-journal diff --git a/database_setup_tools/session_manager.py b/database_setup_tools/session_manager.py index 903bfa7..c642863 100644 --- a/database_setup_tools/session_manager.py +++ b/database_setup_tools/session_manager.py @@ -1,6 +1,6 @@ import threading from functools import cached_property -from typing import Generator, Optional +from typing import Any, Generator, Optional from sqlalchemy import create_engine from sqlalchemy.engine import Engine @@ -21,14 +21,14 @@ def __new__(cls, *args, **kwargs): cls._instances.append((super(cls, cls).__new__(cls), (args, kwargs))) return cls._get_cached_instance(args, kwargs) - def __init__(self, database_uri: str, **kwargs): + def __init__(self, database_uri: str, **engine_options: dict[str, Any]): """Session Manager constructor Args: database_uri (str): The URI of the database to manage sessions for Keyword Args: - **kwargs: Keyword arguments to pass to the engine + **engine_options: Keyword arguments to pass to the engine postgresql: pool_size (int): The maximum number of connections to the database @@ -39,7 +39,7 @@ def __init__(self, database_uri: str, **kwargs): raise TypeError("database_uri must be a string") self._database_uri = database_uri - self._engine = self._get_engine(**kwargs) + self._engine = self._get_engine(**engine_options) self._session_factory = sessionmaker(self.engine) self._Session = scoped_session(self._session_factory) diff --git a/database_setup_tools/setup.py b/database_setup_tools/setup.py index e4ba194..d255d67 100644 --- a/database_setup_tools/setup.py +++ b/database_setup_tools/setup.py @@ -1,5 +1,5 @@ import threading -from typing import List, Optional +from typing import Any, List, Optional, Union import sqlalchemy_utils from sqlalchemy import MetaData, Table @@ -21,13 +21,15 @@ def __new__(cls, *args, **kwargs): cls._instances.append((super(cls, cls).__new__(cls), (args, kwargs))) return cls._get_cached_instance(args, kwargs) - def __init__(self, model_metadata: MetaData, database_uri: str): + def __init__(self, model_metadata: MetaData, database_uri: str, **engine_options: dict[str, Any]): """Set up a database based on its URI and metadata. Will not overwrite existing data. Args: model_metadata (Metadata): The metadata of the models to create the tables for database_uri (str): The URI of the database to create the tables for + Keyword Args: + **engine_options: Keyword arguments to pass to the engine """ if not isinstance(model_metadata, MetaData): raise TypeError("model_metadata must be a MetaData") @@ -37,6 +39,7 @@ def __init__(self, model_metadata: MetaData, database_uri: str): self._model_metadata = model_metadata self._database_uri = database_uri + self._engine_options = engine_options self.create_database() @property @@ -55,7 +58,7 @@ def session_manager(self) -> SessionManager: Returns: SessionManager: The session manager """ - return SessionManager(database_uri=self.database_uri) + return SessionManager(database_uri=self.database_uri, **self._engine_options) @property def database_uri(self) -> str: @@ -85,7 +88,7 @@ def create_database(self) -> bool: return True return False - def truncate(self, tables: Optional[List[SQLModel | SQLModelMetaclass]] = None): + def truncate(self, tables: Optional[List[Union[SQLModel, SQLModelMetaclass]]] = None): """Truncate all tables in the database""" tables_to_truncate: List[Table] = self.model_metadata.sorted_tables if tables is not None: diff --git a/tests/docker-compose.yaml b/tests/docker-compose.yaml index 7fb8a11..11fa818 100644 --- a/tests/docker-compose.yaml +++ b/tests/docker-compose.yaml @@ -4,6 +4,10 @@ services: postgresql-database: image: postgres:15 container_name: "database-setup-tools-test-postgres-database" + command: > + -c ssl=on + -c ssl_cert_file=/etc/ssl/certs/ssl-cert-snakeoil.pem + -c ssl_key_file=/etc/ssl/private/ssl-cert-snakeoil.key ports: - "5432:5432" environment: diff --git a/tests/integration/database_config.py b/tests/integration/database_config.py index 32b624a..6a2dbfe 100644 --- a/tests/integration/database_config.py +++ b/tests/integration/database_config.py @@ -1,5 +1,4 @@ -POSTGRESQL_DATABASE_URI = "postgresql+psycopg2://postgres:postgres@localhost:5432/test" - DATABASE_URIS = [ - POSTGRESQL_DATABASE_URI, + "postgresql+psycopg2://postgres:postgres@localhost:5432/test", # PostgreSQL + "postgresql+psycopg2://postgres:postgres@localhost:5432/test?sslmode=require", # PostgreSQL with SSL ] diff --git a/tests/integration/test_database_integration.py b/tests/integration/test_database_integration.py index 6637dc4..728e8fe 100644 --- a/tests/integration/test_database_integration.py +++ b/tests/integration/test_database_integration.py @@ -2,10 +2,8 @@ import pytest from sqlalchemy.exc import OperationalError -from sqlalchemy.orm.scoping import ScopedSession -from sqlmodel import Field, SQLModel +from sqlalchemy.orm import Session -from database_setup_tools.session_manager import SessionManager from database_setup_tools.setup import DatabaseSetup from tests.integration.database_config import DATABASE_URIS from tests.sample_model import Customer, model_metadata @@ -13,6 +11,10 @@ @pytest.mark.parametrize("database_uri", DATABASE_URIS) class TestDatabaseIntegration: + # + # Fixtures + # + @pytest.fixture def database_setup(self, database_uri: str) -> DatabaseSetup: setup = DatabaseSetup(model_metadata=model_metadata, database_uri=database_uri) @@ -22,31 +24,12 @@ def database_setup(self, database_uri: str) -> DatabaseSetup: setup.drop_database() @pytest.fixture - def database_session(self, database_setup: DatabaseSetup) -> Iterator[ScopedSession]: + def session(self, database_setup: DatabaseSetup) -> Iterator[Session]: """Get a database session""" return next(database_setup.session_manager.get_session()) - def test_create_database_and_tables(self, database_setup: DatabaseSetup, database_session: ScopedSession): - """Test that the tables are created correctly""" - database_session.execute(f"SELECT * FROM {Customer.__tablename__}") - - def test_create_database_multiple_times(self, database_setup: DatabaseSetup, database_session: ScopedSession): - """Test that creating the database multiple times does not cause problems""" - database_setup.create_database() - database_session.execute(f"SELECT * FROM {Customer.__tablename__}") - - def test_drop_database(self, database_setup: DatabaseSetup, database_session: ScopedSession): - """Test that the database is dropped correctly""" - assert database_setup.drop_database() is True - - with pytest.raises(OperationalError): - database_session.execute(f"SELECT * FROM {Customer.__tablename__}") - - assert database_setup.drop_database() is False - - def test_truncate_all_tables(self, database_setup: DatabaseSetup, database_session: ScopedSession): - """Test that all tables are truncated correctly""" - + @pytest.fixture + def delivery_table(self, session: Session) -> str: setup_statements = [ f"""CREATE TABLE delivery ( id INTEGER, @@ -63,44 +46,68 @@ def test_truncate_all_tables(self, database_setup: DatabaseSetup, database_sessi f"INSERT INTO \"{Customer.__tablename__}\" VALUES (1, 'John Doe')", "INSERT INTO \"delivery\" VALUES (1, 'Delivery 1', 1)", ] + for statement in setup_statements: - database_session.execute(statement) + session.execute(statement) - assert database_session.execute(f"SELECT * FROM {Customer.__tablename__}").rowcount == 1 - assert database_session.execute(f'SELECT * FROM "delivery"').rowcount == 1 - database_session.commit() + return "delivery" - database_setup.truncate() + @pytest.fixture + def standalone_table(self, session: Session) -> str: + setup_statements = [ + "CREATE TABLE standalone (id INTEGER, PRIMARY KEY(id))", + 'SELECT * FROM "standalone"', + 'INSERT INTO "standalone" VALUES (1)', + ] - assert database_session.execute(f"SELECT * FROM {Customer.__tablename__}").rowcount == 0 - assert database_session.execute(f'SELECT * FROM "delivery"').rowcount == 0 + for statement in setup_statements: + session.execute(statement) - def test_truncate_custom_tables(self, database_uri: str): - """Test that only specified tables are truncated correctly""" + return "standalone" - class TableToTruncate(SQLModel, table=True): - id: int = Field(index=True, primary_key=True) - name: str + # + # Tests + # - setup = DatabaseSetup(model_metadata=model_metadata, database_uri=database_uri) - setup.drop_database() - setup.create_database() - database_session = next(setup.session_manager.get_session()) + def test_create_database_and_tables(self, database_setup: DatabaseSetup, session: Session): + """Test that the tables are created correctly""" + session.execute(f"SELECT * FROM {Customer.__tablename__}") - setup_statements = [ - f'SELECT * FROM "{Customer.__tablename__}"', - f'SELECT * FROM "{TableToTruncate.__tablename__}"', - f"INSERT INTO \"{Customer.__tablename__}\" VALUES (1, 'John Doe')", - f"INSERT INTO \"{TableToTruncate.__tablename__}\" VALUES (1, 'Test')", - ] - for statement in setup_statements: - database_session.execute(statement) + def test_create_database_multiple_times(self, database_setup: DatabaseSetup, session: Session): + """Test that creating the database multiple times does not cause problems""" + database_setup.create_database() + session.execute(f"SELECT * FROM {Customer.__tablename__}") + + def test_drop_database(self, database_setup: DatabaseSetup, session: Session): + """Test that the database is dropped correctly""" + assert database_setup.drop_database() is True + + with pytest.raises(OperationalError): + session.execute(f"SELECT * FROM {Customer.__tablename__}") + + assert database_setup.drop_database() is False + + def test_truncate_all_tables(self, database_setup: DatabaseSetup, session: Session, delivery_table: str): + """Test that all tables are truncated correctly""" + + assert session.execute(f"SELECT * FROM {Customer.__tablename__}").rowcount == 1 + assert session.execute(f'SELECT * FROM "{delivery_table}"').rowcount == 1 + session.commit() + + database_setup.truncate() + + assert session.execute(f"SELECT * FROM {Customer.__tablename__}").rowcount == 0 + assert session.execute(f'SELECT * FROM "{delivery_table}"').rowcount == 0 + + def test_truncate_custom_tables(self, database_setup: DatabaseSetup, session: Session, delivery_table: str, standalone_table: str): + """Test that only specified tables are truncated correctly""" - assert database_session.execute(f"SELECT * FROM {Customer.__tablename__}").rowcount == 1 - assert database_session.execute(f'SELECT * FROM "{TableToTruncate.__tablename__}"').rowcount == 1 - database_session.commit() + assert session.execute(f"SELECT * FROM {Customer.__tablename__}").rowcount == 1 + assert session.execute(f'SELECT * FROM "{delivery_table}"').rowcount == 1 + session.commit() - setup.truncate(tables=[TableToTruncate]) + database_setup.truncate(tables=[Customer]) - assert database_session.execute(f"SELECT * FROM {Customer.__tablename__}").rowcount == 1 - assert database_session.execute(f'SELECT * FROM "{TableToTruncate.__tablename__}"').rowcount == 0 + assert session.execute(f"SELECT * FROM {Customer.__tablename__}").rowcount == 0 + assert session.execute(f'SELECT * FROM "{delivery_table}"').rowcount == 0 + assert session.execute(f'SELECT * FROM "{standalone_table}"').rowcount == 1