Skip to content

ENH: is_lazy_array and is_writeable_array to return False on non-arrays #237

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 40 additions & 18 deletions array_api_compat/common/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import inspect
import warnings

def _is_jax_zero_gradient_array(x):
def _is_jax_zero_gradient_array(x: object) -> bool:
"""Return True if `x` is a zero-gradient array.

These arrays are a design quirk of Jax that may one day be removed.
Expand All @@ -32,7 +32,8 @@ def _is_jax_zero_gradient_array(x):

return isinstance(x, np.ndarray) and x.dtype == jax.float0

def is_numpy_array(x):

def is_numpy_array(x: object) -> bool:
"""
Return True if `x` is a NumPy array.

Expand Down Expand Up @@ -63,7 +64,8 @@ def is_numpy_array(x):
return (isinstance(x, (np.ndarray, np.generic))
and not _is_jax_zero_gradient_array(x))

def is_cupy_array(x):

def is_cupy_array(x: object) -> bool:
"""
Return True if `x` is a CuPy array.

Expand Down Expand Up @@ -93,7 +95,8 @@ def is_cupy_array(x):
# TODO: Should we reject ndarray subclasses?
return isinstance(x, cp.ndarray)

def is_torch_array(x):

def is_torch_array(x: object) -> bool:
"""
Return True if `x` is a PyTorch tensor.

Expand All @@ -120,7 +123,8 @@ def is_torch_array(x):
# TODO: Should we reject ndarray subclasses?
return isinstance(x, torch.Tensor)

def is_ndonnx_array(x):

def is_ndonnx_array(x: object) -> bool:
"""
Return True if `x` is a ndonnx Array.

Expand All @@ -147,7 +151,8 @@ def is_ndonnx_array(x):

return isinstance(x, ndx.Array)

def is_dask_array(x):

def is_dask_array(x: object) -> bool:
"""
Return True if `x` is a dask.array Array.

Expand All @@ -174,7 +179,8 @@ def is_dask_array(x):

return isinstance(x, dask.array.Array)

def is_jax_array(x):

def is_jax_array(x: object) -> bool:
"""
Return True if `x` is a JAX array.

Expand Down Expand Up @@ -202,6 +208,7 @@ def is_jax_array(x):

return isinstance(x, jax.Array) or _is_jax_zero_gradient_array(x)


def is_pydata_sparse_array(x) -> bool:
"""
Return True if `x` is an array from the `sparse` package.
Expand Down Expand Up @@ -231,7 +238,8 @@ def is_pydata_sparse_array(x) -> bool:
# TODO: Account for other backends.
return isinstance(x, sparse.SparseArray)

def is_array_api_obj(x):

def is_array_api_obj(x: object) -> bool:
"""
Return True if `x` is an array API compatible array object.

Expand All @@ -254,10 +262,12 @@ def is_array_api_obj(x):
or is_pydata_sparse_array(x) \
or hasattr(x, '__array_namespace__')

def _compat_module_name():

def _compat_module_name() -> str:
assert __name__.endswith('.common._helpers')
return __name__.removesuffix('.common._helpers')


def is_numpy_namespace(xp) -> bool:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sphinx gave me grief over ModuleType, so I left it out of scope for this PR.

"""
Returns True if `xp` is a NumPy namespace.
Expand All @@ -278,6 +288,7 @@ def is_numpy_namespace(xp) -> bool:
"""
return xp.__name__ in {'numpy', _compat_module_name() + '.numpy'}


def is_cupy_namespace(xp) -> bool:
"""
Returns True if `xp` is a CuPy namespace.
Expand All @@ -298,6 +309,7 @@ def is_cupy_namespace(xp) -> bool:
"""
return xp.__name__ in {'cupy', _compat_module_name() + '.cupy'}


def is_torch_namespace(xp) -> bool:
"""
Returns True if `xp` is a PyTorch namespace.
Expand All @@ -319,7 +331,7 @@ def is_torch_namespace(xp) -> bool:
return xp.__name__ in {'torch', _compat_module_name() + '.torch'}


