20
20
TypeVar ,
21
21
Type ,
22
22
Union ,
23
+ Set ,
23
24
cast ,
24
25
)
25
26
@@ -82,22 +83,53 @@ class GenericMeta(type):
82
83
Container = Any
83
84
84
85
85
- class Registry :
86
+ class PatchedRegistry :
86
87
87
88
def __init__ (self ):
88
- self ._storage = set ()
89
+ self ._callables : Set [Callable [..., Any ]] = set ()
90
+ self ._attributes : Set [PatchedAttribute ] = set ()
89
91
90
- def add (self , patched : Callable [..., Any ]) -> None :
91
- self ._storage .add (patched )
92
+ def add_callable (self , patched : Callable [..., Any ]) -> None :
93
+ self ._callables .add (patched )
92
94
93
- def get_from_module (self , module : ModuleType ) -> Iterator [Callable [..., Any ]]:
94
- for patched in self ._storage :
95
+ def get_callables_from_module (self , module : ModuleType ) -> Iterator [Callable [..., Any ]]:
96
+ for patched in self ._callables :
95
97
if patched .__module__ != module .__name__ :
96
98
continue
97
99
yield patched
98
100
101
+ def add_attribute (self , patched : 'PatchedAttribute' ):
102
+ self ._attributes .add (patched )
99
103
100
- _patched_registry = Registry ()
104
+ def get_attributes_from_module (self , module : ModuleType ) -> Iterator ['PatchedAttribute' ]:
105
+ for attribute in self ._attributes :
106
+ if not attribute .is_in_module (module ):
107
+ continue
108
+ yield attribute
109
+
110
+ def clear_module_attributes (self , module : ModuleType ):
111
+ for attribute in self ._attributes .copy ():
112
+ if not attribute .is_in_module (module ):
113
+ continue
114
+ self ._attributes .remove (attribute )
115
+
116
+
117
+ class PatchedAttribute :
118
+
119
+ def __init__ (self , member : Any , name : str , marker : '_Marker' ):
120
+ self .member = member
121
+ self .name = name
122
+ self .marker = marker
123
+
124
+ @property
125
+ def module_name (self ) -> str :
126
+ if isinstance (self .member , ModuleType ):
127
+ return self .member .__name__
128
+ else :
129
+ return self .member .__module__
130
+
131
+ def is_in_module (self , module : ModuleType ) -> bool :
132
+ return self .module_name == module .__name__
101
133
102
134
103
135
class ProvidersMap :
@@ -278,9 +310,6 @@ def _is_starlette_request_cls(self, instance: object) -> bool:
278
310
and issubclass (instance , starlette .requests .Request )
279
311
280
312
281
- inspect_filter = InspectFilter ()
282
-
283
-
284
313
def wire ( # noqa: C901
285
314
container : Container ,
286
315
* ,
@@ -301,20 +330,27 @@ def wire( # noqa: C901
301
330
providers_map = ProvidersMap (container )
302
331
303
332
for module in modules :
304
- for name , member in inspect .getmembers (module ):
305
- if inspect_filter .is_excluded (member ):
333
+ for member_name , member in inspect .getmembers (module ):
334
+ if _inspect_filter .is_excluded (member ):
306
335
continue
307
- if inspect .isfunction (member ):
308
- _patch_fn (module , name , member , providers_map )
309
- elif inspect .isclass (member ):
310
- for method_name , method in inspect .getmembers (member , _is_method ):
311
- _patch_method (member , method_name , method , providers_map )
312
336
313
- for patched in _patched_registry .get_from_module (module ):
337
+ if _is_marker (member ):
338
+ _patch_attribute (module , member_name , member , providers_map )
339
+ elif inspect .isfunction (member ):
340
+ _patch_fn (module , member_name , member , providers_map )
341
+ elif inspect .isclass (member ):
342
+ cls = member
343
+ for cls_member_name , cls_member in inspect .getmembers (cls ):
344
+ if _is_marker (cls_member ):
345
+ _patch_attribute (cls , cls_member_name , cls_member , providers_map )
346
+ elif _is_method (cls_member ):
347
+ _patch_method (cls , cls_member_name , cls_member , providers_map )
348
+
349
+ for patched in _patched_registry .get_callables_from_module (module ):
314
350
_bind_injections (patched , providers_map )
315
351
316
352
317
- def unwire (
353
+ def unwire ( # noqa: C901
318
354
* ,
319
355
modules : Optional [Iterable [ModuleType ]] = None ,
320
356
packages : Optional [Iterable [ModuleType ]] = None ,
@@ -335,15 +371,19 @@ def unwire(
335
371
for method_name , method in inspect .getmembers (member , inspect .isfunction ):
336
372
_unpatch (member , method_name , method )
337
373
338
- for patched in _patched_registry .get_from_module (module ):
374
+ for patched in _patched_registry .get_callables_from_module (module ):
339
375
_unbind_injections (patched )
340
376
377
+ for patched_attribute in _patched_registry .get_attributes_from_module (module ):
378
+ _unpatch_attribute (patched_attribute )
379
+ _patched_registry .clear_module_attributes (module )
380
+
341
381
342
382
def inject (fn : F ) -> F :
343
383
"""Decorate callable with injecting decorator."""
344
384
reference_injections , reference_closing = _fetch_reference_injections (fn )
345
385
patched = _get_patched (fn , reference_injections , reference_closing )
346
- _patched_registry .add (patched )
386
+ _patched_registry .add_callable (patched )
347
387
return cast (F , patched )
348
388
349
389
@@ -358,7 +398,7 @@ def _patch_fn(
358
398
if not reference_injections :
359
399
return
360
400
fn = _get_patched (fn , reference_injections , reference_closing )
361
- _patched_registry .add (fn )
401
+ _patched_registry .add_callable (fn )
362
402
363
403
_bind_injections (fn , providers_map )
364
404
@@ -384,7 +424,7 @@ def _patch_method(
384
424
if not reference_injections :
385
425
return
386
426
fn = _get_patched (fn , reference_injections , reference_closing )
387
- _patched_registry .add (fn )
427
+ _patched_registry .add_callable (fn )
388
428
389
429
_bind_injections (fn , providers_map )
390
430
@@ -411,6 +451,31 @@ def _unpatch(
411
451
_unbind_injections (fn )
412
452
413
453
454
+ def _patch_attribute (
455
+ member : Any ,
456
+ name : str ,
457
+ marker : '_Marker' ,
458
+ providers_map : ProvidersMap ,
459
+ ) -> None :
460
+ provider = providers_map .resolve_provider (marker .provider , marker .modifier )
461
+ if provider is None :
462
+ return
463
+
464
+ _patched_registry .add_attribute (PatchedAttribute (member , name , marker ))
465
+
466
+ if isinstance (marker , Provide ):
467
+ instance = provider ()
468
+ setattr (member , name , instance )
469
+ elif isinstance (marker , Provider ):
470
+ setattr (member , name , provider )
471
+ else :
472
+ raise Exception (f'Unknown type of marker { marker } ' )
473
+
474
+
475
+ def _unpatch_attribute (patched : PatchedAttribute ) -> None :
476
+ setattr (patched .member , patched .name , patched .marker )
477
+
478
+
414
479
def _fetch_reference_injections (
415
480
fn : Callable [..., Any ],
416
481
) -> Tuple [Dict [str , Any ], Dict [str , Any ]]:
@@ -484,6 +549,10 @@ def _is_method(member):
484
549
return inspect .ismethod (member ) or inspect .isfunction (member )
485
550
486
551
552
+ def _is_marker (member ):
553
+ return isinstance (member , _Marker )
554
+
555
+
487
556
def _get_patched (fn , reference_injections , reference_closing ):
488
557
if inspect .iscoroutinefunction (fn ):
489
558
patched = _get_async_patched (fn )
@@ -825,9 +894,6 @@ def uninstall(self):
825
894
importlib .invalidate_caches ()
826
895
827
896
828
- _loader = AutoLoader ()
829
-
830
-
831
897
def register_loader_containers (* containers : Container ) -> None :
832
898
"""Register containers in auto-wiring module loader."""
833
899
_loader .register_containers (* containers )
@@ -851,3 +917,8 @@ def uninstall_loader() -> None:
851
917
def is_loader_installed () -> bool :
852
918
"""Check if auto-wiring module loader hook is installed."""
853
919
return _loader .installed
920
+
921
+
922
+ _patched_registry = PatchedRegistry ()
923
+ _inspect_filter = InspectFilter ()
924
+ _loader = AutoLoader ()
0 commit comments