1
- import sqlalchemy as sqla
2
1
import threading
2
+ from functools import cached_property
3
+ from typing import Iterator , Optional
4
+
5
+ import sqlalchemy as sqla
3
6
from sqlalchemy .engine import Engine
4
7
from sqlalchemy .orm import sessionmaker
5
8
from sqlalchemy .orm .scoping import ScopedSession , scoped_session
6
- from typing import Iterator
7
9
8
10
9
11
class SessionManager :
10
12
""" Manages engines, sessions and connection pools. Thread-safe singleton """
11
- _instance = None
13
+ _instances = []
12
14
_lock = threading .Lock ()
13
15
14
16
def __new__ (cls , * args , ** kwargs ):
15
- if not cls ._instance :
17
+ if not cls ._get_cached_instance ( args , kwargs ) :
16
18
with cls ._lock :
17
- if not cls ._instance :
18
- cls ._instance = super (SessionManager , cls ).__new__ (cls )
19
- return cls ._instance
19
+ if not cls ._get_cached_instance ( args , kwargs ) :
20
+ cls ._instances . append (( super (cls , cls ).__new__ (cls ), ( args , kwargs )) )
21
+ return cls ._get_cached_instance ( args , kwargs )
20
22
21
23
def __init__ (self , database_uri : str , ** kwargs ):
22
24
""" Session Manager constructor
@@ -32,17 +34,20 @@ def __init__(self, database_uri: str, **kwargs):
32
34
max_overflow (int): The maximum number of connections to the database
33
35
pre_ping (bool): Whether to ping the database before each connection, may fix connection issues
34
36
"""
37
+ if not isinstance (database_uri , str ):
38
+ raise TypeError ("database_uri must be a string" )
39
+
35
40
self ._database_uri = database_uri
36
41
self ._engine = self ._get_engine (** kwargs )
37
42
self ._session_factory = sessionmaker (self .engine )
38
43
self ._Session = scoped_session (self ._session_factory )
39
44
40
- @property
45
+ @cached_property
41
46
def database_uri (self ) -> str :
42
47
""" Getter for the database URI """
43
48
return self ._database_uri
44
49
45
- @property
50
+ @cached_property
46
51
def engine (self ) -> Engine :
47
52
""" Getter for the engine """
48
53
return self ._engine
@@ -55,3 +60,11 @@ def get_session(self) -> Iterator[ScopedSession]:
55
60
def _get_engine (self , ** kwargs ) -> Engine :
56
61
""" Provides a database engine """
57
62
return sqla .create_engine (self .database_uri , ** kwargs )
63
+
64
+ @classmethod
65
+ def _get_cached_instance (cls , args : tuple , kwargs : dict ) -> Optional [object ]:
66
+ """ Provides a cached instance of the SessionManager class if existing """
67
+ for instance , arguments in cls ._instances :
68
+ if arguments == (args , kwargs ):
69
+ return instance
70
+ return None
0 commit comments