Skip to content

Commit b2e692e

Browse files
committed
ENH: is_lazy_array and is_writeable_array to return False on non-arrays
1 parent 9442237 commit b2e692e

File tree

2 files changed

+65
-21
lines changed

2 files changed

+65
-21
lines changed

array_api_compat/common/_helpers.py

Lines changed: 44 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from typing import TYPE_CHECKING
1111

1212
if TYPE_CHECKING:
13+
from types import ModuleType
1314
from typing import Optional, Union, Any
1415
from ._typing import Array, Device
1516

@@ -18,7 +19,7 @@
1819
import inspect
1920
import warnings
2021

21-
def _is_jax_zero_gradient_array(x):
22+
def _is_jax_zero_gradient_array(x: object) -> bool:
2223
"""Return True if `x` is a zero-gradient array.
2324
2425
These arrays are a design quirk of Jax that may one day be removed.
@@ -32,7 +33,8 @@ def _is_jax_zero_gradient_array(x):
3233

3334
return isinstance(x, np.ndarray) and x.dtype == jax.float0
3435

35-
def is_numpy_array(x):
36+
37+
def is_numpy_array(x: object) -> bool:
3638
"""
3739
Return True if `x` is a NumPy array.
3840
@@ -63,7 +65,8 @@ def is_numpy_array(x):
6365
return (isinstance(x, (np.ndarray, np.generic))
6466
and not _is_jax_zero_gradient_array(x))
6567

66-
def is_cupy_array(x):
68+
69+
def is_cupy_array(x: object) -> bool:
6770
"""
6871
Return True if `x` is a CuPy array.
6972
@@ -93,7 +96,8 @@ def is_cupy_array(x):
9396
# TODO: Should we reject ndarray subclasses?
9497
return isinstance(x, cp.ndarray)
9598

96-
def is_torch_array(x):
99+
100+
def is_torch_array(x: object) -> bool:
97101
"""
98102
Return True if `x` is a PyTorch tensor.
99103
@@ -120,7 +124,8 @@ def is_torch_array(x):
120124
# TODO: Should we reject ndarray subclasses?
121125
return isinstance(x, torch.Tensor)
122126

123-
def is_ndonnx_array(x):
127+
128+
def is_ndonnx_array(x: object) -> bool:
124129
"""
125130
Return True if `x` is a ndonnx Array.
126131
@@ -147,7 +152,8 @@ def is_ndonnx_array(x):
147152

148153
return isinstance(x, ndx.Array)
149154

150-
def is_dask_array(x):
155+
156+
def is_dask_array(x: object) -> bool:
151157
"""
152158
Return True if `x` is a dask.array Array.
153159
@@ -174,7 +180,8 @@ def is_dask_array(x):
174180

175181
return isinstance(x, dask.array.Array)
176182

177-
def is_jax_array(x):
183+
184+
def is_jax_array(x: object) -> bool:
178185
"""
179186
Return True if `x` is a JAX array.
180187
@@ -202,6 +209,7 @@ def is_jax_array(x):
202209

203210
return isinstance(x, jax.Array) or _is_jax_zero_gradient_array(x)
204211

212+
205213
def is_pydata_sparse_array(x) -> bool:
206214
"""
207215
Return True if `x` is an array from the `sparse` package.
@@ -231,7 +239,8 @@ def is_pydata_sparse_array(x) -> bool:
231239
# TODO: Account for other backends.
232240
return isinstance(x, sparse.SparseArray)
233241

234-
def is_array_api_obj(x):
242+
243+
def is_array_api_obj(x: object) -> bool:
235244
"""
236245
Return True if `x` is an array API compatible array object.
237246
@@ -254,11 +263,13 @@ def is_array_api_obj(x):
254263
or is_pydata_sparse_array(x) \
255264
or hasattr(x, '__array_namespace__')
256265

257-
def _compat_module_name():
266+
267+
def _compat_module_name() -> str:
258268
assert __name__.endswith('.common._helpers')
259269
return __name__.removesuffix('.common._helpers')
260270

261-
def is_numpy_namespace(xp) -> bool:
271+
272+
def is_numpy_namespace(xp: ModuleType) -> bool:
262273
"""
263274
Returns True if `xp` is a NumPy namespace.
264275
@@ -278,7 +289,8 @@ def is_numpy_namespace(xp) -> bool:
278289
"""
279290
return xp.__name__ in {'numpy', _compat_module_name() + '.numpy'}
280291

281-
def is_cupy_namespace(xp) -> bool:
292+
293+
def is_cupy_namespace(xp: ModuleType) -> bool:
282294
"""
283295
Returns True if `xp` is a CuPy namespace.
284296
@@ -298,7 +310,8 @@ def is_cupy_namespace(xp) -> bool:
298310
"""
299311
return xp.__name__ in {'cupy', _compat_module_name() + '.cupy'}
300312

301-
def is_torch_namespace(xp) -> bool:
313+
314+
def is_torch_namespace(xp: ModuleType) -> bool:
302315
"""
303316
Returns True if `xp` is a PyTorch namespace.
304317
@@ -319,7 +332,7 @@ def is_torch_namespace(xp) -> bool:
319332
return xp.__name__ in {'torch', _compat_module_name() + '.torch'}
320333

