Skip to content

Commit b9b0206

Browse files
committed
TYP: annotate _internal.get_xp (and curse at ParamSpec for being so useless)
1 parent 9194c5c commit b9b0206

File tree

1 file changed

+17
-6
lines changed

1 file changed

+17
-6
lines changed

array_api_compat/_internal.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,19 @@
44

55
from functools import wraps
66
from inspect import signature
7+
from typing import TYPE_CHECKING
78

8-
def get_xp(xp):
9+
__all__ = ["get_xp"]
10+
11+
if TYPE_CHECKING:
12+
from collections.abc import Callable
13+
from types import ModuleType
14+
from typing import TypeVar
15+
16+
_T = TypeVar("_T")
17+
18+
19+
def get_xp(xp: "ModuleType") -> "Callable[[Callable[..., _T]], Callable[..., _T]]":
920
"""
1021
Decorator to automatically replace xp with the corresponding array module.
1122
@@ -22,14 +33,14 @@ def func(x, /, xp, kwarg=None):
2233
2334
"""
2435

25-
def inner(f):
36+
def inner(f: "Callable[..., _T]", /) -> "Callable[..., _T]":
2637
@wraps(f)
27-
def wrapped_f(*args, **kwargs):
38+
def wrapped_f(*args: object, **kwargs: object) -> object:
2839
return f(*args, xp=xp, **kwargs)
2940

3041
sig = signature(f)
3142
new_sig = sig.replace(
32-
parameters=[sig.parameters[i] for i in sig.parameters if i != "xp"]
43+
parameters=[par for i, par in sig.parameters.items() if i != "xp"]
3344
)
3445

3546
if wrapped_f.__doc__ is None:
@@ -40,7 +51,7 @@ def wrapped_f(*args, **kwargs):
4051
specification for more details.
4152
4253
"""
43-
wrapped_f.__signature__ = new_sig
44-
return wrapped_f
54+
wrapped_f.__signature__ = new_sig # pyright: ignore[reportAttributeAccessIssue]
55+
return wrapped_f # pyright: ignore[reportReturnType]
4556

4657
return inner

0 commit comments

Comments
 (0)