10
10
from typing import TYPE_CHECKING
11
11
12
12
if TYPE_CHECKING :
13
+ from types import ModuleType
13
14
from typing import Optional , Union , Any
14
15
from ._typing import Array , Device
15
16
18
19
import inspect
19
20
import warnings
20
21
21
- def _is_jax_zero_gradient_array (x ) :
22
+ def _is_jax_zero_gradient_array (x : object ) -> bool :
22
23
"""Return True if `x` is a zero-gradient array.
23
24
24
25
These arrays are a design quirk of Jax that may one day be removed.
@@ -32,7 +33,8 @@ def _is_jax_zero_gradient_array(x):
32
33
33
34
return isinstance (x , np .ndarray ) and x .dtype == jax .float0
34
35
35
- def is_numpy_array (x ):
36
+
37
+ def is_numpy_array (x : object ) -> bool :
36
38
"""
37
39
Return True if `x` is a NumPy array.
38
40
@@ -63,7 +65,8 @@ def is_numpy_array(x):
63
65
return (isinstance (x , (np .ndarray , np .generic ))
64
66
and not _is_jax_zero_gradient_array (x ))
65
67
66
- def is_cupy_array (x ):
68
+
69
+ def is_cupy_array (x : object ) -> bool :
67
70
"""
68
71
Return True if `x` is a CuPy array.
69
72
@@ -93,7 +96,8 @@ def is_cupy_array(x):
93
96
# TODO: Should we reject ndarray subclasses?
94
97
return isinstance (x , cp .ndarray )
95
98
96
- def is_torch_array (x ):
99
+
100
+ def is_torch_array (x : object ) -> bool :
97
101
"""
98
102
Return True if `x` is a PyTorch tensor.
99
103
@@ -120,7 +124,8 @@ def is_torch_array(x):
120
124
# TODO: Should we reject ndarray subclasses?
121
125
return isinstance (x , torch .Tensor )
122
126
123
- def is_ndonnx_array (x ):
127
+
128
+ def is_ndonnx_array (x : object ) -> bool :
124
129
"""
125
130
Return True if `x` is a ndonnx Array.
126
131
@@ -147,7 +152,8 @@ def is_ndonnx_array(x):
147
152
148
153
return isinstance (x , ndx .Array )
149
154
150
- def is_dask_array (x ):
155
+
156
+ def is_dask_array (x : object ) -> bool :
151
157
"""
152
158
Return True if `x` is a dask.array Array.
153
159
@@ -174,7 +180,8 @@ def is_dask_array(x):
174
180
175
181
return isinstance (x , dask .array .Array )
176
182
177
- def is_jax_array (x ):
183
+
184
+ def is_jax_array (x : object ) -> bool :
178
185
"""
179
186
Return True if `x` is a JAX array.
180
187
@@ -202,6 +209,7 @@ def is_jax_array(x):
202
209
203
210
return isinstance (x , jax .Array ) or _is_jax_zero_gradient_array (x )
204
211
212
+
205
213
def is_pydata_sparse_array (x ) -> bool :
206
214
"""
207
215
Return True if `x` is an array from the `sparse` package.
@@ -231,7 +239,8 @@ def is_pydata_sparse_array(x) -> bool:
231
239
# TODO: Account for other backends.
232
240
return isinstance (x , sparse .SparseArray )
233
241
234
- def is_array_api_obj (x ):
242
+
243
+ def is_array_api_obj (x : object ) -> bool :
235
244
"""
236
245
Return True if `x` is an array API compatible array object.
237
246
@@ -254,11 +263,13 @@ def is_array_api_obj(x):
254
263
or is_pydata_sparse_array (x ) \
255
264
or hasattr (x , '__array_namespace__' )
256
265
257
- def _compat_module_name ():
266
+
267
+ def _compat_module_name () -> str :
258
268
assert __name__ .endswith ('.common._helpers' )
259
269
return __name__ .removesuffix ('.common._helpers' )
260
270
261
- def is_numpy_namespace (xp ) -> bool :
271
+
272
+ def is_numpy_namespace (xp : ModuleType ) -> bool :
262
273
"""
263
274
Returns True if `xp` is a NumPy namespace.
264
275
@@ -278,7 +289,8 @@ def is_numpy_namespace(xp) -> bool:
278
289
"""
279
290
return xp .__name__ in {'numpy' , _compat_module_name () + '.numpy' }
280
291
281
- def is_cupy_namespace (xp ) -> bool :
292
+
293
+ def is_cupy_namespace (xp : ModuleType ) -> bool :
282
294
"""
283
295
Returns True if `xp` is a CuPy namespace.
284
296
@@ -298,7 +310,8 @@ def is_cupy_namespace(xp) -> bool:
298
310
"""
299
311
return xp .__name__ in {'cupy' , _compat_module_name () + '.cupy' }
300
312
301
- def is_torch_namespace (xp ) -> bool :
313
+
314
+ def is_torch_namespace (xp : ModuleType ) -> bool :
302
315
"""
303
316
Returns True if `xp` is a PyTorch namespace.
304
317
@@ -319,7 +332,7 @@ def is_torch_namespace(xp) -> bool:
319
332
return xp .__name__ in {'torch' , _compat_module_name () + '.torch' }
320
333
321
334
322
- def is_ndonnx_namespace (xp ) :
335
+ def is_ndonnx_namespace (xp : ModuleType ) -> bool :
323
336
"""
324
337
Returns True if `xp` is an NDONNX namespace.
325
338
@@ -337,7 +350,8 @@ def is_ndonnx_namespace(xp):
337
350
"""
338
351
return xp .__name__ == 'ndonnx'
339
352
340
- def is_dask_namespace (xp ):
353
+
354
+ def is_dask_namespace (xp : ModuleType ) -> bool :
341
355
"""
342
356
Returns True if `xp` is a Dask namespace.
343
357
@@ -357,7 +371,8 @@ def is_dask_namespace(xp):
357
371
"""
358
372
return xp .__name__ in {'dask.array' , _compat_module_name () + '.dask.array' }
359
373
360
- def is_jax_namespace (xp ):
374
+
375
+ def is_jax_namespace (xp : ModuleType ) -> bool :
361
376
"""
362
377
Returns True if `xp` is a JAX namespace.
363
378
@@ -378,7 +393,8 @@ def is_jax_namespace(xp):
378
393
"""
379
394
return xp .__name__ in {'jax.numpy' , 'jax.experimental.array_api' }
380
395
381
- def is_pydata_sparse_namespace (xp ):
396
+
397
+ def is_pydata_sparse_namespace (xp : ModuleType ) -> bool :
382
398
"""
383
399
Returns True if `xp` is a pydata/sparse namespace.
384
400
@@ -396,7 +412,8 @@ def is_pydata_sparse_namespace(xp):
396
412
"""
397
413
return xp .__name__ == 'sparse'
398
414
399
- def is_array_api_strict_namespace (xp ):
415
+
416
+ def is_array_api_strict_namespace (xp : ModuleType ) -> bool :
400
417
"""
401
418
Returns True if `xp` is an array-api-strict namespace.
402
419
@@ -414,13 +431,15 @@ def is_array_api_strict_namespace(xp):
414
431
"""
415
432
return xp .__name__ == 'array_api_strict'
416
433
417
- def _check_api_version (api_version ):
434
+
435
+ def _check_api_version (api_version : str ) -> None :
418
436
if api_version in ['2021.12' , '2022.12' ]:
419
437
warnings .warn (f"The { api_version } version of the array API specification was requested but the returned namespace is actually version 2023.12" )
420
438
elif api_version is not None and api_version not in ['2021.12' , '2022.12' ,
421
439
'2023.12' ]:
422
440
raise ValueError ("Only the 2023.12 version of the array API specification is currently supported" )
423
441
442
+
424
443
def array_namespace (* xs , api_version = None , use_compat = None ):
425
444
"""
426
445
Get the array API compatible namespace for the arrays `xs`.
@@ -808,9 +827,10 @@ def size(x: Array) -> int | None:
808
827
return None if math .isnan (out ) else out
809
828
810
829
811
- def is_writeable_array (x ) -> bool :
830
+ def is_writeable_array (x : object ) -> bool :
812
831
"""
813
832
Return False if ``x.__setitem__`` is expected to raise; True otherwise.
833
+ Return False if `x` is not an array API compatible object.
814
834
815
835
Warning
816
836
-------
@@ -821,10 +841,10 @@ def is_writeable_array(x) -> bool:
821
841
return x .flags .writeable
822
842
if is_jax_array (x ) or is_pydata_sparse_array (x ):
823
843
return False
824
- return True
844
+ return is_array_api_obj ( x )
825
845
826
846
827
- def is_lazy_array (x ) -> bool :
847
+ def is_lazy_array (x : object ) -> bool :
828
848
"""Return True if x is potentially a future or it may be otherwise impossible or
829
849
expensive to eagerly read its contents, regardless of their size, e.g. by
830
850
calling ``bool(x)`` or ``float(x)``.
@@ -857,6 +877,9 @@ def is_lazy_array(x) -> bool:
857
877
if is_jax_array (x ) or is_dask_array (x ) or is_ndonnx_array (x ):
858
878
return True
859
879
880
+ if not is_array_api_obj (x ):
881
+ return False
882
+
860
883
# Unknown Array API compatible object. Note that this test may have dire consequences
861
884
# in terms of performance, e.g. for a lazy object that eagerly computes the graph
862
885
# on __bool__ (dask is one such example, which however is special-cased above).
0 commit comments