diff --git a/src/dependency_injector/_cwiring.pyi b/src/dependency_injector/_cwiring.pyi index e7ff12f4..c779b8c4 100644 --- a/src/dependency_injector/_cwiring.pyi +++ b/src/dependency_injector/_cwiring.pyi @@ -1,23 +1,18 @@ -from typing import Any, Awaitable, Callable, Dict, Tuple, TypeVar +from typing import Any, Dict from .providers import Provider -T = TypeVar("T") +class DependencyResolver: + def __init__( + self, + kwargs: Dict[str, Any], + injections: Dict[str, Provider[Any]], + closings: Dict[str, Provider[Any]], + /, + ) -> None: ... + def __enter__(self) -> Dict[str, Any]: ... + def __exit__(self, *exc_info: Any) -> None: ... + async def __aenter__(self) -> Dict[str, Any]: ... + async def __aexit__(self, *exc_info: Any) -> None: ... -def _sync_inject( - fn: Callable[..., T], - args: Tuple[Any, ...], - kwargs: Dict[str, Any], - injections: Dict[str, Provider[Any]], - closings: Dict[str, Provider[Any]], - /, -) -> T: ... -async def _async_inject( - fn: Callable[..., Awaitable[T]], - args: Tuple[Any, ...], - kwargs: Dict[str, Any], - injections: Dict[str, Provider[Any]], - closings: Dict[str, Provider[Any]], - /, -) -> T: ... def _isawaitable(instance: Any) -> bool: ... diff --git a/src/dependency_injector/_cwiring.pyx b/src/dependency_injector/_cwiring.pyx index 84a5485f..3e2775c7 100644 --- a/src/dependency_injector/_cwiring.pyx +++ b/src/dependency_injector/_cwiring.pyx @@ -1,83 +1,110 @@ """Wiring optimizations module.""" -import asyncio -import collections.abc -import inspect -import types +from asyncio import gather +from collections.abc import Awaitable +from inspect import CO_ITERABLE_COROUTINE +from types import CoroutineType, GeneratorType +from .providers cimport Provider, Resource, NULL_AWAITABLE from .wiring import _Marker -from .providers cimport Provider, Resource +cimport cython -def _sync_inject(object fn, tuple args, dict kwargs, dict injections, dict closings, /): - cdef object result +@cython.internal +@cython.no_gc +cdef class KWPair: + cdef str name + cdef object value + + def __cinit__(self, str name, object value, /): + self.name = name + self.value = value + + +cdef inline bint _is_injectable(dict kwargs, str name): + return name not in kwargs or isinstance(kwargs[name], _Marker) + + +cdef class DependencyResolver: + cdef dict kwargs cdef dict to_inject - cdef object arg_key - cdef Provider provider + cdef dict injections + cdef dict closings - to_inject = kwargs.copy() - for arg_key, provider in injections.items(): - if arg_key not in kwargs or isinstance(kwargs[arg_key], _Marker): - to_inject[arg_key] = provider() + def __init__(self, dict kwargs, dict injections, dict closings, /): + self.kwargs = kwargs + self.to_inject = kwargs.copy() + self.injections = injections + self.closings = closings - result = fn(*args, **to_inject) + async def _await_injection(self, kw_pair: KWPair, /) -> None: + self.to_inject[kw_pair.name] = await kw_pair.value - if closings: - for arg_key, provider in closings.items(): - if arg_key in kwargs and not isinstance(kwargs[arg_key], _Marker): - continue - if not isinstance(provider, Resource): - continue - provider.shutdown() + cdef object _await_injections(self, to_await: list): + return gather(*map(self._await_injection, to_await)) - return result + cdef void _handle_injections_sync(self): + cdef Provider provider + for name, provider in self.injections.items(): + if _is_injectable(self.kwargs, name): + self.to_inject[name] = provider() -async def _async_inject(object fn, tuple args, dict kwargs, dict injections, dict closings, /): - cdef object result - cdef dict to_inject - cdef list to_inject_await = [] - cdef list to_close_await = [] - cdef object arg_key - cdef Provider provider - - to_inject = kwargs.copy() - for arg_key, provider in injections.items(): - if arg_key not in kwargs or isinstance(kwargs[arg_key], _Marker): - provide = provider() - if provider.is_async_mode_enabled(): - to_inject_await.append((arg_key, provide)) - elif _isawaitable(provide): - to_inject_await.append((arg_key, provide)) - else: - to_inject[arg_key] = provide - - if to_inject_await: - async_to_inject = await asyncio.gather(*(provide for _, provide in to_inject_await)) - for provide, (injection, _) in zip(async_to_inject, to_inject_await): - to_inject[injection] = provide - - result = await fn(*args, **to_inject) - - if closings: - for arg_key, provider in closings.items(): - if arg_key in kwargs and isinstance(kwargs[arg_key], _Marker): - continue - if not isinstance(provider, Resource): - continue - shutdown = provider.shutdown() - if _isawaitable(shutdown): - to_close_await.append(shutdown) - - await asyncio.gather(*to_close_await) - - return result + cdef list _handle_injections_async(self): + cdef list to_await = [] + cdef Provider provider + + for name, provider in self.injections.items(): + if _is_injectable(self.kwargs, name): + provide = provider() + + if provider.is_async_mode_enabled() or _isawaitable(provide): + to_await.append(KWPair(name, provide)) + else: + self.to_inject[name] = provide + + return to_await + + cdef void _handle_closings_sync(self): + cdef Provider provider + + for name, provider in self.closings.items(): + if _is_injectable(self.kwargs, name) and isinstance(provider, Resource): + provider.shutdown() + + cdef list _handle_closings_async(self): + cdef list to_await = [] + cdef Provider provider + + for name, provider in self.closings.items(): + if _is_injectable(self.kwargs, name) and isinstance(provider, Resource): + if _isawaitable(shutdown := provider.shutdown()): + to_await.append(shutdown) + + return to_await + + def __enter__(self): + self._handle_injections_sync() + return self.to_inject + + def __exit__(self, *_): + self._handle_closings_sync() + + async def __aenter__(self): + if to_await := self._handle_injections_async(): + await self._await_injections(to_await) + return self.to_inject + + def __aexit__(self, *_): + if to_await := self._handle_closings_async(): + return gather(*to_await) + return NULL_AWAITABLE cdef bint _isawaitable(object instance): """Return true if object can be passed to an ``await`` expression.""" - return (isinstance(instance, types.CoroutineType) or - isinstance(instance, types.GeneratorType) and - bool(instance.gi_code.co_flags & inspect.CO_ITERABLE_COROUTINE) or - isinstance(instance, collections.abc.Awaitable)) + return (isinstance(instance, CoroutineType) or + isinstance(instance, GeneratorType) and + bool(instance.gi_code.co_flags & CO_ITERABLE_COROUTINE) or + isinstance(instance, Awaitable)) diff --git a/src/dependency_injector/wiring.py b/src/dependency_injector/wiring.py index 9c976c2d..aadf2cdc 100644 --- a/src/dependency_injector/wiring.py +++ b/src/dependency_injector/wiring.py @@ -10,6 +10,7 @@ from typing import ( TYPE_CHECKING, Any, + AsyncIterator, Callable, Dict, Iterable, @@ -720,6 +721,8 @@ def _get_patched( if inspect.iscoroutinefunction(fn): patched = _get_async_patched(fn, patched_object) + elif inspect.isasyncgenfunction(fn): + patched = _get_async_gen_patched(fn, patched_object) else: patched = _get_sync_patched(fn, patched_object) @@ -1035,36 +1038,41 @@ def is_loader_installed() -> bool: _loader = AutoLoader() # Optimizations -from ._cwiring import _async_inject # noqa -from ._cwiring import _sync_inject # noqa +from ._cwiring import DependencyResolver # noqa: E402 # Wiring uses the following Python wrapper because there is # no possibility to compile a first-type citizen coroutine in Cython. def _get_async_patched(fn: F, patched: PatchedCallable) -> F: @functools.wraps(fn) - async def _patched(*args, **kwargs): - return await _async_inject( - fn, - args, - kwargs, - patched.injections, - patched.closing, - ) + async def _patched(*args: Any, **raw_kwargs: Any) -> Any: + resolver = DependencyResolver(raw_kwargs, patched.injections, patched.closing) + + async with resolver as kwargs: + return await fn(*args, **kwargs) + + return cast(F, _patched) + + +def _get_async_gen_patched(fn: F, patched: PatchedCallable) -> F: + @functools.wraps(fn) + async def _patched(*args: Any, **raw_kwargs: Any) -> AsyncIterator[Any]: + resolver = DependencyResolver(raw_kwargs, patched.injections, patched.closing) + + async with resolver as kwargs: + async for obj in fn(*args, **kwargs): + yield obj return cast(F, _patched) def _get_sync_patched(fn: F, patched: PatchedCallable) -> F: @functools.wraps(fn) - def _patched(*args, **kwargs): - return _sync_inject( - fn, - args, - kwargs, - patched.injections, - patched.closing, - ) + def _patched(*args: Any, **raw_kwargs: Any) -> Any: + resolver = DependencyResolver(raw_kwargs, patched.injections, patched.closing) + + with resolver as kwargs: + return fn(*args, **kwargs) return cast(F, _patched) diff --git a/tests/unit/samples/wiring/asyncinjections.py b/tests/unit/samples/wiring/asyncinjections.py index 204300e3..e0861017 100644 --- a/tests/unit/samples/wiring/asyncinjections.py +++ b/tests/unit/samples/wiring/asyncinjections.py @@ -1,7 +1,9 @@ import asyncio +from typing_extensions import Annotated + from dependency_injector import containers, providers -from dependency_injector.wiring import inject, Provide, Closing +from dependency_injector.wiring import Closing, Provide, inject class TestResource: @@ -42,6 +44,15 @@ async def async_injection( return resource1, resource2 +@inject +async def async_generator_injection( + resource1: object = Provide[Container.resource1], + resource2: object = Closing[Provide[Container.resource2]], +): + yield resource1 + yield resource2 + + @inject async def async_injection_with_closing( resource1: object = Closing[Provide[Container.resource1]], diff --git a/tests/unit/wiring/provider_ids/test_async_injections_py36.py b/tests/unit/wiring/provider_ids/test_async_injections_py36.py index f17f19c7..70f9eb17 100644 --- a/tests/unit/wiring/provider_ids/test_async_injections_py36.py +++ b/tests/unit/wiring/provider_ids/test_async_injections_py36.py @@ -32,6 +32,23 @@ async def test_async_injections(): assert asyncinjections.resource2.shutdown_counter == 0 +@mark.asyncio +async def test_async_generator_injections() -> None: + resources = [] + + async for resource in asyncinjections.async_generator_injection(): + resources.append(resource) + + assert len(resources) == 2 + assert resources[0] is asyncinjections.resource1 + assert asyncinjections.resource1.init_counter == 1 + assert asyncinjections.resource1.shutdown_counter == 0 + + assert resources[1] is asyncinjections.resource2 + assert asyncinjections.resource2.init_counter == 1 + assert asyncinjections.resource2.shutdown_counter == 1 + + @mark.asyncio async def test_async_injections_with_closing(): resource1, resource2 = await asyncinjections.async_injection_with_closing()