Skip to content

Commit b9df88e

Browse files
committed
Fix typing for wiring marker
1 parent 99489af commit b9df88e

File tree

3 files changed

+96
-31
lines changed

3 files changed

+96
-31
lines changed

src/dependency_injector/containers.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ class WiringConfiguration:
4141
class Container:
4242
provider_type: Type[Provider] = Provider
4343
providers: Dict[str, Provider]
44-
dependencies: Dict[str, Provider]
44+
dependencies: Dict[str, Provider[Any]]
4545
overridden: Tuple[Provider]
4646
wiring_config: WiringConfiguration
4747
auto_load_config: bool = True

src/dependency_injector/wiring.py

Lines changed: 59 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,14 @@
88
import sys
99
from types import ModuleType
1010
from typing import (
11+
TYPE_CHECKING,
1112
Any,
1213
Callable,
1314
Dict,
14-
Generic,
1515
Iterable,
1616
Iterator,
1717
Optional,
18+
Protocol,
1819
Set,
1920
Tuple,
2021
Type,
@@ -23,6 +24,7 @@
2324
cast,
2425
)
2526

27+
from typing_extensions import Self
2628

2729
# Hotfix, see: https://github.com/ets-labs/python-dependency-injector/issues/362
2830
if sys.version_info >= (3, 9):
@@ -66,7 +68,6 @@ def get_origin(tp):
6668

6769
from . import providers
6870

69-
7071
__all__ = (
7172
"wire",
7273
"unwire",
@@ -89,7 +90,11 @@ def get_origin(tp):
8990

9091
T = TypeVar("T")
9192
F = TypeVar("F", bound=Callable[..., Any])
92-
Container = Any
93+
94+
if TYPE_CHECKING:
95+
from .containers import Container
96+
else:
97+
Container = Any
9398

9499

95100
class PatchedRegistry:
@@ -777,15 +782,15 @@ class RequiredModifier(Modifier):
777782
def __init__(self) -> None:
778783
self.type_modifier = None
779784

780-
def as_int(self) -> "RequiredModifier":
785+
def as_int(self) -> Self:
781786
self.type_modifier = TypeModifier(int)
782787
return self
783788

784-
def as_float(self) -> "RequiredModifier":
789+
def as_float(self) -> Self:
785790
self.type_modifier = TypeModifier(float)
786791
return self
787792

788-
def as_(self, type_: Type) -> "RequiredModifier":
793+
def as_(self, type_: Type) -> Self:
789794
self.type_modifier = TypeModifier(type_)
790795
return self
791796

@@ -833,15 +838,15 @@ class ProvidedInstance(Modifier):
833838
def __init__(self) -> None:
834839
self.segments = []
835840

836-
def __getattr__(self, item):
841+
def __getattr__(self, item: str) -> Self:
837842
self.segments.append((self.TYPE_ATTRIBUTE, item))
838843
return self
839844

840-
def __getitem__(self, item):
845+
def __getitem__(self, item) -> Self:
841846
self.segments.append((self.TYPE_ITEM, item))
842847
return self
843848

844-
def call(self):
849+
def call(self) -> Self:
845850
self.segments.append((self.TYPE_CALL, None))
846851
return self
847852

@@ -866,36 +871,56 @@ def provided() -> ProvidedInstance:
866871
return ProvidedInstance()
867872

868873

869-
class _Marker(Generic[T]):
874+
MarkerItem = Union[
875+
str,
876+
providers.Provider[Any],
877+
Tuple[str, TypeModifier],
878+
Type[Container],
879+
"_Marker",
880+
]
870881

871-
__IS_MARKER__ = True
872882

873-
def __init__(
874-
self,
875-
provider: Union[providers.Provider, Container, str],
876-
modifier: Optional[Modifier] = None,
877-
) -> None:
878-
if _is_declarative_container(provider):
879-
provider = provider.__self__
880-
self.provider = provider
881-
self.modifier = modifier
883+
if TYPE_CHECKING:
882884

883-
def __class_getitem__(cls, item) -> T:
884-
if isinstance(item, tuple):
885-
return cls(*item)
886-
return cls(item)
885+
class _Marker(Protocol):
886+
__IS_MARKER__: bool
887887

888-
def __call__(self) -> T:
889-
return self
888+
def __call__(self) -> Self: ...
889+
def __getattr__(self, item: str) -> Self: ...
890+
def __getitem__(self, item: Any) -> Any: ...
891+
892+
Provide: _Marker
893+
Provider: _Marker
894+
Closing: _Marker
895+
else:
890896

897+
class _Marker:
891898

892-
class Provide(_Marker): ...
899+
__IS_MARKER__ = True
893900

901+
def __init__(
902+
self,
903+
provider: Union[providers.Provider, Container, str],
904+
modifier: Optional[Modifier] = None,
905+
) -> None:
906+
if _is_declarative_container(provider):
907+
provider = provider.__self__
908+
self.provider = provider
909+
self.modifier = modifier
894910

895-
class Provider(_Marker): ...
911+
def __class_getitem__(cls, item: MarkerItem) -> Self:
912+
if isinstance(item, tuple):
913+
return cls(*item)
914+
return cls(item)
896915

916+
def __call__(self) -> Self:
917+
return self
897918

898-
class Closing(_Marker): ...
919+
class Provide(_Marker): ...
920+
921+
class Provider(_Marker): ...
922+
923+
class Closing(_Marker): ...
899924

900925

901926
class AutoLoader:
@@ -998,8 +1023,8 @@ def is_loader_installed() -> bool:
9981023
_loader = AutoLoader()
9991024

10001025
# Optimizations
1001-
from ._cwiring import _sync_inject # noqa
10021026
from ._cwiring import _async_inject # noqa
1027+
from ._cwiring import _sync_inject # noqa
10031028

10041029

10051030
# Wiring uses the following Python wrapper because there is
@@ -1028,13 +1053,17 @@ def _patched(*args, **kwargs):
10281053
patched.injections,
10291054
patched.closing,
10301055
)
1056+
10311057
return cast(F, _patched)
10321058

