1
+ # pyright: reportPrivateUsage=false
1
2
from __future__ import annotations
2
3
3
- from typing import Optional , Union
4
+ from builtins import bool as py_bool
5
+ from typing import TYPE_CHECKING , cast
6
+
7
+ import numpy as np
4
8
5
9
from .._internal import get_xp
6
10
from ..common import _aliases , _helpers
7
11
from ..common ._typing import NestedSequence , SupportsBufferProtocol
8
12
from ._info import __array_namespace_info__
9
13
from ._typing import Array , Device , DType
10
14
11
- import numpy as np
15
+ if TYPE_CHECKING :
16
+ from typing import Any , Literal , TypeAlias
17
+
18
+ from typing_extensions import Buffer , TypeIs
19
+
20
+ _Copy : TypeAlias = py_bool | Literal [2 ] | np ._CopyMode
12
21
13
22
bool = np .bool_
14
23
65
74
iinfo = get_xp (np )(_aliases .iinfo )
66
75
67
76
68
- def _supports_buffer_protocol (obj ):
77
+ def _supports_buffer_protocol (obj : object ) -> TypeIs [ Buffer ]: # pyright: ignore[reportUnusedFunction]
69
78
try :
70
- memoryview (obj )
79
+ memoryview (obj ) # pyright: ignore[reportArgumentType]
71
80
except TypeError :
72
81
return False
73
82
return True
@@ -78,18 +87,13 @@ def _supports_buffer_protocol(obj):
78
87
# complicated enough that it's easier to define it separately for each module
79
88
# rather than trying to combine everything into one function in common/
80
89
def asarray (
81
- obj : (
82
- Array
83
- | bool | int | float | complex
84
- | NestedSequence [bool | int | float | complex ]
85
- | SupportsBufferProtocol
86
- ),
90
+ obj : Array | complex | NestedSequence [complex ] | SupportsBufferProtocol ,
87
91
/ ,
88
92
* ,
89
- dtype : Optional [ DType ] = None ,
90
- device : Optional [ Device ] = None ,
91
- copy : Optional [ Union [ bool , np . _CopyMode ]] = None ,
92
- ** kwargs ,
93
+ dtype : DType | None = None ,
94
+ device : Device | None = None ,
95
+ copy : _Copy | None = None ,
96
+ ** kwargs : Any ,
93
97
) -> Array :
94
98
"""
95
99
Array API compatibility wrapper for asarray().
@@ -106,51 +110,70 @@ def asarray(
106
110
elif copy is True :
107
111
copy = np ._CopyMode .ALWAYS
108
112
109
- return np .array (obj , copy = copy , dtype = dtype , ** kwargs )
113
+ return np .array (obj , copy = copy , dtype = dtype , ** kwargs ) # pyright: ignore
110
114
111
115
112
116
def astype (
113
117
x : Array ,
114
118
dtype : DType ,
115
119
/ ,
116
120
* ,
117
- copy : bool = True ,
118
- device : Optional [ Device ] = None ,
121
+ copy : py_bool = True ,
122
+ device : Device | None = None ,
119
123
) -> Array :
120
124
_helpers ._check_device (np , device )
121
125
return x .astype (dtype = dtype , copy = copy )
122
126
123
127
124
128
# count_nonzero returns a python int for axis=None and keepdims=False
125
129
# https://github.com/numpy/numpy/issues/17562
126
- def count_nonzero (x : Array , axis = None , keepdims = False ) -> Array :
127
- result = np .count_nonzero (x , axis = axis , keepdims = keepdims )
130
+ def count_nonzero (
131
+ x : Array ,
132
+ axis : int | tuple [int , ...] | None = None ,
133
+ keepdims : py_bool = False ,
134
+ ) -> Array :
135
+ result = cast ("Any" , np .count_nonzero (x , axis = axis , keepdims = keepdims )) # pyright: ignore
128
136
if axis is None and not keepdims :
129
137
return np .asarray (result )
130
138
return result
131
139
132
140
133
141
# These functions are completely new here. If the library already has them
134
142
# (i.e., numpy 2.0), use the library version instead of our wrapper.
135
- if hasattr (np , ' vecdot' ):
143
+ if hasattr (np , " vecdot" ):
136
144
vecdot = np .vecdot
137
145
else :
138
146
vecdot = get_xp (np )(_aliases .vecdot )
139
147
140
- if hasattr (np , ' isdtype' ):
148
+ if hasattr (np , " isdtype" ):
141
149
isdtype = np .isdtype
142
150
else :
143
151
isdtype = get_xp (np )(_aliases .isdtype )
144
152
145
- if hasattr (np , ' unstack' ):
153
+ if hasattr (np , " unstack" ):
146
154
unstack = np .unstack
147
155
else :
148
156
unstack = get_xp (np )(_aliases .unstack )
149
157
150
- __all__ = _aliases .__all__ + ['__array_namespace_info__' , 'asarray' , 'astype' ,
151
- 'acos' , 'acosh' , 'asin' , 'asinh' , 'atan' ,
152
- 'atan2' , 'atanh' , 'bitwise_left_shift' ,
153
- 'bitwise_invert' , 'bitwise_right_shift' ,
154
- 'bool' , 'concat' , 'count_nonzero' , 'pow' ]
155
-
156
- _all_ignore = ['np' , 'get_xp' ]
158
+ __all__ = [
159
+ "__array_namespace_info__" ,
160
+ "asarray" ,
161
+ "astype" ,
162
+ "acos" ,
163
+ "acosh" ,
164
+ "asin" ,
165
+ "asinh" ,
166
+ "atan" ,
167
+ "atan2" ,
168
+ "atanh" ,
169
+ "bitwise_left_shift" ,
170
+ "bitwise_invert" ,
171
+ "bitwise_right_shift" ,
172
+ "bool" ,
173
+ "concat" ,
174
+ "count_nonzero" ,
175
+ "pow" ,
176
+ ]
177
+ __all__ += _aliases .__all__
178
+
179
+ _all_ignore = ["np" , "get_xp" ]
0 commit comments