4
4
5
5
from functools import wraps
6
6
from inspect import signature
7
+ from typing import TYPE_CHECKING
7
8
8
- def get_xp (xp ):
9
+ __all__ = ["get_xp" ]
10
+
11
+ if TYPE_CHECKING :
12
+ from collections .abc import Callable
13
+ from types import ModuleType
14
+ from typing import TypeVar
15
+
16
+ _T = TypeVar ("_T" )
17
+
18
+
19
+ def get_xp (xp : "ModuleType" ) -> "Callable[[Callable[..., _T]], Callable[..., _T]]" :
9
20
"""
10
21
Decorator to automatically replace xp with the corresponding array module.
11
22
@@ -22,14 +33,14 @@ def func(x, /, xp, kwarg=None):
22
33
23
34
"""
24
35
25
- def inner (f ) :
36
+ def inner (f : "Callable[..., _T]" , / ) -> "Callable[..., _T]" :
26
37
@wraps (f )
27
- def wrapped_f (* args , ** kwargs ) :
38
+ def wrapped_f (* args : object , ** kwargs : object ) -> object :
28
39
return f (* args , xp = xp , ** kwargs )
29
40
30
41
sig = signature (f )
31
42
new_sig = sig .replace (
32
- parameters = [sig . parameters [ i ] for i in sig .parameters if i != "xp" ]
43
+ parameters = [par for i , par in sig .parameters . items () if i != "xp" ]
33
44
)
34
45
35
46
if wrapped_f .__doc__ is None :
@@ -40,7 +51,7 @@ def wrapped_f(*args, **kwargs):
40
51
specification for more details.
41
52
42
53
"""
43
- wrapped_f .__signature__ = new_sig
44
- return wrapped_f
54
+ wrapped_f .__signature__ = new_sig # pyright: ignore[reportAttributeAccessIssue]
55
+ return wrapped_f # pyright: ignore[reportReturnType]
45
56
46
57
return inner
0 commit comments