Skip to content

Introduce wiring inspect filter #412

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Feb 27, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions docs/main/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,16 @@ follows `Semantic versioning`_

Development version
-------------------
- Introduce wiring inspect filter to filter out ``flask.request`` and other local proxy objects
from the inspection.
See issue: `#408 <https://github.com/ets-labs/python-dependency-injector/issues/408>`_.
Many thanks to `@bvanfleet <https://github.com/bvanfleet>`_ for reporting the issue and
help in finding the root cause.
- Add ``boto3`` example.
- Add tests for ``.as_float()`` modifier usage with wiring.
- Make refactoring of wiring module and tests.
See PR # `#406 <https://github.com/ets-labs/python-dependency-injector/issues/406>`_.
Thanks to `@withshubh <https://github.com/withshubh>`_ for the contribution:
- Refactor unnecessary ``else`` / ``elif`` in ``wiring`` module when ``if`` block has a
return statement.
- Remove unused imports in tests.
- Use literal syntax to create data structure in tests.
- Add integration with a static analysis tool `DeepSource <https://deepsource.io/>`_.
Expand Down
54 changes: 45 additions & 9 deletions src/dependency_injector/wiring.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,21 @@ class GenericMeta(type):


try:
from fastapi.params import Depends as FastAPIDepends
fastapi_installed = True
import fastapi.params
except ImportError:
fastapi_installed = False
fastapi = None


try:
import starlette.requests
except ImportError:
starlette = None


try:
import werkzeug.local
except ImportError:
werkzeug = None


from . import providers
Expand Down Expand Up @@ -111,20 +122,21 @@ def resolve_provider(
) -> Optional[providers.Provider]:
if isinstance(provider, providers.Delegate):
return self._resolve_delegate(provider)
if isinstance(provider, (
elif isinstance(provider, (
providers.ProvidedInstance,
providers.AttributeGetter,
providers.ItemGetter,
providers.MethodCaller,
)):
return self._resolve_provided_instance(provider)
if isinstance(provider, providers.ConfigurationOption):
elif isinstance(provider, providers.ConfigurationOption):
return self._resolve_config_option(provider)
if isinstance(provider, providers.TypedConfigurationOption):
elif isinstance(provider, providers.TypedConfigurationOption):
return self._resolve_config_option(provider.option, as_=provider.provides)
if isinstance(provider, str):
elif isinstance(provider, str):
return self._resolve_string_id(provider, modifier)
return self._resolve_provider(provider)
else:
return self._resolve_provider(provider)

def _resolve_string_id(
self,
Expand Down Expand Up @@ -247,6 +259,28 @@ def _create_providers_map(
return providers_map


class InspectFilter:

def is_excluded(self, instance: object) -> bool:
if self._is_werkzeug_local_proxy(instance):
return True
elif self._is_starlette_request_cls(instance):
return True
else:
return False

def _is_werkzeug_local_proxy(self, instance: object) -> bool:
return werkzeug and isinstance(instance, werkzeug.local.LocalProxy)

def _is_starlette_request_cls(self, instance: object) -> bool:
return starlette \
and isinstance(instance, type) \
and issubclass(instance, starlette.requests.Request)


inspect_filter = InspectFilter()


def wire( # noqa: C901
container: Container,
*,
Expand All @@ -268,6 +302,8 @@ def wire( # noqa: C901

for module in modules:
for name, member in inspect.getmembers(module):
if inspect_filter.is_excluded(member):
continue
if inspect.isfunction(member):
_patch_fn(module, name, member, providers_map)
elif inspect.isclass(member):
Expand Down Expand Up @@ -530,7 +566,7 @@ def _is_fastapi_default_arg_injection(injection, kwargs):


def _is_fastapi_depends(param: Any) -> bool:
return fastapi_installed and isinstance(param, FastAPIDepends)
return fastapi and isinstance(param, fastapi.params.Depends)


def _is_patched(fn):
Expand Down
34 changes: 34 additions & 0 deletions tests/unit/samples/wiringflask/web.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import sys

from flask import Flask, jsonify, request, current_app, session, g
from flask import _request_ctx_stack, _app_ctx_stack
from dependency_injector import containers, providers
from dependency_injector.wiring import inject, Provide

# This is here for testing wiring bypasses these objects without crashing
request, current_app, session, g # noqa
_request_ctx_stack, _app_ctx_stack # noqa


class Service:
def process(self) -> str:
return 'Ok'


class Container(containers.DeclarativeContainer):

service = providers.Factory(Service)


app = Flask(__name__)


@app.route('/')
@inject
def index(service: Service = Provide[Container.service]):
result = service.process()
return jsonify({'result': result})


container = Container()
container.wire(modules=[sys.modules[__name__]])
33 changes: 33 additions & 0 deletions tests/unit/wiring/test_wiringflask_py36.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import unittest

# Runtime import to avoid syntax errors in samples on Python < 3.5 and reach top-dir
import os
_TOP_DIR = os.path.abspath(
os.path.sep.join((
os.path.dirname(__file__),
'../',
)),
)
_SAMPLES_DIR = os.path.abspath(
os.path.sep.join((
os.path.dirname(__file__),
'../samples/',
)),
)
import sys
sys.path.append(_TOP_DIR)
sys.path.append(_SAMPLES_DIR)

from wiringflask import web


class WiringFlaskTest(unittest.TestCase):

def test(self):
client = web.app.test_client()

with web.app.app_context():
response = client.get('/')

self.assertEqual(response.status_code, 200)
self.assertEqual(response.data, b'{"result":"Ok"}\n')