8
8
import sys
9
9
from types import ModuleType
10
10
from typing import (
11
+ TYPE_CHECKING ,
11
12
Any ,
12
13
Callable ,
13
14
Dict ,
14
- Generic ,
15
15
Iterable ,
16
16
Iterator ,
17
17
Optional ,
18
+ Protocol ,
18
19
Set ,
19
20
Tuple ,
20
21
Type ,
23
24
cast ,
24
25
)
25
26
27
+ from typing_extensions import Self
26
28
27
29
# Hotfix, see: https://github.com/ets-labs/python-dependency-injector/issues/362
28
30
if sys .version_info >= (3 , 9 ):
@@ -66,7 +68,6 @@ def get_origin(tp):
66
68
67
69
from . import providers
68
70
69
-
70
71
__all__ = (
71
72
"wire" ,
72
73
"unwire" ,
@@ -89,7 +90,11 @@ def get_origin(tp):
89
90
90
91
T = TypeVar ("T" )
91
92
F = TypeVar ("F" , bound = Callable [..., Any ])
92
- Container = Any
93
+
94
+ if TYPE_CHECKING :
95
+ from .containers import Container
96
+ else :
97
+ Container = Any
93
98
94
99
95
100
class PatchedRegistry :
@@ -777,15 +782,15 @@ class RequiredModifier(Modifier):
777
782
def __init__ (self ) -> None :
778
783
self .type_modifier = None
779
784
780
- def as_int (self ) -> "RequiredModifier" :
785
+ def as_int (self ) -> Self :
781
786
self .type_modifier = TypeModifier (int )
782
787
return self
783
788
784
- def as_float (self ) -> "RequiredModifier" :
789
+ def as_float (self ) -> Self :
785
790
self .type_modifier = TypeModifier (float )
786
791
return self
787
792
788
- def as_ (self , type_ : Type ) -> "RequiredModifier" :
793
+ def as_ (self , type_ : Type ) -> Self :
789
794
self .type_modifier = TypeModifier (type_ )
790
795
return self
791
796
@@ -833,15 +838,15 @@ class ProvidedInstance(Modifier):
833
838
def __init__ (self ) -> None :
834
839
self .segments = []
835
840
836
- def __getattr__ (self , item ) :
841
+ def __getattr__ (self , item : str ) -> Self :
837
842
self .segments .append ((self .TYPE_ATTRIBUTE , item ))
838
843
return self
839
844
840
- def __getitem__ (self , item ):
845
+ def __getitem__ (self , item ) -> Self :
841
846
self .segments .append ((self .TYPE_ITEM , item ))
842
847
return self
843
848
844
- def call (self ):
849
+ def call (self ) -> Self :
845
850
self .segments .append ((self .TYPE_CALL , None ))
846
851
return self
847
852
@@ -866,36 +871,56 @@ def provided() -> ProvidedInstance:
866
871
return ProvidedInstance ()
867
872
868
873
869
- class _Marker (Generic [T ]):
874
+ MarkerItem = Union [
875
+ str ,
876
+ providers .Provider [Any ],
877
+ Tuple [str , TypeModifier ],
878
+ Type [Container ],
879
+ "_Marker" ,
880
+ ]
870
881
871
- __IS_MARKER__ = True
872
882
873
- def __init__ (
874
- self ,
875
- provider : Union [providers .Provider , Container , str ],
876
- modifier : Optional [Modifier ] = None ,
877
- ) -> None :
878
- if _is_declarative_container (provider ):
879
- provider = provider .__self__
880
- self .provider = provider
881
- self .modifier = modifier
883
+ if TYPE_CHECKING :
882
884
883
- def __class_getitem__ (cls , item ) -> T :
884
- if isinstance (item , tuple ):
885
- return cls (* item )
886
- return cls (item )
885
+ class _Marker (Protocol ):
886
+ __IS_MARKER__ : bool
887
887
888
- def __call__ (self ) -> T :
889
- return self
888
+ def __call__ (self ) -> Self : ...
889
+ def __getattr__ (self , item : str ) -> Self : ...
890
+ def __getitem__ (self , item : Any ) -> Any : ...
891
+
892
+ Provide : _Marker
893
+ Provider : _Marker
894
+ Closing : _Marker
895
+ else :
890
896
897
+ class _Marker :
891
898
892
- class Provide ( _Marker ): ...
899
+ __IS_MARKER__ = True
893
900
901
+ def __init__ (
902
+ self ,
903
+ provider : Union [providers .Provider , Container , str ],
904
+ modifier : Optional [Modifier ] = None ,
905
+ ) -> None :
906
+ if _is_declarative_container (provider ):
907
+ provider = provider .__self__
908
+ self .provider = provider
909
+ self .modifier = modifier
894
910
895
- class Provider (_Marker ): ...
911
+ def __class_getitem__ (cls , item : MarkerItem ) -> Self :
912
+ if isinstance (item , tuple ):
913
+ return cls (* item )
914
+ return cls (item )
896
915
916
+ def __call__ (self ) -> Self :
917
+ return self
897
918
898
- class Closing (_Marker ): ...
919
+ class Provide (_Marker ): ...
920
+
921
+ class Provider (_Marker ): ...
922
+
923
+ class Closing (_Marker ): ...
899
924
900
925
901
926
class AutoLoader :
@@ -998,8 +1023,8 @@ def is_loader_installed() -> bool:
998
1023
_loader = AutoLoader ()
999
1024
1000
1025
# Optimizations
1001
- from ._cwiring import _sync_inject # noqa
1002
1026
from ._cwiring import _async_inject # noqa
1027
+ from ._cwiring import _sync_inject # noqa
1003
1028
1004
1029
1005
1030
# Wiring uses the following Python wrapper because there is
@@ -1028,13 +1053,17 @@ def _patched(*args, **kwargs):
1028
1053
patched .injections ,
1029
1054
patched .closing ,
1030
1055
)
1056
+
1031
1057
return cast (F , _patched )
1032
1058
1033
1059
1034
1060
if sys .version_info >= (3 , 10 ):
1061
+
1035
1062
def _get_annotations (obj : Any ) -> Dict [str , Any ]:
1036
1063
return inspect .get_annotations (obj )
1064
+
1037
1065
else :
1066
+
1038
1067
def _get_annotations (obj : Any ) -> Dict [str , Any ]:
1039
1068
return getattr (obj , "__annotations__" , {})
1040
1069
0 commit comments