diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 886bbbc..4816507 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -27,6 +27,11 @@ jobs: with: python-version: ${{ matrix.python-version }} + - name: Setup Test Infrastructure + run: | + docker-compose -f tests/docker-compose.yaml up -d + shell: bash + - name: Install poetry run: pip install poetry shell: bash @@ -38,3 +43,9 @@ jobs: - name: Test with pytest run: poetry run pytest shell: bash + + - name: Teardown Test Infrastructure + if: always() + run: | + docker-compose -f tests/docker-compose.yaml down -v + shell: bash diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..a70ddf9 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,7 @@ +{ + "python.testing.pytestArgs": [ + "." + ], + "python.testing.unittestEnabled": false, + "python.testing.pytestEnabled": true +} diff --git a/README.md b/README.md index c7924d0..89af4d4 100644 --- a/README.md +++ b/README.md @@ -9,12 +9,14 @@ pip install database-setup-tools ``` ## Features + - **Database creation on app startup** - Thread-safe database **session manager** - Opinionated towards `FastAPI` and `SQLModel` but feasible with any other framework or pure `sqlalchemy` - Easily use a local database in your tests ## Planned features + - Database migrations with `Alembic` ## Example @@ -67,27 +69,40 @@ if __name__ == '__main__': ## Example for pytest **conftest.py** + ```python database_setup = DatabaseSetup(model_metadata=model_metadata, database_uri=DATABASE_URI) + def pytest_sessionstart(session): database_setup.drop_database() database_setup.create_database() ``` **test_users.py** + ```python session_manager = SessionManager(database_uri=DATABASE_URI) + @pytest.fixture def session(): - with session_manager.get_session() as session: - yield session + with session_manager.get_session() as session: + yield session + def test_create_user(session: Session): - user = User(name='Test User') - session.add(user) - session.commit() - assert session.query(User).count() == 1 - assert session.query(User).first().name == 'Test User' -``` \ No newline at end of file + user = User(name='Test User') + session.add(user) + session.commit() + assert session.query(User).count() == 1 + assert session.query(User).first().name == 'Test User' +``` + +## Development + +### Testing + +1. Spin up databases for local integration tests: `docker-compose -f tests/docker-compose.yaml up -d` +1. Create virtual environment & install dependencies: `poetry install` +1. Run tests: `poetry run pytest` diff --git a/database_setup_tools/__init__.py b/database_setup_tools/__init__.py index 8dbe02e..2e41e56 100644 --- a/database_setup_tools/__init__.py +++ b/database_setup_tools/__init__.py @@ -1 +1,4 @@ __version__='1.0.1' + +from .session_manager import SessionManager +from .setup import DatabaseSetup diff --git a/database_setup_tools/session_manager.py b/database_setup_tools/session_manager.py index 90adf33..e60647a 100644 --- a/database_setup_tools/session_manager.py +++ b/database_setup_tools/session_manager.py @@ -1,11 +1,11 @@ import threading from functools import cached_property -from typing import Iterator, Optional +from typing import Generator, Optional -import sqlalchemy as sqla +from sqlalchemy import create_engine from sqlalchemy.engine import Engine -from sqlalchemy.orm import sessionmaker -from sqlalchemy.orm.scoping import ScopedSession, scoped_session +from sqlalchemy.orm import Session, sessionmaker +from sqlalchemy.orm.scoping import scoped_session class SessionManager: @@ -47,19 +47,19 @@ def database_uri(self) -> str: """ Getter for the database URI """ return self._database_uri - @cached_property + @property def engine(self) -> Engine: """ Getter for the engine """ return self._engine - def get_session(self) -> Iterator[ScopedSession]: + def get_session(self) -> Generator[Session, None, None]: """ Provides a (thread safe) scoped session that is wrapped in a context manager """ with self._Session() as session: yield session def _get_engine(self, **kwargs) -> Engine: """ Provides a database engine """ - return sqla.create_engine(self.database_uri, **kwargs) + return create_engine(self.database_uri, **kwargs) @classmethod def _get_cached_instance(cls, args: tuple, kwargs: dict) -> Optional[object]: diff --git a/database_setup_tools/setup.py b/database_setup_tools/setup.py index cbb40c3..20f3cdf 100644 --- a/database_setup_tools/setup.py +++ b/database_setup_tools/setup.py @@ -66,11 +66,14 @@ def drop_database(self) -> bool: return True return False - def create_database(self): + def create_database(self) -> bool: """ Create the database and the tables if not done yet """ - sqlalchemy_utils.create_database(self.database_uri) - session_manager = SessionManager(self.database_uri) - self.model_metadata.create_all(session_manager.engine) + if not sqlalchemy_utils.database_exists(self.database_uri): + sqlalchemy_utils.create_database(self.database_uri) + session_manager = SessionManager(self.database_uri) + self.model_metadata.create_all(session_manager.engine) + return True + return False @classmethod def _get_cached_instance(cls, args: tuple, kwargs: dict) -> Optional[object]: diff --git a/poetry.lock b/poetry.lock index 897bd94..a4ed2b1 100644 --- a/poetry.lock +++ b/poetry.lock @@ -52,21 +52,6 @@ files = [ {file = "certifi-2022.12.7.tar.gz", hash = "sha256:35824b4c3a97115964b408844d64aa14db1cc518f6562e8d7261699d1350a9e3"}, ] -[[package]] -name = "click" -version = "8.1.3" -description = "Composable command line interface toolkit" -category = "dev" -optional = false -python-versions = ">=3.7" -files = [ - {file = "click-8.1.3-py3-none-any.whl", hash = "sha256:bb4d8133cb15a609f44e8213d9b391b0809795062913b383c62be0ee95b1db48"}, - {file = "click-8.1.3.tar.gz", hash = "sha256:7682dc8afb30297001674575ea00d1814d808d6a36af415a82bd481d37ba7b8e"}, -] - -[package.dependencies] -colorama = {version = "*", markers = "platform_system == \"Windows\""} - [[package]] name = "colorama" version = "0.4.6" @@ -161,28 +146,6 @@ files = [ [package.extras] test = ["pytest (>=6)"] -[[package]] -name = "fastapi" -version = "0.87.0" -description = "FastAPI framework, high performance, easy to learn, fast to code, ready for production" -category = "dev" -optional = false -python-versions = ">=3.7" -files = [ - {file = "fastapi-0.87.0-py3-none-any.whl", hash = "sha256:254453a2e22f64e2a1b4e1d8baf67d239e55b6c8165c079d25746a5220c81bb4"}, - {file = "fastapi-0.87.0.tar.gz", hash = "sha256:07032e53df9a57165047b4f38731c38bdcc3be5493220471015e2b4b51b486a4"}, -] - -[package.dependencies] -pydantic = ">=1.6.2,<1.7 || >1.7,<1.7.1 || >1.7.1,<1.7.2 || >1.7.2,<1.7.3 || >1.7.3,<1.8 || >1.8,<1.8.1 || >1.8.1,<2.0.0" -starlette = "0.21.0" - -[package.extras] -all = ["email-validator (>=1.1.1)", "httpx (>=0.23.0)", "itsdangerous (>=1.1.0)", "jinja2 (>=2.11.2)", "orjson (>=3.2.1)", "python-multipart (>=0.0.5)", "pyyaml (>=5.3.1)", "ujson (>=4.0.1,!=4.0.2,!=4.1.0,!=4.2.0,!=4.3.0,!=5.0.0,!=5.1.0)", "uvicorn[standard] (>=0.12.0)"] -dev = ["pre-commit (>=2.17.0,<3.0.0)", "ruff (==0.0.114)", "uvicorn[standard] (>=0.12.0,<0.19.0)"] -doc = ["mdx-include (>=1.4.1,<2.0.0)", "mkdocs (>=1.1.2,<2.0.0)", "mkdocs-markdownextradata-plugin (>=0.1.7,<0.3.0)", "mkdocs-material (>=8.1.4,<9.0.0)", "pyyaml (>=5.3.1,<7.0.0)", "typer[all] (>=0.6.1,<0.7.0)"] -test = ["anyio[trio] (>=3.2.1,<4.0.0)", "black (==22.8.0)", "coverage[toml] (>=6.5.0,<7.0)", "databases[sqlite] (>=0.3.2,<0.7.0)", "email-validator (>=1.1.1,<2.0.0)", "flask (>=1.1.2,<3.0.0)", "httpx (>=0.23.0,<0.24.0)", "isort (>=5.0.6,<6.0.0)", "mypy (==0.982)", "orjson (>=3.2.1,<4.0.0)", "passlib[bcrypt] (>=1.7.2,<2.0.0)", "peewee (>=3.13.3,<4.0.0)", "pytest (>=7.1.3,<8.0.0)", "python-jose[cryptography] (>=3.3.0,<4.0.0)", "python-multipart (>=0.0.5,<0.0.6)", "pyyaml (>=5.3.1,<7.0.0)", "ruff (==0.0.114)", "sqlalchemy (>=1.3.18,<=1.4.41)", "types-orjson (==3.6.2)", "types-ujson (==5.5.0)", "ujson (>=4.0.1,!=4.0.2,!=4.1.0,!=4.2.0,!=4.3.0,!=5.0.0,!=5.1.0,<6.0.0)"] - [[package]] name = "greenlet" version = "2.0.2" @@ -728,25 +691,6 @@ pydantic = ">=1.8.2,<2.0.0" SQLAlchemy = ">=1.4.17,<=1.4.41" sqlalchemy2-stubs = "*" -[[package]] -name = "starlette" -version = "0.21.0" -description = "The little ASGI library that shines." -category = "dev" -optional = false -python-versions = ">=3.7" -files = [ - {file = "starlette-0.21.0-py3-none-any.whl", hash = "sha256:0efc058261bbcddeca93cad577efd36d0c8a317e44376bcfc0e097a2b3dc24a7"}, - {file = "starlette-0.21.0.tar.gz", hash = "sha256:b1b52305ee8f7cfc48cde383496f7c11ab897cd7112b33d998b1317dc8ef9027"}, -] - -[package.dependencies] -anyio = ">=3.4.0,<5" -typing-extensions = {version = ">=3.10.0", markers = "python_version < \"3.10\""} - -[package.extras] -full = ["httpx (>=0.22.0)", "itsdangerous", "jinja2", "python-multipart", "pyyaml"] - [[package]] name = "tomli" version = "2.0.1" @@ -771,26 +715,7 @@ files = [ {file = "typing_extensions-4.5.0.tar.gz", hash = "sha256:5cb5f4a79139d699607b3ef622a1dedafa84e115ab0024e0d9c044a9479ca7cb"}, ] -[[package]] -name = "uvicorn" -version = "0.20.0" -description = "The lightning-fast ASGI server." -category = "dev" -optional = false -python-versions = ">=3.7" -files = [ - {file = "uvicorn-0.20.0-py3-none-any.whl", hash = "sha256:c3ed1598a5668208723f2bb49336f4509424ad198d6ab2615b7783db58d919fd"}, - {file = "uvicorn-0.20.0.tar.gz", hash = "sha256:a4e12017b940247f836bc90b72e725d7dfd0c8ed1c51eb365f5ba30d9f5127d8"}, -] - -[package.dependencies] -click = ">=7.0" -h11 = ">=0.8" - -[package.extras] -standard = ["colorama (>=0.4)", "httptools (>=0.5.0)", "python-dotenv (>=0.13)", "pyyaml (>=5.1)", "uvloop (>=0.14.0,!=0.15.0,!=0.15.1)", "watchfiles (>=0.13)", "websockets (>=10.4)"] - [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "92118087bdb8b423f0a088a6593c5b2b4adce9c9ccacff29b35328972f53463a" +content-hash = "535594047b88cfe7b73b84264c32100936f51af1e47688b23631525d07912142" diff --git a/pyproject.toml b/pyproject.toml index 05bb30b..8bae030 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,8 +13,6 @@ sqlalchemy = "^1.4.41" sqlalchemy-utils = "^0.38.3" [tool.poetry.dev-dependencies] -fastapi = "0.87.0" -uvicorn = "0.20.0" sqlmodel = "0.0.8" psycopg2-binary = "^2.9.5" pytest-cov = "^4.0.0" diff --git a/tests/docker-compose.yaml b/tests/docker-compose.yaml new file mode 100644 index 0000000..07f5038 --- /dev/null +++ b/tests/docker-compose.yaml @@ -0,0 +1,12 @@ +version: "3.9" + +services: + postgresql-database: + image: postgres:14 + container_name: "database-setup-tools-test-postgres-database" + ports: + - "5432:5432" + environment: + POSTGRES_USER: "postgres" + POSTGRES_PASSWORD: "postgres" + POSTGRES_DB: "postgres" diff --git a/tests/integration/database_config.py b/tests/integration/database_config.py new file mode 100644 index 0000000..627149a --- /dev/null +++ b/tests/integration/database_config.py @@ -0,0 +1,7 @@ +SQLITE_DATABASE_URI = "sqlite:///test.db" +POSTGRESQL_DATABASE_URI = "postgresql+psycopg2://postgres:postgres@localhost:5432/test" + +DATABASE_URIS = [ + SQLITE_DATABASE_URI, + POSTGRESQL_DATABASE_URI, +] diff --git a/tests/integration/test_database_integration.py b/tests/integration/test_database_integration.py index 40a2772..4db1b26 100644 --- a/tests/integration/test_database_integration.py +++ b/tests/integration/test_database_integration.py @@ -1,13 +1,16 @@ +from typing import Iterator + import pytest from sqlalchemy.exc import OperationalError from sqlalchemy.orm.scoping import ScopedSession from database_setup_tools.session_manager import SessionManager from database_setup_tools.setup import DatabaseSetup -from tests.sample_model import model_metadata, User +from tests.integration.database_config import DATABASE_URIS +from tests.sample_model import User, model_metadata -@pytest.mark.parametrize('database_uri', ["sqlite:///test.db"]) +@pytest.mark.parametrize('database_uri', DATABASE_URIS) class TestDatabaseIntegration: @pytest.fixture @@ -17,20 +20,21 @@ def database_setup(self, database_uri: str) -> DatabaseSetup: setup.drop_database() @pytest.fixture - def database_session(self, database_uri: str) -> ScopedSession: + def database_session(self, database_uri: str) -> Iterator[ScopedSession]: """ Get a database session """ session_manager = SessionManager(database_uri) return next(session_manager.get_session()) def test_create_database_and_tables(self, database_setup: DatabaseSetup, database_session: ScopedSession): """ Test that the tables are created correctly """ - database_setup.create_database() - # noinspection SqlInjection,SqlDialectInspection - test_query = database_session.execute(f'SELECT * FROM {User.__tablename__}') + database_session.execute(f'SELECT * FROM {User.__tablename__}') - assert test_query.cursor.description[0][0] == 'id' - assert test_query.cursor.description[1][0] == 'name' + def test_create_database_multiple_times(self, database_setup: DatabaseSetup, database_session: ScopedSession): + """ Test that creating the database multiple times does not cause problems """ + database_setup.create_database() + # noinspection SqlInjection,SqlDialectInspection + database_session.execute(f'SELECT * FROM {User.__tablename__}') def test_drop_database(self, database_setup: DatabaseSetup, database_session: ScopedSession): """ Test that the database is dropped correctly """ diff --git a/tests/integration/test_fastapi_integration.py b/tests/integration/test_fastapi_integration.py deleted file mode 100644 index af78700..0000000 --- a/tests/integration/test_fastapi_integration.py +++ /dev/null @@ -1,63 +0,0 @@ -from random import randint - -import pytest -from fastapi import FastAPI, Depends -from sqlmodel import Session -from starlette.testclient import TestClient - -from database_setup_tools.session_manager import SessionManager -from database_setup_tools.setup import DatabaseSetup -from tests.sample_model import model_metadata, User - - -@pytest.mark.parametrize('database_uri', ["sqlite:///test.db"]) -class TestIntegrationDatabaseSetup: - - @pytest.fixture - def database_setup(self, database_uri: str) -> DatabaseSetup: - setup = DatabaseSetup(model_metadata=model_metadata, database_uri=database_uri) - yield setup - setup.drop_database() - - @pytest.fixture - def session_manager(self, database_uri: str) -> SessionManager: - return SessionManager(database_uri=database_uri) - - @pytest.fixture - def fastapi_app(self, session_manager: SessionManager) -> FastAPI: - app = FastAPI() - - @app.post('/users/', response_model=User) - def add_random_user(session: Session = Depends(session_manager.get_session)): - """ Endpoint to add a user with a random name """ - user = User(name=f'User {randint(0, 100)}') - session.add(user) - session.commit() - return user - - @app.get('/users/', response_model=list[User]) - def get_all_users(session: Session = Depends(session_manager.get_session)): - """ Endpoint to get all users """ - return session.query(User).all() - - return app - - @pytest.fixture - def test_client(self, fastapi_app: FastAPI) -> TestClient: - return TestClient(fastapi_app) - - @pytest.fixture(scope="function", autouse=True) - def setup(self, database_setup: DatabaseSetup, ): - database_setup.drop_database() - database_setup.create_database() - - def test_get_all_users(self, test_client: TestClient): - response = test_client.get('/users/') - assert response.status_code == 200 - assert response.json() == [] - - def test_add_random_user(self, test_client: TestClient): - response = test_client.post('/users/') - assert response.status_code == 200 - - assert test_client.get('/users/').json() == [response.json()] diff --git a/tests/sample_model.py b/tests/sample_model.py index 131b8fa..feeb0b8 100644 --- a/tests/sample_model.py +++ b/tests/sample_model.py @@ -3,6 +3,7 @@ class User(SQLModel, table=True): """ User model """ + id: int = Field(index=True, primary_key=True) name: str diff --git a/tests/unit/test_setup.py b/tests/unit/test_setup.py index 72be3c7..93fe61c 100644 --- a/tests/unit/test_setup.py +++ b/tests/unit/test_setup.py @@ -55,10 +55,20 @@ def test_model_metadata(self, database_setup: DatabaseSetup): def test_create_database(self, database_setup: DatabaseSetup, database_uri: str, when: Callable, expect: Callable): unstub() # remove stub for create_database method + when(sqlalchemy_utils).database_exists(database_uri).thenReturn(False) expect(sqlalchemy_utils, times=1).create_database(database_uri) - expect(database_setup.model_metadata, times=1).create_all(SessionManager(database_uri).engine) + expect(database_setup.model_metadata, times=1).create_all(...) # can't check engine argument here because a new one is created for each access + + assert database_setup.create_database() is True + + def test_create_database_skip_if_exists(self, database_setup: DatabaseSetup, database_uri: str, when: Callable, expect: Callable): + unstub() # remove stub for create_database method + + when(sqlalchemy_utils).database_exists(database_uri).thenReturn(True) + expect(sqlalchemy_utils, times=0).create_database(database_uri) + expect(database_setup.model_metadata, times=0).create_all(SessionManager(database_uri).engine) - database_setup.create_database() + assert database_setup.create_database() is False def test_drop_database_success(self, database_setup: DatabaseSetup, database_uri: str, when: Callable, expect: Callable): when(sqlalchemy_utils).database_exists(database_uri).thenReturn(True)