diff --git a/docs/providers/resource.rst b/docs/providers/resource.rst index 918dfa66..cf935f3a 100644 --- a/docs/providers/resource.rst +++ b/docs/providers/resource.rst @@ -61,11 +61,12 @@ When you call ``.shutdown()`` method on a resource provider, it will remove the if any, and switch to uninitialized state. Some of resource initializer types support specifying custom resource shutdown. -Resource provider supports 3 types of initializers: +Resource provider supports 4 types of initializers: - Function -- Generator -- Subclass of ``resources.Resource`` +- Context Manager +- Generator (legacy) +- Subclass of ``resources.Resource`` (legacy) Function initializer -------------------- @@ -103,8 +104,44 @@ you configure global resource: Function initializer does not provide a way to specify custom resource shutdown. -Generator initializer ---------------------- +Context Manager initializer +--------------------------- + +This is an extension to the Function initializer. Resource provider automatically detects if the initializer returns a +context manager and uses it to manage the resource lifecycle. + +.. code-block:: python + + from dependency_injector import containers, providers + + class DatabaseConnection: + def __init__(self, host, port, user, password): + self.host = host + self.port = port + self.user = user + self.password = password + + def __enter__(self): + print(f"Connecting to {self.host}:{self.port} as {self.user}") + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + print("Closing connection") + + + class Container(containers.DeclarativeContainer): + + config = providers.Configuration() + db = providers.Resource( + DatabaseConnection, + host=config.db.host, + port=config.db.port, + user=config.db.user, + password=config.db.password, + ) + +Generator initializer (legacy) +------------------------------ Resource provider can use 2-step generators: @@ -154,8 +191,13 @@ object is not mandatory. You can leave ``yield`` statement empty: argument2=..., ) -Subclass initializer --------------------- +.. note:: + + Generator initializers are automatically wrapped with ``contextmanager`` or ``asynccontextmanager`` decorator when + provided to a ``Resource`` provider. + +Subclass initializer (legacy) +----------------------------- You can create resource initializer by implementing a subclass of the ``resources.Resource``: @@ -263,10 +305,11 @@ Asynchronous function initializer: argument2=..., ) -Asynchronous generator initializer: +Asynchronous Context Manager initializer: .. code-block:: python + @asynccontextmanager async def init_async_resource(argument1=..., argument2=...): connection = await connect() yield connection diff --git a/examples/providers/resource.py b/examples/providers/resource.py index 2079a929..c712468a 100644 --- a/examples/providers/resource.py +++ b/examples/providers/resource.py @@ -3,10 +3,12 @@ import sys import logging from concurrent.futures import ThreadPoolExecutor +from contextlib import contextmanager from dependency_injector import containers, providers +@contextmanager def init_thread_pool(max_workers: int): thread_pool = ThreadPoolExecutor(max_workers=max_workers) yield thread_pool diff --git a/src/dependency_injector/providers.pyx b/src/dependency_injector/providers.pyx index d276903b..43e49d7e 100644 --- a/src/dependency_injector/providers.pyx +++ b/src/dependency_injector/providers.pyx @@ -15,8 +15,11 @@ import re import sys import threading import warnings +from asyncio import ensure_future from configparser import ConfigParser as IniConfigParser +from contextlib import asynccontextmanager, contextmanager from contextvars import ContextVar +from inspect import isasyncgenfunction, isgeneratorfunction try: from inspect import _is_coroutine_mark as _is_coroutine_marker @@ -3598,6 +3601,17 @@ cdef class Dict(Provider): return __provide_keyword_args(kwargs, self._kwargs, self._kwargs_len, self._async_mode) +@cython.no_gc +cdef class NullAwaitable: + def __next__(self): + raise StopIteration from None + + def __await__(self): + return self + + +cdef NullAwaitable NULL_AWAITABLE = NullAwaitable() + cdef class Resource(Provider): """Resource provider provides a component with initialization and shutdown.""" @@ -3653,6 +3667,12 @@ cdef class Resource(Provider): def set_provides(self, provides): """Set provider provides.""" provides = _resolve_string_import(provides) + + if isasyncgenfunction(provides): + provides = asynccontextmanager(provides) + elif isgeneratorfunction(provides): + provides = contextmanager(provides) + self._provides = provides return self @@ -3753,28 +3773,21 @@ cdef class Resource(Provider): """Shutdown resource.""" if not self._initialized: if self._async_mode == ASYNC_MODE_ENABLED: - result = asyncio.Future() - result.set_result(None) - return result + return NULL_AWAITABLE return if self._shutdowner: - try: - shutdown = self._shutdowner(self._resource) - except StopIteration: - pass - else: - if inspect.isawaitable(shutdown): - return self._create_shutdown_future(shutdown) + future = self._shutdowner(None, None, None) + + if __is_future_or_coroutine(future): + return ensure_future(self._shutdown_async(future)) self._resource = None self._initialized = False self._shutdowner = None if self._async_mode == ASYNC_MODE_ENABLED: - result = asyncio.Future() - result.set_result(None) - return result + return NULL_AWAITABLE @property def related(self): @@ -3784,165 +3797,75 @@ cdef class Resource(Provider): yield from filter(is_provider, self.kwargs.values()) yield from super().related + async def _shutdown_async(self, future) -> None: + try: + await future + finally: + self._resource = None + self._initialized = False + self._shutdowner = None + + async def _handle_async_cm(self, obj) -> None: + try: + self._resource = resource = await obj.__aenter__() + self._shutdowner = obj.__aexit__ + return resource + except: + self._initialized = False + raise + + async def _provide_async(self, future) -> None: + try: + obj = await future + + if hasattr(obj, '__aenter__') and hasattr(obj, '__aexit__'): + self._resource = await obj.__aenter__() + self._shutdowner = obj.__aexit__ + elif hasattr(obj, '__enter__') and hasattr(obj, '__exit__'): + self._resource = obj.__enter__() + self._shutdowner = obj.__exit__ + else: + self._resource = obj + self._shutdowner = None + + return self._resource + except: + self._initialized = False + raise + cpdef object _provide(self, tuple args, dict kwargs): if self._initialized: return self._resource - if self._is_resource_subclass(self._provides): - initializer = self._provides() - self._resource = __call( - initializer.init, - args, - self._args, - self._args_len, - kwargs, - self._kwargs, - self._kwargs_len, - self._async_mode, - ) - self._shutdowner = initializer.shutdown - elif self._is_async_resource_subclass(self._provides): - initializer = self._provides() - async_init = __call( - initializer.init, - args, - self._args, - self._args_len, - kwargs, - self._kwargs, - self._kwargs_len, - self._async_mode, - ) - self._initialized = True - return self._create_init_future(async_init, initializer.shutdown) - elif inspect.isgeneratorfunction(self._provides): - initializer = __call( - self._provides, - args, - self._args, - self._args_len, - kwargs, - self._kwargs, - self._kwargs_len, - self._async_mode, - ) - self._resource = next(initializer) - self._shutdowner = initializer.send - elif iscoroutinefunction(self._provides): - initializer = __call( - self._provides, - args, - self._args, - self._args_len, - kwargs, - self._kwargs, - self._kwargs_len, - self._async_mode, - ) + obj = __call( + self._provides, + args, + self._args, + self._args_len, + kwargs, + self._kwargs, + self._kwargs_len, + self._async_mode, + ) + + if __is_future_or_coroutine(obj): self._initialized = True - return self._create_init_future(initializer) - elif isasyncgenfunction(self._provides): - initializer = __call( - self._provides, - args, - self._args, - self._args_len, - kwargs, - self._kwargs, - self._kwargs_len, - self._async_mode, - ) + self._resource = resource = ensure_future(self._provide_async(obj)) + return resource + elif hasattr(obj, '__enter__') and hasattr(obj, '__exit__'): + self._resource = obj.__enter__() + self._shutdowner = obj.__exit__ + elif hasattr(obj, '__aenter__') and hasattr(obj, '__aexit__'): self._initialized = True - return self._create_async_gen_init_future(initializer) - elif callable(self._provides): - self._resource = __call( - self._provides, - args, - self._args, - self._args_len, - kwargs, - self._kwargs, - self._kwargs_len, - self._async_mode, - ) + self._resource = resource = ensure_future(self._handle_async_cm(obj)) + return resource else: - raise Error("Unknown type of resource initializer") + self._resource = obj + self._shutdowner = None self._initialized = True return self._resource - def _create_init_future(self, future, shutdowner=None): - callback = self._async_init_callback - if shutdowner: - callback = functools.partial(callback, shutdowner=shutdowner) - - future = asyncio.ensure_future(future) - future.add_done_callback(callback) - self._resource = future - - return future - - def _create_async_gen_init_future(self, initializer): - if inspect.isasyncgen(initializer): - return self._create_init_future(initializer.__anext__(), initializer.asend) - - future = asyncio.Future() - - create_initializer = asyncio.ensure_future(initializer) - create_initializer.add_done_callback(functools.partial(self._async_create_gen_callback, future)) - self._resource = future - - return future - - def _async_init_callback(self, initializer, shutdowner=None): - try: - resource = initializer.result() - except Exception: - self._initialized = False - else: - self._resource = resource - self._shutdowner = shutdowner - - def _async_create_gen_callback(self, future, initializer_future): - initializer = initializer_future.result() - init_future = self._create_init_future(initializer.__anext__(), initializer.asend) - init_future.add_done_callback(functools.partial(self._async_trigger_result, future)) - - def _async_trigger_result(self, future, future_result): - future.set_result(future_result.result()) - - def _create_shutdown_future(self, shutdown_future): - future = asyncio.Future() - shutdown_future = asyncio.ensure_future(shutdown_future) - shutdown_future.add_done_callback(functools.partial(self._async_shutdown_callback, future)) - return future - - def _async_shutdown_callback(self, future_result, shutdowner): - try: - shutdowner.result() - except StopAsyncIteration: - pass - - self._resource = None - self._initialized = False - self._shutdowner = None - - future_result.set_result(None) - - @staticmethod - def _is_resource_subclass(instance): - if not isinstance(instance, type): - return - from . import resources - return issubclass(instance, resources.Resource) - - @staticmethod - def _is_async_resource_subclass(instance): - if not isinstance(instance, type): - return - from . import resources - return issubclass(instance, resources.AsyncResource) - cdef class Container(Provider): """Container provider provides an instance of declarative container. @@ -4993,14 +4916,6 @@ def iscoroutinefunction(obj): return False -def isasyncgenfunction(obj): - """Check if object is an asynchronous generator function.""" - try: - return inspect.isasyncgenfunction(obj) - except AttributeError: - return False - - def _resolve_string_import(provides): if provides is None: return provides diff --git a/src/dependency_injector/resources.py b/src/dependency_injector/resources.py index 7d71d4d8..8722af22 100644 --- a/src/dependency_injector/resources.py +++ b/src/dependency_injector/resources.py @@ -1,23 +1,54 @@ """Resources module.""" -import abc -from typing import TypeVar, Generic, Optional - +from abc import ABCMeta, abstractmethod +from typing import Any, ClassVar, Generic, Optional, Tuple, TypeVar T = TypeVar("T") -class Resource(Generic[T], metaclass=abc.ABCMeta): +class Resource(Generic[T], metaclass=ABCMeta): + __slots__: ClassVar[Tuple[str, ...]] = ("args", "kwargs", "obj") + + obj: Optional[T] + + def __init__(self, *args: Any, **kwargs: Any) -> None: + self.args = args + self.kwargs = kwargs + self.obj = None - @abc.abstractmethod - def init(self, *args, **kwargs) -> Optional[T]: ... + @abstractmethod + def init(self, *args: Any, **kwargs: Any) -> Optional[T]: ... def shutdown(self, resource: Optional[T]) -> None: ... + def __enter__(self) -> Optional[T]: + self.obj = obj = self.init(*self.args, **self.kwargs) + return obj + + def __exit__(self, *exc_info: Any) -> None: + self.shutdown(self.obj) + self.obj = None + -class AsyncResource(Generic[T], metaclass=abc.ABCMeta): +class AsyncResource(Generic[T], metaclass=ABCMeta): + __slots__: ClassVar[Tuple[str, ...]] = ("args", "kwargs", "obj") - @abc.abstractmethod - async def init(self, *args, **kwargs) -> Optional[T]: ... + obj: Optional[T] + + def __init__(self, *args: Any, **kwargs: Any) -> None: + self.args = args + self.kwargs = kwargs + self.obj = None + + @abstractmethod + async def init(self, *args: Any, **kwargs: Any) -> Optional[T]: ... async def shutdown(self, resource: Optional[T]) -> None: ... + + async def __aenter__(self) -> Optional[T]: + self.obj = obj = await self.init(*self.args, **self.kwargs) + return obj + + async def __aexit__(self, *exc_info: Any) -> None: + await self.shutdown(self.obj) + self.obj = None diff --git a/tests/unit/providers/resource/test_async_resource_py35.py b/tests/unit/providers/resource/test_async_resource_py35.py index 1ca950a8..6458584d 100644 --- a/tests/unit/providers/resource/test_async_resource_py35.py +++ b/tests/unit/providers/resource/test_async_resource_py35.py @@ -2,12 +2,13 @@ import asyncio import inspect -import sys +from contextlib import asynccontextmanager from typing import Any -from dependency_injector import containers, providers, resources from pytest import mark, raises +from dependency_injector import containers, providers, resources + @mark.asyncio async def test_init_async_function(): @@ -70,6 +71,46 @@ async def _init(): assert _init.shutdown_counter == 2 +@mark.asyncio +async def test_init_async_context_manager() -> None: + resource = object() + + init_counter = 0 + shutdown_counter = 0 + + @asynccontextmanager + async def _init(): + nonlocal init_counter, shutdown_counter + + await asyncio.sleep(0.001) + init_counter += 1 + + yield resource + + await asyncio.sleep(0.001) + shutdown_counter += 1 + + provider = providers.Resource(_init) + + result1 = await provider() + assert result1 is resource + assert init_counter == 1 + assert shutdown_counter == 0 + + await provider.shutdown() + assert init_counter == 1 + assert shutdown_counter == 1 + + result2 = await provider() + assert result2 is resource + assert init_counter == 2 + assert shutdown_counter == 1 + + await provider.shutdown() + assert init_counter == 2 + assert shutdown_counter == 2 + + @mark.asyncio async def test_init_async_class(): resource = object() diff --git a/tests/unit/providers/resource/test_resource_py35.py b/tests/unit/providers/resource/test_resource_py35.py index 9b906bd7..842d8ba6 100644 --- a/tests/unit/providers/resource/test_resource_py35.py +++ b/tests/unit/providers/resource/test_resource_py35.py @@ -2,10 +2,12 @@ import decimal import sys +from contextlib import contextmanager from typing import Any -from dependency_injector import containers, providers, resources, errors -from pytest import raises, mark +from pytest import mark, raises + +from dependency_injector import containers, errors, providers, resources def init_fn(*args, **kwargs): @@ -123,6 +125,41 @@ def _init(): assert _init.shutdown_counter == 2 +def test_init_context_manager() -> None: + init_counter, shutdown_counter = 0, 0 + + @contextmanager + def _init(): + nonlocal init_counter, shutdown_counter + + init_counter += 1 + yield + shutdown_counter += 1 + + init_counter = 0 + shutdown_counter = 0 + + provider = providers.Resource(_init) + + result1 = provider() + assert result1 is None + assert init_counter == 1 + assert shutdown_counter == 0 + + provider.shutdown() + assert init_counter == 1 + assert shutdown_counter == 1 + + result2 = provider() + assert result2 is None + assert init_counter == 2 + assert shutdown_counter == 1 + + provider.shutdown() + assert init_counter == 2 + assert shutdown_counter == 2 + + def test_init_class(): class TestResource(resources.Resource): init_counter = 0 @@ -190,7 +227,7 @@ def init(self): def test_init_not_callable(): provider = providers.Resource(1) - with raises(errors.Error): + with raises(TypeError, match=r"object is not callable"): provider.init()