Skip to content

Commit 0c1b7dd

Browse files
committed
ENH: is_lazy_array and is_writeable_array to return False on non-arrays
1 parent 9442237 commit 0c1b7dd

File tree

2 files changed

+61
-18
lines changed

2 files changed

+61
-18
lines changed

array_api_compat/common/_helpers.py

Lines changed: 40 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import inspect
1919
import warnings
2020

21-
def _is_jax_zero_gradient_array(x):
21+
def _is_jax_zero_gradient_array(x: object) -> bool:
2222
"""Return True if `x` is a zero-gradient array.
2323
2424
These arrays are a design quirk of Jax that may one day be removed.
@@ -32,7 +32,8 @@ def _is_jax_zero_gradient_array(x):
3232

3333
return isinstance(x, np.ndarray) and x.dtype == jax.float0
3434

35-
def is_numpy_array(x):
35+
36+
def is_numpy_array(x: object) -> bool:
3637
"""
3738
Return True if `x` is a NumPy array.
3839
@@ -63,7 +64,8 @@ def is_numpy_array(x):
6364
return (isinstance(x, (np.ndarray, np.generic))
6465
and not _is_jax_zero_gradient_array(x))
6566

66-
def is_cupy_array(x):
67+
68+
def is_cupy_array(x: object) -> bool:
6769
"""
6870
Return True if `x` is a CuPy array.
6971
@@ -93,7 +95,8 @@ def is_cupy_array(x):
9395
# TODO: Should we reject ndarray subclasses?
9496
return isinstance(x, cp.ndarray)
9597

96-
def is_torch_array(x):
98+
99+
def is_torch_array(x: object) -> bool:
97100
"""
98101
Return True if `x` is a PyTorch tensor.
99102
@@ -120,7 +123,8 @@ def is_torch_array(x):
120123
# TODO: Should we reject ndarray subclasses?
121124
return isinstance(x, torch.Tensor)
122125

123-
def is_ndonnx_array(x):
126+
127+
def is_ndonnx_array(x: object) -> bool:
124128
"""
125129
Return True if `x` is a ndonnx Array.
126130
@@ -147,7 +151,8 @@ def is_ndonnx_array(x):
147151

148152
return isinstance(x, ndx.Array)
149153

150-
def is_dask_array(x):
154+
155+
def is_dask_array(x: object) -> bool:
151156
"""
152157
Return True if `x` is a dask.array Array.
153158
@@ -174,7 +179,8 @@ def is_dask_array(x):
174179

175180
return isinstance(x, dask.array.Array)
176181

177-
def is_jax_array(x):
182+
183+
def is_jax_array(x: object) -> bool:
178184
"""
179185
Return True if `x` is a JAX array.
180186
@@ -202,6 +208,7 @@ def is_jax_array(x):
202208

203209
return isinstance(x, jax.Array) or _is_jax_zero_gradient_array(x)
204210

211+
205212
def is_pydata_sparse_array(x) -> bool:
206213
"""
207214
Return True if `x` is an array from the `sparse` package.
@@ -231,7 +238,8 @@ def is_pydata_sparse_array(x) -> bool:
231238
# TODO: Account for other backends.
232239
return isinstance(x, sparse.SparseArray)
233240

234-
def is_array_api_obj(x):
241+
242+
def is_array_api_obj(x: object) -> bool:
235243
"""
236244
Return True if `x` is an array API compatible array object.
237245
@@ -254,10 +262,12 @@ def is_array_api_obj(x):
254262
or is_pydata_sparse_array(x) \
255263
or hasattr(x, '__array_namespace__')
256264

257-
def _compat_module_name():
265+
266+
def _compat_module_name() -> str:
258267
assert __name__.endswith('.common._helpers')
259268
return __name__.removesuffix('.common._helpers')
260269

270+
261271
def is_numpy_namespace(xp) -> bool:
262272
"""
263273
Returns True if `xp` is a NumPy namespace.
@@ -278,6 +288,7 @@ def is_numpy_namespace(xp) -> bool:
278288
"""
279289
return xp.__name__ in {'numpy', _compat_module_name() + '.numpy'}
280290

291+
281292
def is_cupy_namespace(xp) -> bool:
282293
"""
283294
Returns True if `xp` is a CuPy namespace.
@@ -298,6 +309,7 @@ def is_cupy_namespace(xp) -> bool:
298309
"""
299310
return xp.__name__ in {'cupy', _compat_module_name() + '.cupy'}
300311

312+
301313
def is_torch_namespace(xp) -> bool:
302314
"""
303315
Returns True if `xp` is a PyTorch namespace.
@@ -319,7 +331,7 @@ def is_torch_namespace(xp) -> bool:
319331
return xp.__name__ in {'torch', _compat_module_name() + '.torch'}
320332

321333