10331059

10341060
if sys.version_info >= (3, 10):
1061+
10351062
def _get_annotations(obj: Any) -> Dict[str, Any]:
10361063
return inspect.get_annotations(obj)
1064+
10371065
else:
1066+
10381067
def _get_annotations(obj: Any) -> Dict[str, Any]:
10391068
return getattr(obj, "__annotations__", {})
10401069

tests/typing/wiring.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
from typing import Iterator
2+
3+
from typing_extensions import Annotated
4+
5+
from dependency_injector.containers import DeclarativeContainer
6+
from dependency_injector.providers import Object, Resource
7+
from dependency_injector.wiring import Closing, Provide, required
8+
9+
10+
def _resource() -> Iterator[int]:
11+
yield 1
12+
13+
14+
class Container(DeclarativeContainer):
15+
value = Object(1)
16+
res = Resource(_resource)
17+
18+
19+
def default_by_ref(value: int = Provide[Container.value]) -> None: ...
20+
def default_by_string(value: int = Provide["value"]) -> None: ...
21+
def default_by_string_with_modifier(
22+
value: int = Provide["value", required().as_int()]
23+
) -> None: ...
24+
def default_container(container: Container = Provide[Container]) -> None: ...
25+
def default_with_closing(value: int = Closing[Provide[Container.res]]) -> None: ...
26+
def annotated_by_ref(value: Annotated[int, Provide[Container.value]]) -> None: ...
27+
def annotated_by_string(value: Annotated[int, Provide["value"]]) -> None: ...
28+
def annotated_by_string_with_modifier(
29+
value: Annotated[int, Provide["value", required().as_int()]],
30+
) -> None: ...
31+
def annotated_container(
32+
container: Annotated[Container, Provide[Container]],
33+
) -> None: ...
34+
def annotated_with_closing(
35+
value: Annotated[int, Closing[Provide[Container.res]]],
36+
) -> None: ...

0 commit comments

Comments
 (0)