18
18
import inspect
19
19
import warnings
20
20
21
- def _is_jax_zero_gradient_array (x ) :
21
+ def _is_jax_zero_gradient_array (x : object ) -> bool :
22
22
"""Return True if `x` is a zero-gradient array.
23
23
24
24
These arrays are a design quirk of Jax that may one day be removed.
@@ -32,7 +32,8 @@ def _is_jax_zero_gradient_array(x):
32
32
33
33
return isinstance (x , np .ndarray ) and x .dtype == jax .float0
34
34
35
- def is_numpy_array (x ):
35
+
36
+ def is_numpy_array (x : object ) -> bool :
36
37
"""
37
38
Return True if `x` is a NumPy array.
38
39
@@ -63,7 +64,8 @@ def is_numpy_array(x):
63
64
return (isinstance (x , (np .ndarray , np .generic ))
64
65
and not _is_jax_zero_gradient_array (x ))
65
66
66
- def is_cupy_array (x ):
67
+
68
+ def is_cupy_array (x : object ) -> bool :
67
69
"""
68
70
Return True if `x` is a CuPy array.
69
71
@@ -93,7 +95,8 @@ def is_cupy_array(x):
93
95
# TODO: Should we reject ndarray subclasses?
94
96
return isinstance (x , cp .ndarray )
95
97
96
- def is_torch_array (x ):
98
+
99
+ def is_torch_array (x : object ) -> bool :
97
100
"""
98
101
Return True if `x` is a PyTorch tensor.
99
102
@@ -120,7 +123,8 @@ def is_torch_array(x):
120
123
# TODO: Should we reject ndarray subclasses?
121
124
return isinstance (x , torch .Tensor )
122
125
123
- def is_ndonnx_array (x ):
126
+
127
+ def is_ndonnx_array (x : object ) -> bool :
124
128
"""
125
129
Return True if `x` is a ndonnx Array.
126
130
@@ -147,7 +151,8 @@ def is_ndonnx_array(x):
147
151
148
152
return isinstance (x , ndx .Array )
149
153
150
- def is_dask_array (x ):
154
+
155
+ def is_dask_array (x : object ) -> bool :
151
156
"""
152
157
Return True if `x` is a dask.array Array.
153
158
@@ -174,7 +179,8 @@ def is_dask_array(x):
174
179
175
180
return isinstance (x , dask .array .Array )
176
181
177
- def is_jax_array (x ):
182
+
183
+ def is_jax_array (x : object ) -> bool :
178
184
"""
179
185
Return True if `x` is a JAX array.
180
186
@@ -202,6 +208,7 @@ def is_jax_array(x):
202
208
203
209
return isinstance (x , jax .Array ) or _is_jax_zero_gradient_array (x )
204
210
211
+
205
212
def is_pydata_sparse_array (x ) -> bool :
206
213
"""
207
214
Return True if `x` is an array from the `sparse` package.
@@ -231,7 +238,8 @@ def is_pydata_sparse_array(x) -> bool:
231
238
# TODO: Account for other backends.
232
239
return isinstance (x , sparse .SparseArray )
233
240
234
- def is_array_api_obj (x ):
241
+
242
+ def is_array_api_obj (x : object ) -> bool :
235
243
"""
236
244
Return True if `x` is an array API compatible array object.
237
245
@@ -254,10 +262,12 @@ def is_array_api_obj(x):
254
262
or is_pydata_sparse_array (x ) \
255
263
or hasattr (x , '__array_namespace__' )
256
264
257
- def _compat_module_name ():
265
+
266
+ def _compat_module_name () -> str :
258
267
assert __name__ .endswith ('.common._helpers' )
259
268
return __name__ .removesuffix ('.common._helpers' )
260
269
270
+
261
271
def is_numpy_namespace (xp ) -> bool :
262
272
"""
263
273
Returns True if `xp` is a NumPy namespace.
@@ -278,6 +288,7 @@ def is_numpy_namespace(xp) -> bool:
278
288
"""
279
289
return xp .__name__ in {'numpy' , _compat_module_name () + '.numpy' }
280
290
291
+
281
292
def is_cupy_namespace (xp ) -> bool :
282
293
"""
283
294
Returns True if `xp` is a CuPy namespace.
@@ -298,6 +309,7 @@ def is_cupy_namespace(xp) -> bool:
298
309
"""
299
310
return xp .__name__ in {'cupy' , _compat_module_name () + '.cupy' }
300
311
312
+
301
313
def is_torch_namespace (xp ) -> bool :
302
314
"""
303
315
Returns True if `xp` is a PyTorch namespace.
@@ -319,7 +331,7 @@ def is_torch_namespace(xp) -> bool:
319
331
return xp .__name__ in {'torch' , _compat_module_name () + '.torch' }
320
332
321
333
322
- def is_ndonnx_namespace (xp ):
334
+ def is_ndonnx_namespace (xp ) -> bool :
323
335
"""
324
336
Returns True if `xp` is an NDONNX namespace.
325
337
@@ -337,7 +349,8 @@ def is_ndonnx_namespace(xp):
337
349
"""
338
350
return xp .__name__ == 'ndonnx'
339
351
340
- def is_dask_namespace (xp ):
352
+
353
+ def is_dask_namespace (xp ) -> bool :
341
354
"""
342
355
Returns True if `xp` is a Dask namespace.
343
356
@@ -357,7 +370,8 @@ def is_dask_namespace(xp):
357
370
"""
358
371
return xp .__name__ in {'dask.array' , _compat_module_name () + '.dask.array' }
359
372
360
- def is_jax_namespace (xp ):
373
+
374
+ def is_jax_namespace (xp ) -> bool :
361
375
"""
362
376
Returns True if `xp` is a JAX namespace.
363
377
@@ -378,7 +392,8 @@ def is_jax_namespace(xp):
378
392
"""
379
393
return xp .__name__ in {'jax.numpy' , 'jax.experimental.array_api' }
380
394
381
- def is_pydata_sparse_namespace (xp ):
395
+
396
+ def is_pydata_sparse_namespace (xp ) -> bool :
382
397
"""
383
398
Returns True if `xp` is a pydata/sparse namespace.
384
399
@@ -396,7 +411,8 @@ def is_pydata_sparse_namespace(xp):
396
411
"""
397
412
return xp .__name__ == 'sparse'
398
413
399
- def is_array_api_strict_namespace (xp ):
414
+
415
+ def is_array_api_strict_namespace (xp ) -> bool :
400
416
"""
401
417
Returns True if `xp` is an array-api-strict namespace.
402
418
@@ -414,13 +430,15 @@ def is_array_api_strict_namespace(xp):
414
430
"""
415
431
return xp .__name__ == 'array_api_strict'
416
432
417
- def _check_api_version (api_version ):
433
+
434
+ def _check_api_version (api_version : str ) -> None :
418
435
if api_version in ['2021.12' , '2022.12' ]:
419
436
warnings .warn (f"The { api_version } version of the array API specification was requested but the returned namespace is actually version 2023.12" )
420
437
elif api_version is not None and api_version not in ['2021.12' , '2022.12' ,
421
438
'2023.12' ]:
422
439
raise ValueError ("Only the 2023.12 version of the array API specification is currently supported" )
423
440
441
+
424
442
def array_namespace (* xs , api_version = None , use_compat = None ):
425
443
"""
426
444
Get the array API compatible namespace for the arrays `xs`.
@@ -808,9 +826,10 @@ def size(x: Array) -> int | None:
808
826
return None if math .isnan (out ) else out
809
827
810
828
811
- def is_writeable_array (x ) -> bool :
829
+ def is_writeable_array (x : object ) -> bool :
812
830
"""
813
831
Return False if ``x.__setitem__`` is expected to raise; True otherwise.
832
+ Return False if `x` is not an array API compatible object.
814
833
815
834
Warning
816
835
-------
@@ -821,10 +840,10 @@ def is_writeable_array(x) -> bool:
821
840
return x .flags .writeable
822
841
if is_jax_array (x ) or is_pydata_sparse_array (x ):
823
842
return False
824
- return True
843
+ return is_array_api_obj ( x )
825
844
826
845
827
- def is_lazy_array (x ) -> bool :
846
+ def is_lazy_array (x : object ) -> bool :
828
847
"""Return True if x is potentially a future or it may be otherwise impossible or
829
848
expensive to eagerly read its contents, regardless of their size, e.g. by
830
849
calling ``bool(x)`` or ``float(x)``.
@@ -857,6 +876,9 @@ def is_lazy_array(x) -> bool:
857
876
if is_jax_array (x ) or is_dask_array (x ) or is_ndonnx_array (x ):
858
877
return True
859
878
879
+ if not is_array_api_obj (x ):
880
+ return False
881
+
860
882
# Unknown Array API compatible object. Note that this test may have dire consequences
861
883
# in terms of performance, e.g. for a lazy object that eagerly computes the graph
862
884
# on __bool__ (dask is one such example, which however is special-cased above).
0 commit comments