def is_ndonnx_namespace(xp):
def is_ndonnx_namespace(xp) -> bool:
"""
Returns True if `xp` is an NDONNX namespace.

Expand All @@ -337,7 +349,8 @@ def is_ndonnx_namespace(xp):
"""
return xp.__name__ == 'ndonnx'

def is_dask_namespace(xp):

def is_dask_namespace(xp) -> bool:
"""
Returns True if `xp` is a Dask namespace.

Expand All @@ -357,7 +370,8 @@ def is_dask_namespace(xp):
"""
return xp.__name__ in {'dask.array', _compat_module_name() + '.dask.array'}

def is_jax_namespace(xp):

def is_jax_namespace(xp) -> bool:
"""
Returns True if `xp` is a JAX namespace.

Expand All @@ -378,7 +392,8 @@ def is_jax_namespace(xp):
"""
return xp.__name__ in {'jax.numpy', 'jax.experimental.array_api'}

def is_pydata_sparse_namespace(xp):

def is_pydata_sparse_namespace(xp) -> bool:
"""
Returns True if `xp` is a pydata/sparse namespace.

Expand All @@ -396,7 +411,8 @@ def is_pydata_sparse_namespace(xp):
"""
return xp.__name__ == 'sparse'

def is_array_api_strict_namespace(xp):

def is_array_api_strict_namespace(xp) -> bool:
"""
Returns True if `xp` is an array-api-strict namespace.

Expand All @@ -414,13 +430,15 @@ def is_array_api_strict_namespace(xp):
"""
return xp.__name__ == 'array_api_strict'

def _check_api_version(api_version):

def _check_api_version(api_version: str) -> None:
if api_version in ['2021.12', '2022.12']:
warnings.warn(f"The {api_version} version of the array API specification was requested but the returned namespace is actually version 2023.12")
elif api_version is not None and api_version not in ['2021.12', '2022.12',
'2023.12']:
raise ValueError("Only the 2023.12 version of the array API specification is currently supported")


def array_namespace(*xs, api_version=None, use_compat=None):
"""
Get the array API compatible namespace for the arrays `xs`.
Expand Down Expand Up @@ -808,9 +826,10 @@ def size(x: Array) -> int | None:
return None if math.isnan(out) else out


def is_writeable_array(x) -> bool:
def is_writeable_array(x: object) -> bool:
"""
Return False if ``x.__setitem__`` is expected to raise; True otherwise.
Return False if `x` is not an array API compatible object.

Warning
-------
Expand All @@ -821,10 +840,10 @@ def is_writeable_array(x) -> bool:
return x.flags.writeable
if is_jax_array(x) or is_pydata_sparse_array(x):
return False
return True
return is_array_api_obj(x)


def is_lazy_array(x) -> bool:
def is_lazy_array(x: object) -> bool:
"""Return True if x is potentially a future or it may be otherwise impossible or
expensive to eagerly read its contents, regardless of their size, e.g. by
calling ``bool(x)`` or ``float(x)``.
Expand Down Expand Up @@ -857,6 +876,9 @@ def is_lazy_array(x) -> bool:
if is_jax_array(x) or is_dask_array(x) or is_ndonnx_array(x):
return True

if not is_array_api_obj(x):
return False

# Unknown Array API compatible object. Note that this test may have dire consequences
# in terms of performance, e.g. for a lazy object that eagerly computes the graph
# on __bool__ (dask is one such example, which however is special-cased above).
Expand Down
21 changes: 21 additions & 0 deletions tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,27 @@ def __bool__(self):
assert is_lazy_array(x)


@pytest.mark.parametrize(
'func',
list(is_array_functions.values())
+ ["is_array_api_obj", "is_lazy_array", "is_writeable_array"]
)
def test_is_array_any_object(func):
"""Test that is_*_array functions return False and don't raise on non-array objects
"""
func = globals()[func]

# These objects are missing attributes such as __name__
assert not func(object())
assert not func(None)
assert not func(1)

class C:
pass

assert not func(C())


@pytest.mark.parametrize("library", all_libraries)
def test_device(library):
xp = import_(library, wrapper=True)
Expand Down
Loading