Skip to content

Commit 633f60a

Browse files
committed
Add support for typing.Annotated
1 parent cc2304e commit 633f60a

File tree

7 files changed

+84
-11
lines changed

7 files changed

+84
-11
lines changed

examples/wiring/example.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from dependency_injector import containers, providers
44
from dependency_injector.wiring import Provide, inject
5+
from typing import Annotated
56

67

78
class Service:
@@ -12,12 +13,18 @@ class Container(containers.DeclarativeContainer):
1213

1314
service = providers.Factory(Service)
1415

15-
16+
# You can place marker on parameter default value
1617
@inject
1718
def main(service: Service = Provide[Container.service]) -> None:
1819
...
1920

2021

22+
# Also, you can place marker with typing.Annotated
23+
@inject
24+
def main_with_annotated(service: Annotated[Service, Provide[Container.service]]) -> None:
25+
...
26+
27+
2128
if __name__ == "__main__":
2229
container = Container()
2330
container.wire(modules=[__name__])

requirements-dev.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,5 +16,6 @@ numpy
1616
scipy
1717
boto3
1818
mypy_boto3_s3
19+
typing_extensions
1920

2021
-r requirements-ext.txt

src/dependency_injector/wiring.py

Lines changed: 35 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,20 @@ class GenericMeta(type):
3636
else:
3737
GenericAlias = None
3838

39+
if sys.version_info >= (3, 9):
40+
from typing import Annotated, get_args, get_origin
41+
else:
42+
try:
43+
from typing_extensions import Annotated, get_args, get_origin
44+
except ImportError:
45+
Annotated = object()
46+
47+
# For preventing NameError. Never executes
48+
def get_args(hint):
49+
return ()
50+
51+
def get_origin(tp):
52+
return None
3953

4054
try:
4155
import fastapi.params
@@ -548,6 +562,24 @@ def _unpatch_attribute(patched: PatchedAttribute) -> None:
548562
setattr(patched.member, patched.name, patched.marker)
549563

550564

565+
def _extract_marker(parameter: inspect.Parameter) -> Optional["_Marker"]:
566+
if get_origin(parameter.annotation) is Annotated:
567+
marker = get_args(parameter.annotation)[1]
568+
else:
569+
marker = parameter.default
570+
571+
if not isinstance(marker, _Marker) and not _is_fastapi_depends(marker):
572+
return None
573+
574+
if _is_fastapi_depends(marker):
575+
marker = marker.dependency
576+
577+
if not isinstance(marker, _Marker):
578+
return None
579+
580+
return marker
581+
582+
551583
def _fetch_reference_injections( # noqa: C901
552584
fn: Callable[..., Any],
553585
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
@@ -573,17 +605,10 @@ def _fetch_reference_injections( # noqa: C901
573605
injections = {}
574606
closing = {}
575607
for parameter_name, parameter in signature.parameters.items():
576-
if not isinstance(parameter.default, _Marker) \
577-
and not _is_fastapi_depends(parameter.default):
578-
continue
608+
marker = _extract_marker(parameter)
579609

580-
marker = parameter.default
581-
582-
if _is_fastapi_depends(marker):
583-
marker = marker.dependency
584-
585-
if not isinstance(marker, _Marker):
586-
continue
610+
if marker is None:
611+
continue
587612

588613
if isinstance(marker, Closing):
589614
marker = marker.provider

tests/unit/samples/wiringfastapi/web.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import sys
22

3+
from typing_extensions import Annotated
4+
35
from fastapi import FastAPI, Depends
46
from fastapi import Request # See: https://github.com/ets-labs/python-dependency-injector/issues/398
57
from fastapi.security import HTTPBasic, HTTPBasicCredentials
@@ -27,6 +29,11 @@ async def index(service: Service = Depends(Provide[Container.service])):
2729
result = await service.process()
2830
return {"result": result}
2931

32+
@app.api_route('/annotated')
33+
@inject
34+
async def annotated(service: Annotated[Service, Depends(Provide[Container.service])]):
35+
result = await service.process()
36+
return {'result': result}
3037

3138
@app.get("/auth")
3239
@inject

tests/unit/samples/wiringflask/web.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing_extensions import Annotated
2+
13
from flask import Flask, jsonify, request, current_app, session, g
24
from flask import _request_ctx_stack, _app_ctx_stack
35
from dependency_injector import containers, providers
@@ -28,5 +30,12 @@ def index(service: Service = Provide[Container.service]):
2830
return jsonify({"result": result})
2931

3032

33+
@app.route("/annotated")
34+
@inject
35+
def annotated(service: Annotated[Service, Provide[Container.service]]):
36+
result = service.process()
37+
return jsonify({'result': result})
38+
39+
3140
container = Container()
3241
container.wire(modules=[__name__])

tests/unit/wiring/test_fastapi_py36.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,20 @@ async def process(self):
3636
assert response.json() == {"result": "Foo"}
3737

3838

39+
@mark.asyncio
40+
async def test_depends_with_annotated(async_client: AsyncClient):
41+
class ServiceMock:
42+
async def process(self):
43+
return "Foo"
44+
45+
with web.container.service.override(ServiceMock()):
46+
response = await async_client.get("/")
47+
48+
assert response.status_code == 200
49+
assert response.json() == {"result": "Foo"}
50+
51+
52+
3953
@mark.asyncio
4054
async def test_depends_injection(async_client: AsyncClient):
4155
response = await async_client.get("/auth", auth=("john_smith", "secret"))

tests/unit/wiring/test_flask_py36.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,3 +29,13 @@ def test_wiring_with_flask():
2929

3030
assert response.status_code == 200
3131
assert json.loads(response.data) == {"result": "OK"}
32+
33+
34+
def test_wiring_with_annotated():
35+
client = web.app.test_client()
36+
37+
with web.app.app_context():
38+
response = client.get("/annotated")
39+
40+
assert response.status_code == 200
41+
assert json.loads(response.data) == {"result": "OK"}

0 commit comments

Comments
 (0)