diff --git a/examples/wiring/example.py b/examples/wiring/example.py index 4221ab13..0e32b192 100644 --- a/examples/wiring/example.py +++ b/examples/wiring/example.py @@ -2,10 +2,10 @@ from dependency_injector import containers, providers from dependency_injector.wiring import Provide, inject +from typing import Annotated -class Service: - ... +class Service: ... class Container(containers.DeclarativeContainer): @@ -13,9 +13,16 @@ class Container(containers.DeclarativeContainer): service = providers.Factory(Service) +# You can place marker on parameter default value @inject -def main(service: Service = Provide[Container.service]) -> None: - ... +def main(service: Service = Provide[Container.service]) -> None: ... + + +# Also, you can place marker with typing.Annotated +@inject +def main_with_annotated( + service: Annotated[Service, Provide[Container.service]] +) -> None: ... if __name__ == "__main__": diff --git a/requirements-dev.txt b/requirements-dev.txt index bc533741..0d759d4e 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -18,5 +18,6 @@ numpy scipy boto3 mypy_boto3_s3 +typing_extensions -r requirements-ext.txt diff --git a/src/dependency_injector/wiring.py b/src/dependency_injector/wiring.py index b5247c9e..67d56f29 100644 --- a/src/dependency_injector/wiring.py +++ b/src/dependency_injector/wiring.py @@ -37,6 +37,21 @@ class GenericMeta(type): ... else: GenericAlias = None +if sys.version_info >= (3, 9): + from typing import Annotated, get_args, get_origin +else: + try: + from typing_extensions import Annotated, get_args, get_origin + except ImportError: + Annotated = object() + + # For preventing NameError. Never executes + def get_args(hint): + return () + + def get_origin(tp): + return None + try: import fastapi.params @@ -572,6 +587,24 @@ def _unpatch_attribute(patched: PatchedAttribute) -> None: setattr(patched.member, patched.name, patched.marker) +def _extract_marker(parameter: inspect.Parameter) -> Optional["_Marker"]: + if get_origin(parameter.annotation) is Annotated: + marker = get_args(parameter.annotation)[1] + else: + marker = parameter.default + + if not isinstance(marker, _Marker) and not _is_fastapi_depends(marker): + return None + + if _is_fastapi_depends(marker): + marker = marker.dependency + + if not isinstance(marker, _Marker): + return None + + return marker + + def _fetch_reference_injections( # noqa: C901 fn: Callable[..., Any], ) -> Tuple[Dict[str, Any], Dict[str, Any]]: @@ -596,18 +629,10 @@ def _fetch_reference_injections( # noqa: C901 injections = {} closing = {} for parameter_name, parameter in signature.parameters.items(): - if not isinstance(parameter.default, _Marker) and not _is_fastapi_depends( - parameter.default - ): - continue + marker = _extract_marker(parameter) - marker = parameter.default - - if _is_fastapi_depends(marker): - marker = marker.dependency - - if not isinstance(marker, _Marker): - continue + if marker is None: + continue if isinstance(marker, Closing): marker = marker.provider diff --git a/tests/unit/samples/wiringfastapi/web.py b/tests/unit/samples/wiringfastapi/web.py index 3cee5450..c1ed5102 100644 --- a/tests/unit/samples/wiringfastapi/web.py +++ b/tests/unit/samples/wiringfastapi/web.py @@ -1,7 +1,11 @@ import sys +from typing_extensions import Annotated + from fastapi import FastAPI, Depends -from fastapi import Request # See: https://github.com/ets-labs/python-dependency-injector/issues/398 +from fastapi import ( + Request, +) # See: https://github.com/ets-labs/python-dependency-injector/issues/398 from fastapi.security import HTTPBasic, HTTPBasicCredentials from dependency_injector import containers, providers from dependency_injector.wiring import inject, Provide @@ -28,11 +32,16 @@ async def index(service: Service = Depends(Provide[Container.service])): return {"result": result} +@app.api_route("/annotated") +@inject +async def annotated(service: Annotated[Service, Depends(Provide[Container.service])]): + result = await service.process() + return {"result": result} + + @app.get("/auth") @inject -def read_current_user( - credentials: HTTPBasicCredentials = Depends(security) -): +def read_current_user(credentials: HTTPBasicCredentials = Depends(security)): return {"username": credentials.username, "password": credentials.password} diff --git a/tests/unit/samples/wiringflask/web.py b/tests/unit/samples/wiringflask/web.py index f273d8aa..8bb44494 100644 --- a/tests/unit/samples/wiringflask/web.py +++ b/tests/unit/samples/wiringflask/web.py @@ -1,3 +1,5 @@ +from typing_extensions import Annotated + from flask import Flask, jsonify, request, current_app, session, g from dependency_injector import containers, providers from dependency_injector.wiring import inject, Provide @@ -26,5 +28,12 @@ def index(service: Service = Provide[Container.service]): return jsonify({"result": result}) +@app.route("/annotated") +@inject +def annotated(service: Annotated[Service, Provide[Container.service]]): + result = service.process() + return jsonify({"result": result}) + + container = Container() container.wire(modules=[__name__]) diff --git a/tests/unit/wiring/test_fastapi_py36.py b/tests/unit/wiring/test_fastapi_py36.py index 1e9ff584..491c991c 100644 --- a/tests/unit/wiring/test_fastapi_py36.py +++ b/tests/unit/wiring/test_fastapi_py36.py @@ -4,13 +4,17 @@ # Runtime import to avoid syntax errors in samples on Python < 3.5 and reach top-dir import os + _SAMPLES_DIR = os.path.abspath( - os.path.sep.join(( - os.path.dirname(__file__), - "../samples/", - )), + os.path.sep.join( + ( + os.path.dirname(__file__), + "../samples/", + ) + ), ) import sys + sys.path.append(_SAMPLES_DIR) @@ -37,6 +41,19 @@ async def process(self): assert response.json() == {"result": "Foo"} +@mark.asyncio +async def test_depends_with_annotated(async_client: AsyncClient): + class ServiceMock: + async def process(self): + return "Foo" + + with web.container.service.override(ServiceMock()): + response = await async_client.get("/") + + assert response.status_code == 200 + assert response.json() == {"result": "Foo"} + + @mark.asyncio async def test_depends_injection(async_client: AsyncClient): response = await async_client.get("/auth", auth=("john_smith", "secret")) diff --git a/tests/unit/wiring/test_flask_py36.py b/tests/unit/wiring/test_flask_py36.py index 751f04d8..97420275 100644 --- a/tests/unit/wiring/test_flask_py36.py +++ b/tests/unit/wiring/test_flask_py36.py @@ -2,19 +2,25 @@ # Runtime import to avoid syntax errors in samples on Python < 3.5 and reach top-dir import os + _TOP_DIR = os.path.abspath( - os.path.sep.join(( - os.path.dirname(__file__), - "../", - )), + os.path.sep.join( + ( + os.path.dirname(__file__), + "../", + ) + ), ) _SAMPLES_DIR = os.path.abspath( - os.path.sep.join(( - os.path.dirname(__file__), - "../samples/", - )), + os.path.sep.join( + ( + os.path.dirname(__file__), + "../samples/", + ) + ), ) import sys + sys.path.append(_TOP_DIR) sys.path.append(_SAMPLES_DIR) @@ -29,3 +35,13 @@ def test_wiring_with_flask(): assert response.status_code == 200 assert json.loads(response.data) == {"result": "OK"} + + +def test_wiring_with_annotated(): + client = web.app.test_client() + + with web.app.app_context(): + response = client.get("/annotated") + + assert response.status_code == 200 + assert json.loads(response.data) == {"result": "OK"}