diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index 9762d0a0..f9f39230 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -18,7 +18,7 @@ import inspect import warnings -def _is_jax_zero_gradient_array(x): +def _is_jax_zero_gradient_array(x: object) -> bool: """Return True if `x` is a zero-gradient array. These arrays are a design quirk of Jax that may one day be removed. @@ -32,7 +32,8 @@ def _is_jax_zero_gradient_array(x): return isinstance(x, np.ndarray) and x.dtype == jax.float0 -def is_numpy_array(x): + +def is_numpy_array(x: object) -> bool: """ Return True if `x` is a NumPy array. @@ -63,7 +64,8 @@ def is_numpy_array(x): return (isinstance(x, (np.ndarray, np.generic)) and not _is_jax_zero_gradient_array(x)) -def is_cupy_array(x): + +def is_cupy_array(x: object) -> bool: """ Return True if `x` is a CuPy array. @@ -93,7 +95,8 @@ def is_cupy_array(x): # TODO: Should we reject ndarray subclasses? return isinstance(x, cp.ndarray) -def is_torch_array(x): + +def is_torch_array(x: object) -> bool: """ Return True if `x` is a PyTorch tensor. @@ -120,7 +123,8 @@ def is_torch_array(x): # TODO: Should we reject ndarray subclasses? return isinstance(x, torch.Tensor) -def is_ndonnx_array(x): + +def is_ndonnx_array(x: object) -> bool: """ Return True if `x` is a ndonnx Array. @@ -147,7 +151,8 @@ def is_ndonnx_array(x): return isinstance(x, ndx.Array) -def is_dask_array(x): + +def is_dask_array(x: object) -> bool: """ Return True if `x` is a dask.array Array. @@ -174,7 +179,8 @@ def is_dask_array(x): return isinstance(x, dask.array.Array) -def is_jax_array(x): + +def is_jax_array(x: object) -> bool: """ Return True if `x` is a JAX array. @@ -202,6 +208,7 @@ def is_jax_array(x): return isinstance(x, jax.Array) or _is_jax_zero_gradient_array(x) + def is_pydata_sparse_array(x) -> bool: """ Return True if `x` is an array from the `sparse` package. @@ -231,7 +238,8 @@ def is_pydata_sparse_array(x) -> bool: # TODO: Account for other backends. return isinstance(x, sparse.SparseArray) -def is_array_api_obj(x): + +def is_array_api_obj(x: object) -> bool: """ Return True if `x` is an array API compatible array object. @@ -254,10 +262,12 @@ def is_array_api_obj(x): or is_pydata_sparse_array(x) \ or hasattr(x, '__array_namespace__') -def _compat_module_name(): + +def _compat_module_name() -> str: assert __name__.endswith('.common._helpers') return __name__.removesuffix('.common._helpers') + def is_numpy_namespace(xp) -> bool: """ Returns True if `xp` is a NumPy namespace. @@ -278,6 +288,7 @@ def is_numpy_namespace(xp) -> bool: """ return xp.__name__ in {'numpy', _compat_module_name() + '.numpy'} + def is_cupy_namespace(xp) -> bool: """ Returns True if `xp` is a CuPy namespace. @@ -298,6 +309,7 @@ def is_cupy_namespace(xp) -> bool: """ return xp.__name__ in {'cupy', _compat_module_name() + '.cupy'} + def is_torch_namespace(xp) -> bool: """ Returns True if `xp` is a PyTorch namespace. @@ -319,7 +331,7 @@ def is_torch_namespace(xp) -> bool: return xp.__name__ in {'torch', _compat_module_name() + '.torch'} -def is_ndonnx_namespace(xp): +def is_ndonnx_namespace(xp) -> bool: """ Returns True if `xp` is an NDONNX namespace. @@ -337,7 +349,8 @@ def is_ndonnx_namespace(xp): """ return xp.__name__ == 'ndonnx' -def is_dask_namespace(xp): + +def is_dask_namespace(xp) -> bool: """ Returns True if `xp` is a Dask namespace. @@ -357,7 +370,8 @@ def is_dask_namespace(xp): """ return xp.__name__ in {'dask.array', _compat_module_name() + '.dask.array'} -def is_jax_namespace(xp): + +def is_jax_namespace(xp) -> bool: """ Returns True if `xp` is a JAX namespace. @@ -378,7 +392,8 @@ def is_jax_namespace(xp): """ return xp.__name__ in {'jax.numpy', 'jax.experimental.array_api'} -def is_pydata_sparse_namespace(xp): + +def is_pydata_sparse_namespace(xp) -> bool: """ Returns True if `xp` is a pydata/sparse namespace. @@ -396,7 +411,8 @@ def is_pydata_sparse_namespace(xp): """ return xp.__name__ == 'sparse' -def is_array_api_strict_namespace(xp): + +def is_array_api_strict_namespace(xp) -> bool: """ Returns True if `xp` is an array-api-strict namespace. @@ -414,13 +430,15 @@ def is_array_api_strict_namespace(xp): """ return xp.__name__ == 'array_api_strict' -def _check_api_version(api_version): + +def _check_api_version(api_version: str) -> None: if api_version in ['2021.12', '2022.12']: warnings.warn(f"The {api_version} version of the array API specification was requested but the returned namespace is actually version 2023.12") elif api_version is not None and api_version not in ['2021.12', '2022.12', '2023.12']: raise ValueError("Only the 2023.12 version of the array API specification is currently supported") + def array_namespace(*xs, api_version=None, use_compat=None): """ Get the array API compatible namespace for the arrays `xs`. @@ -808,9 +826,10 @@ def size(x: Array) -> int | None: return None if math.isnan(out) else out -def is_writeable_array(x) -> bool: +def is_writeable_array(x: object) -> bool: """ Return False if ``x.__setitem__`` is expected to raise; True otherwise. + Return False if `x` is not an array API compatible object. Warning ------- @@ -821,10 +840,10 @@ def is_writeable_array(x) -> bool: return x.flags.writeable if is_jax_array(x) or is_pydata_sparse_array(x): return False - return True + return is_array_api_obj(x) -def is_lazy_array(x) -> bool: +def is_lazy_array(x: object) -> bool: """Return True if x is potentially a future or it may be otherwise impossible or expensive to eagerly read its contents, regardless of their size, e.g. by calling ``bool(x)`` or ``float(x)``. @@ -857,6 +876,9 @@ def is_lazy_array(x) -> bool: if is_jax_array(x) or is_dask_array(x) or is_ndonnx_array(x): return True + if not is_array_api_obj(x): + return False + # Unknown Array API compatible object. Note that this test may have dire consequences # in terms of performance, e.g. for a lazy object that eagerly computes the graph # on __bool__ (dask is one such example, which however is special-cased above). diff --git a/tests/test_common.py b/tests/test_common.py index 07afaddb..e702e4a9 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -156,6 +156,27 @@ def __bool__(self): assert is_lazy_array(x) +@pytest.mark.parametrize( + 'func', + list(is_array_functions.values()) + + ["is_array_api_obj", "is_lazy_array", "is_writeable_array"] +) +def test_is_array_any_object(func): + """Test that is_*_array functions return False and don't raise on non-array objects + """ + func = globals()[func] + + # These objects are missing attributes such as __name__ + assert not func(object()) + assert not func(None) + assert not func(1) + + class C: + pass + + assert not func(C()) + + @pytest.mark.parametrize("library", all_libraries) def test_device(library): xp = import_(library, wrapper=True)