322-
def is_ndonnx_namespace(xp):
334+
def is_ndonnx_namespace(xp) -> bool:
323335
"""
324336
Returns True if `xp` is an NDONNX namespace.
325337
@@ -337,7 +349,8 @@ def is_ndonnx_namespace(xp):
337349
"""
338350
return xp.__name__ == 'ndonnx'
339351

340-
def is_dask_namespace(xp):
352+
353+
def is_dask_namespace(xp) -> bool:
341354
"""
342355
Returns True if `xp` is a Dask namespace.
343356
@@ -357,7 +370,8 @@ def is_dask_namespace(xp):
357370
"""
358371
return xp.__name__ in {'dask.array', _compat_module_name() + '.dask.array'}
359372

360-
def is_jax_namespace(xp):
373+
374+
def is_jax_namespace(xp) -> bool:
361375
"""
362376
Returns True if `xp` is a JAX namespace.
363377
@@ -378,7 +392,8 @@ def is_jax_namespace(xp):
378392
"""
379393
return xp.__name__ in {'jax.numpy', 'jax.experimental.array_api'}
380394

381-
def is_pydata_sparse_namespace(xp):
395+
396+
def is_pydata_sparse_namespace(xp) -> bool:
382397
"""
383398
Returns True if `xp` is a pydata/sparse namespace.
384399
@@ -396,7 +411,8 @@ def is_pydata_sparse_namespace(xp):
396411
"""
397412
return xp.__name__ == 'sparse'
398413

399-
def is_array_api_strict_namespace(xp):
414+
415+
def is_array_api_strict_namespace(xp) -> bool:
400416
"""
401417
Returns True if `xp` is an array-api-strict namespace.
402418
@@ -414,13 +430,15 @@ def is_array_api_strict_namespace(xp):
414430
"""
415431
return xp.__name__ == 'array_api_strict'
416432

417-
def _check_api_version(api_version):
433+
434+
def _check_api_version(api_version: str) -> None:
418435
if api_version in ['2021.12', '2022.12']:
419436
warnings.warn(f"The {api_version} version of the array API specification was requested but the returned namespace is actually version 2023.12")
420437
elif api_version is not None and api_version not in ['2021.12', '2022.12',
421438
'2023.12']:
422439
raise ValueError("Only the 2023.12 version of the array API specification is currently supported")
423440

441+
424442
def array_namespace(*xs, api_version=None, use_compat=None):
425443
"""
426444
Get the array API compatible namespace for the arrays `xs`.
@@ -808,9 +826,10 @@ def size(x: Array) -> int | None:
808826
return None if math.isnan(out) else out
809827

810828

811-
def is_writeable_array(x) -> bool:
829+
def is_writeable_array(x: object) -> bool:
812830
"""
813831
Return False if ``x.__setitem__`` is expected to raise; True otherwise.
832+
Return False if `x` is not an array API compatible object.
814833
815834
Warning
816835
-------
@@ -821,10 +840,10 @@ def is_writeable_array(x) -> bool:
821840
return x.flags.writeable
822841
if is_jax_array(x) or is_pydata_sparse_array(x):
823842
return False
824-
return True
843+
return is_array_api_obj(x)
825844

826845

827-
def is_lazy_array(x) -> bool:
846+
def is_lazy_array(x: object) -> bool:
828847
"""Return True if x is potentially a future or it may be otherwise impossible or
829848
expensive to eagerly read its contents, regardless of their size, e.g. by
830849
calling ``bool(x)`` or ``float(x)``.
@@ -857,6 +876,9 @@ def is_lazy_array(x) -> bool:
857876
if is_jax_array(x) or is_dask_array(x) or is_ndonnx_array(x):
858877
return True
859878

879+
if not is_array_api_obj(x):
880+
return False
881+
860882
# Unknown Array API compatible object. Note that this test may have dire consequences
861883
# in terms of performance, e.g. for a lazy object that eagerly computes the graph
862884
# on __bool__ (dask is one such example, which however is special-cased above).

tests/test_common.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,27 @@ def __bool__(self):
156156
assert is_lazy_array(x)
157157

158158

159+
@pytest.mark.parametrize(
160+
'func',
161+
list(is_array_functions.values())
162+
+ ["is_array_api_obj", "is_lazy_array", "is_writeable_array"]
163+
)
164+
def test_is_array_any_object(func):
165+
"""Test that is_*_array functions return False and don't raise on non-array objects
166+
"""
167+
func = globals()[func]
168+
169+
# These objects are missing attributes such as __name__
170+
assert not func(object())
171+
assert not func(None)
172+
assert not func(1)
173+
174+
class C:
175+
pass
176+
177+
assert not func(C())
178+
179+
159180
@pytest.mark.parametrize("library", all_libraries)
160181
def test_device(library):
161182
xp = import_(library, wrapper=True)

0 commit comments

Comments
 (0)