Skip to content

Commit 512861f

Browse files
committed
ENH: is_lazy_array()
1 parent beac55b commit 512861f

File tree

3 files changed

+127
-2
lines changed

3 files changed

+127
-2
lines changed

array_api_compat/common/_helpers.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -819,6 +819,63 @@ def is_writeable_array(x) -> bool:
819819
return True
820820

821821

822+
def is_lazy_array(x) -> bool:
823+
"""Return True if x is potentially a future or it may be otherwise impossible or
824+
expensive to eagerly read its contents, regardless of their size, e.g. by
825+
calling ``bool(x)`` or ``float(x)``.
826+
827+
Return True otherwise; e.g. ``bool(x)`` etc. is guaranteed to succeed and to be
828+
cheap as long as the array is the right dtype.
829+
830+
Note
831+
----
832+
This function errs on the side of caution for array types that may or may not be
833+
lazy, e.g. JAX arrays, by always returning True for them.
834+
"""
835+
if (
836+
is_numpy_array(x)
837+
or is_cupy_array(x)
838+
or is_torch_array(x)
839+
or is_pydata_sparse_array(x)
840+
):
841+
return False
842+
843+
# **JAX note:** while it is possible to determine if you're inside or outside
844+
# jax.jit by testing the subclass of a jax.Array object, as well as testing bool()
845+
# as we do below for unknown arrays, this is not recommended by JAX best practices.
846+
847+
# **Dask note:** Dask eagerly computes the graph on __bool__, __float__, and so on.
848+
# This behaviour, while impossible to change without breaking backwards
849+
# compatibility, is highly detrimental to performance as the whole graph will end
850+
# up being computed multiple times.
851+
852+
if is_jax_array(x) or is_dask_array(x) or is_ndonnx_array(x):
853+
return True
854+
855+
# Unknown Array API compatible object. Note that this test may have dire consequences
856+
# in terms of performance, e.g. for a lazy object that eagerly computes the graph
857+
# on __bool__ (dask is one such example, which however is special-cased above).
858+
859+
# Select a single point of the array
860+
s = size(x)
861+
if math.isnan(s):
862+
return True
863+
xp = array_namespace(x)
864+
if s > 1:
865+
x = xp.reshape(x, (-1,))[0]
866+
# Cast to dtype=bool and deal with size 0 arrays
867+
x = xp.any(x)
868+
869+
try:
870+
bool(x)
871+
return False
872+
# The Array API standard dictactes that __bool__ should raise TypeError if the
873+
# output cannot be defined.
874+
# Here we're more lenient and also allow for e.g. NotImplementedError.
875+
except Exception:
876+
return True
877+
878+
822879
__all__ = [
823880
"array_namespace",
824881
"device",
@@ -840,6 +897,7 @@ def is_writeable_array(x) -> bool:
840897
"is_pydata_sparse_array",
841898
"is_pydata_sparse_namespace",
842899
"is_writeable_array",
900+
"is_lazy_array",
843901
"size",
844902
"to_device",
845903
]

docs/helper-functions.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ yet.
5252
.. autofunction:: is_pydata_sparse_array
5353
.. autofunction:: is_ndonnx_array
5454
.. autofunction:: is_writeable_array
55+
.. autofunction:: is_lazy_array
5556
.. autofunction:: is_numpy_namespace
5657
.. autofunction:: is_cupy_namespace
5758
.. autofunction:: is_torch_namespace

tests/test_common.py

Lines changed: 68 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@
55
is_dask_namespace, is_jax_namespace, is_pydata_sparse_namespace,
66
)
77

8-
from array_api_compat import device, is_array_api_obj, is_writeable_array, to_device
9-
8+
from array_api_compat import (
9+
device, is_array_api_obj, is_lazy_array, is_writeable_array, to_device
10+
)
1011
from ._helpers import import_, wrapped_libraries, all_libraries
1112

1213
import pytest
@@ -92,6 +93,70 @@ def test_is_writeable_array_numpy():
9293
assert not is_writeable_array(x)
9394

9495

96+
@pytest.mark.parametrize("library", all_libraries)
97+
def test_is_lazy_array(library):
98+
lib = import_(library)
99+
x = lib.asarray([1, 2, 3])
100+
assert isinstance(is_lazy_array(x), bool)
101+
102+
103+
@pytest.mark.parametrize("array", [
104+
[], [1, 2], 1, 0, float("nan"), [[1, 2], [3, 4]]
105+
])
106+
def test_is_lazy_array_unknown(array, monkeypatch):
107+
"""Test is_lazy_array() on an unknown Array API compliant object"""
108+
xp = import_("jax.numpy")
109+
import array_api_compat.common._helpers
110+
import jax
111+
112+
x = xp.asarray(array)
113+
# Prevent is_lazy_array() from special-casing JAX
114+
monkeypatch.setattr(
115+
array_api_compat.common._helpers,
116+
"is_jax_array",
117+
lambda x: False,
118+
)
119+
120+
assert not is_lazy_array(x) # Eager JAX
121+
assert jax.jit(is_lazy_array)(x) # Jitted (lazy) JAX
122+
123+
124+
def test_is_lazy_array_unknown_dask(monkeypatch):
125+
"""Test is_lazy_array() on an unknown Array API compliant object which
126+
- may or may not raise an arbitrary exception on bool()
127+
- may or may not have NaN in its shape
128+
"""
129+
da = import_("dask.array", wrapper=True)
130+
import array_api_compat.common._helpers
131+
132+
x = da.arange(10)
133+
y = x[x > 5]
134+
assert np.isnan(y.size)
135+
136+
def do_not_run(_):
137+
raise AssertionError("do_not_run")
138+
139+
z = x.map_blocks(do_not_run, dtype=x.dtype)
140+
with pytest.raises(AssertionError, match="do_not_run"):
141+
z.compute()
142+
143+
# Prevent is_lazy_array() from special-casing Dask
144+
monkeypatch.setattr(
145+
array_api_compat.common._helpers,
146+
"is_dask_array",
147+
lambda x: False,
148+
)
149+
monkeypatch.setattr(
150+
array_api_compat.common._helpers,
151+
"array_namespace",
152+
lambda x: da,
153+
)
154+
155+
assert not is_lazy_array(x) # Eagerly computes on bool()
156+
assert is_lazy_array(y) # NaN size
157+
assert is_lazy_array(z) # bool() raises AssertionError
158+
159+
95160
@pytest.mark.parametrize("library", all_libraries)
96161
def test_device(library):
97162
xp = import_(library, wrapper=True)
@@ -149,6 +214,7 @@ def test_asarray_cross_library(source_library, target_library, request):
149214

150215
assert is_tgt_type(b), f"Expected {b} to be a {tgt_lib.ndarray}, but was {type(b)}"
151216

217+
152218
@pytest.mark.parametrize("library", wrapped_libraries)
153219
def test_asarray_copy(library):
154220
# Note, we have this test here because the test suite currently doesn't

0 commit comments

Comments
 (0)