321334

322-
def is_ndonnx_namespace(xp):
335+
def is_ndonnx_namespace(xp: ModuleType) -> bool:
323336
"""
324337
Returns True if `xp` is an NDONNX namespace.
325338
@@ -337,7 +350,8 @@ def is_ndonnx_namespace(xp):
337350
"""
338351
return xp.__name__ == 'ndonnx'
339352

340-
def is_dask_namespace(xp):
353+
354+
def is_dask_namespace(xp: ModuleType) -> bool:
341355
"""
342356
Returns True if `xp` is a Dask namespace.
343357
@@ -357,7 +371,8 @@ def is_dask_namespace(xp):
357371
"""
358372
return xp.__name__ in {'dask.array', _compat_module_name() + '.dask.array'}
359373

360-
def is_jax_namespace(xp):
374+
375+
def is_jax_namespace(xp: ModuleType) -> bool:
361376
"""
362377
Returns True if `xp` is a JAX namespace.
363378
@@ -378,7 +393,8 @@ def is_jax_namespace(xp):
378393
"""
379394
return xp.__name__ in {'jax.numpy', 'jax.experimental.array_api'}
380395

381-
def is_pydata_sparse_namespace(xp):
396+
397+
def is_pydata_sparse_namespace(xp: ModuleType) -> bool:
382398
"""
383399
Returns True if `xp` is a pydata/sparse namespace.
384400
@@ -396,7 +412,8 @@ def is_pydata_sparse_namespace(xp):
396412
"""
397413
return xp.__name__ == 'sparse'
398414

399-
def is_array_api_strict_namespace(xp):
415+
416+
def is_array_api_strict_namespace(xp: ModuleType) -> bool:
400417
"""
401418
Returns True if `xp` is an array-api-strict namespace.
402419
@@ -414,13 +431,15 @@ def is_array_api_strict_namespace(xp):
414431
"""
415432
return xp.__name__ == 'array_api_strict'
416433

417-
def _check_api_version(api_version):
434+
435+
def _check_api_version(api_version: str) -> None:
418436
if api_version in ['2021.12', '2022.12']:
419437
warnings.warn(f"The {api_version} version of the array API specification was requested but the returned namespace is actually version 2023.12")
420438
elif api_version is not None and api_version not in ['2021.12', '2022.12',
421439
'2023.12']:
422440
raise ValueError("Only the 2023.12 version of the array API specification is currently supported")
423441

442+
424443
def array_namespace(*xs, api_version=None, use_compat=None):
425444
"""
426445
Get the array API compatible namespace for the arrays `xs`.
@@ -808,9 +827,10 @@ def size(x: Array) -> int | None:
808827
return None if math.isnan(out) else out
809828

810829

811-
def is_writeable_array(x) -> bool:
830+
def is_writeable_array(x: object) -> bool:
812831
"""
813832
Return False if ``x.__setitem__`` is expected to raise; True otherwise.
833+
Return False if `x` is not an array API compatible object.
814834
815835
Warning
816836
-------
@@ -821,10 +841,10 @@ def is_writeable_array(x) -> bool:
821841
return x.flags.writeable
822842
if is_jax_array(x) or is_pydata_sparse_array(x):
823843
return False
824-
return True
844+
return is_array_api_obj(x)
825845

826846

827-
def is_lazy_array(x) -> bool:
847+
def is_lazy_array(x: object) -> bool:
828848
"""Return True if x is potentially a future or it may be otherwise impossible or
829849
expensive to eagerly read its contents, regardless of their size, e.g. by
830850
calling ``bool(x)`` or ``float(x)``.
@@ -857,6 +877,9 @@ def is_lazy_array(x) -> bool:
857877
if is_jax_array(x) or is_dask_array(x) or is_ndonnx_array(x):
858878
return True
859879

880+
if not is_array_api_obj(x):
881+
return False
882+
860883
# Unknown Array API compatible object. Note that this test may have dire consequences
861884
# in terms of performance, e.g. for a lazy object that eagerly computes the graph
862885
# on __bool__ (dask is one such example, which however is special-cased above).

tests/test_common.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,27 @@ def __bool__(self):
156156
assert is_lazy_array(x)
157157

158158

159+
@pytest.mark.parametrize(
160+
'func',
161+
list(is_array_functions.values())
162+
+ ["is_array_api_obj", "is_lazy_array", "is_writeable_array"]
163+
)
164+
def test_is_array_any_object(func):
165+
"""Test that is_*_array functions return False and don't raise on non-array objects
166+
"""
167+
func = globals()[func]
168+
169+
# These objects are missing attributes such as __name__
170+
assert not func(object())
171+
assert not func(None)
172+
assert not func(1)
173+
174+
class C:
175+
pass
176+
177+
assert not func(C())
178+
179+
159180
@pytest.mark.parametrize("library", all_libraries)
160181
def test_device(library):
161182
xp = import_(library, wrapper=True)

0 commit comments

Comments
 (0)