Skip to content

Commit 18870dc

Browse files
committed
TYP: fix typing errors in numpy._aliases
1 parent 9643256 commit 18870dc

File tree

1 file changed

+52
-29
lines changed

1 file changed

+52
-29
lines changed

array_api_compat/numpy/_aliases.py

Lines changed: 52 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,23 @@
1+
# pyright: reportPrivateUsage=false
12
from __future__ import annotations
23

3-
from typing import Optional, Union
4+
from builtins import bool as py_bool
5+
from typing import TYPE_CHECKING, cast
6+
7+
import numpy as np
48

59
from .._internal import get_xp
610
from ..common import _aliases, _helpers
711
from ..common._typing import NestedSequence, SupportsBufferProtocol
812
from ._info import __array_namespace_info__
913
from ._typing import Array, Device, DType
1014

11-
import numpy as np
15+
if TYPE_CHECKING:
16+
from typing import Any, Literal, TypeAlias
17+
18+
from typing_extensions import Buffer, TypeIs
19+
20+
_Copy: TypeAlias = py_bool | Literal[2] | np._CopyMode
1221

1322
bool = np.bool_
1423

@@ -65,9 +74,9 @@
6574
iinfo = get_xp(np)(_aliases.iinfo)
6675

6776

68-
def _supports_buffer_protocol(obj):
77+
def _supports_buffer_protocol(obj: object) -> TypeIs[Buffer]: # pyright: ignore[reportUnusedFunction]
6978
try:
70-
memoryview(obj)
79+
memoryview(obj) # pyright: ignore[reportArgumentType]
7180
except TypeError:
7281
return False
7382
return True
@@ -78,18 +87,13 @@ def _supports_buffer_protocol(obj):
7887
# complicated enough that it's easier to define it separately for each module
7988
# rather than trying to combine everything into one function in common/
8089
def asarray(
81-
obj: (
82-
Array
83-
| bool | int | float | complex
84-
| NestedSequence[bool | int | float | complex]
85-
| SupportsBufferProtocol
86-
),
90+
obj: Array | complex | NestedSequence[complex] | SupportsBufferProtocol,
8791
/,
8892
*,
89-
dtype: Optional[DType] = None,
90-
device: Optional[Device] = None,
91-
copy: Optional[Union[bool, np._CopyMode]] = None,
92-
**kwargs,
93+
dtype: DType | None = None,
94+
device: Device | None = None,
95+
copy: _Copy | None = None,
96+
**kwargs: Any,
9397
) -> Array:
9498
"""
9599
Array API compatibility wrapper for asarray().
@@ -106,51 +110,70 @@ def asarray(
106110
elif copy is True:
107111
copy = np._CopyMode.ALWAYS
108112

109-
return np.array(obj, copy=copy, dtype=dtype, **kwargs)
113+
return np.array(obj, copy=copy, dtype=dtype, **kwargs) # pyright: ignore
110114

111115

112116
def astype(
113117
x: Array,
114118
dtype: DType,
115119
/,
116120
*,
117-
copy: bool = True,
118-
device: Optional[Device] = None,
121+
copy: py_bool = True,
122+
device: Device | None = None,
119123
) -> Array:
120124
_helpers._check_device(np, device)
121125
return x.astype(dtype=dtype, copy=copy)
122126

123127

124128
# count_nonzero returns a python int for axis=None and keepdims=False
125129
# https://github.com/numpy/numpy/issues/17562
126-
def count_nonzero(x: Array, axis=None, keepdims=False) -> Array:
127-
result = np.count_nonzero(x, axis=axis, keepdims=keepdims)
130+
def count_nonzero(
131+
x: Array,
132+
axis: int | tuple[int, ...] | None = None,
133+
keepdims: py_bool = False,
134+
) -> Array:
135+
result = cast("Any", np.count_nonzero(x, axis=axis, keepdims=keepdims)) # pyright: ignore
128136
if axis is None and not keepdims:
129137
return np.asarray(result)
130138
return result
131139

132140

133141
# These functions are completely new here. If the library already has them
134142
# (i.e., numpy 2.0), use the library version instead of our wrapper.
135-
if hasattr(np, 'vecdot'):
143+
if hasattr(np, "vecdot"):
136144
vecdot = np.vecdot
137145
else:
138146
vecdot = get_xp(np)(_aliases.vecdot)
139147

140-
if hasattr(np, 'isdtype'):
148+
if hasattr(np, "isdtype"):
141149
isdtype = np.isdtype
142150
else:
143151
isdtype = get_xp(np)(_aliases.isdtype)
144152

145-
if hasattr(np, 'unstack'):
153+
if hasattr(np, "unstack"):
146154
unstack = np.unstack
147155
else:
148156
unstack = get_xp(np)(_aliases.unstack)
149157

150-
__all__ = _aliases.__all__ + ['__array_namespace_info__', 'asarray', 'astype',
151-
'acos', 'acosh', 'asin', 'asinh', 'atan',
152-
'atan2', 'atanh', 'bitwise_left_shift',
153-
'bitwise_invert', 'bitwise_right_shift',
154-
'bool', 'concat', 'count_nonzero', 'pow']
155-
156-
_all_ignore = ['np', 'get_xp']
158+
__all__ = [
159+
"__array_namespace_info__",
160+
"asarray",
161+
"astype",
162+
"acos",
163+
"acosh",
164+
"asin",
165+
"asinh",
166+
"atan",
167+
"atan2",
168+
"atanh",
169+
"bitwise_left_shift",
170+
"bitwise_invert",
171+
"bitwise_right_shift",
172+
"bool",
173+
"concat",
174+
"count_nonzero",
175+
"pow",
176+
]
177+
__all__ += _aliases.__all__
178+
179+
_all_ignore = ["np", "get_xp"]

0 commit comments

Comments
 (0)