Skip to content

Commit 744fa92

Browse files
authored
QOL & test improvements (#5)
* improve database setup & tests * improve session manager & tests * improve integration tests * add type checking for session manager database uri * adjust readme * python3.9 compliant types
1 parent 9ebb32d commit 744fa92

14 files changed

+341
-588
lines changed

README.md

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,6 @@ if __name__ == '__main__':
6464
uvicorn.run(app, host='0.0.0.0', port=8080)
6565
```
6666

67-
*See [tests/integration/example/app.py](tests/integration/example/app.py)
68-
6967
## Example for pytest
7068

7169
**conftest.py**

database_setup_tools/session_manager.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,24 @@
1-
import sqlalchemy as sqla
21
import threading
2+
from functools import cached_property
3+
from typing import Iterator, Optional
4+
5+
import sqlalchemy as sqla
36
from sqlalchemy.engine import Engine
47
from sqlalchemy.orm import sessionmaker
58
from sqlalchemy.orm.scoping import ScopedSession, scoped_session
6-
from typing import Iterator
79

810

911
class SessionManager:
1012
""" Manages engines, sessions and connection pools. Thread-safe singleton """
11-
_instance = None
13+
_instances = []
1214
_lock = threading.Lock()
1315

1416
def __new__(cls, *args, **kwargs):
15-
if not cls._instance:
17+
if not cls._get_cached_instance(args, kwargs):
1618
with cls._lock:
17-
if not cls._instance:
18-
cls._instance = super(SessionManager, cls).__new__(cls)
19-
return cls._instance
19+
if not cls._get_cached_instance(args, kwargs):
20+
cls._instances.append((super(cls, cls).__new__(cls), (args, kwargs)))
21+
return cls._get_cached_instance(args, kwargs)
2022

2123
def __init__(self, database_uri: str, **kwargs):
2224
""" Session Manager constructor
@@ -32,17 +34,20 @@ def __init__(self, database_uri: str, **kwargs):
3234
max_overflow (int): The maximum number of connections to the database
3335
pre_ping (bool): Whether to ping the database before each connection, may fix connection issues
3436
"""
37+
if not isinstance(database_uri, str):
38+
raise TypeError("database_uri must be a string")
39+
3540
self._database_uri = database_uri
3641
self._engine = self._get_engine(**kwargs)
3742
self._session_factory = sessionmaker(self.engine)
3843
self._Session = scoped_session(self._session_factory)
3944

40-
@property
45+
@cached_property
4146
def database_uri(self) -> str:
4247
""" Getter for the database URI """
4348
return self._database_uri
4449

45-
@property
50+
@cached_property
4651
def engine(self) -> Engine:
4752
""" Getter for the engine """
4853
return self._engine
@@ -55,3 +60,11 @@ def get_session(self) -> Iterator[ScopedSession]:
5560
def _get_engine(self, **kwargs) -> Engine:
5661
""" Provides a database engine """
5762
return sqla.create_engine(self.database_uri, **kwargs)
63+
64+
@classmethod
65+
def _get_cached_instance(cls, args: tuple, kwargs: dict) -> Optional[object]:
66+
""" Provides a cached instance of the SessionManager class if existing """
67+
for instance, arguments in cls._instances:
68+
if arguments == (args, kwargs):
69+
return instance
70+
return None

database_setup_tools/setup.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import threading
2+
from typing import Optional
23

34
import sqlalchemy_utils
45
from sqlalchemy import MetaData
@@ -8,15 +9,15 @@
89

910
class DatabaseSetup:
1011
""" Create the database and the tables if not done yet """
11-
_instance = None
12+
_instances = []
1213
_lock = threading.Lock()
1314

1415
def __new__(cls, *args, **kwargs):
15-
if not cls._instance:
16+
if not cls._get_cached_instance(args, kwargs):
1617
with cls._lock:
17-
if not cls._instance:
18-
cls._instance = super(DatabaseSetup, cls).__new__(cls)
19-
return cls._instance
18+
if not cls._get_cached_instance(args, kwargs):
19+
cls._instances.append((super(cls, cls).__new__(cls), (args, kwargs)))
20+
return cls._get_cached_instance(args, kwargs)
2021

2122
def __init__(self, model_metadata: MetaData, database_uri: str):
2223
""" Set up a database based on its URI and metadata. Will not overwrite existing data.
@@ -26,6 +27,12 @@ def __init__(self, model_metadata: MetaData, database_uri: str):
2627
database_uri (str): The URI of the database to create the tables for
2728
2829
"""
30+
if not isinstance(model_metadata, MetaData):
31+
raise TypeError("model_metadata must be a MetaData")
32+
33+
if not isinstance(database_uri, str):
34+
raise TypeError("database_uri must be a string")
35+
2936
self._model_metadata = model_metadata
3037
self._database_uri = database_uri
3138
self.create_database()
@@ -64,3 +71,11 @@ def create_database(self):
6471
sqlalchemy_utils.create_database(self.database_uri)
6572
session_manager = SessionManager(self.database_uri)
6673
self.model_metadata.create_all(session_manager.engine)
74+
75+
@classmethod
76+
def _get_cached_instance(cls, args: tuple, kwargs: dict) -> Optional[object]:
77+
""" Provides a cached instance of the SessionManager class if existing """
78+
for instance, arguments in cls._instances:
79+
if arguments == (args, kwargs):
80+
return instance
81+
return None

0 commit comments

Comments
 (0)